orchestrator_rule_parser.py

import math
from typing import Optional

from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType


class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
                          return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None

        agent_mode_config = self.app_model_config.agent_mode_dict
        model_dict = self.app_model_config.model_dict

        if agent_mode_config and agent_mode_config.get('enabled'):
            tool_configs = agent_mode_config.get('tools', [])
            agent_provider_name = model_dict.get('provider', 'openai')
            agent_model_name = model_dict.get('name', 'gpt-4')

            agent_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=agent_provider_name,
                model_name=agent_model_name,
                model_kwargs=ModelKwargs(
                    temperature=0.2,
                    top_p=0.3,
                    max_tokens=1500
                )
            )

            # add agent callback to record agent thoughts
            agent_callback = AgentLoopGatherCallbackHandler(
                model_instance=agent_model_instance,
                conversation_message_task=conversation_message_task
            )

            chain_callback.agent_callback = agent_callback
            agent_model_instance.add_callbacks([agent_callback])

            planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

            # only OpenAI chat models (including Azure OpenAI) support function calling;
            # fall back to the ReACT equivalents for everything else
            if agent_model_instance.model_mode != ModelMode.CHAT \
                    or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
                if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                    planning_strategy = PlanningStrategy.REACT
                elif planning_strategy == PlanningStrategy.ROUTER:
                    planning_strategy = PlanningStrategy.REACT_ROUTER

            # the summary model is optional; run without it if the provider is not configured
            try:
                summary_model_instance = ModelFactory.get_text_generation_model(
                    tenant_id=self.tenant_id,
                    model_provider_name=agent_provider_name,
                    model_name=agent_model_name,
                    model_kwargs=ModelKwargs(
                        temperature=0,
                        max_tokens=500
                    ),
                    deduct_quota=False
                )
            except ProviderTokenNotInitError:
                summary_model_instance = None

            tools = self.to_tools(
                agent_model_instance=agent_model_instance,
                tool_configs=tool_configs,
                conversation_message_task=conversation_message_task,
                rest_tokens=rest_tokens,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
                return_resource=return_resource,
                retriever_from=retriever_from
            )

            if len(tools) == 0:
                return None

            agent_configuration = AgentConfiguration(
                strategy=planning_strategy,
                model_instance=agent_model_instance,
                tools=tools,
                summary_model_instance=summary_model_instance,
                memory=memory,
                callbacks=[chain_callback, agent_callback],
                max_iterations=10,
                max_execution_time=400.0,
                early_stopping_method="generate"
            )

            return AgentExecutor(agent_configuration)

        return None
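
    # For illustration only: a minimal `agent_mode_dict` consumed above might look like
    # the sketch below. This shape is inferred from the keys read in this method and is
    # not a documented schema; the values are hypothetical.
    #
    #   {
    #       "enabled": True,
    #       "strategy": "function_call",   # PlanningStrategy value; defaults to "router"
    #       "tools": [...]                 # tool entries, see to_tools() below
    #   }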

    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
            -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert the app sensitive word avoidance config to a chain

        :param model_instance: model instance
        :param callbacks: callbacks for the chain
        :param kwargs: extra keyword arguments passed through to the chain
        :return: the chain, or None if the config is missing or disabled
        """
        sensitive_word_avoidance_rule = None

        if self.app_model_config.sensitive_word_avoidance_dict:
            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
            if sensitive_word_avoidance_config.get("enabled", False):
                if sensitive_word_avoidance_config.get('type') == 'moderation':
                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
                        canned_response=sensitive_word_avoidance_config.get("canned_response")
                                        or 'Your content violates our usage policy. Please revise and try again.',
                    )
                else:
                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
                    if sensitive_words:
                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
                            canned_response=sensitive_word_avoidance_config.get("canned_response")
                                            or 'Your content violates our usage policy. Please revise and try again.',
                            extra_params={
                                'sensitive_words': sensitive_words.split(','),
                            }
                        )

        if sensitive_word_avoidance_rule:
            return SensitiveWordAvoidanceChain(
                model_instance=model_instance,
                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
                output_key="sensitive_word_avoidance_output",
                callbacks=callbacks,
                **kwargs
            )

        return None
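
    # Illustrative shapes of `sensitive_word_avoidance_dict` handled above (key names are
    # taken from this method's reads; the values are hypothetical):
    #
    #   {"enabled": True, "type": "moderation", "canned_response": "..."}
    #   {"enabled": True, "type": "keywords", "words": "foo,bar", "canned_response": "..."}
    #
    # `words` is a single comma-separated string and is split on ',' above.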

    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
                 conversation_message_task: ConversationMessageTask,
                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
                 retriever_from: str = 'dev') -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param agent_model_instance: the agent's model instance
        :param tool_configs: app agent tool configs
        :param conversation_message_task: current conversation message task
        :param rest_tokens: remaining tokens available to the dataset retriever
        :param callbacks: callbacks attached to every created tool
        :param return_resource: whether the dataset tool should return source resources
        :param retriever_from: where the retriever is invoked from
        :return: list of enabled tools
        """
        tools = []
        for tool_config in tool_configs:
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]
            if tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens,
                                                      return_resource, retriever_from)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(agent_model_instance)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool()
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool()
            elif tool_type == "current_datetime":
                tool = self.to_current_datetime_tool()

            if tool:
                if callbacks:
                    tool.callbacks.extend(callbacks)
                tools.append(tool)

        return tools
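
    # Each `tool_configs` entry is a single-key dict mapping the tool type to its config,
    # e.g. (hypothetical values):
    #
    #   [
    #       {"dataset": {"enabled": True, "id": "<dataset-uuid>"}},
    #       {"web_reader": {"enabled": True}},
    #       {"wikipedia": {"enabled": False}},   # skipped: enabled is not exactly True
    #   ]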

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        :param tool_config: tool config containing the dataset id
        :param conversation_message_task: current conversation message task
        :param rest_tokens: remaining tokens available for retrieval
        :param return_resource: whether to return source resources
        :param retriever_from: where the retriever is invoked from
        :return: the dataset retriever tool, or None if the dataset is missing or empty
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        if not dataset:
            return None

        # skip datasets with no documents available for retrieval
        if dataset.available_document_count == 0:
            return None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
            conversation_message_task=conversation_message_task,
            return_resource=return_resource,
            retriever_from=retriever_from
        )

        return tool

    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :param agent_model_instance: the agent's model instance, reused to build the summary model
        :return: the web reader tool
        """
        try:
            summary_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=agent_model_instance.model_provider.provider_name,
                model_name=agent_model_instance.name,
                model_kwargs=ModelKwargs(
                    temperature=0,
                    max_tokens=500
                ),
                deduct_quota=False
            )
        except ProviderTokenNotInitError:
            # summarization is optional; read without it if the provider is not configured
            summary_model_instance = None

        tool = WebReaderTool(
            llm=summary_model_instance.client if summary_model_instance else None,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_google_search_tool(self) -> Optional[BaseTool]:
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_current_datetime_tool(self) -> Optional[BaseTool]:
        tool = DatetimeTool(
            callbacks=[DifyStdOutCallbackHandler()]
        )

        return tool

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3
        MAX_K = 10

        if rest_tokens == -1:
            return DEFAULT_K

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens cannot even fit DEFAULT_K segments, shrink k to whatever fits
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return rest_tokens // segment_max_tokens

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when the 30% context budget holds no more than DEFAULT_K segments, use DEFAULT_K
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # expand k while segments still fit in the 30% context budget, capped at MAX_K
        return min(context_limit_tokens // segment_max_tokens, MAX_K)
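
    # Worked example with hypothetical numbers: with segment_max_tokens = 500 and
    # rest_tokens = 20000, context_limit_tokens = floor(20000 * 0.3) = 6000, so
    # k = min(6000 // 500, MAX_K) = min(12, 10) = 10. With rest_tokens = 3000,
    # context_limit_tokens = 900 <= 500 * 2, so DEFAULT_K = 2 is returned. With
    # rest_tokens = 800 < 500 * 2, the method returns 800 // 500 = 1.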