import os
from typing import Any, Dict, List, Optional

from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatResult, LLMResult
from pydantic import root_validator

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async


class StreamableChatOpenAI(ChatOpenAI):

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment."""
        try:
            import openai
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API."""
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }
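
    # Note: the connection fields added above (api_type, api_base, api_version,
    # api_key, organization) are assumed to be forwarded by the parent ChatOpenAI
    # implementation to openai.ChatCompletion.create along with the sampling
    # parameters from super()._default_params.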

    def get_messages_tokens(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in a list of messages.

        Args:
            messages: The messages to count the tokens of.

        Returns:
            The number of tokens in the messages.
        """
        tokens_per_message = 5
        tokens_per_request = 3

        message_tokens = tokens_per_request
        message_strs = ''
        for message in messages:
            message_strs += message.content
            message_tokens += tokens_per_message

        # tokenize the concatenated contents once instead of once per message
        message_tokens += self.get_num_tokens(message_strs)

        return message_tokens
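
    # Worked example (illustrative): for two messages, get_messages_tokens returns
    # 3 (per-request overhead) + 2 * 5 (per-message overhead)
    # + self.get_num_tokens(<concatenated message contents>).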

    def _generate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        self.callback_manager.on_llm_start(
            {"name": self.__class__.__name__},
            [f"{message.type}: {message.content}" for message in messages],
            verbose=self.verbose,
        )

        chat_result = super()._generate(messages, stop)

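        # ChatResult.generations is a flat list of generations for one prompt,
        # while LLMResult expects one list of generations per prompt, hence the
        # extra nesting below.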
        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )
        self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    async def _agenerate(
        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
    ) -> ChatResult:
        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [f"{message.type}: {message.content}" for message in messages],
                verbose=self.verbose,
            )
        else:
            self.callback_manager.on_llm_start(
                {"name": self.__class__.__name__},
                [f"{message.type}: {message.content}" for message in messages],
                verbose=self.verbose,
            )

        chat_result = await super()._agenerate(messages, stop)

        result = LLMResult(
            generations=[chat_result.generations],
            llm_output=chat_result.llm_output
        )

        if self.callback_manager.is_async:
            await self.callback_manager.on_llm_end(result, verbose=self.verbose)
        else:
            self.callback_manager.on_llm_end(result, verbose=self.verbose)

        return chat_result

    @handle_llm_exceptions
    def generate(
        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return super().generate(messages, stop)

    @handle_llm_exceptions_async
    async def agenerate(
        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
    ) -> LLMResult:
        return await super().agenerate(messages, stop)
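

# Minimal usage sketch (illustrative only): assumes OPENAI_API_KEY is set in the
# environment and that core.llm.error_handle_wraps is importable from the project.
if __name__ == "__main__":
    from langchain.schema import HumanMessage

    llm = StreamableChatOpenAI(model_name="gpt-3.5-turbo", temperature=0)
    result = llm.generate([[HumanMessage(content="Say hello in one word.")]])
    print(result.generations[0][0].text)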