mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Fix/retrieval setting weight default value (#9622)
This commit is contained in:
parent
7d7e0f9800
commit
ff956cb546
|
@ -13,6 +13,11 @@ import ContextVar from './context-var'
|
|||
import ConfigContext from '@/context/debug-configuration'
|
||||
import { AppType } from '@/types/app'
|
||||
import type { DataSet } from '@/models/datasets'
|
||||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
|
||||
const Icon = (
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
|
@ -31,13 +36,25 @@ const DatasetConfig: FC = () => {
|
|||
setModelConfig,
|
||||
showSelectDataSet,
|
||||
isAgent,
|
||||
datasetConfigs,
|
||||
setDatasetConfigs,
|
||||
} = useContext(ConfigContext)
|
||||
const formattingChangedDispatcher = useFormattingChangedDispatcher()
|
||||
|
||||
const hasData = dataSet.length > 0
|
||||
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const onRemove = (id: string) => {
|
||||
setDataSet(dataSet.filter(item => item.id !== id))
|
||||
const filteredDataSets = dataSet.filter(item => item.id !== id)
|
||||
setDataSet(filteredDataSets)
|
||||
const retrievalConfig = getMultipleRetrievalConfig(datasetConfigs as any, filteredDataSets, dataSet, !!currentRerankModel)
|
||||
setDatasetConfigs({
|
||||
...(datasetConfigs as any),
|
||||
...retrievalConfig,
|
||||
})
|
||||
formattingChangedDispatcher()
|
||||
}
|
||||
|
||||
|
|
|
@ -55,7 +55,7 @@ const ConfigContent: FC<Props> = ({
|
|||
retrieval_model: RETRIEVE_TYPE.multiWay,
|
||||
}, isInWorkflow)
|
||||
}
|
||||
}, [type])
|
||||
}, [type, datasetConfigs, isInWorkflow, onChange])
|
||||
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
|
|
|
@ -16,7 +16,6 @@ import type { DataSet } from '@/models/datasets'
|
|||
import type { DatasetConfigs } from '@/models/debug'
|
||||
import {
|
||||
getMultipleRetrievalConfig,
|
||||
getSelectedDatasetsMode,
|
||||
} from '@/app/components/workflow/nodes/knowledge-retrieval/utils'
|
||||
|
||||
type ParamsConfigProps = {
|
||||
|
@ -37,57 +36,8 @@ const ParamsConfig = ({
|
|||
const [tempDataSetConfigs, setTempDataSetConfigs] = useState(datasetConfigs)
|
||||
|
||||
useEffect(() => {
|
||||
const {
|
||||
allEconomic,
|
||||
allHighQuality,
|
||||
allHighQualityFullTextSearch,
|
||||
allHighQualityVectorSearch,
|
||||
allExternal,
|
||||
mixtureHighQualityAndEconomic,
|
||||
inconsistentEmbeddingModel,
|
||||
mixtureInternalAndExternal,
|
||||
} = getSelectedDatasetsMode(selectedDatasets)
|
||||
|
||||
if (allEconomic || allHighQuality || allHighQualityFullTextSearch || allHighQualityVectorSearch || (allExternal && selectedDatasets.length === 1))
|
||||
setRerankSettingModalOpen(false)
|
||||
|
||||
if (mixtureHighQualityAndEconomic || inconsistentEmbeddingModel || mixtureInternalAndExternal || (allExternal && selectedDatasets.length > 1))
|
||||
setRerankSettingModalOpen(true)
|
||||
}, [selectedDatasets])
|
||||
|
||||
useEffect(() => {
|
||||
const {
|
||||
allEconomic,
|
||||
allInternal,
|
||||
allExternal,
|
||||
} = getSelectedDatasetsMode(selectedDatasets)
|
||||
const { datasets, retrieval_model, score_threshold_enabled, ...restConfigs } = datasetConfigs
|
||||
let rerankEnable = restConfigs.reranking_enable
|
||||
|
||||
if (((allInternal && allEconomic) || allExternal) && !restConfigs.reranking_model?.reranking_provider_name && rerankEnable === undefined)
|
||||
rerankEnable = false
|
||||
|
||||
setTempDataSetConfigs({
|
||||
...getMultipleRetrievalConfig({
|
||||
top_k: restConfigs.top_k,
|
||||
score_threshold: restConfigs.score_threshold,
|
||||
reranking_model: restConfigs.reranking_model && {
|
||||
provider: restConfigs.reranking_model.reranking_provider_name,
|
||||
model: restConfigs.reranking_model.reranking_model_name,
|
||||
},
|
||||
reranking_mode: restConfigs.reranking_mode,
|
||||
weights: restConfigs.weights,
|
||||
reranking_enable: rerankEnable,
|
||||
}, selectedDatasets),
|
||||
reranking_model: restConfigs.reranking_model && {
|
||||
reranking_provider_name: restConfigs.reranking_model.reranking_provider_name,
|
||||
reranking_model_name: restConfigs.reranking_model.reranking_model_name,
|
||||
},
|
||||
retrieval_model,
|
||||
score_threshold_enabled,
|
||||
datasets,
|
||||
})
|
||||
}, [selectedDatasets, datasetConfigs])
|
||||
setTempDataSetConfigs(datasetConfigs)
|
||||
}, [datasetConfigs])
|
||||
|
||||
const {
|
||||
defaultModel: rerankDefaultModel,
|
||||
|
@ -135,7 +85,7 @@ const ParamsConfig = ({
|
|||
reranking_mode: restConfigs.reranking_mode,
|
||||
weights: restConfigs.weights,
|
||||
reranking_enable: restConfigs.reranking_enable,
|
||||
}, selectedDatasets)
|
||||
}, selectedDatasets, selectedDatasets, !!isRerankDefaultModelValid)
|
||||
|
||||
setTempDataSetConfigs({
|
||||
...retrievalConfig,
|
||||
|
@ -180,6 +130,7 @@ const ParamsConfig = ({
|
|||
|
||||
<div className='mt-6 flex justify-end'>
|
||||
<Button className='mr-2 flex-shrink-0' onClick={() => {
|
||||
setTempDataSetConfigs(datasetConfigs)
|
||||
setRerankSettingModalOpen(false)
|
||||
}}>{t('common.operation.cancel')}</Button>
|
||||
<Button variant='primary' className='flex-shrink-0' onClick={handleSave} >{t('common.operation.save')}</Button>
|
||||
|
|
|
@ -38,7 +38,7 @@ import ConfigContext from '@/context/debug-configuration'
|
|||
import Config from '@/app/components/app/configuration/config'
|
||||
import Debug from '@/app/components/app/configuration/debug'
|
||||
import Confirm from '@/app/components/base/confirm'
|
||||
import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ModelFeatureEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { ToastContext } from '@/app/components/base/toast'
|
||||
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
|
||||
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
|
||||
|
@ -53,7 +53,10 @@ import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
|
|||
import Drawer from '@/app/components/base/drawer'
|
||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useTextGenerationCurrentProviderAndModelAndModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import {
|
||||
useModelListAndDefaultModelAndCurrentProviderAndModel,
|
||||
useTextGenerationCurrentProviderAndModelAndModelList,
|
||||
} from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { fetchCollectionList } from '@/service/tools'
|
||||
import { type Collection } from '@/app/components/tools/types'
|
||||
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||
|
@ -217,6 +220,9 @@ const Configuration: FC = () => {
|
|||
const [isShowSelectDataSet, { setTrue: showSelectDataSet, setFalse: hideSelectDataSet }] = useBoolean(false)
|
||||
const selectedIds = dataSets.map(item => item.id)
|
||||
const [rerankSettingModalOpen, setRerankSettingModalOpen] = useState(false)
|
||||
const {
|
||||
currentModel: currentRerankModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
const handleSelect = (data: DataSet[]) => {
|
||||
if (isEqual(data.map(item => item.id), dataSets.map(item => item.id))) {
|
||||
hideSelectDataSet()
|
||||
|
@ -263,7 +269,7 @@ const Configuration: FC = () => {
|
|||
reranking_mode: restConfigs.reranking_mode,
|
||||
weights: restConfigs.weights,
|
||||
reranking_enable: restConfigs.reranking_enable,
|
||||
}, newDatasets)
|
||||
}, newDatasets, dataSets, !!currentRerankModel)
|
||||
|
||||
setDatasetConfigs({
|
||||
...retrievalConfig,
|
||||
|
@ -603,9 +609,11 @@ const Configuration: FC = () => {
|
|||
|
||||
syncToPublishedConfig(config)
|
||||
setPublishedConfig(config)
|
||||
const retrievalConfig = getMultipleRetrievalConfig(modelConfig.dataset_configs, datasets, datasets, !!currentRerankModel)
|
||||
setDatasetConfigs({
|
||||
retrieval_model: RETRIEVE_TYPE.multiWay,
|
||||
...modelConfig.dataset_configs,
|
||||
...retrievalConfig,
|
||||
})
|
||||
setHasFetchedDetail(true)
|
||||
})
|
||||
|
|
|
@ -163,7 +163,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
draft.retrieval_mode = newMode
|
||||
if (newMode === RETRIEVE_TYPE.multiWay) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets)
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
|
||||
}
|
||||
else {
|
||||
const hasSetModel = draft.single_retrieval_config?.model?.provider
|
||||
|
@ -180,14 +180,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets])
|
||||
}, [currentModel?.model, currentModel?.model_properties?.mode, currentProvider?.provider, inputs, setInputs, selectedDatasets, currentRerankModel])
|
||||
|
||||
const handleMultipleRetrievalConfigChange = useCallback((newConfig: MultipleRetrievalConfig) => {
|
||||
const newInputs = produce(inputs, (draft) => {
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets)
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(newConfig!, selectedDatasets, selectedDatasets, !!currentRerankModel)
|
||||
})
|
||||
setInputs(newInputs)
|
||||
}, [inputs, setInputs, selectedDatasets])
|
||||
}, [inputs, setInputs, selectedDatasets, currentRerankModel])
|
||||
|
||||
// datasets
|
||||
useEffect(() => {
|
||||
|
@ -231,7 +231,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|
||||
if (payload.retrieval_mode === RETRIEVE_TYPE.multiWay && newDatasets.length > 0) {
|
||||
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets)
|
||||
draft.multiple_retrieval_config = getMultipleRetrievalConfig(multipleRetrievalConfig!, newDatasets, selectedDatasets, !!currentRerankModel)
|
||||
}
|
||||
})
|
||||
setInputs(newInputs)
|
||||
|
@ -243,7 +243,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|||
|| (allExternal && newDatasets.length > 1)
|
||||
)
|
||||
setRerankModelOpen(true)
|
||||
}, [inputs, setInputs, payload.retrieval_mode])
|
||||
}, [inputs, setInputs, payload.retrieval_mode, selectedDatasets, currentRerankModel])
|
||||
|
||||
const filterVar = useCallback((varPayload: Var) => {
|
||||
return varPayload.type === VarType.string
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
import { uniq } from 'lodash-es'
|
||||
import {
|
||||
uniq,
|
||||
xorBy,
|
||||
} from 'lodash-es'
|
||||
import type { MultipleRetrievalConfig } from './types'
|
||||
import type {
|
||||
DataSet,
|
||||
|
@ -15,7 +18,9 @@ export const checkNodeValid = () => {
|
|||
return true
|
||||
}
|
||||
|
||||
export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
|
||||
export const getSelectedDatasetsMode = (datasets: DataSet[] = []) => {
|
||||
if (datasets === null)
|
||||
datasets = []
|
||||
let allHighQuality = true
|
||||
let allHighQualityVectorSearch = true
|
||||
let allHighQualityFullTextSearch = true
|
||||
|
@ -85,7 +90,14 @@ export const getSelectedDatasetsMode = (datasets: DataSet[]) => {
|
|||
} as SelectedDatasetsMode
|
||||
}
|
||||
|
||||
export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetrievalConfig, selectedDatasets: DataSet[]) => {
|
||||
export const getMultipleRetrievalConfig = (
|
||||
multipleRetrievalConfig: MultipleRetrievalConfig,
|
||||
selectedDatasets: DataSet[],
|
||||
originalDatasets: DataSet[],
|
||||
isValidRerankModel?: boolean,
|
||||
) => {
|
||||
const shouldSetWeightDefaultValue = xorBy(selectedDatasets, originalDatasets, 'id').length > 0
|
||||
|
||||
const {
|
||||
allHighQuality,
|
||||
allHighQualityVectorSearch,
|
||||
|
@ -123,6 +135,37 @@ export const getMultipleRetrievalConfig = (multipleRetrievalConfig: MultipleRetr
|
|||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
|
||||
if (allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined) && allInternal && !weights) {
|
||||
if (!isValidRerankModel)
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
else
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
|
||||
result.weights = {
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.semantic
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.semantic
|
||||
: DEFAULT_WEIGHTED_SCORE.other.semantic,
|
||||
embedding_provider_name: selectedDatasets[0].embedding_model_provider,
|
||||
embedding_model_name: selectedDatasets[0].embedding_model,
|
||||
},
|
||||
keyword_setting: {
|
||||
keyword_weight: allHighQualityVectorSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityVectorSearch.keyword
|
||||
: allHighQualityFullTextSearch
|
||||
? DEFAULT_WEIGHTED_SCORE.allHighQualityFullTextSearch.keyword
|
||||
: DEFAULT_WEIGHTED_SCORE.other.keyword,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if (shouldSetWeightDefaultValue && allHighQuality && !inconsistentEmbeddingModel && (reranking_mode === RerankingModeEnum.WeightedScore || reranking_mode === undefined || !isValidRerankModel) && allInternal && weights) {
|
||||
if (!isValidRerankModel)
|
||||
result.reranking_mode = RerankingModeEnum.WeightedScore
|
||||
else
|
||||
result.reranking_mode = RerankingModeEnum.RerankingModel
|
||||
|
||||
result.weights = {
|
||||
vector_setting: {
|
||||
vector_weight: allHighQualityVectorSearch
|
||||
|
|
|
@ -566,14 +566,6 @@ export const DEFAULT_WEIGHTED_SCORE = {
|
|||
semantic: 0,
|
||||
keyword: 1.0,
|
||||
},
|
||||
semanticFirst: {
|
||||
semantic: 0.7,
|
||||
keyword: 0.3,
|
||||
},
|
||||
keywordFirst: {
|
||||
semantic: 0.3,
|
||||
keyword: 0.7,
|
||||
},
|
||||
other: {
|
||||
semantic: 0.7,
|
||||
keyword: 0.3,
|
||||
|
|
Loading…
Reference in New Issue
Block a user