diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 1cd4823e13..95c8f36271 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -113,7 +113,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): try: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) - if model.startswith("o1"): + if "o1" in model: client.chat.completions.create( messages=[{"role": "user", "content": "ping"}], model=model, @@ -311,7 +311,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): prompt_messages = self._clear_illegal_prompt_messages(model, prompt_messages) block_as_stream = False - if model.startswith("o1"): + if "o1" in model: if stream: block_as_stream = True stream = False @@ -404,7 +404,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): ] ) - if model.startswith("o1"): + if "o1" in model: system_message_count = len([m for m in prompt_messages if isinstance(m, SystemPromptMessage)]) if system_message_count > 0: new_prompt_messages = [] @@ -653,7 +653,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): tokens_per_message = 4 # if there's a name, the role is omitted tokens_per_name = -1 - elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or model.startswith("o1"): + elif model.startswith("gpt-35-turbo") or model.startswith("gpt-4") or "o1" in model: tokens_per_message = 3 tokens_per_name = 1 else: