import math
from typing import Mapping, List, Dict, Any, Optional

from langchain import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import Extra

from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.llm_router_chain import LLMRouterChain, RouterOutputParser
from core.conversation_message_task import ConversationMessageTask
from core.llm.llm_builder import LLMBuilder
from core.tool.dataset_index_tool import DatasetTool
from models.dataset import Dataset, DatasetProcessRule

# Fallback number of segments to retrieve from a dataset (the `k` handed to
# DatasetTool) whenever no better value can be derived from the token budget.
DEFAULT_K = 2
# Share of the remaining token budget (`rest_tokens`) that retrieved context may
# occupy when the retrieve-k value is expanded beyond DEFAULT_K.
CONTEXT_TOKENS_PERCENT = 0.3
# Prompt used by the router LLM to pick a dataset.
#
# The text goes through two formatting passes:
#   1. `str.format(destinations=...)` in `from_datasets` fills in the candidate
#      list; in this pass `{{input}}` collapses to `{input}` and the quadruple
#      braces around the JSON example collapse to double braces.
#   2. The result is used as the `PromptTemplate` template with `input` as its
#      only variable; the remaining double braces are presumably rendered as
#      literal braces there (default f-string template format — confirm against
#      the pinned langchain version).
# Because this is a regular (non-raw) string, each `\\` below is a single
# backslash in the prompt the model sees, and a trailing `\` joins the line
# with the next one.
MULTI_PROMPT_ROUTER_TEMPLATE = """
Given a raw text input to a language model select the model prompt best suited for \
the input. You will be given the names of the available prompts and a description of \
what the prompt is best suited for. You may also revise the original input if you \
think that revising it will ultimately lead to a better response from the language \
model.

<< FORMATTING >>
Return a markdown code snippet with a JSON object formatted to look like, \
no any other string out of markdown code snippet:
```json
{{{{
    "destination": string \\ name of the prompt to use or "DEFAULT"
    "next_inputs": string \\ a potentially modified version of the original input
}}}}
```

REMEMBER: "destination" MUST be one of the candidate prompt names specified below OR \
it can be "DEFAULT" if the input is not well suited for any of the candidate prompts.
REMEMBER: "next_inputs" can just be the original input if you don't think any \
modifications are needed.

<< CANDIDATE PROMPTS >>
{destinations}

<< INPUT >>
{{input}}

<< OUTPUT >>
"""


