Browse Source

feat: support LLM understand video (#9828)

非法操作 5 months ago
parent
commit
033ab5490b

+ 2 - 1
api/.env.example

@@ -285,8 +285,9 @@ UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
 UPLOAD_VIDEO_FILE_SIZE_LIMIT=100
 UPLOAD_AUDIO_FILE_SIZE_LIMIT=50
 
-# Model Configuration
+# Model configuration
 MULTIMODAL_SEND_IMAGE_FORMAT=base64
+MULTIMODAL_SEND_VIDEO_FORMAT=base64
 PROMPT_GENERATION_MAX_TOKENS=512
 CODE_GENERATION_MAX_TOKENS=1024
 

+ 7 - 2
api/configs/feature/__init__.py

@@ -634,12 +634,17 @@ class IndexingConfig(BaseSettings):
     )
 
 
-class ImageFormatConfig(BaseSettings):
+class VisionFormatConfig(BaseSettings):
     MULTIMODAL_SEND_IMAGE_FORMAT: Literal["base64", "url"] = Field(
         description="Format for sending images in multimodal contexts ('base64' or 'url'), default is base64",
         default="base64",
     )
 
+    MULTIMODAL_SEND_VIDEO_FORMAT: Literal["base64", "url"] = Field(
+        description="Format for sending videos in multimodal contexts ('base64' or 'url'), default is base64",
+        default="base64",
+    )
+
 
 class CeleryBeatConfig(BaseSettings):
     CELERY_BEAT_SCHEDULER_TIME: int = Field(
@@ -742,7 +747,7 @@ class FeatureConfig(
     FileAccessConfig,
     FileUploadConfig,
     HttpConfig,
-    ImageFormatConfig,
+    VisionFormatConfig,
     InnerAPIConfig,
     IndexingConfig,
     LoggingConfig,

+ 10 - 2
api/core/file/file_manager.py

@@ -3,7 +3,7 @@ import base64
 from configs import dify_config
 from core.file import file_repository
 from core.helper import ssrf_proxy
-from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent
+from core.model_runtime.entities import AudioPromptMessageContent, ImagePromptMessageContent, VideoPromptMessageContent
 from extensions.ext_database import db
 from extensions.ext_storage import storage
 
@@ -71,6 +71,12 @@ def to_prompt_message_content(f: File, /):
             if f.extension is None:
                 raise ValueError("Missing file extension")
             return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+        case FileType.VIDEO:
+            if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
+                data = _to_url(f)
+            else:
+                data = _to_base64_data_string(f)
+            return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case _:
             raise ValueError(f"file type {f.type} is not supported")
 
@@ -112,7 +118,7 @@ def _download_file_content(path: str, /):
 def _get_encoded_string(f: File, /):
     match f.transfer_method:
         case FileTransferMethod.REMOTE_URL:
-            response = ssrf_proxy.get(f.remote_url)
+            response = ssrf_proxy.get(f.remote_url, follow_redirects=True)
             response.raise_for_status()
             content = response.content
             encoded_string = base64.b64encode(content).decode("utf-8")
@@ -140,6 +146,8 @@ def _file_to_encoded_string(f: File, /):
     match f.type:
         case FileType.IMAGE:
             return _to_base64_data_string(f)
+        case FileType.VIDEO:
+            return _to_base64_data_string(f)
         case FileType.AUDIO:
             return _get_encoded_string(f)
         case _:

+ 2 - 0
api/core/model_runtime/entities/__init__.py

@@ -12,11 +12,13 @@ from .message_entities import (
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from .model_entities import ModelPropertyKey
 
 __all__ = [
     "ImagePromptMessageContent",
+    "VideoPromptMessageContent",
     "PromptMessage",
     "PromptMessageRole",
     "LLMUsage",

+ 7 - 0
api/core/model_runtime/entities/message_entities.py

@@ -56,6 +56,7 @@ class PromptMessageContentType(Enum):
     TEXT = "text"
     IMAGE = "image"
     AUDIO = "audio"
+    VIDEO = "video"
 
 
 class PromptMessageContent(BaseModel):
@@ -75,6 +76,12 @@ class TextPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.TEXT
 
 
+class VideoPromptMessageContent(PromptMessageContent):
+    type: PromptMessageContentType = PromptMessageContentType.VIDEO
+    data: str = Field(..., description="Base64 encoded video data")
+    format: str = Field(..., description="Video format")
+
+
 class AudioPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.AUDIO
     data: str = Field(..., description="Base64 encoded audio data")

+ 9 - 0
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -29,6 +29,7 @@ from core.model_runtime.entities.message_entities import (
     TextPromptMessageContent,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.model_entities import (
     AIModelEntity,
@@ -431,6 +432,14 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
                             sub_message_dict = {"image": image_url}
                             sub_messages.append(sub_message_dict)
+                        elif message_content.type == PromptMessageContentType.VIDEO:
+                            message_content = cast(VideoPromptMessageContent, message_content)
+                            video_url = message_content.data
+                            if message_content.data.startswith("data:"):
+                                raise InvokeError("not support base64, please set MULTIMODAL_SEND_VIDEO_FORMAT to url")
+
+                            sub_message_dict = {"video": video_url}
+                            sub_messages.append(sub_message_dict)
 
                     # resort sub_messages to ensure text is always at last
                     sub_messages = sorted(sub_messages, key=lambda x: "text" in x)

+ 26 - 12
api/core/model_runtime/model_providers/zhipuai/llm/llm.py

@@ -313,21 +313,35 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         return params
 
     def _construct_glm_4v_messages(self, prompt_message: Union[str, list[PromptMessageContent]]) -> list[dict]:
-        if isinstance(prompt_message, str):
+        if isinstance(prompt_message, list):
+            sub_messages = []
+            for item in prompt_message:
+                if item.type == PromptMessageContentType.IMAGE:
+                    sub_messages.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": self._remove_base64_header(item.data)},
+                        }
+                    )
+                elif item.type == PromptMessageContentType.VIDEO:
+                    sub_messages.append(
+                        {
+                            "type": "video_url",
+                            "video_url": {"url": self._remove_base64_header(item.data)},
+                        }
+                    )
+                else:
+                    sub_messages.append({"type": "text", "text": item.data})
+            return sub_messages
+        else:
             return [{"type": "text", "text": prompt_message}]
 
-        return [
-            {"type": "image_url", "image_url": {"url": self._remove_image_header(item.data)}}
-            if item.type == PromptMessageContentType.IMAGE
-            else {"type": "text", "text": item.data}
-            for item in prompt_message
-        ]
-
-    def _remove_image_header(self, image: str) -> str:
-        if image.startswith("data:image"):
-            return image.split(",")[1]
+    def _remove_base64_header(self, file_content: str) -> str:
+        if file_content.startswith("data:"):
+            data_split = file_content.split(";base64,")
+            return data_split[1]
 
-        return image
+        return file_content
 
     def _handle_generate_response(
         self,

+ 4 - 1
api/core/workflow/nodes/llm/node.py

@@ -14,6 +14,7 @@ from core.model_runtime.entities import (
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
@@ -560,7 +561,9 @@ class LLMNode(BaseNode[LLMNodeData]):
                         # cuz vision detail is related to the configuration from FileUpload feature.
                         content_item.detail = vision_detail
                         prompt_message_content.append(content_item)
-                    elif isinstance(content_item, TextPromptMessageContent | AudioPromptMessageContent):
+                    elif isinstance(
+                        content_item, TextPromptMessageContent | AudioPromptMessageContent | VideoPromptMessageContent
+                    ):
                         prompt_message_content.append(content_item)
 
                 if len(prompt_message_content) > 1:

+ 2 - 2
web/app/components/app/configuration/index.tsx

@@ -468,8 +468,8 @@ const Configuration: FC = () => {
           transfer_methods: modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
         },
         enabled: !!(modelConfig.file_upload?.enabled || modelConfig.file_upload?.image?.enabled),
-        allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image],
-        allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || FILE_EXTS[SupportUploadFileTypes.image].map(ext => `.${ext}`),
+        allowed_file_types: modelConfig.file_upload?.allowed_file_types || [SupportUploadFileTypes.image, SupportUploadFileTypes.video],
+        allowed_file_extensions: modelConfig.file_upload?.allowed_file_extensions || [...FILE_EXTS[SupportUploadFileTypes.image], ...FILE_EXTS[SupportUploadFileTypes.video]].map(ext => `.${ext}`),
         allowed_file_upload_methods: modelConfig.file_upload?.allowed_file_upload_methods || modelConfig.file_upload?.image?.transfer_methods || ['local_file', 'remote_url'],
         number_limits: modelConfig.file_upload?.number_limits || modelConfig.file_upload?.image?.number_limits || 3,
         fileUploadConfig: fileUploadConfigResponse,