Browse Source

feat:support azure whisper model and fix:rename text-embedidng-ada-002.yaml to text-embedding-ada-002.yaml (#2732)

呆萌闷油瓶 1 year ago
parent
commit
9819ad347f

+ 17 - 0
api/core/model_runtime/model_providers/azure_openai/_constant.py

@@ -526,3 +526,20 @@ EMBEDDING_BASE_MODELS = [
         )
     )
 ]
+SPEECH2TEXT_BASE_MODELS = [
+    AzureBaseModel(
+        base_model_name='whisper-1',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label'
+            ),
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_type=ModelType.SPEECH2TEXT,
+            model_properties={
+                ModelPropertyKey.FILE_UPLOAD_LIMIT: 25,
+                ModelPropertyKey.SUPPORTED_FILE_EXTENSIONS: 'flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm'
+            }
+        )
+    )
+]

+ 7 - 0
api/core/model_runtime/model_providers/azure_openai/azure_openai.yaml

@@ -15,6 +15,7 @@ help:
 supported_model_types:
   - llm
   - text-embedding
+  - speech2text
 configurate_methods:
   - customizable-model
 model_credential_schema:
@@ -99,6 +100,12 @@ model_credential_schema:
           show_on:
             - variable: __model_type
               value: text-embedding
+        - label:
+            en_US: whisper-1
+          value: whisper-1
+          show_on:
+            - variable: __model_type
+              value: speech2text
       placeholder:
         zh_Hans: 在此输入您的模型版本
         en_US: Enter your model version

+ 0 - 0
api/core/model_runtime/model_providers/azure_openai/speech2text/__init__.py


+ 81 - 0
api/core/model_runtime/model_providers/azure_openai/speech2text/speech2text.py

@@ -0,0 +1,81 @@
+from typing import IO, Optional
+
+from openai import AzureOpenAI
+
+from core.model_runtime.entities.model_entities import AIModelEntity
+from core.model_runtime.errors.validate import CredentialsValidateFailedError
+from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
+from core.model_runtime.model_providers.azure_openai._common import _CommonAzureOpenAI
+from core.model_runtime.model_providers.azure_openai._constant import SPEECH2TEXT_BASE_MODELS, AzureBaseModel
+
+
+class AzureOpenAISpeech2TextModel(_CommonAzureOpenAI, Speech2TextModel):
+    """
+    Model class for OpenAI Speech to text model.
+    """
+
+    def _invoke(self, model: str, credentials: dict,
+                file: IO[bytes], user: Optional[str] = None) \
+            -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :param user: unique user id
+        :return: text for given audio file
+        """
+        return self._speech2text_invoke(model, credentials, file)
+
+    def validate_credentials(self, model: str, credentials: dict) -> None:
+        """
+        Validate model credentials
+
+        :param model: model name
+        :param credentials: model credentials
+        :return:
+        """
+        try:
+            audio_file_path = self._get_demo_file_path()
+
+            with open(audio_file_path, 'rb') as audio_file:
+                self._speech2text_invoke(model, credentials, audio_file)
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _speech2text_invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str:
+        """
+        Invoke speech2text model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param file: audio file
+        :return: text for given audio file
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+
+        # init model client
+        client = AzureOpenAI(**credentials_kwargs)
+
+        response = client.audio.transcriptions.create(model=model, file=file)
+
+        return response.text
+
+    def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]:
+        ai_model_entity = self._get_ai_model_entity(credentials['base_model_name'], model)
+        return ai_model_entity.entity
+
+
+    @staticmethod
+    def _get_ai_model_entity(base_model_name: str, model: str) -> AzureBaseModel:
+        for ai_model_entity in SPEECH2TEXT_BASE_MODELS:
+            if ai_model_entity.base_model_name == base_model_name:
+                ai_model_entity_copy = copy.deepcopy(ai_model_entity)
+                ai_model_entity_copy.entity.model = model
+                ai_model_entity_copy.entity.label.en_US = model
+                ai_model_entity_copy.entity.label.zh_Hans = model
+                return ai_model_entity_copy
+
+        return None

+ 1 - 1
api/core/model_runtime/model_providers/openai/speech2text/whisper-1.yaml

@@ -2,4 +2,4 @@ model: whisper-1
 model_type: speech2text
 model_properties:
   file_upload_limit: 25
-  supported_file_extensions: mp3,mp4,mpeg,mpga,m4a,wav,webm
+  supported_file_extensions: flac,mp3,mp4,mpeg,mpga,m4a,ogg,wav,webm

+ 0 - 0
api/core/model_runtime/model_providers/openai/text_embedding/text-embedidng-ada-002.yaml → api/core/model_runtime/model_providers/openai/text_embedding/text-embedding-ada-002.yaml