| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135 | """Base classes for LLM-powered router chains."""from __future__ import annotationsimport jsonfrom typing import Any, Dict, List, Optional, Type, cast, NamedTuplefrom langchain.chains.base import Chainfrom pydantic import root_validatorfrom langchain.chains import LLMChainfrom langchain.prompts import BasePromptTemplatefrom langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModelclass Route(NamedTuple):    destination: Optional[str]    next_inputs: Dict[str, Any]class LLMRouterChain(Chain):    """A router chain that uses an LLM chain to perform routing."""    llm_chain: LLMChain    """LLM chain used to perform routing"""    @root_validator()    def validate_prompt(cls, values: dict) -> dict:        prompt = values["llm_chain"].prompt        if prompt.output_parser is None:            raise ValueError(                "LLMRouterChain requires base llm_chain prompt to have an output"                " parser that converts LLM text output to a dictionary with keys"                " 'destination' and 'next_inputs'. Received a prompt with no output"                " parser."            )        return values    @property    def input_keys(self) -> List[str]:        """Will be whatever keys the LLM chain prompt expects.        :meta private:        """        return self.llm_chain.input_keys    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:        super()._validate_outputs(outputs)        if not isinstance(outputs["next_inputs"], dict):            raise ValueError    def _call(        self,        inputs: Dict[str, Any]    ) -> Dict[str, Any]:        output = cast(            Dict[str, Any],            self.llm_chain.predict_and_parse(**inputs),        )        return output    @classmethod    def from_llm(        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any    ) -> LLMRouterChain:        """Convenience constructor."""        llm_chain = LLMChain(llm=llm, prompt=prompt)        return cls(llm_chain=llm_chain, **kwargs)    @property    def output_keys(self) -> List[str]:        return ["destination", "next_inputs"]    def route(self, inputs: Dict[str, Any]) -> Route:        result = self(inputs)        return Route(result["destination"], result["next_inputs"])class RouterOutputParser(BaseOutputParser[Dict[str, str]]):    """Parser for output of router chain int he multi-prompt chain."""    default_destination: str = "DEFAULT"    next_inputs_type: Type = str    next_inputs_inner_key: str = "input"    def parse_json_markdown(self, json_string: str) -> dict:        # Remove the triple backticks if present        start_index = json_string.find("```json")        end_index = json_string.find("```", start_index + len("```json"))        if start_index != -1 and end_index != -1:            extracted_content = json_string[start_index + len("```json"):end_index].strip()            # Parse the JSON string into a Python dictionary            parsed = json.loads(extracted_content)        else:            raise Exception("Could not find JSON block in the output.")        return parsed    def parse_and_check_json_markdown(self, text: str, expected_keys: List[str]) -> dict:        try:            json_obj = self.parse_json_markdown(text)        except json.JSONDecodeError as e:            raise OutputParserException(f"Got invalid JSON object. Error: {e}")        for key in expected_keys:            if key not in json_obj:                raise OutputParserException(                    f"Got invalid return object. Expected key `{key}` "                    f"to be present, but got {json_obj}"                )        return json_obj    def parse(self, text: str) -> Dict[str, Any]:        try:            expected_keys = ["destination", "next_inputs"]            parsed = self.parse_and_check_json_markdown(text, expected_keys)            if not isinstance(parsed["destination"], str):                raise ValueError("Expected 'destination' to be a string.")            if not isinstance(parsed["next_inputs"], self.next_inputs_type):                raise ValueError(                    f"Expected 'next_inputs' to be {self.next_inputs_type}."                )            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}            if (                parsed["destination"].strip().lower()                == self.default_destination.lower()            ):                parsed["destination"] = None            else:                parsed["destination"] = parsed["destination"].strip()            return parsed        except Exception as e:            raise OutputParserException(                f"Parsing text\n{text}\n raised following error:\n{e}"            )
 |