From 071e7800a0bc8b0fdca454c689a088ec07e26b5a Mon Sep 17 00:00:00 2001 From: takatost Date: Wed, 23 Aug 2023 19:48:31 +0800 Subject: [PATCH] fix: add hf task field (#976) Co-authored-by: StyleZhang --- .../model-page/configs/huggingface_hub.tsx | 53 +++++++++++++++++++ .../model-page/declarations.ts | 1 + .../account-setting/model-page/index.tsx | 5 +- .../model-page/model-modal/Form.tsx | 2 +- 4 files changed, 58 insertions(+), 3 deletions(-) diff --git a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx index 420ee2b567..efcc578592 100644 --- a/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx +++ b/web/app/components/header/account-setting/model-page/configs/huggingface_hub.tsx @@ -38,6 +38,7 @@ const config: ProviderConfig = { defaultValue: { model_type: 'text-generation', huggingfacehub_api_type: 'hosted_inference_api', + task_type: 'text-generation', }, validateKeys: (v?: FormValue) => { if (v?.huggingfacehub_api_type === 'hosted_inference_api') { @@ -51,10 +52,36 @@ const config: ProviderConfig = { 'huggingfacehub_api_token', 'model_name', 'huggingfacehub_endpoint_url', + 'task_type', ] } return [] }, + filterValue: (v?: FormValue) => { + let filteredKeys: string[] = [] + if (v?.huggingfacehub_api_type === 'hosted_inference_api') { + filteredKeys = [ + 'huggingfacehub_api_type', + 'huggingfacehub_api_token', + 'model_name', + 'model_type', + ] + } + if (v?.huggingfacehub_api_type === 'inference_endpoints') { + filteredKeys = [ + 'huggingfacehub_api_type', + 'huggingfacehub_api_token', + 'model_name', + 'huggingfacehub_endpoint_url', + 'task_type', + 'model_type', + ] + } + return filteredKeys.reduce((prev: FormValue, next: string) => { + prev[next] = v?.[next] || '' + return prev + }, {}) + }, fields: [ { type: 'radio', @@ -120,6 +147,32 @@ const config: ProviderConfig = { 'zh-Hans': '在此输入您的端点 URL', }, }, + { + hidden: (value?: FormValue) => value?.huggingfacehub_api_type === 'hosted_inference_api', + type: 'radio', + key: 'task_type', + required: true, + label: { + 'en': 'Task', + 'zh-Hans': 'Task', + }, + options: [ + { + key: 'text2text-generation', + label: { + 'en': 'Text-to-Text Generation', + 'zh-Hans': 'Text-to-Text Generation', + }, + }, + { + key: 'text-generation', + label: { + 'en': 'Text Generation', + 'zh-Hans': 'Text Generation', + }, + }, + ], + }, ], }, } diff --git a/web/app/components/header/account-setting/model-page/declarations.ts b/web/app/components/header/account-setting/model-page/declarations.ts index ea6f8cb81d..0835dd837c 100644 --- a/web/app/components/header/account-setting/model-page/declarations.ts +++ b/web/app/components/header/account-setting/model-page/declarations.ts @@ -91,6 +91,7 @@ export type ProviderConfigModal = { icon: ReactElement defaultValue?: FormValue validateKeys?: string[] | ((v?: FormValue) => string[]) + filterValue?: (v?: FormValue) => FormValue fields: Field[] link: { href: string diff --git a/web/app/components/header/account-setting/model-page/index.tsx b/web/app/components/header/account-setting/model-page/index.tsx index 281ada2be5..22ccfd23b3 100644 --- a/web/app/components/header/account-setting/model-page/index.tsx +++ b/web/app/components/header/account-setting/model-page/index.tsx @@ -124,8 +124,9 @@ const ModelPage = () => { updateModelList(ModelType.embeddings) mutateProviders() } - const handleSave = async (v?: FormValue) => { - if (v && modelModalConfig) { + const handleSave = async (originValue?: FormValue) => { + if (originValue && modelModalConfig) { + const v = modelModalConfig.filterValue ? modelModalConfig.filterValue(originValue) : originValue let body, url if (ConfigurableProviders.includes(modelModalConfig.key)) { const { model_name, model_type, ...config } = v diff --git a/web/app/components/header/account-setting/model-page/model-modal/Form.tsx b/web/app/components/header/account-setting/model-page/model-modal/Form.tsx index 08b54a1162..6f3838c860 100644 --- a/web/app/components/header/account-setting/model-page/model-modal/Form.tsx +++ b/web/app/components/header/account-setting/model-page/model-modal/Form.tsx @@ -68,7 +68,7 @@ const Form: FC = ({ return true }, run: () => { - return validateModelProviderFn(modelModal!.key, v) + return validateModelProviderFn(modelModal!.key, modelModal?.filterValue ? modelModal?.filterValue(v) : v) }, }) }