# multi_dataset_router_chain.py
  1. import math
  2. import re
  3. from typing import Mapping, List, Dict, Any, Optional
  4. from langchain import PromptTemplate
  5. from langchain.callbacks.manager import CallbackManagerForChainRun
  6. from langchain.chains.base import Chain
  7. from pydantic import Extra
  8. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  9. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  10. from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
  11. from core.conversation_message_task import ConversationMessageTask
  12. from core.llm.llm_builder import LLMBuilder
  13. from core.tool.dataset_index_tool import DatasetTool
  14. from models.dataset import Dataset, DatasetProcessRule
# Default number of segments (k) to retrieve per dataset when no dynamic
# sizing applies.
DEFAULT_K = 2

# Fraction of the remaining prompt-window tokens that retrieved context is
# allowed to occupy (see _dynamic_calc_retrieve_k).
CONTEXT_TOKENS_PERCENT = 0.3

# Router prompt template. NOTE the doubled/quadrupled braces: this string is
# first filled with str.format() (for {destinations}), which halves every
# brace pair — so {{{{ becomes a literal {{ and {{input}} becomes {input},
# leaving a valid PromptTemplate placeholder for the second pass.
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
  42. class MultiDatasetRouterChain(Chain):
  43. """Use a single chain to route an input to one of multiple candidate chains."""
  44. router_chain: LLMRouterChain
  45. """Chain for deciding a destination chain and the input to it."""
  46. dataset_tools: Mapping[str, DatasetTool]
  47. """Map of name to candidate chains that inputs can be routed to."""
  48. class Config:
  49. """Configuration for this pydantic object."""
  50. extra = Extra.forbid
  51. arbitrary_types_allowed = True
  52. @property
  53. def input_keys(self) -> List[str]:
  54. """Will be whatever keys the router chain prompt expects.
  55. :meta private:
  56. """
  57. return self.router_chain.input_keys
  58. @property
  59. def output_keys(self) -> List[str]:
  60. return ["text"]
  61. @classmethod
  62. def from_datasets(
  63. cls,
  64. tenant_id: str,
  65. datasets: List[Dataset],
  66. conversation_message_task: ConversationMessageTask,
  67. rest_tokens: int,
  68. **kwargs: Any,
  69. ):
  70. """Convenience constructor for instantiating from destination prompts."""
  71. llm = LLMBuilder.to_llm(
  72. tenant_id=tenant_id,
  73. model_name='gpt-3.5-turbo',
  74. temperature=0,
  75. max_tokens=1024,
  76. callbacks=[DifyStdOutCallbackHandler()]
  77. )
  78. destinations = ["[[{}]]: {}".format(d.id, d.description.replace('\n', ' ') if d.description
  79. else ('useful for when you want to answer queries about the ' + d.name))
  80. for d in datasets]
  81. destinations_str = "\n".join(destinations)
  82. router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
  83. destinations=destinations_str
  84. )
  85. router_prompt = PromptTemplate(
  86. template=router_template,
  87. input_variables=["input"],
  88. output_parser=RouterOutputParser(),
  89. )
  90. router_chain = LLMRouterChain.from_llm(llm, router_prompt)
  91. dataset_tools = {}
  92. for dataset in datasets:
  93. # fulfill description when it is empty
  94. if dataset.available_document_count == 0 or dataset.available_document_count == 0:
  95. continue
  96. description = dataset.description
  97. if not description:
  98. description = 'useful for when you want to answer queries about the ' + dataset.name
  99. k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
  100. if k == 0:
  101. continue
  102. dataset_tool = DatasetTool(
  103. name=f"dataset-{dataset.id}",
  104. description=description,
  105. k=k,
  106. dataset=dataset,
  107. callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
  108. )
  109. dataset_tools[str(dataset.id)] = dataset_tool
  110. return cls(
  111. router_chain=router_chain,
  112. dataset_tools=dataset_tools,
  113. **kwargs,
  114. )
  115. @classmethod
  116. def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
  117. processing_rule = dataset.latest_process_rule
  118. if not processing_rule:
  119. return DEFAULT_K
  120. if processing_rule.mode == "custom":
  121. rules = processing_rule.rules_dict
  122. if not rules:
  123. return DEFAULT_K
  124. segmentation = rules["segmentation"]
  125. segment_max_tokens = segmentation["max_tokens"]
  126. else:
  127. segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
  128. # when rest_tokens is less than default context tokens
  129. if rest_tokens < segment_max_tokens * DEFAULT_K:
  130. return rest_tokens // segment_max_tokens
  131. context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
  132. # when context_limit_tokens is less than default context tokens, use default_k
  133. if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
  134. return DEFAULT_K
  135. # Expand the k value when there's still some room left in the 30% rest tokens space
  136. return context_limit_tokens // segment_max_tokens
  137. def _call(
  138. self,
  139. inputs: Dict[str, Any],
  140. run_manager: Optional[CallbackManagerForChainRun] = None,
  141. ) -> Dict[str, Any]:
  142. if len(self.dataset_tools) == 0:
  143. return {"text": ''}
  144. elif len(self.dataset_tools) == 1:
  145. return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}
  146. route = self.router_chain.route(inputs)
  147. destination = ''
  148. if route.destination:
  149. pattern = r'\b[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\b'
  150. match = re.search(pattern, route.destination, re.IGNORECASE)
  151. if match:
  152. destination = match.group()
  153. if not destination:
  154. return {"text": ''}
  155. elif destination in self.dataset_tools:
  156. return {"text": self.dataset_tools[destination].run(
  157. route.next_inputs['input']
  158. )}
  159. else:
  160. raise ValueError(
  161. f"Received invalid destination chain name '{destination}'"
  162. )