import os

from langchain.callbacks.manager import Callbacks
from langchain.schema import BaseMessage, LLMResult
from langchain.chat_models import ChatOpenAI
from typing import Optional, List, Dict, Any, Union, Tuple

from pydantic import root_validator

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableChatOpenAI(ChatOpenAI):
    """`ChatOpenAI` variant that passes OpenAI credentials per request.

    Compared with the base class it changes the default timeout and retry
    count, replaces the environment validator, adds the API connection
    settings to the default call parameters, and wraps `generate` with
    `handle_openai_exceptions`.
    """

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to OpenAI completion API. Default is (5.0, 300.0).

    NOTE(review): a tuple is presumably interpreted as (connect, read) seconds
    by the underlying HTTP client -- confirm against the openai package.
    """
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the openai package is usable and `n` is consistent.

        Stores `openai.ChatCompletion` in ``values["client"]``.

        Raises:
            ValueError: if `openai` cannot be imported, if it has no
                `ChatCompletion` attribute, if ``n < 1``, or if ``n > 1``
                while streaming is enabled.
        """
        try:
            import openai
        except ImportError as exc:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from exc
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError as exc:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            ) from exc
        if values["n"] < 1:
            raise ValueError("n must be at least 1.")
        if values["n"] > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling OpenAI API.

        Extends the parent's defaults with the connection settings so they
        travel with each call: API type, base URL (taken from the
        ``OPENAI_API_BASE`` environment variable, falling back to the public
        endpoint), API key and organization.
        """
        return {
            **super()._default_params,
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            # Normalise an empty/falsy organization to None.
            "organization": self.openai_organization or None,
        }

    @handle_openai_exceptions
    def generate(
        self,
        messages: List[List[BaseMessage]],
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Delegate to the parent `generate`, wrapped by `handle_openai_exceptions`.

        NOTE(review): the decorator presumably translates OpenAI errors into
        project-level exceptions -- see `core.llm.wrappers.openai_wrapper`.
        """
        return super().generate(messages, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict) -> dict:
        """Move the sampling parameters of *params* under ``model_kwargs``.

        ``top_p``, ``frequency_penalty`` and ``presence_penalty`` are removed
        from the top level of *params* and placed in a new dict stored under
        ``params['model_kwargs']`` (replacing any existing value). Missing
        keys fall back to ``1``, ``0`` and ``0`` respectively.

        Args:
            params: model parameters; modified in place.

        Returns:
            The same *params* dict, after modification.
        """
        # pop() with a default both reads and removes the key, so a missing
        # key yields the default instead of raising KeyError on deletion.
        model_kwargs = {
            'top_p': params.pop('top_p', 1),
            'frequency_penalty': params.pop('frequency_penalty', 0),
            'presence_penalty': params.pop('presence_penalty', 0),
        }

        params['model_kwargs'] = model_kwargs

        return params