dify/api/core/llm/llm_builder.py

149 lines
7.2 KiB
Python

from typing import Union, Optional, List
from langchain.callbacks.base import BaseCallbackHandler
from core.constant import llm_constant
from core.llm.error import ProviderTokenNotInitError
from core.llm.provider.base import BaseProvider
from core.llm.provider.llm_provider_service import LLMProviderService
from core.llm.streamable_azure_chat_open_ai import StreamableAzureChatOpenAI
from core.llm.streamable_azure_open_ai import StreamableAzureOpenAI
from core.llm.streamable_chat_anthropic import StreamableChatAnthropic
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI
from models.provider import ProviderType, ProviderName
class LLMBuilder:
    """
    Build LangChain LLM wrappers for a tenant from its configured model providers.

    Credential storage conventions for providers:

    1. For providers named 'OpenAI', the OPENAI_API_KEY value is stored directly in encrypted_config.
    2. For providers named 'Azure OpenAI', encrypted_config stores the serialized values of four fields:
        OPENAI_API_TYPE=azure
        OPENAI_API_VERSION=2022-12-01
        OPENAI_API_BASE=https://your-resource-name.openai.azure.com
        OPENAI_API_KEY=<your Azure OpenAI API key>
    3. For providers named 'Anthropic', the ANTHROPIC_API_KEY value is stored directly in encrypted_config.
    4. For providers named 'Cohere', the COHERE_API_KEY value is stored directly in encrypted_config.
    5. For providers named 'HUGGINGFACEHUB', the HUGGINGFACEHUB_API_KEY value is stored directly in
       encrypted_config.

    Note: ``to_llm`` currently only builds OpenAI, Azure OpenAI and Anthropic models
    (see ``_LLM_CLASSES``).

    Provider selection rules:

    6. Providers with provider_type 'CUSTOM' can be created through the admin interface;
       'SYSTEM' providers cannot.
    7. If both CUSTOM and SYSTEM providers exist, the CUSTOM provider is preferred
       (see ``get_default_provider``). NOTE(review): earlier design notes said this preference
       could be changed via an input parameter, but no such parameter exists in this class — confirm.
    8. For SYSTEM providers, quota_used must not exceed quota_limit; once the quota is exceeded
       the provider cannot be used. Currently only the TRIAL quota_type is supported, and it never
       resets. (Presumably enforced by ``BaseProvider.get_valid_provider`` — verify.)
    """

    # (mode, provider name) -> LLM wrapper class. Combinations not listed here are unsupported;
    # e.g. only a chat wrapper is wired up for Anthropic.
    _LLM_CLASSES = {
        ('chat', ProviderName.OPENAI.value): StreamableChatOpenAI,
        ('chat', ProviderName.AZURE_OPENAI.value): StreamableAzureChatOpenAI,
        ('chat', ProviderName.ANTHROPIC.value): StreamableChatAnthropic,
        ('completion', ProviderName.OPENAI.value): StreamableOpenAI,
        ('completion', ProviderName.AZURE_OPENAI.value): StreamableAzureOpenAI,
    }

    @classmethod
    def to_llm(cls, tenant_id: str, model_name: str, **kwargs) -> Union[
        StreamableOpenAI, StreamableChatOpenAI, StreamableAzureOpenAI,
        StreamableAzureChatOpenAI, StreamableChatAnthropic
    ]:
        """
        Build an LLM wrapper for ``model_name`` using the tenant's provider credentials.

        :param tenant_id: tenant whose provider configuration is used.
        :param model_name: a model name listed in ``llm_constant``.
        :param kwargs: optional generation params: temperature (default 0), max_tokens (256),
            top_p (1), frequency_penalty (0), presence_penalty (0), callbacks (None),
            streaming (False).
        :return: an instance of the wrapper class matching the model's mode and provider.
        :raises ValueError: if the model name is empty or unknown, or the mode/provider
            combination is not supported.
        :raises ProviderTokenNotInitError: if an OpenAI-family model has no valid
            OpenAI / Azure OpenAI provider for the tenant.
        """
        # Resolve the mode first: it is a pure in-memory lookup, so unsupported models are
        # rejected before any provider or credential lookups run.
        mode = cls.get_mode_by_model(model_name)
        provider = cls.get_default_provider(tenant_id, model_name)

        llm_cls = cls._LLM_CLASSES.get((mode, provider))
        if llm_cls is None:
            raise ValueError(f"model name {model_name} is not supported.")

        model_credentials = cls.get_model_credentials(tenant_id, provider, model_name)

        model_kwargs = {
            'model_name': model_name,
            'temperature': kwargs.get('temperature', 0),
            'max_tokens': kwargs.get('max_tokens', 256),
            'top_p': kwargs.get('top_p', 1),
            'frequency_penalty': kwargs.get('frequency_penalty', 0),
            'presence_penalty': kwargs.get('presence_penalty', 0),
            'callbacks': kwargs.get('callbacks'),
            'streaming': kwargs.get('streaming', False),
        }
        # Credentials are merged last so they can supply provider-specific constructor args.
        model_kwargs.update(model_credentials)
        # Each wrapper class maps the generic params onto its own constructor arguments.
        model_kwargs = llm_cls.get_kwargs_from_model_params(model_kwargs)
        return llm_cls(**model_kwargs)

    @classmethod
    def to_llm_from_model(cls, tenant_id: str, model: dict, streaming: bool = False,
                          callbacks: Optional[List[BaseCallbackHandler]] = None) -> Union[
        StreamableOpenAI, StreamableChatOpenAI, StreamableAzureOpenAI,
        StreamableAzureChatOpenAI, StreamableChatAnthropic
    ]:
        """
        Build an LLM wrapper from a model config dict.

        :param tenant_id: tenant whose provider configuration is used.
        :param model: dict with a ``"name"`` key and an optional ``"completion_params"`` dict.
        :param streaming: whether the returned LLM streams its output.
        :param callbacks: LangChain callback handlers attached to the LLM.
        :return: see ``to_llm``.
        :raises ValueError: see ``to_llm``.
        """
        model_name = model.get("name")
        # Treat an explicit ``"completion_params": None`` the same as a missing entry.
        completion_params = model.get("completion_params") or {}
        # NOTE(review): these fallbacks (top_p=0, penalties=0.1) differ from to_llm's own
        # defaults (top_p=1, penalties=0); kept unchanged for compatibility — confirm intended.
        return cls.to_llm(
            tenant_id=tenant_id,
            model_name=model_name,
            temperature=completion_params.get('temperature', 0),
            max_tokens=completion_params.get('max_tokens', 256),
            top_p=completion_params.get('top_p', 0),
            frequency_penalty=completion_params.get('frequency_penalty', 0.1),
            presence_penalty=completion_params.get('presence_penalty', 0.1),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_mode_by_model(cls, model_name: str) -> str:
        """
        Return ``"chat"`` or ``"completion"`` for ``model_name`` per ``llm_constant.models_by_mode``.

        :raises ValueError: if ``model_name`` is empty or not listed under either mode.
        """
        if not model_name:
            raise ValueError("empty model name is not supported.")

        if model_name in llm_constant.models_by_mode['chat']:
            return "chat"
        if model_name in llm_constant.models_by_mode['completion']:
            return "completion"
        raise ValueError(f"model name {model_name} is not supported.")

    @classmethod
    def get_model_credentials(cls, tenant_id: str, model_provider: str, model_name: str) -> dict:
        """
        Return the API credentials for ``model_name`` from the tenant's ``model_provider``.

        :raises ValueError: if ``model_name`` is empty (a subclass of the ``Exception`` raised
            previously, so existing ``except Exception`` handlers still apply).
        Provider-lookup failures are raised by ``LLMProviderService`` (exact types not visible here).
        """
        if not model_name:
            raise ValueError('model name not found')

        provider_service = LLMProviderService(tenant_id=tenant_id, provider_name=model_provider)
        return provider_service.get_credentials(model_name)

    @classmethod
    def get_default_provider(cls, tenant_id: str, model_name: str) -> str:
        """
        Return the name of the provider that should serve ``model_name`` for this tenant.

        Models whose static provider is 'openai' may be served by either the OpenAI or the
        Azure OpenAI provider; the first valid match in this order wins:
        CUSTOM OpenAI, CUSTOM Azure OpenAI, SYSTEM OpenAI, SYSTEM Azure OpenAI.
        All other models use their static provider name unchanged.

        :raises ValueError: if ``model_name`` is not listed in ``llm_constant.models``.
        :raises ProviderTokenNotInitError: if an 'openai' model has no valid OpenAI or
            Azure OpenAI provider for the tenant.
        """
        if model_name not in llm_constant.models:
            raise ValueError(f"model name {model_name} is not supported.")

        provider_name = llm_constant.models[model_name]
        if provider_name != 'openai':
            return provider_name

        # get the default provider (openai / azure_openai) for the tenant
        openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.OPENAI.value)
        azure_openai_provider = BaseProvider.get_valid_provider(tenant_id, ProviderName.AZURE_OPENAI.value)

        # CUSTOM beats SYSTEM; within the same type, OpenAI beats Azure OpenAI.
        for provider_type in (ProviderType.CUSTOM.value, ProviderType.SYSTEM.value):
            for provider in (openai_provider, azure_openai_provider):
                if provider and provider.provider_type == provider_type:
                    return provider.provider_name

        raise ProviderTokenNotInitError(
            f"No valid {provider_name} model provider credentials found. "
            f"Please go to Settings -> Model Provider to complete your provider credentials."
        )