|
@@ -0,0 +1,81 @@
|
|
|
+from typing import IO, Optional
|
|
|
+
|
|
|
+from openai import AzureOpenAI
|
|
|
+
|
|
|
+from core.model_runtime.entities.model_entities import AIModelEntity
|
|
|
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
|
|
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
|
|
|
+from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
|
|
|
+from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
|
|
|
+
|
|
|
+
|
|
|
class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
    """
    Model class for Azure OpenAI speech-to-text models.
    """

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id (accepted for interface compatibility;
            it is not forwarded to the transcription request)
        :return: text for given audio file
        """
        return self._speech2text_invoke(model, credentials, file)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by transcribing the demo audio file.

        :param model: model name
        :param credentials: model credentials
        :return:
        :raises CredentialsValidateFailedError: if locating/opening the demo
            file or the transcription request fails for any reason
        """
        try:
            # NOTE(review): _get_demo_file_path is not defined in this class;
            # presumably inherited from Speech2TextModel - confirm.
            audio_file_path = self._get_demo_file_path()

            with open(audio_file_path, 'rb') as audio_file:
                self._speech2text_invoke(model, credentials, audio_file)
        except Exception as ex:
            # Every failure is surfaced as a validation error; chain the cause
            # so the original traceback is preserved for debugging.
            raise CredentialsValidateFailedError(str(ex)) from ex

    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :return: text for given audio file
        """
        # transform credentials to kwargs for model instance
        # NOTE(review): _to_credential_kwargs is not defined in this class;
        # presumably provided by _CommonAzureOpenAI - confirm.
        credentials_kwargs = self._to_credential_kwargs(credentials)

        # init model client
        client = AzureOpenAI(**credentials_kwargs)

        response = client.audio.transcriptions.create(model=model, file=file)

        return response.text

    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
        """
        Build the model schema for a customizable model.

        :param model: model name, written into the returned schema
        :param credentials: model credentials; must contain 'base_model_name'
        :return: schema entity of the matching base model, or None when no
            base model in SPEECH2TEXT_BASE_MODELS matches
        :raises KeyError: if credentials has no 'base_model_name' key
        """
        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
        if ai_model_entity is None:
            # Unknown base model: honour the Optional return type instead of
            # failing with AttributeError on None.
            return None
        return ai_model_entity.entity

    @staticmethod
    def _get_ai_model_entity(base_model_name: str, model: str) -> Optional[AzureBaseModel]:
        """
        Find the base model definition and return a customized copy of it.

        :param base_model_name: base model name to look up in SPEECH2TEXT_BASE_MODELS
        :param model: model name written into the copy's model field and labels
        :return: deep copy of the matching definition with model/label
            overridden, or None when no definition matches
        """
        # Imported locally because the module's import block does not bring
        # `copy` into scope; without it the deepcopy below raises NameError.
        import copy

        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
            if ai_model_entity.base_model_name == base_model_name:
                # Deep copy so the shared module-level definition is never mutated.
                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
                ai_model_entity_copy.entity.model = model
                ai_model_entity_copy.entity.label.en_US = model
                ai_model_entity_copy.entity.label.zh_Hans = model
                return ai_model_entity_copy

        return None
|