dify/web/context/provider-context.tsx
2023-08-22 14:55:20 +08:00

73 lines
3.0 KiB
TypeScript

'use client'
import { createContext, useContext } from 'use-context-selector'
import useSWR from 'swr'
import { fetchDefaultModal, fetchModelList } from '@/service/common'
import { ModelFeature, ModelType } from '@/app/components/header/account-setting/model-page/declarations'
import type { BackendModel } from '@/app/components/header/account-setting/model-page/declarations'
const ProviderContext = createContext<{
textGenerationModelList: BackendModel[]
embeddingsModelList: BackendModel[]
speech2textModelList: BackendModel[]
agentThoughtModelList: BackendModel[]
updateModelList: (type: ModelType) => void
embeddingsDefaultModel?: BackendModel
mutateEmbeddingsDefaultModel: () => void
speech2textDefaultModel?: BackendModel
mutateSpeech2textDefaultModel: () => void
}>({
textGenerationModelList: [],
embeddingsModelList: [],
speech2textModelList: [],
agentThoughtModelList: [],
updateModelList: () => {},
speech2textDefaultModel: undefined,
mutateSpeech2textDefaultModel: () => {},
embeddingsDefaultModel: undefined,
mutateEmbeddingsDefaultModel: () => {},
})
export const useProviderContext = () => useContext(ProviderContext)
type ProviderContextProviderProps = {
children: React.ReactNode
}
export const ProviderContextProvider = ({
children,
}: ProviderContextProviderProps) => {
const { data: embeddingsDefaultModel, mutate: mutateEmbeddingsDefaultModel } = useSWR('/workspaces/current/default-model?model_type=embeddings', fetchDefaultModal)
const { data: speech2textDefaultModel, mutate: mutateSpeech2textDefaultModel } = useSWR('/workspaces/current/default-model?model_type=speech2text', fetchDefaultModal)
const fetchModelListUrlPrefix = '/workspaces/current/models/model-type/'
const { data: textGenerationModelList, mutate: mutateTextGenerationModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.textGeneration}`, fetchModelList)
const { data: embeddingsModelList, mutate: mutateEmbeddingsModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.embeddings}`, fetchModelList)
const { data: speech2textModelList } = useSWR(`${fetchModelListUrlPrefix}${ModelType.speech2text}`, fetchModelList)
const agentThoughtModelList = textGenerationModelList?.filter((item) => {
return item.features?.includes(ModelFeature.agentThought)
})
const updateModelList = (type: ModelType) => {
if (type === ModelType.textGeneration)
mutateTextGenerationModelList()
if (type === ModelType.embeddings)
mutateEmbeddingsModelList()
}
return (
<ProviderContext.Provider value={{
textGenerationModelList: textGenerationModelList || [],
embeddingsModelList: embeddingsModelList || [],
speech2textModelList: speech2textModelList || [],
agentThoughtModelList: agentThoughtModelList || [],
updateModelList,
embeddingsDefaultModel,
mutateEmbeddingsDefaultModel,
speech2textDefaultModel,
mutateSpeech2textDefaultModel,
}}>
{children}
</ProviderContext.Provider>
)
}
export default ProviderContext