default.ts 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import { BlockEnum } from '../../types'
  2. import type { NodeDefault } from '../../types'
  3. import type { KnowledgeRetrievalNodeType } from './types'
  4. import { ALL_CHAT_AVAILABLE_BLOCKS, ALL_COMPLETION_AVAILABLE_BLOCKS } from '@/app/components/workflow/constants'
  5. import { DATASET_DEFAULT } from '@/config'
  6. import { RETRIEVE_TYPE } from '@/types/app'
  // Namespace prefix shared by every i18n key this node's validation uses.
  const i18nPrefix = 'workflow'

  /**
   * Node-level defaults for the Knowledge Retrieval workflow node: the initial
   * payload, which block types may be connected before/after it, and the
   * validation applied to its configuration.
   */
  const nodeDefault: NodeDefault<KnowledgeRetrievalNodeType> = {
    // Payload for a newly added node: no query variable or datasets selected,
    // multi-way retrieval with reranking switched off.
    defaultValue: {
      query_variable_selector: [],
      dataset_ids: [],
      retrieval_mode: RETRIEVE_TYPE.multiWay,
      multiple_retrieval_config: {
        top_k: DATASET_DEFAULT.top_k,
        score_threshold: undefined,
        reranking_enable: false,
      },
    },

    /**
     * Block types that may be connected into this node.
     * @param isChatMode - whether the workflow runs in chat mode.
     * @returns the chat block list unchanged, or the completion block list with
     *   `BlockEnum.End` filtered out.
     */
    getAvailablePrevNodes(isChatMode: boolean) {
      const nodes = isChatMode
        ? ALL_CHAT_AVAILABLE_BLOCKS
        : ALL_COMPLETION_AVAILABLE_BLOCKS.filter(type => type !== BlockEnum.End)
      return nodes
    },

    /**
     * Block types that may follow this node.
     * @param isChatMode - whether the workflow runs in chat mode.
     * @returns the chat or completion block list, unfiltered.
     */
    getAvailableNextNodes(isChatMode: boolean) {
      const nodes = isChatMode ? ALL_CHAT_AVAILABLE_BLOCKS : ALL_COMPLETION_AVAILABLE_BLOCKS
      return nodes
    },

    /**
     * Validates a Knowledge Retrieval node configuration and reports only the
     * first problem found.
     *
     * Checks, in order:
     * 1. a query variable is selected;
     * 2. at least one dataset is selected;
     * 3. multi-way retrieval with reranking not explicitly disabled has a rerank
     *    model provider;
     * 4. one-way retrieval has a model provider.
     *
     * @param payload - the node configuration to validate.
     * @param t - translation function called with a key and optional
     *   interpolation values (NOTE(review): presumably i18next's `t`).
     * @returns `isValid` is true iff no error was found; `errorMessage` holds the
     *   translated message, or `''` when valid.
     */
    checkValid(payload: KnowledgeRetrievalNodeType, t: any) {
      // Shared "<field> is required" message for an already-translated field label.
      const fieldRequired = (field: string) => t(`${i18nPrefix}.errorMsg.fieldRequired`, { field })

      let errorMessages = ''
      if (!errorMessages && (!payload.query_variable_selector || payload.query_variable_selector.length === 0))
        errorMessages = fieldRequired(t(`${i18nPrefix}.nodes.knowledgeRetrieval.queryVariable`))

      if (!errorMessages && (!payload.dataset_ids || payload.dataset_ids.length === 0))
        errorMessages = fieldRequired(t(`${i18nPrefix}.nodes.knowledgeRetrieval.knowledge`))

      // A rerank model is only needed while reranking is on. `defaultValue`
      // sets `reranking_enable: false`, so demanding a model unconditionally
      // rejected nodes whose reranking is switched off. Payloads that omit the
      // flag keep the previous, stricter requirement.
      const multipleConfig = payload.multiple_retrieval_config
      const rerankModelRequired = payload.retrieval_mode === RETRIEVE_TYPE.multiWay
        && multipleConfig?.reranking_enable !== false
      if (!errorMessages && rerankModelRequired && !multipleConfig?.reranking_model?.provider)
        errorMessages = fieldRequired(t(`${i18nPrefix}.errorMsg.fields.rerankModel`))

      if (!errorMessages && payload.retrieval_mode === RETRIEVE_TYPE.oneWay && !payload.single_retrieval_config?.model?.provider)
        errorMessages = fieldRequired(t('common.modelProvider.systemReasoningModel.key'))

      return {
        isValid: !errorMessages,
        errorMessage: errorMessages,
      }
    },
  }

  export default nodeDefault