mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
feat: tool credentials cache and introduce _position.yaml (#2386)
This commit is contained in:
parent
6278ff0f30
commit
5010706d8b
49
api/core/helper/tool_provider_cache.py
Normal file
49
api/core/helper/tool_provider_cache.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||
if cached_provider_credentials:
|
||||
try:
|
||||
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
|
||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
return cached_provider_credentials
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, credentials: dict) -> None:
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
redis_client.delete(self.cache_key)
|
15
api/core/tools/provider/_position.yaml
Normal file
15
api/core/tools/provider/_position.yaml
Normal file
|
@ -0,0 +1,15 @@
|
|||
- google
|
||||
- bing
|
||||
- wikipedia
|
||||
- dalle
|
||||
- azuredalle
|
||||
- webscraper
|
||||
- wolframalpha
|
||||
- github
|
||||
- chart
|
||||
- time
|
||||
- yahoo
|
||||
- stablediffusion
|
||||
- vectorizer
|
||||
- youtube
|
||||
- gaode
|
|
@ -1,31 +1,29 @@
|
|||
from typing import List
|
||||
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from typing import List
|
||||
from yaml import load, FullLoader
|
||||
|
||||
position = {
|
||||
'google': 1,
|
||||
'bing': 2,
|
||||
'wikipedia': 2,
|
||||
'dalle': 3,
|
||||
'webscraper': 4,
|
||||
'wolframalpha': 5,
|
||||
'chart': 6,
|
||||
'time': 7,
|
||||
'yahoo': 8,
|
||||
'stablediffusion': 9,
|
||||
'vectorizer': 10,
|
||||
'youtube': 11,
|
||||
'github': 12,
|
||||
'gaode': 13
|
||||
}
|
||||
import os.path
|
||||
|
||||
position = {}
|
||||
|
||||
class BuiltinToolProviderSort:
|
||||
@staticmethod
|
||||
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
|
||||
global position
|
||||
if not position:
|
||||
tmp_position = {}
|
||||
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
|
||||
with open(file_path, 'r') as f:
|
||||
for pos, val in enumerate(load(f, Loader=FullLoader)):
|
||||
tmp_position[val] = pos
|
||||
position = tmp_position
|
||||
|
||||
def sort_compare(provider: UserToolProvider) -> int:
|
||||
# if provider.type == UserToolProvider.ProviderType.MODEL:
|
||||
# return position.get(f'model_provider.{provider.name}', 10000)
|
||||
return position.get(provider.name, 10000)
|
||||
|
||||
sorted_providers = sorted(providers, key=sort_compare)
|
||||
|
||||
return sorted_providers
|
||||
return sorted_providers
|
|
@ -1,10 +1,10 @@
|
|||
from typing import Any, Dict
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from typing import Dict, Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentials
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.helper import encrypter
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
|
||||
|
||||
class ToolConfiguration(BaseModel):
|
||||
tenant_id: str
|
||||
|
@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
|
|||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
credentials = self._deep_copy(credentials)
|
||||
|
||||
# get fields need to be decrypted
|
||||
fields = self.provider_controller.get_credentials_schema()
|
||||
for field_name, field in fields.items():
|
||||
|
@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
|
|||
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
cache.set(credentials)
|
||||
return credentials
|
Loading…
Reference in New Issue
Block a user