From c06e766d7e66bf8763fbadd49474064c01e6b1f4 Mon Sep 17 00:00:00 2001 From: zxhlyh Date: Thu, 4 Jan 2024 16:46:51 +0800 Subject: [PATCH] feat: model parameter prefefined (#1917) --- .../app/configuration/config-model/index.tsx | 423 ------------------ .../config-model/model-mode-type-label.tsx | 29 -- .../configuration/config-model/model-name.tsx | 26 -- .../configuration/config-model/param-item.tsx | 95 ---- .../config-model/provider-name.tsx | 18 - .../components/app/configuration/index.tsx | 1 - .../model-parameter-modal/index.tsx | 100 +++++ 7 files changed, 100 insertions(+), 592 deletions(-) delete mode 100644 web/app/components/app/configuration/config-model/index.tsx delete mode 100644 web/app/components/app/configuration/config-model/model-mode-type-label.tsx delete mode 100644 web/app/components/app/configuration/config-model/model-name.tsx delete mode 100644 web/app/components/app/configuration/config-model/param-item.tsx delete mode 100644 web/app/components/app/configuration/config-model/provider-name.tsx diff --git a/web/app/components/app/configuration/config-model/index.tsx b/web/app/components/app/configuration/config-model/index.tsx deleted file mode 100644 index cf3e8aa858..0000000000 --- a/web/app/components/app/configuration/config-model/index.tsx +++ /dev/null @@ -1,423 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect, useState } from 'react' -import cn from 'classnames' -import { useTranslation } from 'react-i18next' -import { useBoolean, useClickAway, useGetState } from 'ahooks' -import { InformationCircleIcon } from '@heroicons/react/24/outline' -import produce from 'immer' -import ParamItem from './param-item' -import { SlidersH } from '@/app/components/base/icons/src/vender/line/mediaAndDevices' -import Radio from '@/app/components/base/radio' -import Panel from '@/app/components/base/panel' -import type { CompletionParams } from '@/models/debug' -import { TONE_LIST } from '@/config' -import Toast from '@/app/components/base/toast' -import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback' -import { formatNumber } from '@/utils/format' -import { Brush01 } from '@/app/components/base/icons/src/vender/solid/editor' -import { Scales02 } from '@/app/components/base/icons/src/vender/solid/FinanceAndECommerce' -import { Target04 } from '@/app/components/base/icons/src/vender/solid/general' -import { Sliders02 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' -import { fetchModelParams } from '@/service/debug' -import Loading from '@/app/components/base/loading' -import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' -import type { ModelModeType } from '@/types/app' -import ModelIcon from '@/app/components/header/account-setting/model-provider-page/model-icon' -import ModelName from '@/app/components/header/account-setting/model-provider-page/model-name' -import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks' - -export type IConfigModelProps = { - isAdvancedMode: boolean - mode: string - modelId: string - provider: string - setModel: (model: { id: string; provider: string; mode: ModelModeType; features: string[] }) => void - completionParams: CompletionParams - onCompletionParamsChange: (newParams: CompletionParams) => void - disabled: boolean -} - -const ConfigModel: FC = ({ - isAdvancedMode, - modelId, - provider, - setModel, - completionParams, - onCompletionParamsChange, - disabled, -}) => { - const { t } = useTranslation() - const [isShowConfig, { setFalse: hideConfig, toggle: toogleShowConfig }] = useBoolean(false) - const [maxTokenSettingTipVisible, setMaxTokenSettingTipVisible] = useState(false) - const configContentRef = React.useRef(null) - const { - currentProvider, - currentModel: currModel, - textGenerationModelList, - } = useTextGenerationCurrentProviderAndModelAndModelList( - { provider, model: modelId }, - ) - - const media = useBreakpoints() - const isMobile = media === MediaType.mobile - - // Cache loaded model param - const [allParams, setAllParams, getAllParams] = useGetState>>({}) - const currParams = allParams[provider]?.[modelId] - const hasEnableParams = currParams && Object.keys(currParams).some(key => currParams[key].enabled) - const allSupportParams = ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty', 'max_tokens'] - const currSupportParams = currParams ? allSupportParams.filter(key => currParams[key].enabled) : allSupportParams - if (isAdvancedMode) - currSupportParams.push('stop') - - useEffect(() => { - (async () => { - if (!allParams[provider]?.[modelId]) { - const res = await fetchModelParams(provider, modelId) - const newAllParams = produce(allParams, (draft) => { - if (!draft[provider]) - draft[provider] = {} - - draft[provider][modelId] = res - }) - setAllParams(newAllParams) - } - })() - }, [provider, modelId, allParams, setAllParams]) - - useClickAway(() => { - hideConfig() - }, configContentRef) - - const selectedModel = { name: modelId } // options.find(option => option.id === modelId) - - const ensureModelParamLoaded = (provider: string, modelId: string) => { - return new Promise((resolve) => { - if (getAllParams()[provider]?.[modelId]) { - resolve() - return - } - const runId = setInterval(() => { - if (getAllParams()[provider]?.[modelId]) { - resolve() - clearInterval(runId) - } - }, 500) - }) - } - - const transformValue = (value: number, fromRange: [number, number], toRange: [number, number]): number => { - const [fromStart = 0, fromEnd] = fromRange - const [toStart = 0, toEnd] = toRange - - // The following three if is to avoid precision loss - if (fromStart === toStart && fromEnd === toEnd) - return value - - if (value <= fromStart) - return toStart - - if (value >= fromEnd) - return toEnd - - const fromLength = fromEnd - fromStart - const toLength = toEnd - toStart - - let adjustedValue = (value - fromStart) * (toLength / fromLength) + toStart - adjustedValue = parseFloat(adjustedValue.toFixed(2)) - return adjustedValue - } - - const handleSelectModel = ({ id, provider: nextProvider, mode, features }: { id: string; provider: string; mode: ModelModeType; features: string[] }) => { - return async () => { - const prevParamsRule = getAllParams()[provider]?.[modelId] - - setModel({ - id, - provider: nextProvider || 'openai', - mode, - features, - }) - - await ensureModelParamLoaded(nextProvider, id) - - const nextParamsRule = getAllParams()[nextProvider]?.[id] - // debugger - const nextSelectModelMaxToken = nextParamsRule.max_tokens.max - const newConCompletionParams = produce(completionParams, (draft: any) => { - if (nextParamsRule.max_tokens.enabled) { - if (completionParams.max_tokens > nextSelectModelMaxToken) { - Toast.notify({ - type: 'warning', - message: t('common.model.params.setToCurrentModelMaxTokenTip', { maxToken: formatNumber(nextSelectModelMaxToken) }), - }) - draft.max_tokens = parseFloat((nextSelectModelMaxToken * 0.8).toFixed(2)) - } - // prev don't have max token - if (!completionParams.max_tokens) - draft.max_tokens = nextParamsRule.max_tokens.default - } - else { - delete draft.max_tokens - } - - allSupportParams.forEach((key) => { - if (key === 'max_tokens') - return - - if (!nextParamsRule[key].enabled) { - delete draft[key] - return - } - - if (draft[key] === undefined) { - draft[key] = nextParamsRule[key].default || 0 - return - } - - if (!prevParamsRule[key].enabled) { - draft[key] = nextParamsRule[key].default || 0 - return - } - - draft[key] = transformValue( - draft[key], - [prevParamsRule[key].min, prevParamsRule[key].max], - [nextParamsRule[key].min, nextParamsRule[key].max], - ) - }) - }) - onCompletionParamsChange(newConCompletionParams) - } - } - - // only openai support this - function matchToneId(completionParams: CompletionParams): number { - const remvoedCustomeTone = TONE_LIST.slice(0, -1) - const CUSTOM_TONE_ID = 4 - const tone = remvoedCustomeTone.find((tone) => { - return tone.config?.temperature === completionParams.temperature - && tone.config?.top_p === completionParams.top_p - && tone.config?.presence_penalty === completionParams.presence_penalty - && tone.config?.frequency_penalty === completionParams.frequency_penalty - }) - return tone ? tone.id : CUSTOM_TONE_ID - } - - // tone is a preset of completionParams. - const [toneId, setToneId] = React.useState(matchToneId(completionParams)) // default is Balanced - const toneTabBgClassName = ({ - 1: 'bg-[#F5F8FF]', - 2: 'bg-[#F4F3FF]', - 3: 'bg-[#F6FEFC]', - })[toneId] || '' - // set completionParams by toneId - const handleToneChange = (id: number) => { - if (id === 4) - return // custom tone - const tone = TONE_LIST.find(tone => tone.id === id) - if (tone) { - setToneId(id) - onCompletionParamsChange({ - ...tone.config, - max_tokens: completionParams.max_tokens, - } as CompletionParams) - } - } - - useEffect(() => { - setToneId(matchToneId(completionParams)) - }, [completionParams]) - - const handleParamChange = (key: string, value: number | string[]) => { - if (value === undefined) - return - if ((completionParams as any)[key] === value) - return - - if (key === 'stop') { - onCompletionParamsChange({ - ...completionParams, - [key]: value as string[], - }) - } - else { - const currParamsRule = getAllParams()[provider]?.[modelId] - let notOutRangeValue = parseFloat((value as number).toFixed(2)) - notOutRangeValue = Math.max(currParamsRule[key].min, notOutRangeValue) - notOutRangeValue = Math.min(currParamsRule[key].max, notOutRangeValue) - onCompletionParamsChange({ - ...completionParams, - [key]: notOutRangeValue, - }) - } - } - const ableStyle = 'bg-indigo-25 border-[#2A87F5] cursor-pointer' - const diabledStyle = 'bg-[#FFFCF5] border-[#F79009]' - - const getToneIcon = (toneId: number) => { - const className = 'w-[14px] h-[14px]' - const res = ({ - 1: , - 2: , - 3: , - 4: , - })[toneId] - return res - } - useEffect(() => { - if (!currParams) - return - - const max = currParams.max_tokens.max - const isSupportMaxToken = currParams.max_tokens.enabled - if (isSupportMaxToken && currentProvider?.provider !== 'anthropic' && completionParams.max_tokens > max * 2 / 3) - setMaxTokenSettingTipVisible(true) - else - setMaxTokenSettingTipVisible(false) - }, [currParams, completionParams.max_tokens, setMaxTokenSettingTipVisible, currentProvider]) - return ( -
-
!disabled && toogleShowConfig()} - > - { - currentProvider && ( - - ) - } - { - currModel && ( - - ) - } - {disabled ? : } -
- {isShowConfig && ( - - - - - - - - - - } - title={t('appDebug.modelConfig.title')} - > -
-
-
{t('appDebug.modelConfig.model')}
- { - const targetProvider = textGenerationModelList.find(modelItem => modelItem.provider === provider) - const targetModelItem = targetProvider?.models.find(modelItem => modelItem.model === model) - handleSelectModel({ - id: model, - provider, - mode: targetModelItem?.model_properties.mode as ModelModeType, - features: targetModelItem?.features || [], - })() - }} - /> -
- {hasEnableParams && ( -
- )} - - {/* Tone type */} - {['openai', 'azure_openai'].includes(provider) && ( -
-
{t('appDebug.modelConfig.setTone')}
- - <> - {TONE_LIST.slice(0, 3).map(tone => ( -
- - <> - {getToneIcon(tone.id)} - {!isMobile &&
{t(`common.model.tone.${tone.name}`) as string}
} -
- -
- {tone.id !== toneId && tone.id + 1 !== toneId && (
)} -
- ))} - - - <> - {getToneIcon(TONE_LIST[3].id)} - {!isMobile &&
{t(`common.model.tone.${TONE_LIST[3].name}`) as string}
} - -
-
-
- )} - - {/* Params */} -
- {(allParams[provider]?.[modelId]) - ? ( - currSupportParams.map(key => ()) - ) - : ( - - )} -
-
- { - maxTokenSettingTipVisible && ( -
- -
{t('common.model.params.maxTokenSettingTip')}
-
- ) - } -
- )} -
- - ) -} - -export default React.memo(ConfigModel) diff --git a/web/app/components/app/configuration/config-model/model-mode-type-label.tsx b/web/app/components/app/configuration/config-model/model-mode-type-label.tsx deleted file mode 100644 index 8f4791541d..0000000000 --- a/web/app/components/app/configuration/config-model/model-mode-type-label.tsx +++ /dev/null @@ -1,29 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' -import { useTranslation } from 'react-i18next' -import cn from 'classnames' -import type { ModelModeType } from '@/types/app' - -type Props = { - className?: string - type: ModelModeType - isHighlight?: boolean -} - -const ModelModeTypeLabel: FC = ({ - className, - type, - isHighlight, -}) => { - const { t } = useTranslation() - - return ( -
- {t(`appDebug.modelConfig.modeType.${type}`)} -
- ) -} -export default React.memo(ModelModeTypeLabel) diff --git a/web/app/components/app/configuration/config-model/model-name.tsx b/web/app/components/app/configuration/config-model/model-name.tsx deleted file mode 100644 index 789aea537f..0000000000 --- a/web/app/components/app/configuration/config-model/model-name.tsx +++ /dev/null @@ -1,26 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' - -export type IModelNameProps = { - modelId: string - modelDisplayName?: string -} - -export const supportI18nModelName = [ - 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k', - 'gpt-4', 'gpt-4-32k', - 'text-davinci-003', 'text-embedding-ada-002', 'whisper-1', - 'claude-instant-1', 'claude-2', -] - -const ModelName: FC = ({ - modelDisplayName, -}) => { - return ( - - {modelDisplayName} - - ) -} -export default React.memo(ModelName) diff --git a/web/app/components/app/configuration/config-model/param-item.tsx b/web/app/components/app/configuration/config-model/param-item.tsx deleted file mode 100644 index 9f2863b232..0000000000 --- a/web/app/components/app/configuration/config-model/param-item.tsx +++ /dev/null @@ -1,95 +0,0 @@ -'use client' -import type { FC } from 'react' -import React, { useEffect } from 'react' -import { useTranslation } from 'react-i18next' -import Tooltip from '@/app/components/base/tooltip' -import Slider from '@/app/components/base/slider' -import TagInput from '@/app/components/base/tag-input' - -export const getFitPrecisionValue = (num: number, precision: number | null) => { - if (!precision || !(`${num}`).includes('.')) - return num - - const currNumPrecision = (`${num}`).split('.')[1].length - if (currNumPrecision > precision) - return parseFloat(num.toFixed(precision)) - - return num -} - -export type IParamIteProps = { - id: string - name: string - tip: string - value: number | string[] - step?: number - min?: number - max: number - precision: number | null - onChange: (key: string, value: number | string[]) => void - inputType?: 'inputTag' | 'slider' -} - -const TIMES_TEMPLATE = '1000000000000' -const ParamItem: FC = ({ id, name, tip, step = 0.1, min = 0, max, precision, value, inputType, onChange }) => { - const { t } = useTranslation() - - const getToIntTimes = (num: number) => { - if (precision) - return parseInt(TIMES_TEMPLATE.slice(0, precision + 1), 10) - if (num < 5) - return 10 - return 1 - } - - const times = getToIntTimes(max) - - useEffect(() => { - if (precision) - onChange(id, getFitPrecisionValue(value, precision)) - }, [value, precision]) - return ( -
-
-
- {name} - {/* Give tooltip different tip to avoiding hide bug */} - {tip}
} position='top' selector={`param-name-tooltip-${id}`}> - - - - -
- {inputType === 'inputTag' &&
{t('common.model.params.stop_sequencesPlaceholder')}
} -
-
- {inputType === 'inputTag' - ? onChange(id, newSequences)} - customizedConfirmKey='Tab' - /> - : ( - <> -
- { - onChange(id, value / times) - }} /> -
- { - let value = getFitPrecisionValue(isNaN(parseFloat(e.target.value)) ? min : parseFloat(e.target.value), precision) - if (value < min) - value = min - - if (value > max) - value = max - onChange(id, value) - }} /> - - ) - } -
- - ) -} -export default React.memo(ParamItem) diff --git a/web/app/components/app/configuration/config-model/provider-name.tsx b/web/app/components/app/configuration/config-model/provider-name.tsx deleted file mode 100644 index a9e713e8ad..0000000000 --- a/web/app/components/app/configuration/config-model/provider-name.tsx +++ /dev/null @@ -1,18 +0,0 @@ -'use client' -import type { FC } from 'react' -import React from 'react' - -export type IProviderNameProps = { - provideName: string -} - -const ProviderName: FC = ({ - provideName, -}) => { - return ( - - {provideName} - - ) -} -export default React.memo(ProviderName) diff --git a/web/app/components/app/configuration/index.tsx b/web/app/components/app/configuration/index.tsx index 4dc7059319..b96ea1ae74 100644 --- a/web/app/components/app/configuration/index.tsx +++ b/web/app/components/app/configuration/index.tsx @@ -28,7 +28,6 @@ import type { ExternalDataTool } from '@/models/common' import type { DataSet } from '@/models/datasets' import type { ModelConfig as BackendModelConfig, VisionSettings } from '@/types/app' import ConfigContext from '@/context/debug-configuration' -// import ConfigModel from '@/app/components/app/configuration/config-model' import Config from '@/app/components/app/configuration/config' import Debug from '@/app/components/app/configuration/debug' import Confirm from '@/app/components/base/confirm' diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx index fd8175e29e..d0febbc215 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/index.tsx @@ -1,6 +1,7 @@ import type { FC } from 'react' import { useEffect, useMemo, useState } from 'react' import useSWR from 'swr' +import cn from 'classnames' import { useTranslation } from 'react-i18next' import type { DefaultModel, @@ -32,6 +33,13 @@ import { fetchModelParameterRules } from '@/service/common' import Loading from '@/app/components/base/loading' import { useProviderContext } from '@/context/provider-context' import TooltipPlus from '@/app/components/base/tooltip-plus' +import Radio from '@/app/components/base/radio' +import { TONE_LIST } from '@/config' +import { Brush01 } from '@/app/components/base/icons/src/vender/solid/editor' +import { Scales02 } from '@/app/components/base/icons/src/vender/solid/FinanceAndECommerce' +import { Target04 } from '@/app/components/base/icons/src/vender/solid/general' +import { Sliders02 } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices' +import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' type ModelParameterModalProps = { isAdvancedMode: boolean @@ -71,6 +79,8 @@ const ModelParameterModal: FC = ({ const { t } = useTranslation() const language = useLanguage() const { hasSettedApiKey, modelProviders } = useProviderContext() + const media = useBreakpoints() + const isMobile = media === MediaType.mobile const [open, setOpen] = useState(false) const { data: parameterRulesData, isLoading } = useSWR(`/workspaces/current/model-providers/${provider}/models/parameter-rules?model=${modelId}`, fetchModelParameterRules) const { @@ -89,6 +99,44 @@ const ModelParameterModal: FC = ({ return parameterRulesData?.data || [] }, [parameterRulesData]) + // only openai support this + function matchToneId(completionParams: FormValue): number { + const remvoedCustomeTone = TONE_LIST.slice(0, -1) + const CUSTOM_TONE_ID = 4 + const tone = remvoedCustomeTone.find((tone) => { + return tone.config?.temperature === completionParams.temperature + && tone.config?.top_p === completionParams.top_p + && tone.config?.presence_penalty === completionParams.presence_penalty + && tone.config?.frequency_penalty === completionParams.frequency_penalty + }) + return tone ? tone.id : CUSTOM_TONE_ID + } + + // tone is a preset of completionParams. + const [toneId, setToneId] = useState(matchToneId(completionParams)) // default is Balanced + const toneTabBgClassName = ({ + 1: 'bg-[#F5F8FF]', + 2: 'bg-[#F4F3FF]', + 3: 'bg-[#F6FEFC]', + })[toneId] || '' + // set completionParams by toneId + const handleToneChange = (id: number) => { + if (id === 4) + return // custom tone + const tone = TONE_LIST.find(tone => tone.id === id) + if (tone) { + setToneId(id) + onCompletionParamsChange({ + ...tone.config, + max_tokens: completionParams.max_tokens, + }) + } + } + + useEffect(() => { + setToneId(matchToneId(completionParams)) + }, [completionParams]) + const handleParamChange = (key: string, value: ParameterValue) => { onCompletionParamsChange({ ...completionParams, @@ -138,6 +186,17 @@ const ModelParameterModal: FC = ({ handleInitialParams() }, [parameterRules]) + const getToneIcon = (toneId: number) => { + const className = 'w-[14px] h-[14px]' + const res = ({ + 1: , + 2: , + 3: , + 4: , + })[toneId] + return res + } + return ( = ({
) } + {['openai', 'azure_openai'].includes(provider) && !isLoading && !!parameterRules.length && ( +
+
{t('appDebug.modelConfig.setTone')}
+ + <> + {TONE_LIST.slice(0, 3).map(tone => ( +
+ + <> + {getToneIcon(tone.id)} + {!isMobile &&
{t(`common.model.tone.${tone.name}`) as string}
} +
+ +
+ {tone.id !== toneId && tone.id + 1 !== toneId && (
)} +
+ ))} + + + <> + {getToneIcon(TONE_LIST[3].id)} + {!isMobile &&
{t(`common.model.tone.${TONE_LIST[3].name}`) as string}
} + +
+
+
+ )} { !isLoading && !!parameterRules.length && ( [