"""Base classes for LLM-powered router chains."""
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional, Type, cast, NamedTuple

from langchain.chains.base import Chain
from pydantic import root_validator

from langchain.chains import LLMChain
from langchain.prompts import BasePromptTemplate
from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel


class Route(NamedTuple):
    """Routing decision returned by ``LLMRouterChain.route``.

    A plain two-field tuple: the selected destination and the inputs to pass
    along to it.
    """

    # Name of the selected destination. ``None`` is a valid value; in this
    # file ``RouterOutputParser.parse`` produces ``None`` when the destination
    # text matches its ``default_destination``.
    destination: Optional[str]
    # Inputs to forward to the selected destination, keyed by input name.
    next_inputs: Dict[str, Any]


class LLMRouterChain(Chain):
    """A router chain that uses an LLM chain to perform routing.

    The wrapped ``llm_chain`` must have a prompt with an output parser that
    converts the LLM text output to a dictionary with the keys
    ``destination`` and ``next_inputs``.
    """

    llm_chain: LLMChain
    """LLM chain used to perform routing"""

    @root_validator()
    def validate_prompt(cls, values: dict) -> dict:
        """Check that the prompt of ``llm_chain`` has an output parser.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The unchanged ``values`` dictionary.

        Raises:
            ValueError: If the prompt of ``llm_chain`` has no output parser.
        """
        llm_chain = values.get("llm_chain")
        if llm_chain is None:
            # ``llm_chain`` is absent when it was not supplied or failed its
            # own field validation. Indexing ``values`` here would raise a bare
            # KeyError and mask that error, so leave reporting to pydantic.
            return values
        if llm_chain.prompt.output_parser is None:
            raise ValueError(
                "LLMRouterChain requires base llm_chain prompt to have an output"
                " parser that converts LLM text output to a dictionary with keys"
                " 'destination' and 'next_inputs'. Received a prompt with no output"
                " parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the LLM chain prompt expects.

        :meta private:
        """
        return self.llm_chain.input_keys

    def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
        """Run the base-class output validation, then check ``next_inputs``.

        Args:
            outputs: Outputs produced by this chain.

        Raises:
            ValueError: If ``outputs["next_inputs"]`` is not a dictionary.
        """
        super()._validate_outputs(outputs)
        next_inputs = outputs["next_inputs"]
        if not isinstance(next_inputs, dict):
            # Name the offending type so a bad parser is easy to diagnose.
            raise ValueError(
                "LLMRouterChain expects 'next_inputs' to be a dictionary, but"
                f" got {type(next_inputs).__name__}."
            )

    def _call(
        self,
        inputs: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Run the LLM chain on ``inputs`` and return its parsed output.

        Args:
            inputs: Keyword inputs forwarded to ``llm_chain.predict_and_parse``.

        Returns:
            The parsed LLM output, treated as a dictionary.
        """
        # ``cast`` only informs the type checker; no runtime conversion happens.
        output = cast(
            Dict[str, Any],
            self.llm_chain.predict_and_parse(**inputs),
        )
        return output

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
    ) -> LLMRouterChain:
        """Convenience constructor.

        Args:
            llm: Language model used to build the inner ``LLMChain``.
            prompt: Prompt used to build the inner ``LLMChain``.
            **kwargs: Extra keyword arguments passed to the class constructor.

        Returns:
            A router chain wrapping ``LLMChain(llm=llm, prompt=prompt)``.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)

    @property
    def output_keys(self) -> List[str]:
        """Keys of the output dictionary: ``destination`` and ``next_inputs``."""
        return ["destination", "next_inputs"]

    def route(self, inputs: Dict[str, Any]) -> Route:
        """Call the chain on ``inputs`` and package the result as a ``Route``.

        Args:
            inputs: Inputs passed to the chain call.

        Returns:
            A ``Route`` built from the ``destination`` and ``next_inputs``
            entries of the chain result.
        """
        result = self(inputs)
        return Route(result["destination"], result["next_inputs"])


class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
    """Parser for output of router chain in the multi-prompt chain.

    Expects the text to contain a fenced JSON block opened with three
    backticks followed by ``json``. The block must hold an object with the
    keys ``destination`` and ``next_inputs``.
    """

    # Destination value (compared case-insensitively after stripping) that is
    # mapped to ``None`` by ``parse``.
    default_destination: str = "DEFAULT"
    # Type the parsed ``next_inputs`` value must be an instance of.
    next_inputs_type: Type = str
    # Key under which the parsed ``next_inputs`` value is wrapped in the result.
    next_inputs_inner_key: str = "input"

    def parse_json_markdown(self, json_string: str) -> dict:
        """Extract and decode the first fenced JSON block in ``json_string``.

        Args:
            json_string: Text containing a block delimited by an opening
                fence of three backticks plus ``json`` and a closing fence of
                three backticks.

        Returns:
            The value decoded from the block's JSON content.

        Raises:
            OutputParserException: If no complete fenced JSON block is found.
            json.JSONDecodeError: If the block content is not valid JSON.
        """
        opening_fence = "```json"
        start_index = json_string.find(opening_fence)
        if start_index == -1:
            raise OutputParserException("Could not find JSON block in the output.")

        # Only look for the closing fence after the opening one.
        content_start = start_index + len(opening_fence)
        end_index = json_string.find("```", content_start)
        if end_index == -1:
            raise OutputParserException("Could not find JSON block in the output.")

        return json.loads(json_string[content_start:end_index].strip())

    def parse_and_check_json_markdown(
        self, text: str, expected_keys: List[str]
    ) -> dict:
        """Decode the fenced JSON block in ``text`` and verify required keys.

        Args:
            text: Text containing the fenced JSON block.
            expected_keys: Keys that must be present in the decoded object.

        Returns:
            The decoded JSON object.

        Raises:
            OutputParserException: If the block is missing, is not valid JSON,
                does not decode to an object, or lacks an expected key.
        """
        try:
            json_obj = self.parse_json_markdown(text)
        except json.JSONDecodeError as e:
            raise OutputParserException(
                f"Got invalid JSON object. Error: {e}"
            ) from e
        if not isinstance(json_obj, dict):
            # A JSON list or scalar could otherwise pass the membership check
            # below and be returned despite the ``dict`` contract.
            raise OutputParserException(
                f"Got invalid return object. Expected a JSON object, "
                f"but got {json_obj}"
            )
        for key in expected_keys:
            if key not in json_obj:
                raise OutputParserException(
                    f"Got invalid return object. Expected key `{key}` "
                    f"to be present, but got {json_obj}"
                )
        return json_obj

    def parse(self, text: str) -> Dict[str, Any]:
        """Parse router LLM output into ``destination`` and ``next_inputs``.

        Args:
            text: Raw LLM output containing the fenced JSON block.

        Returns:
            The decoded object with ``next_inputs`` wrapped as
            ``{next_inputs_inner_key: value}`` and ``destination`` stripped of
            surrounding whitespace, or ``None`` when it matches
            ``default_destination`` case-insensitively.

        Raises:
            OutputParserException: If anything goes wrong while parsing or
                validating; the original error is attached as the cause.
        """
        try:
            expected_keys = ["destination", "next_inputs"]
            parsed = self.parse_and_check_json_markdown(text, expected_keys)
            if not isinstance(parsed["destination"], str):
                raise ValueError("Expected 'destination' to be a string.")
            if not isinstance(parsed["next_inputs"], self.next_inputs_type):
                raise ValueError(
                    f"Expected 'next_inputs' to be {self.next_inputs_type}."
                )
            parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
            destination = parsed["destination"].strip()
            if destination.lower() == self.default_destination.lower():
                parsed["destination"] = None
            else:
                parsed["destination"] = destination
            return parsed
        except Exception as e:
            # Deliberately broad: every failure is reported as a parser error.
            raise OutputParserException(
                f"Parsing text\n{text}\n raised following error:\n{e}"
            ) from e