From 793205afc547264151ef6ac428f485422adf6daf Mon Sep 17 00:00:00 2001 From: Yi Xiao <54782454+YIXIAO0@users.noreply.github.com> Date: Sat, 12 Oct 2024 21:24:43 +0800 Subject: [PATCH] Feat: rerank model verification in front end (#9271) --- .../params-config/config-content.tsx | 49 +++++++++++++---- .../common/retrieval-param-config/index.tsx | 49 +++++++++++++---- .../workflow/hooks/use-workflow-start-run.tsx | 55 +++++++++++++++++++ .../components/workflow/hooks/use-workflow.ts | 28 ++++++++++ web/i18n/en-US/workflow.ts | 1 + web/i18n/zh-Hans/workflow.ts | 1 + 6 files changed, 159 insertions(+), 24 deletions(-) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 7f83a14d58..f556121518 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -1,6 +1,6 @@ 'use client' -import { memo, useEffect, useMemo } from 'react' +import { memo, useCallback, useEffect, useMemo } from 'react' import type { FC } from 'react' import { useTranslation } from 'react-i18next' import WeightedScore from './weighted-score' @@ -11,7 +11,7 @@ import type { DatasetConfigs, } from '@/models/debug' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import type { ModelConfig } from '@/app/components/workflow/types' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import Tooltip from '@/app/components/base/tooltip' @@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets' import cn from '@/utils/classnames' import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks' import Switch from '@/app/components/base/switch' +import Toast from '@/app/components/base/toast' type Props = { datasetConfigs: DatasetConfigs @@ -60,6 +61,24 @@ const ConfigContent: FC = ({ modelList: rerankModelList, defaultModel: rerankDefaultModel, } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentModel, rerankDefaultModel, t]) + const rerankModel = (() => { if (datasetConfigs.reranking_model?.reranking_provider_name) { return { @@ -231,16 +250,22 @@ const ConfigContent: FC = ({
{ selectedDatasetsMode.allEconomic && ( - { - onChange({ - ...datasetConfigs, - reranking_enable: v, - }) - }} - /> +
+ { + onChange({ + ...datasetConfigs, + reranking_enable: v, + }) + }} + /> +
) }
{t('common.modelProvider.rerankModel.key')}
diff --git a/web/app/components/datasets/common/retrieval-param-config/index.tsx b/web/app/components/datasets/common/retrieval-param-config/index.tsx index 323e47f3b4..9d48d56a8d 100644 --- a/web/app/components/datasets/common/retrieval-param-config/index.tsx +++ b/web/app/components/datasets/common/retrieval-param-config/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React from 'react' +import React, { useCallback } from 'react' import { useTranslation } from 'react-i18next' import cn from '@/utils/classnames' @@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch' import Tooltip from '@/app/components/base/tooltip' import type { RetrievalConfig } from '@/types/app' import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector' -import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' +import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { DEFAULT_WEIGHTED_SCORE, @@ -19,6 +19,7 @@ import { WeightedScoreEnum, } from '@/models/datasets' import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score' +import Toast from '@/app/components/base/toast' type Props = { type: RETRIEVE_METHOD @@ -38,6 +39,24 @@ const RetrievalParamConfig: FC = ({ defaultModel: rerankDefaultModel, modelList: rerankModelList, } = useModelListAndDefaultModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) + + const handleDisabledSwitchClick = useCallback(() => { + if (!currentModel) + Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') }) + }, [currentModel, rerankDefaultModel, t]) + const isHybridSearch = type === RETRIEVE_METHOD.hybrid const rerankModel = (() => { @@ -99,16 +118,22 @@ const RetrievalParamConfig: FC = ({
{canToggleRerankModalEnable && ( - { - onChange({ - ...value, - reranking_enable: v, - }) - }} - /> +
+ { + onChange({ + ...value, + reranking_enable: v, + }) + }} + disabled={!currentModel} + /> +
)}
{t('common.modelProvider.rerankModel.key')} diff --git a/web/app/components/workflow/hooks/use-workflow-start-run.tsx b/web/app/components/workflow/hooks/use-workflow-start-run.tsx index b2b1c69975..77e959b573 100644 --- a/web/app/components/workflow/hooks/use-workflow-start-run.tsx +++ b/web/app/components/workflow/hooks/use-workflow-start-run.tsx @@ -1,17 +1,25 @@ import { useCallback } from 'react' import { useStoreApi } from 'reactflow' +import { useTranslation } from 'react-i18next' import { useWorkflowStore } from '../store' import { BlockEnum, WorkflowRunningStatus, } from '../types' +import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types' +import type { Node } from '../types' +import { useWorkflow } from './use-workflow' import { useIsChatMode, useNodesSyncDraft, useWorkflowInteractions, useWorkflowRun, } from './index' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { useFeaturesStore } from '@/app/components/base/features/hooks' +import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default' +import Toast from '@/app/components/base/toast' export const useWorkflowStartRun = () => { const store = useStoreApi() @@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => { const isChatMode = useIsChatMode() const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions() const { handleRun } = useWorkflowRun() + const { isFromStartNode } = useWorkflow() const { doSyncWorkflowDraft } = useNodesSyncDraft() + const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault + const { t } = useTranslation() + const { + modelList: rerankModelList, + defaultModel: rerankDefaultModel, + } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) + + const { + currentModel, + } = useCurrentProviderAndModel( + rerankModelList, + rerankDefaultModel + ? { + ...rerankDefaultModel, + provider: rerankDefaultModel.provider.provider, + } + : undefined, + ) const handleWorkflowStartRunInWorkflow = useCallback(async () => { const { @@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => { const { getNodes } = store.getState() const nodes = getNodes() const startNode = nodes.find(node => node.data.type === BlockEnum.Start) + const knowledgeRetrievalNodes = nodes.filter((node: Node) => + node.data.type === BlockEnum.KnowledgeRetrieval, + ) const startVariables = startNode?.data.variables || [] const fileSettings = featuresStore!.getState().features.file const { @@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => { setShowEnvPanel, } = workflowStore.getState() + if (knowledgeRetrievalNodes.length > 0) { + for (const node of knowledgeRetrievalNodes) { + if (isFromStartNode(node.id)) { + const res = checkKnowledgeRetrievalValid(node.data, t) + if (!res.isValid || !currentModel || !rerankDefaultModel) { + const errorMessage = res.errorMessage + if (errorMessage) { + Toast.notify({ + type: 'error', + message: errorMessage, + }) + return false + } + else { + Toast.notify({ + type: 'error', + message: t('appDebug.datasetConfig.rerankModelRequired'), + }) + return false + } + } + } + } + } + setShowEnvPanel(false) if (showDebugAndPreviewPanel) { diff --git a/web/app/components/workflow/hooks/use-workflow.ts b/web/app/components/workflow/hooks/use-workflow.ts index b201b28b88..ec7ce66e5f 100644 --- a/web/app/components/workflow/hooks/use-workflow.ts +++ b/web/app/components/workflow/hooks/use-workflow.ts @@ -235,6 +235,33 @@ export const useWorkflow = () => { return nodes.filter(node => node.parentId === nodeId) }, [store]) + const isFromStartNode = useCallback((nodeId: string) => { + const { getNodes } = store.getState() + const nodes = getNodes() + const currentNode = nodes.find(node => node.id === nodeId) + + if (!currentNode) + return false + + if (currentNode.data.type === BlockEnum.Start) + return true + + const checkPreviousNodes = (node: Node) => { + const previousNodes = getBeforeNodeById(node.id) + + for (const prevNode of previousNodes) { + if (prevNode.data.type === BlockEnum.Start) + return true + if (checkPreviousNodes(prevNode)) + return true + } + + return false + } + + return checkPreviousNodes(currentNode) + }, [store, getBeforeNodeById]) + const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => { const { getNodes, setNodes } = store.getState() const afterNodes = getAfterNodesInSameBranch(nodeId) @@ -389,6 +416,7 @@ export const useWorkflow = () => { checkParallelLimit, checkNestedParallelLimit, isValidConnection, + isFromStartNode, formatTimeFromNow, getNode, getBeforeNodeById, diff --git a/web/i18n/en-US/workflow.ts b/web/i18n/en-US/workflow.ts index a7e768911f..d5ab6eb728 100644 --- a/web/i18n/en-US/workflow.ts +++ b/web/i18n/en-US/workflow.ts @@ -172,6 +172,7 @@ const translation = { }, errorMsg: { fieldRequired: '{{field}} is required', + rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.', authRequired: 'Authorization is required', invalidJson: '{{field}} is invalid JSON', fields: { diff --git a/web/i18n/zh-Hans/workflow.ts b/web/i18n/zh-Hans/workflow.ts index 3579ec5df3..4959a87be7 100644 --- a/web/i18n/zh-Hans/workflow.ts +++ b/web/i18n/zh-Hans/workflow.ts @@ -172,6 +172,7 @@ const translation = { }, errorMsg: { fieldRequired: '{{field}} 不能为空', + rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。', authRequired: '请先授权', invalidJson: '{{field}} 是非法的 JSON', fields: {