import json
import threading
from typing import Optional, List

from flask import Flask
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig

# Fallback retrieval settings for a dataset whose own `retrieval_model` is empty.
# Shared module-level object: read it, don't mutate it.
default_retrieval_model = dict(
    search_method='semantic_search',
    reranking_enable=False,
    reranking_model=dict(
        reranking_provider_name='',
        reranking_model_name='',
    ),
    top_k=2,
    score_threshold_enabled=False,
)

class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities.

    Converts an app's ``AppModelConfig`` (agent mode, model, dataset and
    retriever-resource settings) into runtime objects: an ``AgentExecutor``
    and the LangChain tools it is given.
    """

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
                          retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        """
        Build an agent executor from the app's agent-mode configuration.

        :param conversation_message_task: task the agent and dataset callbacks record into
        :param memory: optional chat memory handed to the agent
        :param rest_tokens: remaining token budget, forwarded to the tool builders
        :param chain_callback: main-chain callback; its ``agent_callback`` attribute is set here
        :param tenant_id: tenant id forwarded to the tool builders (read by multi-dataset retrieval)
        :param retriever_from: origin tag forwarded to dataset retriever tools
        :return: an ``AgentExecutor``, or ``None`` when agent mode is missing/disabled
                 or no tool could be built from the configuration
        """
        agent_mode_config = self.app_model_config.agent_mode_dict
        if not agent_mode_config or not agent_mode_config.get('enabled'):
            return None

        model_dict = self.app_model_config.model_dict
        retriever_resource_config = self.app_model_config.retriever_resource_dict or {}
        return_resource = retriever_resource_config.get('enabled', False)

        tool_configs = agent_mode_config.get('tools', [])
        agent_provider_name = model_dict.get('provider', 'openai')
        agent_model_name = model_dict.get('name', 'gpt-4')
        dataset_configs = self.app_model_config.dataset_configs_dict

        agent_model_instance = ModelFactory.get_text_generation_model(
            tenant_id=self.tenant_id,
            model_provider_name=agent_provider_name,
            model_name=agent_model_name,
            model_kwargs=ModelKwargs(
                temperature=0.2,
                top_p=0.3,
                max_tokens=1500
            )
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_instance=agent_model_instance,
            conversation_message_task=conversation_message_task
        )

        chain_callback.agent_callback = agent_callback
        agent_model_instance.add_callbacks([agent_callback])

        planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

        # Function-call based strategies need a model with `support_function_call`;
        # otherwise switch to the ReACT counterpart of the requested strategy.
        if not agent_model_instance.support_function_call:
            if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                planning_strategy = PlanningStrategy.REACT
            elif planning_strategy == PlanningStrategy.ROUTER:
                planning_strategy = PlanningStrategy.REACT_ROUTER

        summary_model_instance = self._get_summary_model_instance(agent_provider_name, agent_model_name)

        tools = self.to_tools(
            tool_configs=tool_configs,
            callbacks=[agent_callback, DifyStdOutCallbackHandler()],
            agent_model_instance=agent_model_instance,
            conversation_message_task=conversation_message_task,
            rest_tokens=rest_tokens,
            return_resource=return_resource,
            retriever_from=retriever_from,
            dataset_configs=dataset_configs,
            tenant_id=tenant_id
        )

        if not tools:
            return None

        agent_configuration = AgentConfiguration(
            strategy=planning_strategy,
            model_instance=agent_model_instance,
            tools=tools,
            summary_model_instance=summary_model_instance,
            memory=memory,
            callbacks=[chain_callback, agent_callback],
            max_iterations=10,
            max_execution_time=400.0,
            early_stopping_method="generate"
        )

        return AgentExecutor(agent_configuration)

    def _get_summary_model_instance(self, model_provider_name: str, model_name: str) -> Optional[BaseLLM]:
        """
        Create a text-generation model with temperature 0, at most 500 tokens and ``deduct_quota=False``.

        :param model_provider_name: provider of the model to instantiate
        :param model_name: name of the model to instantiate
        :return: the model instance, or ``None`` if ``ProviderTokenNotInitError`` is raised
        """
        try:
            return ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=model_provider_name,
                model_name=model_name,
                model_kwargs=ModelKwargs(
                    temperature=0,
                    max_tokens=500
                ),
                deduct_quota=False
            )
        except ProviderTokenNotInitError:
            return None

    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools.

        Each config is a single-entry dict ``{tool_type: tool_settings}``; only entries whose
        ``enabled`` setting is exactly ``True`` are used. Dataset configs are collected and
        converted together at the end so that, in 'multiple' retrieval mode, one retriever
        can span all of them.

        :param tool_configs: app agent tool configs
        :param callbacks: callbacks attached to every non-dataset tool (ignored when None)
        :param kwargs: forwarded to the individual tool builders
        :return: the built tools; empty/unknown configs and builders returning None are skipped
        """
        builders = {
            "web_reader": self.to_web_reader_tool,
            "google_search": self.to_google_search_tool,
            "wikipedia": self.to_wikipedia_tool,
            "current_datetime": self.to_current_datetime_tool,
        }

        tools = []
        dataset_tools = []
        for tool_config in tool_configs:
            if not tool_config:
                continue

            tool_type, tool_val = next(iter(tool_config.items()))
            if not isinstance(tool_val, dict) or tool_val.get("enabled") is not True:
                continue

            if tool_type == "dataset":
                dataset_tools.append(tool_config)
                continue

            builder = builders.get(tool_type)
            tool = builder(tool_config=tool_val, **kwargs) if builder else None
            if not tool:
                continue

            if callbacks is not None:
                if tool.callbacks is None:
                    # copy a list so tools don't end up sharing (and mutating) one list object
                    tool.callbacks = list(callbacks) if isinstance(callbacks, list) else callbacks
                else:
                    tool.callbacks.extend(callbacks)
            tools.append(tool)

        # format dataset tool
        if dataset_tools:
            dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
            if dataset_retriever_tools:
                tools.extend(dataset_retriever_tools)
        return tools

    def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
                                  return_resource: bool = False, retriever_from: str = 'dev',
                                  **kwargs) \
            -> Optional[List[BaseTool]]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset.

        With ``retrieval_model == 'single'`` (the default) one ``DatasetRetrieverTool`` is built per
        usable dataset; with ``'multiple'`` one ``DatasetMultiRetrieverTool`` spans all of them.
        Datasets not found for this tenant, or with no available documents, are skipped.

        :param tool_configs: dataset tool configs, each shaped ``{'dataset': {'id': ...}}``
        :param conversation_message_task: task the dataset callbacks record into
        :param return_resource: forwarded to the retriever tools
        :param retriever_from: origin tag forwarded to the retriever tools
        :param kwargs: reads ``dataset_configs`` and, in 'multiple' mode, ``tenant_id``
        :return: list of retriever tools (possibly empty)
        """
        dataset_configs = kwargs.get('dataset_configs') or {}
        retrieval_strategy = dataset_configs.get('retrieval_model', 'single')

        tools = []
        dataset_ids = []
        for tool_config in tool_configs:
            dataset_id = (tool_config.get('dataset') or {}).get('id')

            # get dataset from dataset id
            dataset = db.session.query(Dataset).filter(
                Dataset.tenant_id == self.tenant_id,
                Dataset.id == dataset_id
            ).first()

            # nothing to retrieve from a missing dataset or one without available documents
            if not dataset or dataset.available_document_count == 0:
                continue

            dataset_ids.append(dataset.id)
            if retrieval_strategy != 'single':
                continue

            retrieval_model_config = dataset.retrieval_model or default_retrieval_model

            # score_threshold is only applied when explicitly enabled
            score_threshold = None
            if retrieval_model_config.get("score_threshold_enabled"):
                score_threshold = retrieval_model_config.get("score_threshold")

            # top_k is used as configured; _dynamic_calc_retrieve_k is not applied here
            tools.append(DatasetRetrieverTool.from_dataset(
                dataset=dataset,
                top_k=retrieval_model_config['top_k'],
                score_threshold=score_threshold,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
                conversation_message_task=conversation_message_task,
                return_resource=return_resource,
                retriever_from=retriever_from
            ))

        # like 'single' mode, build nothing when no dataset is usable
        if retrieval_strategy == 'multiple' and dataset_ids:
            reranking_model = dataset_configs.get('reranking_model') or {}

            score_threshold = None
            if dataset_configs.get('score_threshold_enabled', False):
                score_threshold = dataset_configs.get('score_threshold', 0.5)

            tools.append(DatasetMultiRetrieverTool.from_dataset(
                dataset_ids=dataset_ids,
                tenant_id=kwargs.get('tenant_id', self.tenant_id),
                top_k=dataset_configs.get('top_k', 2),
                score_threshold=score_threshold,
                callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
                conversation_message_task=conversation_message_task,
                return_resource=return_resource,
                retriever_from=retriever_from,
                reranking_provider_name=reranking_model.get('reranking_provider_name'),
                reranking_model_name=reranking_model.get('reranking_model_name')
            ))

        return tools

    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
        """
        A tool for reading web pages.

        The reader's ``model_instance`` is a temperature-0, 500-token, non-quota-deducting
        instance of the agent's own model, or ``None`` when that model's provider token is
        not initialized.

        :param tool_config: the tool's settings (not read here)
        :param agent_model_instance: the agent's model; its provider and name are reused
        :return: a ``WebReaderTool``
        """
        summary_model_instance = self._get_summary_model_instance(
            agent_model_instance.model_provider.provider_name,
            agent_model_instance.name
        )

        return WebReaderTool(
            model_instance=summary_model_instance,
            max_chunk_length=4000,
            continue_reading=True
        )

    def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        """
        A Google search tool backed by the tenant's SerpAPI credentials.

        :return: the tool, or ``None`` when no SerpAPI credentials are configured
        """
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        return Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput
        )

    def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        """A tool that reports the current date and time."""
        return DatetimeTool()

    def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        """A Wikipedia query tool whose returned content is capped at 4000 characters."""
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
        """
        Shrink ``top_k`` so that ``top_k`` segments of maximum size fit in ``rest_tokens``.

        :param dataset: dataset whose latest process rule determines the segment max tokens
        :param top_k: requested number of segments
        :param rest_tokens: remaining token budget; ``-1`` means "no limit"
        :return: the adjusted top_k (never negative)
        """
        if rest_tokens == -1:
            return top_k

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return top_k

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return top_k

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when rest_tokens is less than default context tokens
        if rest_tokens < segment_max_tokens * top_k:
            # a negative budget (other than -1) must not produce a negative top_k
            return max(rest_tokens // segment_max_tokens, 0)

        return min(top_k, 10)