diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index f556121518..04ae146645 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -63,7 +63,7 @@ const ConfigContent: FC = ({ } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) const { - currentModel, + currentModel: currentRerankModel, } = useCurrentProviderAndModel( rerankModelList, rerankDefaultModel @@ -74,11 +74,6 @@ const ConfigContent: FC = ({ : undefined, ) - const handleDisabledSwitchClick = useCallback(() => { - if (!currentModel) - Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) - }, [currentModel, rerankDefaultModel, t]) - const rerankModel = (() => { if (datasetConfigs.reranking_model?.reranking_provider_name) { return { @@ -164,12 +159,33 @@ const ConfigContent: FC = ({ const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel + const canManuallyToggleRerank = useMemo(() => { + return !( + (selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic) + || selectedDatasetsMode.allExternal + ) + }, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal]) + const showRerankModel = useMemo(() => { - if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic) + if (!canManuallyToggleRerank) return false - return true - }, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic]) + return datasetConfigs.reranking_enable + }, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentRerankModel && !showRerankModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentRerankModel, showRerankModel, t]) + + useEffect(() => { + if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) { + onChange({ + ...datasetConfigs, + reranking_enable: showRerankModel, + }) + } + }, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange]) return (
@@ -256,13 +272,15 @@ const ConfigContent: FC = ({ > { - onChange({ - ...datasetConfigs, - reranking_enable: v, - }) + if (canManuallyToggleRerank) { + onChange({ + ...datasetConfigs, + reranking_enable: v, + }) + } }} />
diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 2d3df0b039..91d0e4e590 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -42,6 +42,7 @@ const ParamsConfig = ({ allHighQuality, allHighQualityFullTextSearch, allHighQualityVectorSearch, + allInternal, allExternal, mixtureHighQualityAndEconomic, inconsistentEmbeddingModel, @@ -50,7 +51,7 @@ const ParamsConfig = ({ const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs let rerankEnable = restConfigs.reranking_enable - if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal) + if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) rerankEnable = false if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1)) diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index 77e959b573..b2b1c69975 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,25 +1,17 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' -import { useTranslation } from 'react-i18next' import { useWorkflowStore } from '../store' import { BlockEnum, WorkflowRunningStatus, } from '../types' -import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' -import type { Node } from '../types' -import { useWorkflow } from './use-workflow' import { useIsChatMode, useNodesSyncDraft, useWorkflowInteractions, useWorkflowRun, } from './index' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' -import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks' -import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' -import Toast from '@/app/components/base/toast' export const useWorkflowStartRun = () => { const store = useStoreApi() @@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => { const isChatMode = useIsChatMode() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleRun } = useWorkflowRun() - const { isFromStartNode } = useWorkflow() const { doSyncWorkflowDraft } = useNodesSyncDraft() - const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault - const { t } = useTranslation() - const { - modelList: rerankModelList, - defaultModel: rerankDefaultModel, - } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) - - const { - currentModel, - } = useCurrentProviderAndModel( - rerankModelList, - rerankDefaultModel - ? { - ...rerankDefaultModel, - provider: rerankDefaultModel.provider.provider, - } - : undefined, - ) const handleWorkflowStartRunInWorkflow = useCallback(async () => { const { @@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => { const { getNodes } = store.getState() const nodes = getNodes() const startNode = nodes.find(node => node.data.type === BlockEnum.Start) - const knowledgeRetrievalNodes = nodes.filter((node: Node) => - node.data.type === BlockEnum.KnowledgeRetrieval, - ) const startVariables = startNode?.data.variables || [] const fileSettings = featuresStore!.getState().features.file const { @@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => { setShowEnvPanel, } = workflowStore.getState() - if (knowledgeRetrievalNodes.length > 0) { - for (const node of knowledgeRetrievalNodes) { - if (isFromStartNode(node.id)) { - const res = checkKnowledgeRetrievalValid(node.data, t) - if (!res.isValid || !currentModel || !rerankDefaultModel) { - const errorMessage = res.errorMessage - if (errorMessage) { - Toast.notify({ - type: 'error', - message: errorMessage, - }) - return false - } - else { - Toast.notify({ - type: 'error', - message: t('appDebug.datasetConfig.rerankModelRequired'), - }) - return false - } - } - } - } - } - setShowEnvPanel(false) if (showDebugAndPreviewPanel) { diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts index e83a5d97b5..01c1e31ccc 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/use-config.ts @@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets' import { fetchDatasets } from '@/service/datasets' import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud' import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { @@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { const startNodeId = startNode?.id const { inputs, setInputs: doSetInputs } = useNodeCrud(id, payload) + const inputRef = useRef(inputs) + const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => { const newInputs = produce(s, (draft) => { if (s.retrieval_mode === RETRIEVE_TYPE.multiWay) @@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { }) // not work in pass to draft... doSetInputs(newInputs) + inputRef.current = newInputs }, [doSetInputs]) - const inputRef = useRef(inputs) - useEffect(() => { - inputRef.current = inputs - }, [inputs]) - const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => { const newInputs = produce(inputs, (draft) => { draft.query_variable_selector = newVar as ValueSelector @@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration) const { + modelList: rerankModelList, defaultModel: rerankDefaultModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + const { + currentModel: currentRerankModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => { const newInputs = produce(inputRef.current, (draft) => { if (!draft.single_retrieval_config) { @@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { // set defaults models useEffect(() => { const inputs = inputRef.current - if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider) + if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel) return if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider) @@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { } } } - const multipleRetrievalConfig = draft.multiple_retrieval_config draft.multiple_retrieval_config = { top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k, @@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { reranking_model: multipleRetrievalConfig?.reranking_model, reranking_mode: multipleRetrievalConfig?.reranking_mode, weights: multipleRetrievalConfig?.weights, + reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined + ? multipleRetrievalConfig.reranking_enable + : Boolean(currentRerankModel && rerankDefaultModel), } }) setInputs(newInput) @@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => { }, []) useEffect(() => { + const inputs = inputRef.current let query_variable_selector: ValueSelector = inputs.query_variable_selector if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId) query_variable_selector = [startNodeId, 'sys.query'] - setInputs({ - ...inputs, - query_variable_selector, - }) + setInputs(produce(inputs, (draft) => { + draft.query_variable_selector = query_variable_selector + })) // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts index 85ae6c4c96..e48777d948 100644 --- a/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts +++ b/web/app/components/workflow/nodes/knowledge-retrieval/utils.ts @@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr reranking_mode, reranking_model, weights, - reranking_enable: allEconomic ? reranking_enable : true, + reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, } if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)