123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- from typing import Mapping, List, Dict, Any, Optional
- from langchain import LLMChain, PromptTemplate, ConversationChain
- from langchain.callbacks import CallbackManager
- from langchain.chains.base import Chain
- from langchain.schema import BaseLanguageModel
- from pydantic import Extra
- from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
- from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
- from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
- from core.conversation_message_task import ConversationMessageTask
- from core.llm.llm_builder import LLMBuilder
- from core.tool.dataset_tool_builder import DatasetToolBuilder
- from core.tool.llama_index_tool import EnhanceLlamaIndexTool
- from models.dataset import Dataset
# Prompt used by the dataset router. It is rendered in two passes:
#   1. `from_datasets` calls `str.format(destinations=...)` on it, which fills
#      `{destinations}` and collapses every doubled brace once
#      (`{{input}}` -> `{input}`, `{{{{` -> `{{`, `}}}}` -> `}}`).
#   2. The result is handed to `PromptTemplate` with `input` as its only
#      variable; presumably that stage uses f-string style formatting, so the
#      remaining `{{` / `}}` end up as the literal JSON braces - TODO confirm.
# Any text substituted in pass 1 is therefore still subject to pass 2.
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.
<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like:
```json
{{{{
"destination": string \\ name of the prompt to use or "DEFAULT"
"next_inputs": string \\ a potentially modified version of the original input
}}}}
```
REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.
<< CANDIDATE PROMPTS >>
{destinations}
<< INPUT >>
{{input}}
<< OUTPUT >>
"""
class MultiDatasetRouterChain(Chain):
    """Route a query to one of several dataset tools via an LLM router.

    The router chain is shown one line per dataset (``<id>: <description>``)
    and asked to pick a destination id; the matching dataset tool is then run
    with the (possibly rewritten) query and its result is returned under the
    ``"text"`` output key.
    """

    # Chain that decides the destination dataset and the input to send to it.
    router_chain: LLMRouterChain
    # Dataset tools an input can be routed to, keyed by dataset id.
    dataset_tools: Mapping[str, EnhanceLlamaIndexTool]

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """Return the single output key, ``"text"``."""
        return ["text"]

    @classmethod
    def from_datasets(
            cls,
            tenant_id: str,
            datasets: List[Dataset],
            conversation_message_task: ConversationMessageTask,
            **kwargs: Any,
    ) -> "MultiDatasetRouterChain":
        """Build a router chain plus one dataset tool per dataset.

        :param tenant_id: tenant passed to ``LLMBuilder.to_llm`` for the router LLM.
        :param datasets: datasets that queries may be routed to.
        :param conversation_message_task: passed to each tool's
            ``DatasetToolCallbackHandler``.
        :param kwargs: forwarded unchanged to the class constructor.
        :return: a configured ``MultiDatasetRouterChain``.
        """
        llm_callback_manager = CallbackManager([DifyStdOutCallbackHandler()])
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callback_manager=llm_callback_manager
        )

        # One "<id>: <description>" line per dataset; fall back to a generic
        # description built from the dataset name when none is set.
        destinations = []
        for d in datasets:
            if d.description:
                description = d.description.replace('\n', ' ')
            else:
                description = 'useful for when you want to answer queries about the ' + d.name
            destinations.append("{}: {}".format(d.id, description))

        # The formatted template is parsed a second time by PromptTemplate, so
        # any brace coming from user-authored names/descriptions must be
        # doubled here; otherwise it would be read as a prompt placeholder and
        # break prompt construction or rendering.
        destinations_str = "\n".join(destinations).replace('{', '{{').replace('}', '}}')

        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations=destinations_str
        )
        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )
        router_chain = LLMRouterChain.from_llm(llm, router_prompt)

        dataset_tools = {}
        for dataset in datasets:
            dataset_tool = DatasetToolBuilder.build_dataset_tool(
                dataset=dataset,
                response_mode='no_synthesizer',  # "compact"
                callback_handler=DatasetToolCallbackHandler(conversation_message_task)
            )
            # Key by the stringified id: that is the form shown to the router
            # above, and the router's destination comes back as text.
            dataset_tools[str(dataset.id)] = dataset_tool

        return cls(
            router_chain=router_chain,
            dataset_tools=dataset_tools,
            **kwargs,
        )

    def _call(
            self,
            inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the query against the dataset tool chosen by the router.

        :param inputs: chain inputs; ``inputs['input']`` is the user query.
        :return: ``{"text": <tool output>}``, or ``{"text": ''}`` when there are
            no dataset tools or the router picks no destination.
        :raises ValueError: if the router names a destination that is not one
            of the configured dataset tools.
        """
        if not self.dataset_tools:
            return {"text": ''}

        if len(self.dataset_tools) == 1:
            # Nothing to choose between: skip the router LLM call entirely.
            only_tool = next(iter(self.dataset_tools.values()))
            return {"text": only_tool.run(inputs['input'])}

        route = self.router_chain.route(inputs)

        if not route.destination:
            return {"text": ''}

        if route.destination in self.dataset_tools:
            return {"text": self.dataset_tools[route.destination].run(
                route.next_inputs['input']
            )}

        raise ValueError(
            f"Received invalid destination chain name '{route.destination}'"
        )
|