class MultiDatasetRouterChain(Chain):
    """Route a query to one of several datasets and return that dataset's retrieval output.

    An LLM-backed router chain picks the dataset whose description best matches
    the input; the matching ``DatasetTool`` is then run with the (possibly
    rewritten) query. With zero usable datasets the chain returns an empty
    string, and with exactly one it skips routing and queries that dataset
    directly.
    """

    router_chain: LLMRouterChain
    """Chain for deciding a destination dataset and the input to it."""
    dataset_tools: Mapping[str, DatasetTool]
    """Map of dataset id (as a string) to the retrieval tool for that dataset."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the router chain prompt expects.

        :meta private:
        """
        return self.router_chain.input_keys

    @property
    def output_keys(self) -> List[str]:
        """The chain produces a single ``"text"`` output."""
        return ["text"]

    @classmethod
    def from_datasets(
            cls,
            tenant_id: str,
            datasets: List[Dataset],
            conversation_message_task: ConversationMessageTask,
            rest_tokens: int,
            **kwargs: Any,
    ) -> "MultiDatasetRouterChain":
        """Build a router chain over the usable datasets in ``datasets``.

        A dataset is usable when it has at least one available document and the
        token budget leaves room for at least one retrieved segment. Only usable
        datasets are registered as tools *and* advertised to the router, so the
        router cannot be offered a destination that ``_call`` would then reject.

        :param tenant_id: tenant whose LLM credentials are used for the router model.
        :param datasets: candidate datasets.
        :param conversation_message_task: passed to the dataset tool callback handler.
        :param rest_tokens: token budget still available for retrieved context.
        :param kwargs: extra keyword arguments forwarded to the chain constructor.
        :return: the configured chain.
        """
        llm = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name='gpt-3.5-turbo',
            temperature=0,
            max_tokens=1024,
            callbacks=[DifyStdOutCallbackHandler()]
        )

        dataset_tools = {}
        destinations = []
        for dataset in datasets:
            # Nothing to retrieve from a dataset without available documents.
            if dataset.available_document_count == 0:
                continue

            # No token budget left for even a single segment of this dataset.
            k = cls._dynamic_calc_retrieve_k(dataset, rest_tokens)
            if k <= 0:
                continue

            # Fall back to a generic description when the dataset has none.
            description = dataset.description or (
                'useful for when you want to answer queries about the ' + dataset.name
            )
            dataset_id = str(dataset.id)

            dataset_tools[dataset_id] = DatasetTool(
                name=f"dataset-{dataset.id}",
                description=description,
                k=k,
                dataset=dataset,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task), DifyStdOutCallbackHandler()]
            )

            # One candidate per line. Braces are doubled because the finished
            # text becomes a PromptTemplate template, where a bare "{...}" in
            # user-provided text would be parsed as a template variable.
            destination = "[[{}]]: {}".format(dataset_id, description.replace('\n', ' '))
            destinations.append(destination.replace('{', '{{').replace('}', '}}'))

        router_template = MULTI_PROMPT_ROUTER_TEMPLATE.format(
            destinations="\n".join(destinations)
        )

        router_prompt = PromptTemplate(
            template=router_template,
            input_variables=["input"],
            output_parser=RouterOutputParser(),
        )

        return cls(
            router_chain=LLMRouterChain.from_llm(llm, router_prompt),
            dataset_tools=dataset_tools,
            **kwargs,
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
        """Compute how many segments to retrieve from ``dataset`` within the token budget.

        :param dataset: dataset whose latest process rule defines the segment size.
        :param rest_tokens: token budget still available for retrieved context.
        :return: ``DEFAULT_K`` when the segment size is unknown; otherwise a
            non-negative count, where 0 means not even one segment fits.
        """
        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return DEFAULT_K

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return DEFAULT_K

            # A custom rule without a segmentation size is treated like a missing rule.
            try:
                segment_max_tokens = rules["segmentation"]["max_tokens"]
            except (KeyError, TypeError):
                return DEFAULT_K
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # An unset or zero segment size cannot be divided by; use the default.
        if not segment_max_tokens:
            return DEFAULT_K

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * DEFAULT_K:
            # Clamp at 0: floor division of a negative budget would otherwise
            # yield a negative k that slips past the caller's "no room" check.
            # NOTE(review): assumes a negative rest_tokens means "no budget".
            return max(rest_tokens // segment_max_tokens, 0)

        context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)

        # when context_limit_tokens is less than default context tokens, use default_k
        if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
            return DEFAULT_K

        # Expand the k value when there's still some room left in the 30% rest tokens space
        return context_limit_tokens // segment_max_tokens

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Pick a dataset for the input and return its retrieval output.

        :param inputs: chain inputs; ``inputs['input']`` is the user query.
        :param run_manager: accepted for the ``Chain`` interface; not used here.
        :return: ``{"text": ...}`` holding the tool output, or an empty string
            when there is no dataset tool or the router selects no destination.
        :raises ValueError: if the router names a destination that is not a
            registered dataset tool.
        """
        if not self.dataset_tools:
            return {"text": ''}

        if len(self.dataset_tools) == 1:
            # Only one candidate: skip the routing LLM call entirely.
            only_tool = next(iter(self.dataset_tools.values()))
            return {"text": only_tool.run(inputs['input'])}

        route = self.router_chain.route(inputs)

        if not route.destination:
            return {"text": ''}

        if route.destination not in self.dataset_tools:
            raise ValueError(
                f"Received invalid destination chain name '{route.destination}'"
            )

        tool = self.dataset_tools[route.destination]
        return {"text": tool.run(route.next_inputs['input'])}