import logging
import time
from typing import Any, Dict, List, Union, Optional

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, LLMResult, HumanMessage, AIMessage, SystemMessage

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.llm.streamable_chat_open_ai import StreamableChatOpenAI
from core.llm.streamable_open_ai import StreamableOpenAI


class LLMCallbackHandler(BaseCallbackHandler):
    """Callback handler that records an LLM call on a conversation message task.

    On start it captures the prompt(s) and their token count in an
    ``LLMMessage``; while streaming it forwards each token to the task; on
    completion (or on a user-initiated stop while streaming) it fills in
    latency and completion token count and passes the message to
    ``conversation_message_task.save_message``.
    """

    def __init__(self, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
                 conversation_message_task: ConversationMessageTask):
        """Bind the handler to the LLM used for token counting and to the task.

        Args:
            llm: Model wrapper; its ``get_num_tokens`` / ``get_messages_tokens``
                methods are used to count tokens.
            conversation_message_task: Task that receives streamed text and the
                final ``LLMMessage``.
        """
        self.llm = llm
        self.llm_message = LLMMessage()
        # Set by on_llm_start; remains None until an LLM call has begun.
        self.start_at = None
        self.conversation_message_task = conversation_message_task

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def _elapsed_since_start(self) -> float:
        """Return seconds elapsed since ``on_llm_start``, or 0.0 if it never ran."""
        if self.start_at is None:
            # Without a recorded start time there is nothing meaningful to
            # measure; avoid a TypeError from subtracting None.
            return 0.0
        return time.perf_counter() - self.start_at

    def _completion_tokens_for(self, response: LLMResult, text: str) -> int:
        """Return the completion token count for a non-streamed response.

        Uses the count reported in ``response.llm_output['token_usage']`` when
        it is present, and otherwise counts the tokens of ``text`` locally with
        ``self.llm.get_num_tokens``.
        """
        token_usage = (response.llm_output or {}).get('token_usage') or {}
        completion_tokens = token_usage.get('completion_tokens')
        if completion_tokens is None:
            return self.llm.get_num_tokens(text)
        return completion_tokens

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Record the start time, the prompt(s) and the prompt token count.

        When ``serialized['name']`` contains ``'Chat'``, each prompt is parsed
        as ``"<role>: <content>"``. The roles ``human`` and ``ai`` are stored
        as ``user`` and ``assistant``; any other role is stored unchanged and
        counted as a system message. A prompt without the ``": "`` separator
        is treated as a user prompt. Otherwise only ``prompts[0]`` is recorded,
        with role ``user``.
        """
        self.start_at = time.perf_counter()

        if 'Chat' in serialized['name']:
            real_prompts = []
            messages = []
            for prompt in prompts:
                role, separator, content = prompt.partition(': ')
                if not separator:
                    # No role prefix: keep the full text and label it as user
                    # input, as the non-chat branch does for raw prompts.
                    role, content = 'human', prompt

                if role == 'human':
                    role = 'user'
                    message = HumanMessage(content=content)
                elif role == 'ai':
                    role = 'assistant'
                    message = AIMessage(content=content)
                else:
                    message = SystemMessage(content=content)

                real_prompts.append({
                    "role": role,
                    "text": content
                })
                messages.append(message)

            self.llm_message.prompt = real_prompts
            self.llm_message.prompt_tokens = self.llm.get_messages_tokens(messages)
        else:
            self.llm_message.prompt = [{
                "role": 'user',
                "text": prompts[0]
            }]

            self.llm_message.prompt_tokens = self.llm.get_num_tokens(prompts[0])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Finalize latency, completion and token count, then save the message.

        In non-streaming mode the completion text is taken from the first
        generation of ``response`` and appended to the task. In streaming mode
        the text was already accumulated by ``on_llm_new_token``, so only its
        token count is computed here.
        """
        self.llm_message.latency = self._elapsed_since_start()

        if not self.conversation_message_task.streaming:
            text = response.generations[0][0].text
            self.conversation_message_task.append_message_text(text)
            self.llm_message.completion = text
            self.llm_message.completion_tokens = self._completion_tokens_for(response, text)
        else:
            self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)

        self.conversation_message_task.save_message(self.llm_message)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token to the task and append it to the completion."""
        self.conversation_message_task.append_message_text(token)
        # NOTE(review): relies on LLMMessage.completion starting out as a
        # string; confirm against the LLMMessage definition.
        self.llm_message.completion += token

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Handle an error raised during the LLM call.

        A ``ConversationTaskStoppedException`` while streaming saves the
        partial completion collected so far with ``by_stopped=True``; the same
        exception in non-streaming mode is ignored. Any other error is logged.
        """
        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                self.llm_message.latency = self._elapsed_since_start()
                self.llm_message.completion_tokens = self.llm.get_num_tokens(self.llm_message.completion)
                self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        else:
            logging.error(error)

    def on_chain_start(
            self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """Do nothing."""
        pass

    def on_chain_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_start(
            self,
            serialized: Dict[str, Any],
            input_str: str,
            **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_action(
            self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
    ) -> Any:
        """Do nothing."""
        pass

    def on_tool_end(
            self,
            output: str,
            color: Optional[str] = None,
            observation_prefix: Optional[str] = None,
            llm_prefix: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Do nothing."""
        pass

    def on_tool_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass

    def on_text(
            self,
            text: str,
            color: Optional[str] = None,
            end: str = "",
            **kwargs: Optional[str],
    ) -> None:
        """Do nothing."""
        pass

    def on_agent_finish(
            self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Do nothing."""
        pass