from typing import Mapping, List, Dict, Any

from langchain import PromptTemplate
from langchain.callbacks import CallbackManager
from langchain.chains.base import Chain
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_tool_builder import DatasetToolBuilder
from core.tool.llama_index_tool import EnhanceLlamaIndexTool
from models.dataset import Dataset

MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model, select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""


class MultiDatasetRouterChain(Chain):
    """Use a single chain to route an input to one of multiple candidate dataset tools."""

    router_chain: LLMRouterChain
    """Chain for deciding a destination dataset and the input to it."""
    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]
    """Map of dataset id to candidate dataset tools that inputs can be routed to."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """The single output key produced by this chain.

        :meta private:
        """
        return ["text"]

    @classmethod
    def from_datasets(
        cls,
        tenant_id: str,
        datasets: List[Dataset],
        conversation_message_task: ConversationMessageTask,
        **kwargs: Any,
    ):
        """Convenience constructor that builds a router chain and one tool per dataset."""
        # The routing decision is made by a deterministic (temperature=0) LLM.
        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        # One "id: description" line per dataset; fall back to a generic
        # description when the dataset has none.
        destinations = [
            "{}: {}".format(
                d.id,
                d.description.replace('\n', ' ') if d.description
                else 'useful for when you want to answer queries about the ' + d.name
            )
            for d in datasets
        ]
        destinations_str = "\n".join(destinations)
        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            dataset_tool = DatasetToolBuilder.build_dataset_tool(
                dataset=dataset,
                response_mode='no_synthesizer',  # "compact"
                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
            )
            dataset_tools[dataset.id] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        # No datasets: nothing to route to.
        if len(self.dataset_tools) == 0:
            return {"text": ''}
        # A single dataset needs no routing decision.
        elif len(self.dataset_tools) == 1:
            return {"text": next(iter(self.dataset_tools.values())).run(inputs['input'])}

        # Let the LLM router pick a destination dataset (or none).
        route = self.router_chain.route(inputs)

        if not route.destination:
            return {"text": ''}
        elif route.destination in self.dataset_tools:
            return {"text": self.dataset_tools[route.destination].run(
                route.next_inputs['input']
            )}
        else:
            raise ValueError(
                f"Received invalid destination chain name '{route.destination}'"
            )
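
# --- Usage sketch (illustrative only, not part of the chain itself) ----------
# A minimal sketch of how this chain is typically wired up. It assumes that
# `datasets` (a list of `Dataset` rows) and `conversation_message_task` have
# already been loaded by the surrounding Dify request context; `tenant_id` and
# `query` are likewise placeholders, not module-level globals.
#
#     router = MultiDatasetRouterChain.from_datasets(
#         tenant_id=tenant_id,
#         datasets=datasets,
#         conversation_message_task=conversation_message_task,
#     )
#     # `run` accepts the single "input" key expected by the router prompt and
#     # returns the "text" output of whichever dataset tool was selected.
#     answer = router.run(query)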