orchestrator_rule_parser.py

import math
from typing import Optional

from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.chat_models import ChatOpenAI
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig

class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config
        self.agent_summary_model_name = "gpt-3.5-turbo-16k"

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
            -> Optional[AgentExecutor]:
        """
        Convert the app agent mode config to an AgentExecutor.

        Returns None when agent mode is disabled, unconfigured, or no tools
        could be built from the config.
        """
        if not self.app_model_config.agent_mode_dict:
            return None

        agent_mode_config = self.app_model_config.agent_mode_dict
        model_dict = self.app_model_config.model_dict

        if agent_mode_config and agent_mode_config.get('enabled'):
            tool_configs = agent_mode_config.get('tools', [])
            agent_model_name = model_dict.get('name', 'gpt-4')

            # add agent callback to record agent thoughts
            agent_callback = AgentLoopGatherCallbackHandler(
                model_name=agent_model_name,
                conversation_message_task=conversation_message_task
            )

            chain_callback.agent_callback = agent_callback

            agent_llm = LLMBuilder.to_llm(
                tenant_id=self.tenant_id,
                model_name=agent_model_name,
                temperature=0,
                max_tokens=1500,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
            )

            planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

            # only OpenAI chat models (including Azure) support function calling; fall back to ReAct otherwise
            if not isinstance(agent_llm, ChatOpenAI) \
                    and planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                planning_strategy = PlanningStrategy.REACT

            summary_llm = LLMBuilder.to_llm(
                tenant_id=self.tenant_id,
                model_name=self.agent_summary_model_name,
                temperature=0,
                max_tokens=500,
                callbacks=[DifyStdOutCallbackHandler()]
            )

            tools = self.to_tools(
                tool_configs=tool_configs,
                conversation_message_task=conversation_message_task,
                model_name=self.agent_summary_model_name,
                rest_tokens=rest_tokens,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()]
            )

            if len(tools) == 0:
                return None

            agent_configuration = AgentConfiguration(
                strategy=planning_strategy,
                llm=agent_llm,
                tools=tools,
                summary_llm=summary_llm,
                memory=memory,
                callbacks=[chain_callback, agent_callback],
                max_iterations=10,
                max_execution_time=400.0,
                early_stopping_method="generate"
            )

            return AgentExecutor(agent_configuration)

        return None
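
    # Illustrative only (not part of the original file): the agent mode config
    # read above is a plain dict. A minimal shape that would satisfy
    # to_agent_executor might look like the following; the dataset id is
    # hypothetical, and 'router' is the only strategy string that appears
    # verbatim in this file (the PlanningStrategy members referenced above
    # suggest 'function_call', 'multi_function_call' and 'react' as
    # alternatives):
    #
    #     agent_mode_dict = {
    #         "enabled": True,
    #         "strategy": "router",
    #         "tools": [
    #             {"dataset": {"enabled": True, "id": "<dataset-uuid>"}},
    #             {"web_reader": {"enabled": True}},
    #         ],
    #     }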

    def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
            -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert the app sensitive word avoidance config to a chain.

        :param callbacks: callbacks to attach to the chain
        :param kwargs: extra keyword arguments passed through to the chain
        :return: the chain, or None if the feature is disabled or no words are configured
        """
        if not self.app_model_config.sensitive_word_avoidance_dict:
            return None

        sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
        sensitive_words = sensitive_word_avoidance_config.get("words", "")
        if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
            return SensitiveWordAvoidanceChain(
                sensitive_words=sensitive_words.split(","),
                canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
                output_key="sensitive_word_avoidance_output",
                callbacks=callbacks,
                **kwargs
            )

        return None
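
    # Illustrative only: the config consumed above is a dict with an enabled
    # flag, a single comma-separated string of words (as split(",") above
    # expects), and a canned response returned when a sensitive word is hit:
    #
    #     sensitive_word_avoidance_dict = {
    #         "enabled": True,
    #         "words": "foo,bar",
    #         "canned_response": "I can't talk about that.",
    #     }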

    def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
                 model_name: str, rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
        """
        Convert the app agent tool configs to tools.

        :param tool_configs: app agent tool configs
        :param conversation_message_task: current conversation message task
        :param model_name: model used by tools that need an LLM (e.g. the web reader's summarizer)
        :param rest_tokens: tokens remaining for the prompt context
        :param callbacks: callbacks to attach to each tool
        :return: the enabled tools, in config order
        """
        tools = []
        for tool_config in tool_configs:
            # each entry is a single-key dict: {tool_type: tool_settings}
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]
            if tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(model_name)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool()
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool()

            if tool:
                if callbacks:
                    tool.callbacks.extend(callbacks)
                tools.append(tool)

        return tools
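
    # Illustrative only: each tool config entry is a single-key dict mapping
    # one of the tool types dispatched on above ("dataset", "web_reader",
    # "google_search", "wikipedia") to its settings; the dataset id below is
    # hypothetical:
    #
    #     tool_configs = [
    #         {"dataset": {"enabled": True, "id": "<dataset-uuid>"}},
    #         {"google_search": {"enabled": True}},
    #         {"wikipedia": {"enabled": False}},  # skipped: not enabled
    #     ]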

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  rest_tokens: int) \
            -> Optional[BaseTool]:
        """
        Build a dataset tool, which retrieves information from a dataset.

        :param tool_config: tool config containing the dataset id
        :param conversation_message_task: current conversation message task
        :param rest_tokens: tokens remaining for the prompt context, used to size k
        :return: the tool, or None if the dataset is missing or empty
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        # skip datasets that don't exist or have no available documents
        if not dataset or dataset.available_document_count == 0:
            return None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
        )

        return tool

    def to_web_reader_tool(self, model_name: str) -> Optional[BaseTool]:
        """
        Build a tool for reading web pages, with an LLM to summarize long content.

        :param model_name: model used to summarize page chunks
        :return: the tool
        """
        summary_llm = LLMBuilder.to_llm(
            tenant_id=self.tenant_id,
            model_name=model_name,
            temperature=0,
            max_tokens=500,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        tool = WebReaderTool(
            llm=summary_llm,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_google_search_tool(self) -> Optional[BaseTool]:
        """Build a Google search tool backed by SerpAPI, or None if no credentials are configured."""
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
        """Build a Wikipedia lookup tool."""
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        """
        Dynamically size the retrieval k from the remaining token budget and
        the dataset's segment size.
        """
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens can't even fit DEFAULT_K segments, fit as many segments as possible
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when the 30% context budget doesn't exceed DEFAULT_K segments, use DEFAULT_K
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # expand k while there's still room in the 30% context budget
        return context_limit_tokens // segment_max_tokens
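
    # Worked example (not part of the original file), assuming
    # segment_max_tokens = 500, so DEFAULT_K segments cost 1000 tokens:
    #
    #     rest_tokens = 800    -> 800 < 1000, so k = 800 // 500 = 1
    #     rest_tokens = 3000   -> context_limit_tokens = floor(3000 * 0.3) = 900,
    #                             900 <= 1000, so k = DEFAULT_K = 2
    #     rest_tokens = 10000  -> context_limit_tokens = 3000 > 1000,
    #                             so k = 3000 // 500 = 6
    #
    # A minimal usage sketch (hypothetical wiring; the surrounding objects come
    # from the rest of the application):
    #
    #     parser = OrchestratorRuleParser(tenant_id=tenant_id, app_model_config=app_model_config)
    #     executor = parser.to_agent_executor(
    #         conversation_message_task=task,
    #         memory=memory,
    #         rest_tokens=rest_tokens,
    #         chain_callback=chain_callback,
    #     )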