123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351 |
- import math
- from typing import Optional
- from flask import current_app
- from langchain import WikipediaAPIWrapper
- from langchain.callbacks.manager import Callbacks
- from langchain.memory.chat_memory import BaseChatMemory
- from langchain.tools import BaseTool, Tool, WikipediaQueryRun
- from pydantic import BaseModel, Field
- from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
- from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
- from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
- from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
- from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
- from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
- from core.conversation_message_task import ConversationMessageTask
- from core.model_providers.error import ProviderTokenNotInitError
- from core.model_providers.model_factory import ModelFactory
- from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
- from core.model_providers.models.llm.base import BaseLLM
- from core.tool.current_datetime_tool import DatetimeTool
- from core.tool.dataset_retriever_tool import DatasetRetrieverTool
- from core.tool.provider.serpapi_provider import SerpAPIToolProvider
- from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
- from core.tool.web_reader_tool import WebReaderTool
- from extensions.ext_database import db
- from models.dataset import Dataset, DatasetProcessRule
- from models.model import AppModelConfig
- from models.provider import ProviderType
class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    # Reply used by the sensitive-word chain when the app config supplies no
    # (or an empty) canned response.
    _DEFAULT_CANNED_RESPONSE = 'Your content violates our usage policy. Please revise and try again.'

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        """
        :param tenant_id: tenant id; used to scope model factory calls and dataset queries
        :param app_model_config: app model config providing agent mode, model and
            sensitive-word-avoidance settings
        """
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
                          return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        """
        Build an agent executor from the app's agent mode config.

        :param conversation_message_task: current conversation message task; passed to the agent
            callback and to dataset tools
        :param memory: optional chat memory handed to the agent configuration
        :param rest_tokens: remaining token budget used to size dataset retrieval
            (-1 makes dataset tools use the default k)
        :param chain_callback: main chain callback; its ``agent_callback`` attribute is set here
        :param return_resource: forwarded to dataset retriever tools
        :param retriever_from: forwarded to dataset retriever tools
        :return: an AgentExecutor, or None when agent mode is missing/disabled or no tool could be built
        """
        agent_mode_config = self.app_model_config.agent_mode_dict
        if not agent_mode_config or not agent_mode_config.get('enabled'):
            return None

        model_dict = self.app_model_config.model_dict
        tool_configs = agent_mode_config.get('tools', [])
        agent_provider_name = model_dict.get('provider', 'openai')
        agent_model_name = model_dict.get('name', 'gpt-4')

        agent_model_instance = ModelFactory.get_text_generation_model(
            tenant_id=self.tenant_id,
            model_provider_name=agent_provider_name,
            model_name=agent_model_name,
            model_kwargs=ModelKwargs(
                temperature=0.2,
                top_p=0.3,
                max_tokens=1500
            )
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_instance=agent_model_instance,
            conversation_message_task=conversation_message_task
        )
        chain_callback.agent_callback = agent_callback
        agent_model_instance.add_callbacks([agent_callback])

        planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

        # only OpenAI chat model (include Azure) support function call, use ReACT instead
        if agent_model_instance.model_mode != ModelMode.CHAT \
                or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
            if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
                planning_strategy = PlanningStrategy.REACT
            elif planning_strategy == PlanningStrategy.ROUTER:
                planning_strategy = PlanningStrategy.REACT_ROUTER

        summary_model_instance = self._get_summary_model_instance(agent_provider_name, agent_model_name)

        tools = self.to_tools(
            agent_model_instance=agent_model_instance,
            tool_configs=tool_configs,
            conversation_message_task=conversation_message_task,
            rest_tokens=rest_tokens,
            callbacks=[agent_callback, DifyStdOutCallbackHandler()],
            return_resource=return_resource,
            retriever_from=retriever_from
        )
        if not tools:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_instance=agent_model_instance,
            tools=tools,
            summary_model_instance=summary_model_instance,
            memory=memory,
            callbacks=[chain_callback, agent_callback],
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate"
        )
        return AgentExecutor(agent_configuration)

    def _get_summary_model_instance(self, provider_name: str, model_name: str) -> Optional[BaseLLM]:
        """
        Create the model instance used for summarisation (temperature 0, max 500 tokens,
        no quota deduction).

        :param provider_name: model provider name
        :param model_name: model name
        :return: the model instance, or None when ModelFactory raises ProviderTokenNotInitError
        """
        try:
            return ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=provider_name,
                model_name=model_name,
                model_kwargs=ModelKwargs(
                    temperature=0,
                    max_tokens=500
                ),
                deduct_quota=False
            )
        except ProviderTokenNotInitError:
            return None

    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
            -> Optional[SensitiveWordAvoidanceChain]:
        """
        Convert app sensitive word avoidance config to chain

        :param model_instance: model instance
        :param callbacks: callbacks for the chain
        :param kwargs: extra keyword arguments forwarded to SensitiveWordAvoidanceChain
        :return: the chain, or None when the feature is disabled or no keyword is configured
        """
        config = self.app_model_config.sensitive_word_avoidance_dict
        if not config or not config.get("enabled", False):
            return None

        # an empty/missing canned response falls back to the default message
        canned_response = config.get("canned_response") or self._DEFAULT_CANNED_RESPONSE

        if config.get('type') == 'moderation':
            sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                type=SensitiveWordAvoidanceRule.Type.MODERATION,
                canned_response=canned_response,
            )
        else:
            # Comma-separated keyword list: trim whitespace around each word and drop
            # empty entries (e.g. from "a, b," input).
            # NOTE(review): an empty keyword would likely match every text if the chain
            # does substring checks — confirm against SensitiveWordAvoidanceChain.
            sensitive_words = [
                word.strip()
                for word in (config.get("words") or "").split(',')
                if word.strip()
            ]
            if not sensitive_words:
                return None

            sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
                type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
                canned_response=canned_response,
                extra_params={
                    'sensitive_words': sensitive_words,
                }
            )

        return SensitiveWordAvoidanceChain(
            model_instance=model_instance,
            sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
            output_key="sensitive_word_avoidance_output",
            callbacks=callbacks,
            **kwargs
        )

    def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
                 conversation_message_task: ConversationMessageTask,
                 rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
                 retriever_from: str = 'dev') -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param agent_model_instance: agent model; its provider/name are reused by the web reader tool
        :param tool_configs: app agent tool configs; each maps a tool type to its settings
            (only the first entry of each config is read)
        :param conversation_message_task: forwarded to dataset tools
        :param rest_tokens: remaining token budget, forwarded to dataset tools
        :param callbacks: extra callbacks appended to every built tool (may be None)
        :param return_resource: forwarded to dataset tools
        :param retriever_from: forwarded to dataset tools
        :return: the tools that were enabled and could be built
        """
        tools = []
        for tool_config in tool_configs:
            if not tool_config:
                continue

            tool_type, tool_val = next(iter(tool_config.items()))
            # only tools explicitly enabled with a boolean True are built
            if not isinstance(tool_val, dict) or tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens,
                                                      return_resource, retriever_from)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(agent_model_instance)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool()
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool()
            elif tool_type == "current_datetime":
                tool = self.to_current_datetime_tool()

            if tool is None:
                continue

            # every builder in this class creates the tool with a callbacks list;
            # the guard avoids list.extend(None) when no extra callbacks are passed
            if callbacks:
                tool.callbacks.extend(callbacks)
            tools.append(tool)

        return tools

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
            -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        :param tool_config: dataset tool settings; ``id`` selects the dataset
        :param conversation_message_task: passed to the tool and its DatasetToolCallbackHandler
        :param rest_tokens: remaining token budget used to choose the retrieval k
        :param return_resource: forwarded to DatasetRetrieverTool.from_dataset
        :param retriever_from: forwarded to DatasetRetrieverTool.from_dataset
        :return: the tool, or None when the dataset is not found for this tenant or has no available documents
        """
        # get dataset from dataset id (scoped to the current tenant)
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        if not dataset or dataset.available_document_count == 0:
            return None

        k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
        return DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            k=k,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
            conversation_message_task=conversation_message_task,
            return_resource=return_resource,
            retriever_from=retriever_from
        )

    def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :param agent_model_instance: agent model whose provider/name are reused for the summary model
        :return: the web reader tool (without a summary LLM if the summary model is unavailable)
        """
        summary_model_instance = self._get_summary_model_instance(
            agent_model_instance.model_provider.provider_name,
            agent_model_instance.name
        )

        return WebReaderTool(
            llm=summary_model_instance.client if summary_model_instance else None,
            max_chunk_length=4000,
            continue_reading=True,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    def to_google_search_tool(self) -> Optional[BaseTool]:
        """
        A Google search tool backed by the tenant's SerpAPI credentials.

        :return: the tool, or None when the provider yields no function kwargs (credentials)
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        return Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    def to_current_datetime_tool(self) -> Optional[BaseTool]:
        """A tool that reports the current date/time."""
        return DatetimeTool(
            callbacks=[DifyStdOutCallbackHandler()]
        )

    def to_wikipedia_tool(self) -> Optional[BaseTool]:
        """A Wikipedia query tool; page content is limited to 4000 characters."""
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput,
            callbacks=[DifyStdOutCallbackHandler()]
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        """
        Choose how many segments (k) to retrieve from a dataset for the remaining token budget.

        :param dataset: dataset whose latest process rule determines the segment size
        :param rest_tokens: remaining tokens; -1 skips the calculation and returns the default k
        :return: the number of segments to retrieve, in the range [0, MAX_K]
        """
        DEFAULT_K = 2
        CONTEXT_TOKENS_PERCENT = 0.3
        MAX_K = 10

        if rest_tokens == -1:
            return DEFAULT_K

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K
            segment_max_tokens = (rules.get("segmentation") or {}).get("max_tokens")
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # malformed custom rules (missing/zero/negative segment size): use the default
        # instead of raising KeyError or ZeroDivisionError
        if not segment_max_tokens or segment_max_tokens < 0:
            return DEFAULT_K

        # when rest_tokens is less than default context tokens, take as many whole
        # segments as fit; clamp so a negative budget never yields a negative k
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            return max(rest_tokens // segment_max_tokens, 0)

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
        return min(context_limit_tokens // segment_max_tokens, MAX_K)
|