import re
from collections.abc import Generator
from time import time
from types import SimpleNamespace
from typing import Any, Literal, Optional, Union

from openai import AzureOpenAI, BadRequestError, OpenAI
from openai._types import NOT_GIVEN, NotGiven
from openai.resources.completions import Completions
from openai.types import Completion as CompletionMessage
from openai.types.completion import CompletionChoice
from openai.types.completion_usage import CompletionUsage

from core.model_runtime.errors.invoke import InvokeAuthorizationError


class MockCompletionsClass:
    """Canned replacements for the OpenAI legacy-completions endpoint.

    ``completion_create`` takes ``self: Completions``. It is presumably patched in place of
    ``Completions.create``, so inside it ``self`` is NOT a ``MockCompletionsClass``. That is
    why class-level constants are always reached through ``MockCompletionsClass.<name>``.
    """

    # Identifier stamped on every mocked completion object.
    _MOCK_COMPLETION_ID = "cmpl-3QJQa5jXJ5Z5X"

    # Model names for which the client's API key format is validated.
    _OPENAI_MODELS = frozenset(
        {
            "babbage-002",
            "davinci-002",
            "gpt-3.5-turbo-instruct",
            "text-davinci-003",
            "text-davinci-002",
            "text-davinci-001",
            "code-davinci-002",
            "text-curie-001",
            "text-babbage-001",
            "text-ada-001",
        }
    )
    _AZURE_OPENAI_MODELS = frozenset({"gpt-35-turbo-instruct"})
    _KEY_CHECKED_MODELS = _OPENAI_MODELS | _AZURE_OPENAI_MODELS

    # Compiled once instead of on every mocked request.
    _BASE_URL_PATTERN = re.compile(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$")
    _OPENAI_API_KEY_PATTERN = re.compile(r"sk-[a-zA-Z0-9]{24,}$")

    @staticmethod
    def _build_completion(
        model: str,
        text: str,
        finish_reason: Literal["stop", "length", "content_filter"],
        usage: Optional[CompletionUsage] = None,
    ) -> CompletionMessage:
        """Build one mocked completion object holding a single choice.

        ``usage`` is only passed to the model when given. This keeps chunks without usage
        identical to the original construction, which omitted the field entirely.
        """
        extra: dict[str, Any] = {} if usage is None else {"usage": usage}
        return CompletionMessage(
            id=MockCompletionsClass._MOCK_COMPLETION_ID,
            object="text_completion",
            created=int(time()),
            model=model,
            system_fingerprint="",
            choices=[CompletionChoice(text=text, index=0, logprobs=None, finish_reason=finish_reason)],
            **extra,
        )

    @staticmethod
    def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
        """Return a single, non-streamed completion whose text is ``"mock"``."""
        return MockCompletionsClass._build_completion(
            model,
            "mock",
            "stop",
            CompletionUsage(prompt_tokens=2, completion_tokens=1, total_tokens=3),
        )

    @staticmethod
    def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
        """Yield a fixed text one character per chunk, then a final empty chunk with usage."""
        full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
        for char in full_text:
            # NOTE(review): intermediate chunks are tagged "content_filter", presumably because
            # CompletionChoice.finish_reason does not accept None; real streamed chunks would
            # carry no finish reason. Confirm before relying on finish_reason in tests.
            yield MockCompletionsClass._build_completion(model, char, "content_filter")
        # Terminal chunk: empty text, "stop", and the token accounting.
        yield MockCompletionsClass._build_completion(
            model,
            "",
            "stop",
            CompletionUsage(prompt_tokens=2, completion_tokens=17, total_tokens=19),
        )

    def completion_create(
        self: Completions,
        *,
        model: Union[
            str,
            Literal[
                "babbage-002",
                "davinci-002",
                "gpt-3.5-turbo-instruct",
                "text-davinci-003",
                "text-davinci-002",
                "text-davinci-001",
                "code-davinci-002",
                "text-curie-001",
                "text-babbage-001",
                "text-ada-001",
            ],
        ],
        prompt: Union[str, list[str], list[int], list[list[int]], None],
        stream: Optional[bool] | NotGiven = NOT_GIVEN,
        **kwargs: Any,
    ) -> Union[CompletionMessage, Generator[CompletionMessage, None, None]]:
        """Validate the calling client's configuration, then return canned completion output.

        Args:
            model: requested model name. API-key format is only checked for the known
                OpenAI / Azure OpenAI model names.
            prompt: prompt payload; any falsy value is rejected.
            stream: when truthy, return a generator of chunks instead of one completion.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            A single ``Completion``, or a generator of ``Completion`` chunks when streaming.

        Raises:
            InvokeAuthorizationError: the client's base URL or API key looks invalid.
            BadRequestError: ``prompt`` is empty.
        """
        client = self._client
        if not MockCompletionsClass._BASE_URL_PATTERN.match(str(client.base_url)):
            raise InvokeAuthorizationError("Invalid base url")

        # Providers exposing an OpenAI-compatible API may use no key or a different key
        # format, so the key format is only enforced for the known OpenAI/Azure models.
        if model in MockCompletionsClass._KEY_CHECKED_MODELS:
            # A client may carry no api_key (None); treat it as empty rather than crashing.
            api_key = client.api_key or ""
            # Exact type checks on purpose: in openai-python AzureOpenAI derives from OpenAI,
            # so isinstance() would wrongly apply the "sk-..." format to Azure clients.
            client_type = type(client)
            if client_type is OpenAI and not MockCompletionsClass._OPENAI_API_KEY_PATTERN.match(api_key):
                raise InvokeAuthorizationError("Invalid api key")
            if client_type is AzureOpenAI and len(api_key) < 18:
                raise InvokeAuthorizationError("Invalid api key")

        if not prompt:
            # openai's BadRequestError (an APIStatusError) requires keyword-only `response` and
            # `body`; the bare BadRequestError("Invalid prompt") used to fail with TypeError.
            # This stand-in supplies what the constructor reads (request, status_code, headers).
            stub_response = SimpleNamespace(request=None, status_code=400, headers={})
            raise BadRequestError("Invalid prompt", response=stub_response, body=None)  # type: ignore[arg-type]

        if stream:
            return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
        return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)