mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 11:42:29 +08:00
fix: rerank validation fix from frontend
This commit is contained in:
parent
cd7ab6231f
commit
47c9595a0e
|
@ -63,7 +63,7 @@ const ConfigContent: FC<Props> = ({
|
|||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel,
|
||||
currentModel: currentRerankModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
|
@ -74,11 +74,6 @@ const ConfigContent: FC<Props> = ({
|
|||
: undefined,
|
||||
)
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentModel, rerankDefaultModel, t])
|
||||
|
||||
const rerankModel = (() => {
|
||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||
return {
|
||||
|
@ -164,12 +159,33 @@ const ConfigContent: FC<Props> = ({
|
|||
const showWeightedScorePanel = showWeightedScore && datasetConfigs.reranking_mode === RerankingModeEnum.WeightedScore && datasetConfigs.weights
|
||||
const selectedRerankMode = datasetConfigs.reranking_mode || RerankingModeEnum.RerankingModel
|
||||
|
||||
const canManuallyToggleRerank = useMemo(() => {
|
||||
return !(
|
||||
(selectedDatasetsMode.allInternal && selectedDatasetsMode.allEconomic)
|
||||
|| selectedDatasetsMode.allExternal
|
||||
)
|
||||
}, [selectedDatasetsMode.allEconomic, selectedDatasetsMode.allExternal, selectedDatasetsMode.allInternal])
|
||||
|
||||
const showRerankModel = useMemo(() => {
|
||||
if (datasetConfigs.reranking_enable === false && selectedDatasetsMode.allEconomic)
|
||||
if (!canManuallyToggleRerank)
|
||||
return false
|
||||
|
||||
return true
|
||||
}, [datasetConfigs.reranking_enable, selectedDatasetsMode.allEconomic])
|
||||
return datasetConfigs.reranking_enable
|
||||
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable])
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentRerankModel && !showRerankModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentRerankModel, showRerankModel, t])
|
||||
|
||||
useEffect(() => {
|
||||
if (!canManuallyToggleRerank && showRerankModel !== datasetConfigs.reranking_enable) {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: showRerankModel,
|
||||
})
|
||||
}
|
||||
}, [canManuallyToggleRerank, showRerankModel, datasetConfigs, onChange])
|
||||
|
||||
return (
|
||||
<div>
|
||||
|
@ -256,13 +272,15 @@ const ConfigContent: FC<Props> = ({
|
|||
>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={currentModel ? showRerankModel : false}
|
||||
disabled={!currentModel}
|
||||
defaultValue={showRerankModel}
|
||||
disabled={!currentRerankModel || !canManuallyToggleRerank}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: v,
|
||||
})
|
||||
if (canManuallyToggleRerank) {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
|
|
@ -42,6 +42,7 @@ const ParamsConfig = ({
|
|||
allHighQuality,
|
||||
allHighQualityFullTextSearch,
|
||||
allHighQualityVectorSearch,
|
||||
allInternal,
|
||||
allExternal,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
|
@ -50,7 +51,7 @@ const ParamsConfig = ({
|
|||
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||
let rerankEnable = restConfigs.reranking_enable
|
||||
|
||||
if ((allEconomic && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined) || allExternal)
|
||||
if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
|
||||
rerankEnable = false
|
||||
|
||||
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
||||
|
|
|
@ -1,25 +1,17 @@
|
|||
import { useCallback } from 'react'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useWorkflowStore } from '../store'
|
||||
import {
|
||||
BlockEnum,
|
||||
WorkflowRunningStatus,
|
||||
} from '../types'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
import type { Node } from '../types'
|
||||
import { useWorkflow } from './use-workflow'
|
||||
import {
|
||||
useIsChatMode,
|
||||
useNodesSyncDraft,
|
||||
useWorkflowInteractions,
|
||||
useWorkflowRun,
|
||||
} from './index'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
export const useWorkflowStartRun = () => {
|
||||
const store = useStoreApi()
|
||||
|
@ -28,26 +20,7 @@ export const useWorkflowStartRun = () => {
|
|||
const isChatMode = useIsChatMode()
|
||||
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const { isFromStartNode } = useWorkflow()
|
||||
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
||||
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
||||
const {
|
||||
|
@ -60,9 +33,6 @@ export const useWorkflowStartRun = () => {
|
|||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
|
||||
node.data.type === BlockEnum.KnowledgeRetrieval,
|
||||
)
|
||||
const startVariables = startNode?.data.variables || []
|
||||
const fileSettings = featuresStore!.getState().features.file
|
||||
const {
|
||||
|
@ -72,31 +42,6 @@ export const useWorkflowStartRun = () => {
|
|||
setShowEnvPanel,
|
||||
} = workflowStore.getState()
|
||||
|
||||
if (knowledgeRetrievalNodes.length > 0) {
|
||||
for (const node of knowledgeRetrievalNodes) {
|
||||
if (isFromStartNode(node.id)) {
|
||||
const res = checkKnowledgeRetrievalValid(node.data, t)
|
||||
if (!res.isValid || !currentModel || !rerankDefaultModel) {
|
||||
const errorMessage = res.errorMessage
|
||||
if (errorMessage) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
return false
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('appDebug.datasetConfig.rerankModelRequired'),
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setShowEnvPanel(false)
|
||||
|
||||
if (showDebugAndPreviewPanel) {
|
||||
|
|
|
@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
|||
import { fetchDatasets } from '@/service/datasets'
|
||||
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
||||
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
||||
|
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
const startNodeId = startNode?.id
|
||||
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
|
||||
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
||||
const newInputs = produce(s, (draft) => {
|
||||
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
||||
|
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
})
|
||||
// not work in pass to draft...
|
||||
doSetInputs(newInputs)
|
||||
inputRef.current = newInputs
|
||||
}, [doSetInputs])
|
||||
|
||||
const inputRef = useRef(inputs)
|
||||
useEffect(() => {
|
||||
inputRef.current = inputs
|
||||
}, [inputs])
|
||||
|
||||
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = newVar as ValueSelector
|
||||
|
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
||||
const newInputs = produce(inputRef.current, (draft) => {
|
||||
if (!draft.single_retrieval_config) {
|
||||
|
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
// set defaults models
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
||||
return
|
||||
|
||||
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
||||
|
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = {
|
||||
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
||||
|
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
reranking_model: multipleRetrievalConfig?.reranking_model,
|
||||
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
||||
weights: multipleRetrievalConfig?.weights,
|
||||
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
||||
? multipleRetrievalConfig.reranking_enable
|
||||
: Boolean(currentRerankModel && rerankDefaultModel),
|
||||
}
|
||||
})
|
||||
setInputs(newInput)
|
||||
|
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const inputs = inputRef.current
|
||||
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
||||
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
||||
query_variable_selector = [startNodeId, 'sys.query']
|
||||
|
||||
setInputs({
|
||||
...inputs,
|
||||
query_variable_selector,
|
||||
})
|
||||
setInputs(produce(inputs, (draft) => {
|
||||
draft.query_variable_selector = query_variable_selector
|
||||
}))
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
|
|
|
@ -113,7 +113,7 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
|
|||
reranking_mode,
|
||||
reranking_model,
|
||||
weights,
|
||||
reranking_enable: allEconomic ? reranking_enable : true,
|
||||
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
|
||||
}
|
||||
|
||||
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal)
|
||||
|
|
Loading…
Reference in New Issue
Block a user