main_chain_builder.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. from typing import Optional, List
  2. from langchain.callbacks import SharedCallbackManager, CallbackManager
  3. from langchain.chains import SequentialChain
  4. from langchain.chains.base import Chain
  5. from langchain.memory.chat_memory import BaseChatMemory
  6. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  7. from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
  8. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  9. from core.chain.chain_builder import ChainBuilder
  10. from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
  11. from core.conversation_message_task import ConversationMessageTask
  12. from extensions.ext_database import db
  13. from models.dataset import Dataset
  14. class MainChainBuilder:
  15. @classmethod
  16. def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
  17. conversation_message_task: ConversationMessageTask):
  18. first_input_key = "input"
  19. final_output_key = "output"
  20. chains = []
  21. chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)
  22. # agent mode
  23. tool_chains, chains_output_key = cls.get_agent_chains(
  24. tenant_id=tenant_id,
  25. agent_mode=agent_mode,
  26. memory=memory,
  27. conversation_message_task=conversation_message_task
  28. )
  29. chains += tool_chains
  30. if chains_output_key:
  31. final_output_key = chains_output_key
  32. if len(chains) == 0:
  33. return None
  34. for chain in chains:
  35. # do not add handler into singleton callback manager
  36. if not isinstance(chain.callback_manager, SharedCallbackManager):
  37. chain.callback_manager.add_handler(chain_callback_handler)
  38. # build main chain
  39. overall_chain = SequentialChain(
  40. chains=chains,
  41. input_variables=[first_input_key],
  42. output_variables=[final_output_key],
  43. memory=memory, # only for use the memory prompt input key
  44. )
  45. return overall_chain
  46. @classmethod
  47. def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
  48. conversation_message_task: ConversationMessageTask):
  49. # agent mode
  50. chains = []
  51. if agent_mode and agent_mode.get('enabled'):
  52. tools = agent_mode.get('tools', [])
  53. pre_fixed_chains = []
  54. # agent_tools = []
  55. datasets = []
  56. for tool in tools:
  57. tool_type = list(tool.keys())[0]
  58. tool_config = list(tool.values())[0]
  59. if tool_type == 'sensitive-word-avoidance':
  60. chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
  61. if chain:
  62. pre_fixed_chains.append(chain)
  63. elif tool_type == "dataset":
  64. # get dataset from dataset id
  65. dataset = db.session.query(Dataset).filter(
  66. Dataset.tenant_id == tenant_id,
  67. Dataset.id == tool_config.get("id")
  68. ).first()
  69. if dataset:
  70. datasets.append(dataset)
  71. # add pre-fixed chains
  72. chains += pre_fixed_chains
  73. if len(datasets) > 0:
  74. # tool to chain
  75. multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
  76. tenant_id=tenant_id,
  77. datasets=datasets,
  78. conversation_message_task=conversation_message_task,
  79. callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
  80. )
  81. chains.append(multi_dataset_router_chain)
  82. final_output_key = cls.get_chains_output_key(chains)
  83. return chains, final_output_key
  84. @classmethod
  85. def get_chains_output_key(cls, chains: List[Chain]):
  86. if len(chains) > 0:
  87. return chains[-1].output_keys[0]
  88. return None