fix: retrieval setting validate (#10454)

This commit is contained in:
zxhlyh 2024-11-12 14:38:24 +08:00 committed by GitHub
parent 16b9665033
commit e4d175780e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 130 additions and 49 deletions

View File

@ -47,12 +47,16 @@ const DatasetConfig: FC = () => {
const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const onRemove = (id: string) => { const onRemove = (id: string) => {
const filteredDataSets = dataSet.filter(item => item.id !== id) const filteredDataSets = dataSet.filter(item => item.id !== id)
setDataSet(filteredDataSets) setDataSet(filteredDataSets)
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel) const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
setDatasetConfigs({ setDatasetConfigs({
...(datasetConfigs as any), ...(datasetConfigs as any),
...retrievalConfig, ...retrievalConfig,

View File

@ -172,7 +172,7 @@ const ConfigContent: FC<Props> = ({
return false return false
return datasetConfigs.reranking_enable return datasetConfigs.reranking_enable
}, [canManuallyToggleRerank, datasetConfigs.reranking_enable]) }, [canManuallyToggleRerank, datasetConfigs.reranking_enable, isRerankDefaultModelValid])
const handleDisabledSwitchClick = useCallback(() => { const handleDisabledSwitchClick = useCallback(() => {
if (!currentRerankModel && !showRerankModel) if (!currentRerankModel && !showRerankModel)

View File

@ -43,6 +43,7 @@ const ParamsConfig = ({
const { const {
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
currentModel: isRerankDefaultModelValid, currentModel: isRerankDefaultModelValid,
currentProvider: rerankDefaultProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const isValid = () => { const isValid = () => {
@ -91,7 +92,10 @@ const ParamsConfig = ({
reranking_mode: restConfigs.reranking_mode, reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights, weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid) }, selectedDatasets, selectedDatasets, {
provider: rerankDefaultProvider?.provider,
model: isRerankDefaultModelValid?.model,
})
setTempDataSetConfigs({ setTempDataSetConfigs({
...retrievalConfig, ...retrievalConfig,

View File

@ -226,6 +226,7 @@ const Configuration: FC = () => {
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false) const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank) } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
const handleSelect = (data: DataSet[]) => { const handleSelect = (data: DataSet[]) => {
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) { if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
@ -279,7 +280,10 @@ const Configuration: FC = () => {
reranking_mode: restConfigs.reranking_mode, reranking_mode: restConfigs.reranking_mode,
weights: restConfigs.weights, weights: restConfigs.weights,
reranking_enable: restConfigs.reranking_enable, reranking_enable: restConfigs.reranking_enable,
}, newDatasets, dataSets, !!currentRerankModel) }, newDatasets, dataSets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
setDatasetConfigs({ setDatasetConfigs({
...retrievalConfig, ...retrievalConfig,
@ -620,7 +624,10 @@ const Configuration: FC = () => {
syncToPublishedConfig(config) syncToPublishedConfig(config)
setPublishedConfig(config) setPublishedConfig(config)
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel) const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
setDatasetConfigs({ setDatasetConfigs({
retrieval_model: RETRIEVE_TYPE.multiWay, retrieval_model: RETRIEVE_TYPE.multiWay,
...modelConfig.dataset_configs, ...modelConfig.dataset_configs,

View File

@ -1,7 +1,7 @@
import { BlockEnum } from '../../types' import { BlockEnum } from '../../types'
import type { NodeDefault } from '../../types' import type { NodeDefault } from '../../types'
import type { KnowledgeRetrievalNodeType } from './types' import type { KnowledgeRetrievalNodeType } from './types'
import { RerankingModeEnum } from '@/models/datasets' import { checkoutRerankModelConfigedInRetrievalSettings } from './utils'
import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants' import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
import { DATASET_DEFAULT } from '@/config' import { DATASET_DEFAULT } from '@/config'
import { RETRIEVE_TYPE } from '@/types/app' import { RETRIEVE_TYPE } from '@/types/app'
@ -36,12 +36,17 @@ const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0)) if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) }) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`) })
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.multiWay && payload.multiple_retrieval_config?.reranking_mode === RerankingModeEnum.RerankingModel && !payload.multiple_retrieval_config?.reranking_model?.provider && payload.multiple_retrieval_config?.reranking_enable)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider) if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') }) errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t('common.modelProvider.systemReasoningModel.key') })
const { _datasets, multiple_retrieval_config, retrieval_mode } = payload
if (retrieval_mode === RETRIEVE_TYPE.multiWay) {
const checked = checkoutRerankModelConfigedInRetrievalSettings(_datasets || [], multiple_retrieval_config)
if (!errorMessages && !checked)
errorMessages = t(`${i18nPrefix}.errorMsg.fieldRequired`, { field: t(`${i18nPrefix}.errorMsg.fields.rerankModel`) })
}
return { return {
isValid: !errorMessages, isValid: !errorMessages,
errorMessage: errorMessages, errorMessage: errorMessages,

View File

@ -1,6 +1,7 @@
import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types' import type { CommonNodeType, ModelConfig, ValueSelector } from '@/app/components/workflow/types'
import type { RETRIEVE_TYPE } from '@/types/app' import type { RETRIEVE_TYPE } from '@/types/app'
import type { import type {
DataSet,
RerankingModeEnum, RerankingModeEnum,
} from '@/models/datasets' } from '@/models/datasets'
@ -35,4 +36,5 @@ export type KnowledgeRetrievalNodeType = CommonNodeType & {
retrieval_mode: RETRIEVE_TYPE retrieval_mode: RETRIEVE_TYPE
multiple_retrieval_config?: MultipleRetrievalConfig multiple_retrieval_config?: MultipleRetrievalConfig
single_retrieval_config?: SingleRetrievalConfig single_retrieval_config?: SingleRetrievalConfig
_datasets?: DataSet[]
} }

View File

@ -67,6 +67,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
const { const {
currentModel: currentRerankModel, currentModel: currentRerankModel,
currentProvider: currentRerankProvider,
} = useCurrentProviderAndModel( } = useCurrentProviderAndModel(
rerankModelList, rerankModelList,
rerankDefaultModel rerankDefaultModel
@ -163,7 +164,10 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
draft.retrieval_mode = newMode draft.retrieval_mode = newMode
if (newMode === RETRIEVE_TYPE.multiWay) { if (newMode === RETRIEVE_TYPE.multiWay) {
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
} }
else { else {
const hasSetModel = draft.single_retrieval_config?.model?.provider const hasSetModel = draft.single_retrieval_config?.model?.provider
@ -180,14 +184,17 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} }
}) })
setInputs(newInputs) setInputs(newInputs)
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel]) }, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => { const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel) draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
}) })
setInputs(newInputs) setInputs(newInputs)
}, [inputs, setInputs, selectedDatasets, currentRerankModel]) }, [inputs, setInputs, selectedDatasets, currentRerankModel, currentRerankProvider])
// datasets // datasets
useEffect(() => { useEffect(() => {
@ -200,6 +207,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} }
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = datasetIds draft.dataset_ids = datasetIds
draft._datasets = selectedDatasets
}) })
setInputs(newInputs) setInputs(newInputs)
})() })()
@ -228,10 +236,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
} = getSelectedDatasetsMode(newDatasets) } = getSelectedDatasetsMode(newDatasets)
const newInputs = produce(inputs, (draft) => { const newInputs = produce(inputs, (draft) => {
draft.dataset_ids = newDatasets.map(d => d.id) draft.dataset_ids = newDatasets.map(d => d.id)
draft._datasets = newDatasets
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) { if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
const multipleRetrievalConfig = draft.multiple_retrieval_config const multipleRetrievalConfig = draft.multiple_retrieval_config
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel) draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, {
provider: currentRerankProvider?.provider,
model: currentRerankModel?.model,
})
} }
}) })
setInputs(newInputs) setInputs(newInputs)
@ -243,7 +255,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|| allExternal || allExternal
) )
setRerankModelOpen(true) setRerankModelOpen(true)
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel]) }, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel, currentRerankProvider])
const filterVar = useCallback((varPayload: Var) => { const filterVar = useCallback((varPayload: Var) => {
return varPayload.type === VarType.string return varPayload.type === VarType.string

View File

@ -94,9 +94,10 @@ export const getMultipleRetrievalConfig = (
multipleRetrievalConfig: MultipleRetrievalConfig, multipleRetrievalConfig: MultipleRetrievalConfig,
selectedDatasets: DataSet[], selectedDatasets: DataSet[],
originalDatasets: DataSet[], originalDatasets: DataSet[],
isValidRerankModel?: boolean, validRerankModel?: { provider?: string; model?: string },
) => { ) => {
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0 const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
const rerankModelIsValid = validRerankModel?.provider && validRerankModel?.model
const { const {
allHighQuality, allHighQuality,
@ -128,18 +129,10 @@ export const getMultipleRetrievalConfig = (
reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true, reranking_enable: ((allInternal && allEconomic) || allExternal) ? reranking_enable : true,
} }
if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) if (!rerankModelIsValid)
result.reranking_mode = RerankingModeEnum.RerankingModel result.reranking_model = undefined
if (allHighQuality && !inconsistentEmbeddingModel && reranking_mode === undefined && allInternal)
result.reranking_mode = RerankingModeEnum.WeightedScore
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
if (!isValidRerankModel)
result.reranking_mode = RerankingModeEnum.WeightedScore
else
result.reranking_mode = RerankingModeEnum.RerankingModel
const setDefaultWeights = () => {
result.weights = { result.weights = {
vector_setting: { vector_setting: {
vector_weight: allHighQualityVectorSearch vector_weight: allHighQualityVectorSearch
@ -160,31 +153,85 @@ export const getMultipleRetrievalConfig = (
} }
} }
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) { if (allEconomic || mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || allExternal || mixtureInternalAndExternal) {
if (!isValidRerankModel) result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_mode = RerankingModeEnum.WeightedScore
else
result.reranking_mode = RerankingModeEnum.RerankingModel
result.weights = { if (rerankModelIsValid) {
vector_setting: { result.reranking_mode = RerankingModeEnum.RerankingModel
vector_weight: allHighQualityVectorSearch result.reranking_model = {
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic provider: validRerankModel?.provider || '',
: allHighQualityFullTextSearch model: validRerankModel?.model || '',
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic }
: DEFAULT_WEIGHTED_SCORE.other.semantic, }
embedding_provider_name: selectedDatasets[0].embedding_model_provider, else {
embedding_model_name: selectedDatasets[0].embedding_model, result.reranking_model = undefined
}, }
keyword_setting: { }
keyword_weight: allHighQualityVectorSearch
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword if (allHighQuality && !inconsistentEmbeddingModel && allInternal) {
: allHighQualityFullTextSearch if (!reranking_mode) {
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword if (validRerankModel?.provider && validRerankModel?.model) {
: DEFAULT_WEIGHTED_SCORE.other.keyword, result.reranking_mode = RerankingModeEnum.RerankingModel
}, result.reranking_model = {
provider: validRerankModel.provider,
model: validRerankModel.model,
}
}
else {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()
}
}
if (reranking_mode === RerankingModeEnum.WeightedScore && !weights)
setDefaultWeights()
if (reranking_mode === RerankingModeEnum.WeightedScore && weights && shouldSetWeightDefaultValue) {
if (rerankModelIsValid) {
result.reranking_mode = RerankingModeEnum.RerankingModel
result.reranking_model = {
provider: validRerankModel.provider || '',
model: validRerankModel.model || '',
}
}
else {
setDefaultWeights()
}
}
if (reranking_mode === RerankingModeEnum.RerankingModel && !rerankModelIsValid && shouldSetWeightDefaultValue) {
result.reranking_mode = RerankingModeEnum.WeightedScore
setDefaultWeights()
} }
} }
return result return result
} }
export const checkoutRerankModelConfigedInRetrievalSettings = (
datasets: DataSet[],
multipleRetrievalConfig?: MultipleRetrievalConfig,
) => {
if (!multipleRetrievalConfig)
return true
const {
allEconomic,
allExternal,
} = getSelectedDatasetsMode(datasets)
const {
reranking_enable,
reranking_mode,
reranking_model,
} = multipleRetrievalConfig
if (reranking_mode === RerankingModeEnum.RerankingModel && (!reranking_model?.provider || !reranking_model?.model)) {
if ((allEconomic || allExternal) && !reranking_enable)
return true
return false
}
return true
}