multi_dataset_router_chain.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. import math
  2. from typing import Mapping, List, Dict, Any, Optional
  3. from langchain import PromptTemplate
  4. from langchain.callbacks.manager import CallbackManagerForChainRun
  5. from langchain.chains.base import Chain
  6. from pydantic import Extra
  7. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  8. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  9. from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
  10. from core.conversation_message_task import ConversationMessageTask
  11. from core.llm.llm_builder import LLMBuilder
  12. from core.tool.dataset_index_tool import DatasetTool
  13. from models.dataset import Dataset, DatasetProcessRule
  # Baseline number of segments to retrieve per dataset. Returned as-is when a
  # dataset has no usable processing rule, and used as the reference size that
  # the dynamic k calculation compares the token budget against.
  DEFAULT_K: int = 2
  # Fraction (30%) of the remaining prompt-token budget used as the context cap
  # when computing a retrieval k larger than DEFAULT_K.
  CONTEXT_TOKENS_PERCENT: float = 0.3
  # Prompt for the LLM router built in MultiDatasetRouterChain.from_datasets.
  # The text goes through two formatting passes: first
  # str.format(destinations=...) inserts the candidate list, then the result is
  # used as a PromptTemplate with input variable "input" (presumably rendered
  # with the same str.format-style brace rules). Braces are therefore escaped
  # twice: "{{input}}" survives the first pass as the "{input}" placeholder, and
  # "{{{{" / "}}}}" end up as literal "{" / "}" around the JSON example.
  # Lines ending in a backslash are continuations and are joined in the
  # rendered prompt.
  # NOTE(review): the wording "no any other string out of markdown code snippet"
  # is ungrammatical; rewording it changes the prompt sent to the LLM, so
  # re-evaluate routing quality if you touch it.
  # NOTE(review): this listing appears to have lost blank lines / indentation
  # from the original source; confirm the exact prompt text before editing.
  16. MULTI_PROMPT_ROUTER_TEMPLATE = """
  17. Given a raw text input to a language model select the model prompt best suited for \
  18. the input. You will be given the names of the available prompts and a description of \
  19. what the prompt is best suited for. You may also revise the original input if you \
  20. think that revising it will ultimately lead to a better response from the language \
  21. model.
  22. << FORMATTING >>
  23. Return a markdown code snippet with a JSON object formatted to look like, \
  24. no any other string out of markdown code snippet:
  25. ```json
  26. {{{{
  27. "destination": string \\ name of the prompt to use or "DEFAULT"
  28. "next_inputs": string \\ a potentially modified version of the original input
  29. }}}}
  30. ```
  31. REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
  32. it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
  33. REMEMBER: "next_inputs" can just be the original input if you don't think any \
  34. modifications are needed.
  35. << CANDIDATE PROMPTS >>
  36. {destinations}
  37. << INPUT >>
  38. {{input}}
  39. << OUTPUT >>
  40. """
  class MultiDatasetRouterChain(Chain):
      """Route an input to the most suitable of several dataset tools.

      An LLM router chain picks a destination dataset (or none), and the
      matching ``DatasetTool`` is run on the router's (possibly rewritten)
      input. The tool's output is returned under the ``"text"`` key.
      """

      # Chain that decides the destination dataset and the input to send to it.
      router_chain: LLMRouterChain
      # Map of dataset id (as str) to the DatasetTool that inputs can be routed to.
      dataset_tools: Mapping[str, DatasetTool]

      class Config:
          """Configuration for this pydantic object."""

          extra = Extra.forbid
          arbitrary_types_allowed = True

      @property
      def input_keys(self) -> List[str]:
          """Will be whatever keys the router chain prompt expects.

          :meta private:
          """
          return self.router_chain.input_keys

      @property
      def output_keys(self) -> List[str]:
          """Single output key holding the selected tool's result.

          :meta private:
          """
          return ["text"]

      @classmethod
      def from_datasets(
          cls,
          tenant_id: str,
          datasets: List[Dataset],
          conversation_message_task: ConversationMessageTask,
          rest_tokens: int,
          **kwargs: Any,
      ):
          """Build a router chain over ``datasets``.

          A ``DatasetTool`` is created for each dataset that has available
          documents and a positive retrieval ``k``. Only those datasets are
          listed as candidate destinations in the router prompt, so the router
          cannot choose a destination that ``_call`` has no tool for.

          :param tenant_id: passed to ``LLMBuilder.to_llm`` for the router LLM.
          :param datasets: candidate datasets.
          :param conversation_message_task: passed to each tool's
              ``DatasetToolCallbackHandler``.
          :param rest_tokens: remaining prompt-token budget, used to size each
              tool's ``k``.
          :param kwargs: extra fields forwarded to the chain constructor.
          :return: a configured ``MultiDatasetRouterChain``.
          """
          llm = LLMBuilder.to_llm(
              tenant_id=tenant_id,
              model_name='gpt-3.5-turbo',
              temperature=0,
              max_tokens=1024,
              callbacks=[DifyStdOutCallbackHandler()]
          )

          dataset_tools: Dict[str, DatasetTool] = {}
          routable_datasets: List[Dataset] = []
          for dataset in datasets:
              # Skip datasets with no available documents.
              # NOTE(review): the original condition tested
              # `available_document_count == 0` twice; the second operand was
              # probably meant to be a different counter (e.g. segments) - confirm.
              if dataset.available_document_count == 0:
                  continue

              k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
              # The token budget leaves no room for even one segment.
              if k <= 0:
                  continue

              dataset_tools[str(dataset.id)] = DatasetTool(
                  name=f"dataset-{dataset.id}",
                  description=cls._dataset_description(dataset),
                  k=k,
                  dataset=dataset,
                  callbacks=[
                      DatasetToolCallbackHandler(conversation_message_task),
                      DifyStdOutCallbackHandler(),
                  ]
              )
              routable_datasets.append(dataset)

          # Offer the router only datasets that actually got a tool; newlines are
          # flattened so each candidate stays on one prompt line.
          destinations_str = "\n".join(
              "[[{}]]: {}".format(d.id, cls._dataset_description(d).replace('\n', ' '))
              for d in routable_datasets
          )
          router_prompt = PromptTemplate(
              template=MULTI_PROMPT_ROUTER_TEMPLATE.format(destinations=destinations_str),
              input_variables=["input"],
              output_parser=RouterOutputParser(),
          )
          router_chain = LLMRouterChain.from_llm(llm, router_prompt)

          return cls(
              router_chain=router_chain,
              dataset_tools=dataset_tools,
              **kwargs,
          )

      @staticmethod
      def _dataset_description(dataset: Dataset) -> str:
          """Return the dataset's description, or a generic one built from its name."""
          return dataset.description or 'useful for when you want to answer queries about the ' + dataset.name

      @classmethod
      def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
          """Return how many segments (``k``) to retrieve from ``dataset``.

          The per-segment token size comes from the dataset's latest processing
          rule: custom ``segmentation.max_tokens``, or the automatic default.

          - ``DEFAULT_K`` is returned when there is no usable rule.
          - If ``DEFAULT_K`` segments do not fit in ``rest_tokens``, as many whole
            segments as fit are returned (never negative; possibly 0).
          - Otherwise ``k`` exceeds ``DEFAULT_K`` only when
            ``CONTEXT_TOKENS_PERCENT`` of ``rest_tokens`` has room for more
            segments.
          """
          processing_rule = dataset.latest_process_rule
          if not processing_rule:
              return DEFAULT_K

          if processing_rule.mode == "custom":
              rules = processing_rule.rules_dict
              if not rules:
                  return DEFAULT_K
              segment_max_tokens = (rules.get("segmentation") or {}).get("max_tokens")
          else:
              segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

          # Malformed custom rules (missing or non-positive max_tokens) would
          # otherwise raise KeyError / ZeroDivisionError; fall back like the
          # "no rules" case above.
          if not segment_max_tokens or segment_max_tokens <= 0:
              return DEFAULT_K

          default_context_tokens = segment_max_tokens * DEFAULT_K

          # The budget is smaller than DEFAULT_K segments: take as many whole
          # segments as fit. Clamp at 0 because rest_tokens may be negative.
          if rest_tokens < default_context_tokens:
              return max(rest_tokens // segment_max_tokens, 0)

          context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
          # The capped budget is no bigger than the default context: keep DEFAULT_K.
          if context_limit_tokens <= default_context_tokens:
              return DEFAULT_K

          # Expand k when the capped share of rest_tokens has room for more segments.
          return context_limit_tokens // segment_max_tokens

      def _call(
          self,
          inputs: Dict[str, Any],
          run_manager: Optional[CallbackManagerForChainRun] = None,
      ) -> Dict[str, Any]:
          """Run the dataset tool chosen for ``inputs['input']``.

          - With no tools, returns empty text.
          - With exactly one tool, runs it directly without consulting the router.
          - Otherwise the router picks a destination. If it returns none, empty
            text is returned.

          :raises ValueError: if the router names a destination with no tool.
          """
          if not self.dataset_tools:
              return {"text": ''}

          if len(self.dataset_tools) == 1:
              only_tool = next(iter(self.dataset_tools.values()))
              return {"text": only_tool.run(inputs['input'])}

          route = self.router_chain.route(inputs)
          if not route.destination:
              return {"text": ''}

          tool = self.dataset_tools.get(route.destination)
          if tool is None:
              raise ValueError(
                  f"Received invalid destination chain name '{route.destination}'"
              )
          return {"text": tool.run(route.next_inputs['input'])}
  153. raise ValueError(
  154. f"Received invalid destination chain name '{route.destination}'"
  155. )