from typing import Optional, List

from langchain.callbacks import SharedCallbackManager, CallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory

from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.chain.multi_dataset_router_chain import MultiDatasetRouterChain
from core.conversation_message_task import ConversationMessageTask
from extensions.ext_database import db
from models.dataset import Dataset


class MainChainBuilder:
    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                conversation_message_task: ConversationMessageTask):
        first_input_key = "input"
        final_output_key = "output"

        chains = []

        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        tool_chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            memory=memory,
            conversation_message_task=conversation_message_task
        )
        chains += tool_chains

        if chains_output_key:
            final_output_key = chains_output_key

        if len(chains) == 0:
            return None

        for chain in chains:
            # do not add handler into singleton callback manager
            if not isinstance(chain.callback_manager, SharedCallbackManager):
                chain.callback_manager.add_handler(chain_callback_handler)

        # build main chain
        overall_chain = SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # only for use the memory prompt input key
        )

        return overall_chain

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                         conversation_message_task: ConversationMessageTask):
        # agent mode
        chains = []
        if agent_mode and agent_mode.get('enabled'):
            tools = agent_mode.get('tools', [])

            pre_fixed_chains = []
            # agent_tools = []
            datasets = []
            for tool in tools:
                tool_type = list(tool.keys())[0]
                tool_config = list(tool.values())[0]
                if tool_type == 'sensitive-word-avoidance':
                    chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                    if chain:
                        pre_fixed_chains.append(chain)
                elif tool_type == "dataset":
                    # get dataset from dataset id
                    dataset = db.session.query(Dataset).filter(
                        Dataset.tenant_id == tenant_id,
                        Dataset.id == tool_config.get("id")
                    ).first()

                    if dataset:
                        datasets.append(dataset)

            # add pre-fixed chains
            chains += pre_fixed_chains

            if len(datasets) > 0:
                # tool to chain
                multi_dataset_router_chain = MultiDatasetRouterChain.from_datasets(
                    tenant_id=tenant_id,
                    datasets=datasets,
                    conversation_message_task=conversation_message_task,
                    callback_manager=CallbackManager([DifyStdOutCallbackHandler()])
                )
                chains.append(multi_dataset_router_chain)

        final_output_key = cls.get_chains_output_key(chains)

        return chains, final_output_key

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]):
        if len(chains) > 0:
            return chains[-1].output_keys[0]
        return None