import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Optional, Union, cast

from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
    AgentChatAppGenerateEntity,
    ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import MessageFileParser
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageTool,
    SystemPromptMessage,
    TextPromptMessageContent,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolRuntimeVariablePool,
)
from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from extensions.ext_database import db
from models.model import Conversation, Message, MessageAgentThought
from models.tools import ToolConversationVariables

logger = logging.getLogger(__name__)


class BaseAgentRunner(AppRunner):
    """
    Shared state and persistence helpers for agent runners.

    The constructor resolves the dataset retriever tools, counts the agent
    thoughts already stored for the current message, inspects the model schema
    for optional features and rebuilds the conversation history as prompt
    messages. The remaining methods convert tools into prompt tool schemas and
    create or update ``MessageAgentThought`` rows.
    """

    def __init__(self, tenant_id: str,
                 application_generate_entity: AgentChatAppGenerateEntity,
                 conversation: Conversation,
                 app_config: AgentChatAppConfig,
                 model_config: ModelConfigWithCredentialsEntity,
                 config: AgentEntity,
                 queue_manager: AppQueueManager,
                 message: Message,
                 user_id: str,
                 memory: Optional[TokenBufferMemory] = None,
                 prompt_messages: Optional[list[PromptMessage]] = None,
                 variables_pool: Optional[ToolRuntimeVariablePool] = None,
                 db_variables: Optional[ToolConversationVariables] = None,
                 model_instance: Optional[ModelInstance] = None
                 ) -> None:
        """
        Agent runner
        :param tenant_id: tenant id
        :param application_generate_entity: application generate entity
        :param conversation: conversation
        :param app_config: app config
        :param model_config: model config
        :param config: agent config
        :param queue_manager: queue manager
        :param message: message
        :param user_id: user id
        :param memory: memory
        :param prompt_messages: prompt messages; only the system messages are
            carried over into the rebuilt history
        :param variables_pool: tool runtime variables pool
        :param db_variables: tool conversation variables
        :param model_instance: model instance; it is dereferenced
            unconditionally below, so passing None raises AttributeError
        """
        self.tenant_id = tenant_id
        self.application_generate_entity = application_generate_entity
        self.conversation = conversation
        self.app_config = app_config
        self.model_config = model_config
        self.config = config
        self.queue_manager = queue_manager
        self.message = message
        self.user_id = user_id
        self.memory = memory
        # organize_agent_history reads self.message, self.tenant_id and
        # self.app_config, so it has to run after those assignments
        self.history_prompt_messages = self.organize_agent_history(
            prompt_messages=prompt_messages or []
        )
        self.variables_pool = variables_pool
        self.db_variables_pool = db_variables
        self.model_instance = model_instance

        # init callback
        self.agent_callback = DifyAgentCallbackHandler()
        # init dataset tools
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager=queue_manager,
            app_id=self.app_config.app_id,
            message_id=message.id,
            user_id=user_id,
            invoke_from=self.application_generate_entity.invoke_from,
        )
        self.dataset_tools = DatasetRetrieverTool.get_dataset_tools(
            tenant_id=tenant_id,
            dataset_ids=app_config.dataset.dataset_ids if app_config.dataset else [],
            retrieve_config=app_config.dataset.retrieve_config if app_config.dataset else None,
            return_resource=app_config.additional_features.show_retrieve_source,
            invoke_from=application_generate_entity.invoke_from,
            hit_callback=hit_callback
        )
        # get how many agent thoughts have been created
        self.agent_thought_count = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.message_id == self.message.id,
        ).count()
        db.session.close()

        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
        features = (model_schema.features or []) if model_schema else []

        # check if model supports stream tool call
        self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features

        # check if model supports vision
        if ModelFeature.VISION in features:
            self.files = application_generate_entity.files
        else:
            self.files = []

        self.query = None
        self._current_thoughts: list[PromptMessage] = []

    def _repack_app_generate_entity(self, app_generate_entity: AgentChatAppGenerateEntity) \
            -> AgentChatAppGenerateEntity:
        """
        Repack app generate entity

        Replaces a missing simple prompt template with an empty string. The
        entity is modified in place and returned.
        """
        prompt_template = app_generate_entity.app_config.prompt_template
        if prompt_template.simple_prompt_template is None:
            prompt_template.simple_prompt_template = ''

        return app_generate_entity

    @staticmethod
    def _apply_llm_parameters(prompt_tool: PromptMessageTool, parameters) -> None:
        """
        Write the LLM-facing parameters into the prompt tool's JSON schema

        Parameters whose form is not LLM are skipped. Each remaining parameter
        becomes a property with a type and a description, plus an ``enum``
        for SELECT parameters that have options. Required parameters are added
        to ``required`` once.
        """
        properties = prompt_tool.parameters['properties']
        required = prompt_tool.parameters['required']

        for parameter in parameters:
            if parameter.form != ToolParameter.ToolParameterForm.LLM:
                continue

            schema = {
                "type": ToolParameterConverter.get_parameter_type(parameter.type),
                "description": parameter.llm_description or '',
            }

            if parameter.type == ToolParameter.ToolParameterType.SELECT:
                enum = [option.value for option in parameter.options or []]
                if enum:
                    schema['enum'] = enum

            properties[parameter.name] = schema

            if parameter.required and parameter.name not in required:
                required.append(parameter.name)

    @staticmethod
    def _dump_if_dict(value: Union[str, dict]) -> str:
        """
        Serialize dict values to JSON and return any other value unchanged

        Non-ASCII characters are kept readable when possible; if that fails
        the default ASCII-escaped encoding is tried.
        """
        if not isinstance(value, dict):
            return value

        try:
            return json.dumps(value, ensure_ascii=False)
        except Exception:
            return json.dumps(value)

    @staticmethod
    def _load_json_object(raw: Optional[str], fallback: dict) -> dict:
        """
        Parse ``raw`` as a JSON object

        Returns ``fallback`` when ``raw`` is not a string, is not valid JSON,
        or decodes to something other than a dict (for example a bare number,
        string or list), so callers can always use ``.get``.
        """
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            return fallback

        return parsed if isinstance(parsed, dict) else fallback

    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
        """
        convert tool to prompt message tool

        :param tool: agent tool entity from the app config
        :return: the prompt tool schema and the tool runtime it was built from
        """
        tool_entity = ToolManager.get_agent_tool_runtime(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
            agent_tool=tool,
            invoke_from=self.application_generate_entity.invoke_from
        )
        tool_entity.load_variables(self.variables_pool)

        message_tool = PromptMessageTool(
            name=tool.tool_name,
            description=tool_entity.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        self._apply_llm_parameters(message_tool, tool_entity.get_all_runtime_parameters())

        return message_tool, tool_entity

    def _convert_dataset_retriever_tool_to_prompt_message_tool(self, tool: DatasetRetrieverTool) -> PromptMessageTool:
        """
        convert dataset retriever tool to prompt message tool

        Every runtime parameter is exposed as a string property; the parameter
        form is not checked here.
        """
        prompt_tool = PromptMessageTool(
            name=tool.identity.name,
            description=tool.description.llm,
            parameters={
                "type": "object",
                "properties": {},
                "required": [],
            }
        )

        properties = prompt_tool.parameters['properties']
        required = prompt_tool.parameters['required']

        for parameter in tool.get_runtime_parameters():
            properties[parameter.name] = {
                "type": 'string',
                "description": parameter.llm_description or '',
            }

            if parameter.required and parameter.name not in required:
                required.append(parameter.name)

        return prompt_tool

    def _init_prompt_tools(self) -> tuple[dict[str, Tool], list[PromptMessageTool]]:
        """
        Init tools

        :return: tool instances keyed by tool name, and the prompt tool
            schemas for the agent tools followed by the dataset tools
        """
        tool_instances = {}
        prompt_messages_tools = []

        for tool in self.app_config.agent.tools if self.app_config.agent else []:
            try:
                prompt_tool, tool_entity = self._convert_tool_to_prompt_message_tool(tool)
            except Exception:
                # api tool may be deleted; skip it, but leave a trace in the log
                logger.warning("Skipping agent tool %s: failed to load it", tool.tool_name, exc_info=True)
                continue
            # save tool entity
            tool_instances[tool.tool_name] = tool_entity
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)

        # convert dataset tools into ModelRuntime Tool format
        for dataset_tool in self.dataset_tools:
            prompt_tool = self._convert_dataset_retriever_tool_to_prompt_message_tool(dataset_tool)
            # save prompt tool
            prompt_messages_tools.append(prompt_tool)
            # save tool entity
            tool_instances[dataset_tool.identity.name] = dataset_tool

        return tool_instances, prompt_messages_tools

    def update_prompt_message_tool(self, tool: Tool, prompt_tool: PromptMessageTool) -> PromptMessageTool:
        """
        update prompt message tool

        Refreshes the schema of ``prompt_tool`` in place from the tool's
        current runtime parameters and returns it.
        """
        # try to get tool runtime parameters
        self._apply_llm_parameters(prompt_tool, tool.get_runtime_parameters() or [])

        return prompt_tool

    def create_agent_thought(self, message_id: str, message: str,
                             tool_name: str, tool_input: str, messages_ids: list[str]
                             ) -> MessageAgentThought:
        """
        Create agent thought

        Inserts a new thought row positioned after the existing ones and
        increments the local thought counter.

        :param message_id: id of the message the thought belongs to
        :param message: message text stored on the thought
        :param tool_name: tool name stored on the thought
        :param tool_input: tool input stored on the thought
        :param messages_ids: message file ids, stored as a JSON list (an
            empty string is stored when there are none)
        :return: the persisted agent thought
        """
        thought = MessageAgentThought(
            message_id=message_id,
            message_chain_id=None,
            thought='',
            tool=tool_name,
            tool_labels_str='{}',
            tool_meta_str='{}',
            tool_input=tool_input,
            message=message,
            message_token=0,
            message_unit_price=0,
            message_price_unit=0,
            message_files=json.dumps(messages_ids) if messages_ids else '',
            answer='',
            observation='',
            answer_token=0,
            answer_unit_price=0,
            answer_price_unit=0,
            tokens=0,
            total_price=0,
            position=self.agent_thought_count + 1,
            currency='USD',
            latency=0,
            created_by_role='account',
            created_by=self.user_id,
        )

        db.session.add(thought)
        db.session.commit()
        # reload the row before closing the session that owns it
        db.session.refresh(thought)
        db.session.close()

        self.agent_thought_count += 1

        return thought

    def save_agent_thought(self,
                           agent_thought: MessageAgentThought,
                           tool_name: str,
                           tool_input: Union[str, dict],
                           thought: str,
                           observation: Union[str, dict],
                           tool_invoke_meta: Union[str, dict],
                           answer: str,
                           messages_ids: list[str],
                           llm_usage: Optional[LLMUsage] = None) -> None:
        """
        Save agent thought

        Reloads the thought by id and updates every field whose argument is
        not None. Dict values for ``tool_input``, ``observation`` and
        ``tool_invoke_meta`` are stored as JSON. Labels are added for any
        tool in the thought that has none yet.

        :raises ValueError: if the agent thought no longer exists
        """
        agent_thought_id = agent_thought.id
        agent_thought = db.session.query(MessageAgentThought).filter(
            MessageAgentThought.id == agent_thought_id
        ).first()

        if agent_thought is None:
            db.session.close()
            raise ValueError(f'Agent thought {agent_thought_id} not found')

        if thought is not None:
            agent_thought.thought = thought

        if tool_name is not None:
            agent_thought.tool = tool_name

        if tool_input is not None:
            agent_thought.tool_input = self._dump_if_dict(tool_input)

        if observation is not None:
            agent_thought.observation = self._dump_if_dict(observation)

        if answer is not None:
            agent_thought.answer = answer

        if messages_ids:
            agent_thought.message_files = json.dumps(messages_ids)

        if llm_usage:
            agent_thought.message_token = llm_usage.prompt_tokens
            agent_thought.message_price_unit = llm_usage.prompt_price_unit
            agent_thought.message_unit_price = llm_usage.prompt_unit_price
            agent_thought.answer_token = llm_usage.completion_tokens
            agent_thought.answer_price_unit = llm_usage.completion_price_unit
            agent_thought.answer_unit_price = llm_usage.completion_unit_price
            agent_thought.tokens = llm_usage.total_tokens
            agent_thought.total_price = llm_usage.total_price

        # check if tool labels is not empty
        labels = agent_thought.tool_labels or {}
        tools = agent_thought.tool.split(';') if agent_thought.tool else []
        for tool in tools:
            if not tool or tool in labels:
                continue
            tool_label = ToolManager.get_tool_label(tool)
            if tool_label:
                labels[tool] = tool_label.to_dict()
            else:
                # no label registered: use the tool name itself
                labels[tool] = {'en_US': tool, 'zh_Hans': tool}

        agent_thought.tool_labels_str = json.dumps(labels)

        if tool_invoke_meta is not None:
            agent_thought.tool_meta_str = self._dump_if_dict(tool_invoke_meta)

        db.session.commit()
        db.session.close()

    def update_db_variables(self, tool_variables: ToolRuntimeVariablePool, db_variables: ToolConversationVariables):
        """
        convert tool variables to db variables

        Stores the current tool variable pool on the conversation's
        ``ToolConversationVariables`` row. The row is looked up again by
        conversation id; the ``db_variables`` argument is not read and is kept
        only for interface compatibility. Nothing is written when the
        conversation has no such row.
        """
        try:
            record = db.session.query(ToolConversationVariables).filter(
                ToolConversationVariables.conversation_id == self.message.conversation_id,
            ).first()

            if record is None:
                logger.warning(
                    "No tool conversation variables found for conversation %s; skipping update",
                    self.message.conversation_id,
                )
                return

            # stored as a naive datetime holding the UTC time
            record.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
            record.variables_str = json.dumps(jsonable_encoder(tool_variables.pool))
            db.session.commit()
        finally:
            db.session.close()

    def _organize_agent_thought(self, agent_thought: MessageAgentThought) -> list[PromptMessage]:
        """
        Convert one stored agent thought into prompt messages

        A thought without tools becomes a single assistant message. A thought
        with tools becomes an assistant message carrying one tool call per
        tool, followed by one tool message per call.
        """
        if not agent_thought.tool:
            return [AssistantPromptMessage(content=agent_thought.thought)]

        tools = agent_thought.tool.split(';')
        tool_inputs = self._load_json_object(
            agent_thought.tool_input,
            fallback={tool: {} for tool in tools},
        )
        tool_responses = self._load_json_object(
            agent_thought.observation,
            fallback={tool: agent_thought.observation for tool in tools},
        )

        tool_calls: list[AssistantPromptMessage.ToolCall] = []
        tool_call_response: list[ToolPromptMessage] = []
        for tool in tools:
            # generate a uuid for tool call
            tool_call_id = str(uuid.uuid4())
            tool_calls.append(AssistantPromptMessage.ToolCall(
                id=tool_call_id,
                type='function',
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name=tool,
                    arguments=json.dumps(tool_inputs.get(tool, {})),
                )
            ))
            tool_call_response.append(ToolPromptMessage(
                content=tool_responses.get(tool, agent_thought.observation),
                name=tool,
                tool_call_id=tool_call_id,
            ))

        return [
            AssistantPromptMessage(
                content=agent_thought.thought,
                tool_calls=tool_calls,
            ),
            *tool_call_response
        ]

    def organize_agent_history(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
        """
        Organize agent history

        Keeps the system messages from ``prompt_messages`` and appends the
        stored conversation, oldest first and excluding the current message.
        Each earlier message contributes its user prompt followed by either
        its agent thoughts or, when it has none, its answer.
        """
        # carry over every system message found in the supplied prompt messages
        result: list[PromptMessage] = [
            prompt_message for prompt_message in prompt_messages
            if isinstance(prompt_message, SystemPromptMessage)
        ]

        messages: list[Message] = db.session.query(Message).filter(
            Message.conversation_id == self.message.conversation_id,
        ).order_by(Message.created_at.asc()).all()

        for message in messages:
            if message.id == self.message.id:
                continue

            result.append(self.organize_agent_user_prompt(message))

            agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
            if not agent_thoughts:
                if message.answer:
                    result.append(AssistantPromptMessage(content=message.answer))
                continue

            for agent_thought in agent_thoughts:
                result.extend(self._organize_agent_thought(agent_thought))

        db.session.close()

        return result

    def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
        """
        Build the user prompt message for a stored message

        The content is the plain query text unless the message has files that
        convert to file objects under the app's file upload config; in that
        case it is the query text followed by each file's prompt content.
        """
        files = message.message_files
        if not files:
            return UserPromptMessage(content=message.query)

        file_extra_config = FileUploadConfigManager.convert(message.app_model_config.to_dict())
        if not file_extra_config:
            return UserPromptMessage(content=message.query)

        message_file_parser = MessageFileParser(
            tenant_id=self.tenant_id,
            app_id=self.app_config.app_id,
        )
        file_objs = message_file_parser.transform_message_files(
            files,
            file_extra_config
        )
        if not file_objs:
            return UserPromptMessage(content=message.query)

        prompt_message_contents = [TextPromptMessageContent(data=message.query)]
        prompt_message_contents.extend(file_obj.prompt_message_content for file_obj in file_objs)

        return UserPromptMessage(content=prompt_message_contents)