mirror of
https://github.com/langgenius/dify.git
synced 2024-11-16 03:32:23 +08:00
Feat: rerank model verification in front end (#9271)
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
This commit is contained in:
parent
c6b74daa0a
commit
793205afc5
|
@ -1,6 +1,6 @@
|
|||
'use client'
|
||||
|
||||
import { memo, useEffect, useMemo } from 'react'
|
||||
import { memo, useCallback, useEffect, useMemo } from 'react'
|
||||
import type { FC } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import WeightedScore from './weighted-score'
|
||||
|
@ -11,7 +11,7 @@ import type {
|
|||
DatasetConfigs,
|
||||
} from '@/models/debug'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import type { ModelConfig } from '@/app/components/workflow/types'
|
||||
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
|
@ -23,6 +23,7 @@ import { RerankingModeEnum } from '@/models/datasets'
|
|||
import cn from '@/utils/classnames'
|
||||
import { useSelectedDatasetsMode } from '@/app/components/workflow/nodes/knowledge-retrieval/hooks'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
type Props = {
|
||||
datasetConfigs: DatasetConfigs
|
||||
|
@ -60,6 +61,24 @@ const ConfigContent: FC<Props> = ({
|
|||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentModel, rerankDefaultModel, t])
|
||||
|
||||
const rerankModel = (() => {
|
||||
if (datasetConfigs.reranking_model?.reranking_provider_name) {
|
||||
return {
|
||||
|
@ -231,16 +250,22 @@ const ConfigContent: FC<Props> = ({
|
|||
<div className='flex items-center'>
|
||||
{
|
||||
selectedDatasetsMode.allEconomic && (
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={showRerankModel}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}}
|
||||
/>
|
||||
<div
|
||||
className='flex items-center'
|
||||
onClick={handleDisabledSwitchClick}
|
||||
>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={currentModel ? showRerankModel : false}
|
||||
disabled={!currentModel}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...datasetConfigs,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
<div className='leading-[32px] ml-1 text-text-secondary system-sm-semibold'>{t('common.modelProvider.rerankModel.key')}</div>
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import cn from '@/utils/classnames'
|
||||
|
@ -11,7 +11,7 @@ import Switch from '@/app/components/base/switch'
|
|||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import type { RetrievalConfig } from '@/types/app'
|
||||
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
|
||||
import { useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import {
|
||||
DEFAULT_WEIGHTED_SCORE,
|
||||
|
@ -19,6 +19,7 @@ import {
|
|||
WeightedScoreEnum,
|
||||
} from '@/models/datasets'
|
||||
import WeightedScore from '@/app/components/app/configuration/dataset-config/params-config/weighted-score'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
type Props = {
|
||||
type: RETRIEVE_METHOD
|
||||
|
@ -38,6 +39,24 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||
defaultModel: rerankDefaultModel,
|
||||
modelList: rerankModelList,
|
||||
} = useModelListAndDefaultModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleDisabledSwitchClick = useCallback(() => {
|
||||
if (!currentModel)
|
||||
Toast.notify({ type: 'error', message: t('workflow.errorMsg.rerankModelRequired') })
|
||||
}, [currentModel, rerankDefaultModel, t])
|
||||
|
||||
const isHybridSearch = type === RETRIEVE_METHOD.hybrid
|
||||
|
||||
const rerankModel = (() => {
|
||||
|
@ -99,16 +118,22 @@ const RetrievalParamConfig: FC<Props> = ({
|
|||
<div>
|
||||
<div className='flex h-8 items-center text-[13px] font-medium text-gray-900 space-x-2'>
|
||||
{canToggleRerankModalEnable && (
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={value.reranking_enable}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}}
|
||||
/>
|
||||
<div
|
||||
className='flex items-center'
|
||||
onClick={handleDisabledSwitchClick}
|
||||
>
|
||||
<Switch
|
||||
size='md'
|
||||
defaultValue={currentModel ? value.reranking_enable : false}
|
||||
onChange={(v) => {
|
||||
onChange({
|
||||
...value,
|
||||
reranking_enable: v,
|
||||
})
|
||||
}}
|
||||
disabled={!currentModel}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
<div className='flex items-center'>
|
||||
<span className='mr-0.5'>{t('common.modelProvider.rerankModel.key')}</span>
|
||||
|
|
|
@ -1,17 +1,25 @@
|
|||
import { useCallback } from 'react'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useWorkflowStore } from '../store'
|
||||
import {
|
||||
BlockEnum,
|
||||
WorkflowRunningStatus,
|
||||
} from '../types'
|
||||
import type { KnowledgeRetrievalNodeType } from '../nodes/knowledge-retrieval/types'
|
||||
import type { Node } from '../types'
|
||||
import { useWorkflow } from './use-workflow'
|
||||
import {
|
||||
useIsChatMode,
|
||||
useNodesSyncDraft,
|
||||
useWorkflowInteractions,
|
||||
useWorkflowRun,
|
||||
} from './index'
|
||||
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
||||
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
import KnowledgeRetrievalDefault from '@/app/components/workflow/nodes/knowledge-retrieval/default'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
|
||||
export const useWorkflowStartRun = () => {
|
||||
const store = useStoreApi()
|
||||
|
@ -20,7 +28,26 @@ export const useWorkflowStartRun = () => {
|
|||
const isChatMode = useIsChatMode()
|
||||
const { handleCancelDebugAndPreviewPanel } = useWorkflowInteractions()
|
||||
const { handleRun } = useWorkflowRun()
|
||||
const { isFromStartNode } = useWorkflow()
|
||||
const { doSyncWorkflowDraft } = useNodesSyncDraft()
|
||||
const { checkValid: checkKnowledgeRetrievalValid } = KnowledgeRetrievalDefault
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
modelList: rerankModelList,
|
||||
defaultModel: rerankDefaultModel,
|
||||
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
||||
|
||||
const {
|
||||
currentModel,
|
||||
} = useCurrentProviderAndModel(
|
||||
rerankModelList,
|
||||
rerankDefaultModel
|
||||
? {
|
||||
...rerankDefaultModel,
|
||||
provider: rerankDefaultModel.provider.provider,
|
||||
}
|
||||
: undefined,
|
||||
)
|
||||
|
||||
const handleWorkflowStartRunInWorkflow = useCallback(async () => {
|
||||
const {
|
||||
|
@ -33,6 +60,9 @@ export const useWorkflowStartRun = () => {
|
|||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||
const knowledgeRetrievalNodes = nodes.filter((node: Node<KnowledgeRetrievalNodeType>) =>
|
||||
node.data.type === BlockEnum.KnowledgeRetrieval,
|
||||
)
|
||||
const startVariables = startNode?.data.variables || []
|
||||
const fileSettings = featuresStore!.getState().features.file
|
||||
const {
|
||||
|
@ -42,6 +72,31 @@ export const useWorkflowStartRun = () => {
|
|||
setShowEnvPanel,
|
||||
} = workflowStore.getState()
|
||||
|
||||
if (knowledgeRetrievalNodes.length > 0) {
|
||||
for (const node of knowledgeRetrievalNodes) {
|
||||
if (isFromStartNode(node.id)) {
|
||||
const res = checkKnowledgeRetrievalValid(node.data, t)
|
||||
if (!res.isValid || !currentModel || !rerankDefaultModel) {
|
||||
const errorMessage = res.errorMessage
|
||||
if (errorMessage) {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: errorMessage,
|
||||
})
|
||||
return false
|
||||
}
|
||||
else {
|
||||
Toast.notify({
|
||||
type: 'error',
|
||||
message: t('appDebug.datasetConfig.rerankModelRequired'),
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setShowEnvPanel(false)
|
||||
|
||||
if (showDebugAndPreviewPanel) {
|
||||
|
|
|
@ -235,6 +235,33 @@ export const useWorkflow = () => {
|
|||
return nodes.filter(node => node.parentId === nodeId)
|
||||
}, [store])
|
||||
|
||||
const isFromStartNode = useCallback((nodeId: string) => {
|
||||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const currentNode = nodes.find(node => node.id === nodeId)
|
||||
|
||||
if (!currentNode)
|
||||
return false
|
||||
|
||||
if (currentNode.data.type === BlockEnum.Start)
|
||||
return true
|
||||
|
||||
const checkPreviousNodes = (node: Node) => {
|
||||
const previousNodes = getBeforeNodeById(node.id)
|
||||
|
||||
for (const prevNode of previousNodes) {
|
||||
if (prevNode.data.type === BlockEnum.Start)
|
||||
return true
|
||||
if (checkPreviousNodes(prevNode))
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
return checkPreviousNodes(currentNode)
|
||||
}, [store, getBeforeNodeById])
|
||||
|
||||
const handleOutVarRenameChange = useCallback((nodeId: string, oldValeSelector: ValueSelector, newVarSelector: ValueSelector) => {
|
||||
const { getNodes, setNodes } = store.getState()
|
||||
const afterNodes = getAfterNodesInSameBranch(nodeId)
|
||||
|
@ -389,6 +416,7 @@ export const useWorkflow = () => {
|
|||
checkParallelLimit,
|
||||
checkNestedParallelLimit,
|
||||
isValidConnection,
|
||||
isFromStartNode,
|
||||
formatTimeFromNow,
|
||||
getNode,
|
||||
getBeforeNodeById,
|
||||
|
|
|
@ -172,6 +172,7 @@ const translation = {
|
|||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} is required',
|
||||
rerankModelRequired: 'Before turning on the Rerank Model, please confirm that the model has been successfully configured in the settings.',
|
||||
authRequired: 'Authorization is required',
|
||||
invalidJson: '{{field}} is invalid JSON',
|
||||
fields: {
|
||||
|
|
|
@ -172,6 +172,7 @@ const translation = {
|
|||
},
|
||||
errorMsg: {
|
||||
fieldRequired: '{{field}} 不能为空',
|
||||
rerankModelRequired: '开启 Rerank 模型前,请务必确认模型已在设置中成功配置。',
|
||||
authRequired: '请先授权',
|
||||
invalidJson: '{{field}} 是非法的 JSON',
|
||||
fields: {
|
||||
|
|
Loading…
Reference in New Issue
Block a user