"""
    Some tools (e.g. WebScraperTool, WikipediaSearchTool) need to use a model themselves.

    Therefore, a model manager is needed so that tools can list, invoke, and validate models.
"""

import json
from typing import cast

from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from models.tools import ToolModelInvoke


class InvokeModelError(Exception):
    """Raised when a tool cannot resolve or invoke the default model."""

class ModelInvocationUtils:
    """
    Static helpers that let tools work with a tenant's default LLM.

    Every public method resolves the model through
    ``ModelManager.get_default_model_instance`` with ``ModelType.LLM``.
    """

    # Context size reported when the model schema does not declare one.
    _FALLBACK_CONTEXT_TOKENS = 2048

    @staticmethod
    def _get_default_llm_instance(tenant_id: str):
        """
        Resolve the tenant's default LLM model instance.

        :param tenant_id: tenant id
        :return: the model instance returned by the model manager
        :raises InvokeModelError: if the model manager returns a falsy instance
        """
        model_instance = ModelManager().get_default_model_instance(
            tenant_id=tenant_id, model_type=ModelType.LLM,
        )

        if not model_instance:
            raise InvokeModelError('Model not found')

        return model_instance

    @staticmethod
    def get_max_llm_context_tokens(
        tenant_id: str,
    ) -> int:
        """
        Get the context size of the tenant's default LLM.

        :param tenant_id: tenant id
        :return: the ``CONTEXT_SIZE`` model property, or 2048 when the schema
            does not declare one
        :raises InvokeModelError: if no model instance or no model schema is found
        """
        model_instance = ModelInvocationUtils._get_default_llm_instance(tenant_id)

        llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
        schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)

        if not schema:
            raise InvokeModelError('No model schema found')

        max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
        if max_tokens is None:
            return ModelInvocationUtils._FALLBACK_CONTEXT_TOKENS

        return max_tokens

    @staticmethod
    def calculate_tokens(
        tenant_id: str,
        prompt_messages: list[PromptMessage]
    ) -> int:
        """
        Count the tokens of the prompt messages using the tenant's default LLM.

        :param tenant_id: tenant id
        :param prompt_messages: prompt messages
        :return: token count as reported by the model instance
        :raises InvokeModelError: if no model instance is found
        """
        model_instance = ModelInvocationUtils._get_default_llm_instance(tenant_id)

        return model_instance.get_llm_num_tokens(prompt_messages)

    @staticmethod
    def invoke(
        user_id: str, tenant_id: str,
        tool_type: str, tool_name: str,
        prompt_messages: list[PromptMessage]
    ) -> LLMResult:
        """
        Invoke the tenant's default LLM on behalf of a tool and record the call.

        A ``ToolModelInvoke`` row is committed before the model is called and
        updated with the response (and usage, when present) afterwards.

        :param user_id: user id, stored on the record and passed to the model as ``user``
        :param tenant_id: tenant id, the tenant id of the creator of the tool
        :param tool_type: tool type, stored on the record
        :param tool_name: tool name, stored on the record
        :param prompt_messages: prompt messages
        :return: the non-streaming LLMResult returned by the model
        :raises InvokeModelError: if no model instance is found or the invocation fails
        """
        model_instance = ModelInvocationUtils._get_default_llm_instance(tenant_id)

        # get prompt tokens
        prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)

        model_parameters = {
            'temperature': 0.8,
            'top_p': 0.8,
        }

        # create tool model invoke; answer/usage fields are placeholders until
        # the model has responded
        tool_model_invoke = ToolModelInvoke(
            user_id=user_id,
            tenant_id=tenant_id,
            provider=model_instance.provider,
            tool_type=tool_type,
            tool_name=tool_name,
            model_parameters=json.dumps(model_parameters),
            prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
            model_response='',
            prompt_tokens=prompt_tokens,
            answer_tokens=0,
            answer_unit_price=0,
            answer_price_unit=0,
            provider_response_latency=0,
            total_price=0,
            currency='USD',
        )

        db.session.add(tool_model_invoke)
        db.session.commit()

        # Every failure is re-raised as InvokeModelError; the original exception
        # is kept as the explicit cause so it stays visible in tracebacks.
        try:
            response: LLMResult = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=model_parameters,
                tools=[], stop=[], stream=False, user=user_id, callbacks=[]
            )
        except InvokeRateLimitError as e:
            raise InvokeModelError(f'Invoke rate limit error: {e}') from e
        except InvokeBadRequestError as e:
            raise InvokeModelError(f'Invoke bad request error: {e}') from e
        except InvokeConnectionError as e:
            raise InvokeModelError(f'Invoke connection error: {e}') from e
        except InvokeAuthorizationError as e:
            # the message carries no detail from the underlying error
            raise InvokeModelError('Invoke authorization error') from e
        except InvokeServerUnavailableError as e:
            raise InvokeModelError(f'Invoke server unavailable error: {e}') from e
        except Exception as e:
            raise InvokeModelError(f'Invoke error: {e}') from e

        # update tool model invoke
        tool_model_invoke.model_response = response.message.content
        usage = response.usage
        if usage:
            tool_model_invoke.answer_tokens = usage.completion_tokens
            tool_model_invoke.answer_unit_price = usage.completion_unit_price
            tool_model_invoke.answer_price_unit = usage.completion_price_unit
            tool_model_invoke.provider_response_latency = usage.latency
            tool_model_invoke.total_price = usage.total_price
            tool_model_invoke.currency = usage.currency

        db.session.commit()

        return response