feat: tool credentials cache and introduce _position.yaml (#2386)

This commit is contained in:
Yeuoly 2024-02-05 12:39:42 +08:00 committed by GitHub
parent 6278ff0f30
commit 5010706d8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 96 additions and 26 deletions

View File

@ -0,0 +1,49 @@
import json
from enum import Enum
from json import JSONDecodeError
from typing import Optional
from extensions.ext_redis import redis_client
class ToolProviderCredentialsCacheType(Enum):
PROVIDER = "tool_provider"
class ToolProviderCredentialsCache:
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
def get(self) -> Optional[dict]:
"""
Get cached model provider credentials.
:return:
"""
cached_provider_credentials = redis_client.get(self.cache_key)
if cached_provider_credentials:
try:
cached_provider_credentials = cached_provider_credentials.decode('utf-8')
cached_provider_credentials = json.loads(cached_provider_credentials)
except JSONDecodeError:
return None
return cached_provider_credentials
else:
return None
def set(self, credentials: dict) -> None:
"""
Cache model provider credentials.
:param credentials: provider credentials
:return:
"""
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
def delete(self) -> None:
"""
Delete cached model provider credentials.
:return:
"""
redis_client.delete(self.cache_key)

View File

@ -0,0 +1,15 @@
- google
- bing
- wikipedia
- dalle
- azuredalle
- webscraper
- wolframalpha
- github
- chart
- time
- yahoo
- stablediffusion
- vectorizer
- youtube
- gaode

View File

@ -1,31 +1,29 @@
from typing import List
from core.tools.entities.user_entities import UserToolProvider
from core.tools.entities.tool_entities import ToolProviderType
from typing import List
from yaml import load, FullLoader
position = {
'google': 1,
'bing': 2,
'wikipedia': 2,
'dalle': 3,
'webscraper': 4,
'wolframalpha': 5,
'chart': 6,
'time': 7,
'yahoo': 8,
'stablediffusion': 9,
'vectorizer': 10,
'youtube': 11,
'github': 12,
'gaode': 13
}
import os.path
position = {}
class BuiltinToolProviderSort:
@staticmethod
def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
global position
if not position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path, 'r') as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
position = tmp_position
def sort_compare(provider: UserToolProvider) -> int:
# if provider.type == UserToolProvider.ProviderType.MODEL:
# return position.get(f'model_provider.{provider.name}', 10000)
return position.get(provider.name, 10000)
sorted_providers = sorted(providers, key=sort_compare)
return sorted_providers
return sorted_providers

View File

@ -1,10 +1,10 @@
from typing import Any, Dict
from core.helper import encrypter
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from typing import Dict, Any
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderCredentials
from core.tools.provider.tool_provider import ToolProviderController
from core.helper import encrypter
from core.helper.tool_provider_cache import ToolProviderCredentialsCacheType, ToolProviderCredentialsCache
class ToolConfiguration(BaseModel):
tenant_id: str
@ -63,8 +63,15 @@ class ToolConfiguration(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
)
cached_credentials = cache.get()
if cached_credentials:
return cached_credentials
credentials = self._deep_copy(credentials)
# get fields need to be decrypted
fields = self.provider_controller.get_credentials_schema()
for field_name, field in fields.items():
@ -74,5 +81,6 @@ class ToolConfiguration(BaseModel):
credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
except:
pass
cache.set(credentials)
return credentials