multi_dataset_router_chain.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. from typing import Mapping, List, Dict, Any, Optional
  2. from langchain import LLMChain, PromptTemplate, ConversationChain
  3. from langchain.callbacks import CallbackManager
  4. from langchain.chains.base import Chain
  5. from langchain.schema import BaseLanguageModel
  6. from pydantic import Extra
  7. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  8. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  9. from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
  10. from core.conversation_message_task import ConversationMessageTask
  11. from core.llm.llm_builder import LLMBuilder
  12. from core.tool.dataset_tool_builder import DatasetToolBuilder
  13. from core.tool.llama_index_tool import EnhanceLlamaIndexTool
  14. from models.dataset import Dataset
# Router prompt (adapted from langchain's multi-prompt router template): asks the
# LLM to pick the best destination (here: a dataset id) for a raw input and to
# optionally rewrite the input, answering with a ```json markdown snippet
# containing "destination" and "next_inputs".
#
# Brace escaping: this template is first filled with str.format(destinations=...),
# which halves the braces -- "{{{{" becomes "{{" and "{{input}}" becomes
# "{input}" -- so the resulting string is a valid PromptTemplate whose only
# input variable is "input". The trailing "\" inside the string are literal
# line-continuations within the triple-quoted string.
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
  39. class MultiDatasetRouterChain(Chain):
  40. """Use a single chain to route an input to one of multiple candidate chains."""
  41. router_chain: LLMRouterChain
  42. """Chain for deciding a destination chain and the input to it."""
  43. dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
  44. """Map of name to candidate chains that inputs can be routed to."""
  45. class Config:
  46. """Configuration for this pydantic object."""
  47. extra = Extra.forbid
  48. arbitrary_types_allowed = True
  49. @property
  50. def input_keys(self) -> List[str]:
  51. """Will be whatever keys the router chain prompt expects.
  52. :meta private:
  53. """
  54. return self.router_chain.input_keys
  55. @property
  56. def output_keys(self) -> List[str]:
  57. return ["text"]
  58. @classmethod
  59. def from_datasets(
  60. cls,
  61. tenant_id: str,
  62. datasets: List[Dataset],
  63. conversation_message_task: ConversationMessageTask,
  64. **kwargs: Any,
  65. ):
  66. """Convenience constructor for instantiating from destination prompts."""
  67. llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
  68. llm = LLMBuilder.to_llm(
  69. tenant_id=tenant_id,
  70. model_name='gpt-3.5-turbo',
  71. temperature=0,
  72. max_tokens=1024,
  73. callback_manager=llm_callback_manager
  74. )
  75. destinations = ["{}: {}".format(d.id, d.description.replace('\n', ' ')) for d in datasets]
  76. destinations_str = "\n".join(destinations)
  77. router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
  78. destinations=destinations_str
  79. )
  80. router_prompt = PromptTemplate(
  81. template=router_template,
  82. input_variables=["input"],
  83. output_parser=RouterOutputParser(),
  84. )
  85. router_chain = LLMRouterChain.from_llm(llm, router_prompt)
  86. dataset_tools = {}
  87. for dataset in datasets:
  88. dataset_tool = DatasetToolBuilder.build_dataset_tool(
  89. dataset=dataset,
  90. response_mode='no_synthesizer', # "compact"
  91. callback_handler=DatasetToolCallbackHandler(conversation_message_task)
  92. )
  93. dataset_tools[dataset.id] = dataset_tool
  94. return cls(
  95. router_chain=router_chain,
  96. dataset_tools=dataset_tools,
  97. **kwargs,
  98. )
  99. def _call(
  100. self,
  101. inputs: Dict[str, Any]
  102. ) -> Dict[str, Any]:
  103. if len(self.dataset_tools) == 0:
  104. return {"text": ''}
  105. elif len(self.dataset_tools) == 1:
  106. return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
  107. route = self.router_chain.route(inputs)
  108. if not route.destination:
  109. return {"text": ''}
  110. elif route.destination in self.dataset_tools:
  111. return {"text": self.dataset_tools[route.destination].run(
  112. route.next_inputs['input']
  113. )}
  114. else:
  115. raise ValueError(
  116. f"Received invalid destination chain name '{route.destination}'"
  117. )