Browse Source

feat: [backend] vision support (#1510)

Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
takatost 1 year ago
parent
commit
41d0a8b295
61 changed files with 1551 additions and 288 deletions
  1. 11 0
      api/.env.example
  2. 7 0
      api/app.py
  3. 120 83
      api/config.py
  4. 4 0
      api/controllers/console/app/completion.py
  5. 2 4
      api/controllers/console/app/conversation.py
  6. 0 3
      api/controllers/console/datasets/data_source.py
  7. 5 6
      api/controllers/console/datasets/file.py
  8. 12 0
      api/controllers/console/explore/completion.py
  9. 11 3
      api/controllers/console/explore/conversation.py
  10. 3 2
      api/controllers/console/explore/installed_app.py
  11. 13 2
      api/controllers/console/explore/parameter.py
  12. 2 0
      api/controllers/console/explore/saved_message.py
  13. 3 0
      api/controllers/console/universal_chat/chat.py
  14. 9 2
      api/controllers/console/universal_chat/conversation.py
  15. 10 0
      api/controllers/files/__init__.py
  16. 40 0
      api/controllers/files/image_preview.py
  17. 1 1
      api/controllers/service_api/__init__.py
  18. 13 2
      api/controllers/service_api/app/app.py
  19. 6 1
      api/controllers/service_api/app/completion.py
  20. 9 2
      api/controllers/service_api/app/conversation.py
  21. 23 0
      api/controllers/service_api/app/error.py
  22. 42 0
      api/controllers/service_api/app/file.py
  23. 2 1
      api/controllers/service_api/app/message.py
  24. 3 2
      api/controllers/service_api/dataset/document.py
  25. 1 1
      api/controllers/web/__init__.py
  26. 13 2
      api/controllers/web/app.py
  27. 4 0
      api/controllers/web/completion.py
  28. 9 2
      api/controllers/web/conversation.py
  29. 25 1
      api/controllers/web/error.py
  30. 36 0
      api/controllers/web/file.py
  31. 2 0
      api/controllers/web/message.py
  32. 3 0
      api/controllers/web/saved_message.py
  33. 8 2
      api/core/callback_handler/llm_callback_handler.py
  34. 24 9
      api/core/completion.py
  35. 24 6
      api/core/conversation_message_task.py
  36. 0 0
      api/core/file/__init__.py
  37. 79 0
      api/core/file/file_obj.py
  38. 180 0
      api/core/file/message_file_parser.py
  39. 79 0
      api/core/file/upload_file_parser.py
  40. 6 2
      api/core/generator/llm_generator.py
  41. 19 1
      api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py
  42. 46 2
      api/core/model_providers/models/entity/message.py
  43. 0 2
      api/core/model_providers/models/llm/openai_model.py
  44. 153 76
      api/core/prompt/prompt_transform.py
  45. 103 1
      api/core/third_party/langchain/llms/chat_open_ai.py
  46. 4 10
      api/events/event_handlers/generate_conversation_name_when_first_message_created.py
  47. 36 1
      api/extensions/ext_storage.py
  48. 3 2
      api/fields/app_fields.py
  49. 9 7
      api/fields/conversation_fields.py
  50. 2 1
      api/fields/file_fields.py
  51. 2 0
      api/fields/message_fields.py
  52. 59 0
      api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
  53. 63 4
      api/models/model.py
  54. 33 1
      api/services/app_model_config_service.py
  55. 40 10
      api/services/completion_service.py
  56. 36 4
      api/services/conversation_service.py
  57. 54 21
      api/services/file_service.py
  58. 4 2
      api/services/web_conversation_service.py
  59. 29 1
      api/tests/integration_tests/models/llm/test_openai_model.py
  60. 7 3
      docker/docker-compose.yaml
  61. 5 0
      docker/nginx/conf.d/default.conf

+ 11 - 0
api/.env.example

@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
 APP_API_URL=http://127.0.0.1:5001
 APP_WEB_URL=http://127.0.0.1:3000
 
+# Files URL
+FILES_URL=http://127.0.0.1:5001
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
@@ -70,6 +73,14 @@ MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
 MILVUS_SECURE=false
 
+# Upload configuration
+UPLOAD_FILE_SIZE_LIMIT=15
+UPLOAD_FILE_BATCH_LIMIT=5
+UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+
+# Model Configuration
+MULTIMODAL_SEND_IMAGE_FORMAT=base64
+
 # Mail configuration, support: resend
 MAIL_TYPE=
 MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

+ 7 - 0
api/app.py

@@ -126,6 +126,7 @@ def register_blueprints(app):
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp
     from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp
 
     CORS(service_api_bp,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):
 
     app.register_blueprint(console_app_bp)
 
+    CORS(files_bp,
+         allow_headers=['Content-Type'],
+         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
+         )
+    app.register_blueprint(files_bp)
+
 
 # create app
 app = create_app()

+ 120 - 83
api/config.py

@@ -26,6 +26,7 @@ DEFAULTS = {
     'SERVICE_API_URL': 'https://api.dify.ai',
     'APP_WEB_URL': 'https://udify.app',
     'APP_API_URL': 'https://udify.app',
+    'FILES_URL': '',
     'STORAGE_TYPE': 'local',
     'STORAGE_LOCAL_PATH': 'storage',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
-    'OUTPUT_MODERATION_BUFFER_SIZE': 300
+    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
+    'OUTPUT_MODERATION_BUFFER_SIZE': 300,
+    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
 }
 
 
@@ -84,15 +87,9 @@ class Config:
     """Application configuration class."""
 
     def __init__(self):
-        # app settings
-        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
-        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
-        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
-        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
-        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
-        self.CONSOLE_URL = get_env('CONSOLE_URL')
-        self.API_URL = get_env('API_URL')
-        self.APP_URL = get_env('APP_URL')
+        # ------------------------
+        # General Configurations.
+        # ------------------------
         self.CURRENT_VERSION = "0.3.29"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
@@ -100,13 +97,71 @@ class Config:
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
 
+        # The backend URL prefix of the console API.
+        # used to concatenate the login authorization callback or notion integration callback.
+        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
+
+        # The front-end URL prefix of the console web.
+        # used to concatenate some front-end addresses and for CORS configuration use.
+        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
+
+        # WebApp API backend Url prefix.
+        # used to declare the back-end URL for the front-end API.
+        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
+
+        # WebApp Url prefix.
+        # used to display WebAPP API Base Url to the front-end.
+        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
+
+        # Service API Url prefix.
+        # used to display Service API Base Url to the front-end.
+        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
+
+        # File preview or download Url prefix.
+        # used to display File preview or download Url to the front-end or as Multi-model inputs;
+        # Url is signed and has expiration time.
+        self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
+
+        # Fallback Url prefix.
+        # Will be deprecated in the future.
+        self.CONSOLE_URL = get_env('CONSOLE_URL')
+        self.API_URL = get_env('API_URL')
+        self.APP_URL = get_env('APP_URL')
+
         # Your App secret key will be used for securely signing the session cookie
         # Make sure you are changing this key for your deployment with a strong key.
         # You can generate a strong key using `openssl rand -base64 42`.
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         self.SECRET_KEY = get_env('SECRET_KEY')
 
-        # redis settings
+        # cors settings
+        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
+            'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
+        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
+            'WEB_API_CORS_ALLOW_ORIGINS', '*')
+
+        # check update url
+        self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
+
+        # ------------------------
+        # Database Configurations.
+        # ------------------------
+        db_credentials = {
+            key: get_env(key) for key in
+            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
+        }
+
+        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
+        self.SQLALCHEMY_ENGINE_OPTIONS = {
+            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
+            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
+        }
+
+        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
+
+        # ------------------------
+        # Redis Configurations.
+        # ------------------------
         self.REDIS_HOST = get_env('REDIS_HOST')
         self.REDIS_PORT = get_env('REDIS_PORT')
         self.REDIS_USERNAME = get_env('REDIS_USERNAME')
@@ -114,7 +169,18 @@ class Config:
         self.REDIS_DB = get_env('REDIS_DB')
         self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
 
-        # storage settings
+        # ------------------------
+        # Celery worker Configurations.
+        # ------------------------
+        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
+        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
+        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
+            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
+        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
+
+        # ------------------------
+        # File Storage Configurations.
+        # ------------------------
         self.STORAGE_TYPE = get_env('STORAGE_TYPE')
         self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
         self.S3_ENDPOINT = get_env('S3_ENDPOINT')
@@ -123,68 +189,72 @@ class Config:
         self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
         self.S3_REGION = get_env('S3_REGION')
 
-        # vector store settings, only support weaviate, qdrant
+        # ------------------------
+        # Vector Store Configurations.
+        # Currently, only support: qdrant, milvus, zilliz, weaviate
+        # ------------------------
         self.VECTOR_STORE = get_env('VECTOR_STORE')
 
-        # weaviate settings
-        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
-        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
-        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
-        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
-
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
 
-        # milvus setting
+        # milvus / zilliz setting
         self.MILVUS_HOST = get_env('MILVUS_HOST')
         self.MILVUS_PORT = get_env('MILVUS_PORT')
         self.MILVUS_USER = get_env('MILVUS_USER')
         self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
         self.MILVUS_SECURE = get_env('MILVUS_SECURE')
 
+        # weaviate settings
+        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
+        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
+        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
+        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
 
-        # cors settings
-        self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
-        self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
-            'WEB_API_CORS_ALLOW_ORIGINS', '*')
-
-        # mail settings
+        # ------------------------
+        # Mail Configurations.
+        # ------------------------
         self.MAIL_TYPE = get_env('MAIL_TYPE')
         self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
         self.RESEND_API_KEY = get_env('RESEND_API_KEY')
 
-        # sentry settings
+        # ------------------------
+        # Sentry Configurations.
+        # ------------------------
         self.SENTRY_DSN = get_env('SENTRY_DSN')
         self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
         self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
 
-        # check update url
-        self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
+        # ------------------------
+        # Business Configurations.
+        # ------------------------
 
-        # database settings
-        db_credentials = {
-            key: get_env(key) for key in
-            ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
-        }
+        # multi model send image format, support base64, url, default is base64
+        self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
 
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/{db_credentials['DB_DATABASE']}"
-        self.SQLALCHEMY_ENGINE_OPTIONS = {
-            'pool_size': int(get_env('SQLALCHEMY_POOL_SIZE')),
-            'pool_recycle': int(get_env('SQLALCHEMY_POOL_RECYCLE'))
-        }
+        # Dataset Configurations.
+        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
+        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
 
-        self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
+        # File upload Configurations.
+        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
+        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
+        self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
 
-        # celery settings
-        self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
-        self.CELERY_BACKEND = get_env('CELERY_BACKEND')
-        self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
-            if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
-        self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
+        # Moderation in app Configurations.
+        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
+
+        # Notion integration setting
+        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
+        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
+        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
+        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
 
-        # hosted provider credentials
+        # ------------------------
+        # Platform Configurations.
+        # ------------------------
         self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
@@ -212,26 +282,6 @@ class Config:
         self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
         self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
 
-        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
-        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
-
-        # notion import setting
-        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
-        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
-        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
-        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
-        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
-
-        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
-        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
-
-        # uploading settings
-        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
-        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
-
-        # moderation settings
-        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
-
 
 class CloudEditionConfig(Config):
 
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
         self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
         self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
 
-
-class TestConfig(Config):
-
-    def __init__(self):
-        super().__init__()
-
-        self.EDITION = "SELF_HOSTED"
-        self.TESTING = True
-
-        db_credentials = {
-            key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
-        }
-
-        # use a different database for testing: dify_test
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

+ 4 - 0
api/controllers/console/app/completion.py

@@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
         account = flask_login.current_user
 
@@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
@@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
         account = flask_login.current_user
 

+ 2 - 4
api/controllers/console/app/conversation.py

@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
 
         return _get_conversation(app_id, conversation_id, 'completion')
-    
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
 
         return _get_conversation(app_id, conversation_id, 'chat')
-    
+
     @setup_required
     @login_required
     @account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
         return {'result': 'success'}, 204
 
 
-
-
 api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
 api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
 api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')

+ 0 - 3
api/controllers/console/datasets/data_source.py

@@ -1,7 +1,6 @@
 import datetime
 import json
 
-from cachetools import TTLCache
 from flask import request
 from flask_login import current_user
 from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
-cache = TTLCache(maxsize=None, ttl=30)
-
 
 class DataSourceApi(Resource):
 

+ 5 - 6
api/controllers/console/datasets/file.py

@@ -1,5 +1,5 @@
-from cachetools import TTLCache
 from flask import request, current_app
+from flask_login import current_user
 
 import services
 from libs.login import login_required
@@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields
 
 from services.file_service import FileService
 
-cache = TTLCache(maxsize=None, ttl=30)
-
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
 PREVIEW_WORDS_LIMIT = 3000
 
 
@@ -30,9 +27,11 @@ class FileApi(Resource):
     def get(self):
         file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
         batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
+        image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
         return {
             'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit
+            'batch_count_limit': batch_count_limit,
+            'image_file_size_limit': image_file_size_limit
         }, 200
 
     @setup_required
@@ -51,7 +50,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:

+ 12 - 0
api/controllers/console/explore/completion.py

@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 import json
 import logging
+from datetime import datetime
 from typing import Generator, Union
 
 from flask import Response, stream_with_context
@@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
 from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from extensions.ext_database import db
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
         try:
             response = CompletionService.completion(
@@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
         try:
             response = CompletionService.completion(

+ 11 - 3
api/controllers/console/explore/conversation.py

@@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
                 user=current_user,
                 last_id=args['last_id'],
                 limit=args['limit'],
-                pinned=pinned
+                pinned=pinned,
+                exclude_debug_conversation=True
             )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 

+ 3 - 2
api/controllers/console/explore/installed_app.py

@@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
             }
             for installed_app in installed_apps
         ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
-                            if app['last_used_at'] is not None else datetime.min))
+        installed_apps.sort(key=lambda app: (-app['is_pinned'],
+                                             app['last_used_at'] is None,
+                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
 
         return {'installed_apps': installed_apps}
 

+ 13 - 2
api/controllers/console/explore/parameter.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
 

+ 2 - 0
api/controllers/console/explore/saved_message.py

@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
 
 feedback_fields = {
     'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'query': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
 }

+ 3 - 0
api/controllers/console/universal_chat/chat.py

@@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
@@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
         del args['model']
         del args['tools']
 
+        args['auto_generate_name'] = False
+
         try:
             response = CompletionService.completion(
                 app_model=app_model,

+ 9 - 2
api/controllers/console/universal_chat/conversation.py

@@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 

+ 10 - 0
api/controllers/files/__init__.py

@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint('files', __name__)
+api = ExternalApi(bp)
+
+
+from . import image_preview

+ 40 - 0
api/controllers/files/image_preview.py

@@ -0,0 +1,40 @@
+from flask import request, Response
+from flask_restful import Resource
+
+import services
+from controllers.files import api
+from libs.exception import BaseHTTPException
+from services.file_service import FileService
+
+
+class ImagePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get('timestamp')
+        nonce = request.args.get('nonce')
+        sign = request.args.get('sign')
+
+        if not timestamp or not nonce or not sign:
+            return {'content': 'Invalid request.'}, 400
+
+        try:
+            generator, mimetype = FileService.get_image_preview(
+                file_id,
+                timestamp,
+                nonce,
+                sign
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 1 - 1
api/controllers/service_api/__init__.py

@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
 api = ExternalApi(bp)
 
 
-from .app import completion, app, conversation, message, audio
+from .app import completion, app, conversation, message, audio, file
 
 from .dataset import document, segment, dataset

+ 13 - 2
api/controllers/service_api/app/app.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with
+from flask import current_app
 
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
 

+ 6 - 1
api/controllers/service_api/app/completion.py

@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
+        args['auto_generate_name'] = False
+
         try:
             response = CompletionService.completion(
                 app_model=app_model,
                 user=end_user,
                 args=args,
                 from_source='api',
-                streaming=streaming
+                streaming=streaming,
             )
 
             return compact_response(response)
@@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
 
         args = parser.parse_args()
 

+ 9 - 2
api/controllers/service_api/app/conversation.py

@@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 

+ 23 - 0
api/controllers/service_api/app/error.py

@@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     description = "Provider not support speech to text."
     code = 400
 
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 42 - 0
api/controllers/service_api/app/file.py

@@ -0,0 +1,42 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.service_api import api
+from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(AppApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+
+        file = request.files['file']
+        user_args = request.form.get('user')
+
+        if end_user is None and user_args is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')

+ 2 - 1
api/controllers/service_api/app/message.py

@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
 from services.message_service import MessageService
 from extensions.ext_database import db
 from models.model import Message, EndUser
-
+from fields.conversation_fields import message_file_fields
 
 class MessageListApi(AppApiResource):
     feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
         'inputs': fields.Raw,
         'query': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField

+ 3 - 2
api/controllers/service_api/dataset/document.py

@@ -2,6 +2,7 @@ import json
 
 from flask import request
 from flask_restful import reqparse, marshal
+from flask_login import current_user
 from sqlalchemy import desc
 from werkzeug.exceptions import NotFound
 
@@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
         if len(request.files) > 1:
             raise TooManyFilesError()
 
-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
         data_source = {
             'type': 'upload_file',
             'info_list': {
@@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
             if len(request.files) > 1:
                 raise TooManyFilesError()
 
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
             data_source = {
                 'type': 'upload_file',
                 'info_list': {

+ 1 - 1
api/controllers/web/__init__.py

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
 api = ExternalApi(bp)
 
 
-from . import completion, app, conversation, message, site, saved_message, audio, passport
+from . import completion, app, conversation, message, site, saved_message, audio, passport, file

+ 13 - 2
api/controllers/web/app.py

@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
 

+ 4 - 0
api/controllers/web/completion.py

@@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
 
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
         try:
             response = CompletionService.completion(
@@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
         try:
             response = CompletionService.completion(

+ 9 - 2
api/controllers/web/conversation.py

@@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 

+ 25 - 1
api/controllers/web/error.py

@@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     error_code = 'provider_not_support_speech_to_text'
     description = "Provider not support speech to text."
-    code = 400
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415

+ 36 - 0
api/controllers/web/file.py

@@ -0,0 +1,36 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.web import api
+from controllers.web.wraps import WebApiResource
+from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(WebApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+        # get file from request
+        file = request.files['file']
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')

+ 2 - 0
api/controllers/web/message.py

@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
+from fields.conversation_fields import message_file_fields
 
 
 class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
         'inputs': fields.Raw,
         'query': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField

+ 3 - 0
api/controllers/web/saved_message.py

@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
+
 
 feedback_fields = {
     'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'query': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
 }

+ 8 - 2
api/core/callback_handler/llm_callback_handler.py

@@ -11,7 +11,8 @@ from pydantic import BaseModel
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
+    ImagePromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.moderation.base import ModerationOutputsResult, ModerationAction
 from core.moderation.factory import ModerationFactory
@@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):
 
             real_prompts.append({
                 "role": role,
-                "text": message.content
+                "text": message.content,
+                "files": [{
+                    "type": file.type.value,
+                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
+                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
+                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
             })
 
         self.llm_message.prompt = real_prompts

+ 24 - 9
api/core/completion.py

@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
 from core.external_data_tool.factory import ExternalDataToolFactory
+from core.file.file_obj import FileObj
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory
 class Completion:
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
-                 is_override: bool = False, retriever_from: str = 'dev'):
+                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
+                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
+                 auto_generate_name: bool = True):
         """
         errors: ProviderTokenNotInitError
         """
@@ -64,16 +66,21 @@ class Completion:
             is_override=is_override,
             inputs=inputs,
             query=query,
+            files=files,
             streaming=streaming,
-            model_instance=final_model_instance
+            model_instance=final_model_instance,
+            auto_generate_name=auto_generate_name
         )
 
+        prompt_message_files = [file.prompt_message_file for file in files]
+
         rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
             mode=app.mode,
             model_instance=final_model_instance,
             app_model_config=app_model_config,
             query=query,
-            inputs=inputs
+            inputs=inputs,
+            files=prompt_message_files
         )
 
         # init orchestrator rule parser
@@ -95,6 +102,7 @@ class Completion:
                     app_model_config=app_model_config,
                     query=query,
                     inputs=inputs,
+                    files=prompt_message_files,
                     agent_execute_result=None,
                     conversation_message_task=conversation_message_task,
                     memory=memory,
@@ -146,6 +154,7 @@ class Completion:
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
+                files=prompt_message_files,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
@@ -257,6 +266,7 @@ class Completion:
     @classmethod
     def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                       inputs: dict,
+                      files: List[PromptMessageFile],
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -266,10 +276,12 @@ class Completion:
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
             prompt_messages, stop_words = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 model_instance=model_instance
@@ -280,6 +292,7 @@ class Completion:
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 model_instance=model_instance
@@ -337,7 +350,7 @@ class Completion:
 
     @classmethod
     def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict) -> int:
+                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
         model_limited_tokens = model_instance.model_rules.max_tokens.max
         max_tokens = model_instance.get_model_kwargs().max_tokens
 
@@ -348,15 +361,16 @@ class Completion:
             max_tokens = 0
 
         prompt_transform = PromptTransform()
-        prompt_messages = []
 
         # get prompt without memory and context
         if app_model_config.prompt_type == 'simple':
             prompt_messages, _ = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
+                app_model_config=app_model_config,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=None,
                 memory=None,
                 model_instance=model_instance
@@ -367,6 +381,7 @@ class Completion:
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=None,
                 memory=None,
                 model_instance=model_instance

+ 24 - 6
api/core/conversation_message_task.py

@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.chain_result import ChainResult
+from core.file.file_obj import FileObj
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,14 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource
+    MessageChain, DatasetRetrieverResource, MessageFile
 
 
 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
-                 conversation: Optional[Conversation] = None, is_override: bool = False):
+                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
+                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
+                 auto_generate_name: bool = True):
         self.start_at = time.perf_counter()
 
         self.task_id = task_id
@@ -35,6 +37,7 @@ class ConversationMessageTask:
         self.user = user
         self.inputs = inputs
         self.query = query
+        self.files = files
         self.streaming = streaming
 
         self.conversation = conversation
@@ -45,6 +48,7 @@ class ConversationMessageTask:
         self.message = None
 
         self.retriever_resource = None
+        self.auto_generate_name = auto_generate_name
 
         self.model_dict = self.app_model_config.model_dict
         self.provider_name = self.model_dict.get('provider')
@@ -100,7 +104,7 @@ class ConversationMessageTask:
                 model_id=self.model_name,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=self.mode,
-                name='',
+                name='New conversation',
                 inputs=self.inputs,
                 introduction=introduction,
                 system_instruction=system_instruction,
@@ -142,6 +146,19 @@ class ConversationMessageTask:
         db.session.add(self.message)
         db.session.commit()
 
+        for file in self.files:
+            message_file = MessageFile(
+                message_id=self.message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
+                created_by=self.user.id
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
     def append_message_text(self, text: str):
         if text is not None:
             self._pub_handler.pub_text(text)
@@ -176,7 +193,8 @@ class ConversationMessageTask:
         message_was_created.send(
             self.message,
             conversation=self.conversation,
-            is_first_message=self.is_new_conversation
+            is_first_message=self.is_new_conversation,
+            auto_generate_name=self.auto_generate_name
         )
 
         if not by_stopped:

+ 0 - 0
api/core/file/__init__.py


+ 79 - 0
api/core/file/file_obj.py

@@ -0,0 +1,79 @@
+import enum
+from typing import Optional
+
+from pydantic import BaseModel
+
+from core.file.upload_file_parser import UploadFileParser
+from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
+from extensions.ext_database import db
+from models.model import UploadFile
+
+
+class FileType(enum.Enum):
+    IMAGE = 'image'
+
+    @staticmethod
+    def value_of(value):
+        for member in FileType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class FileTransferMethod(enum.Enum):
+    REMOTE_URL = 'remote_url'
+    LOCAL_FILE = 'local_file'
+
+    @staticmethod
+    def value_of(value):
+        for member in FileTransferMethod:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+class FileObj(BaseModel):
+    id: Optional[str]
+    tenant_id: str
+    type: FileType
+    transfer_method: FileTransferMethod
+    url: Optional[str]
+    upload_file_id: Optional[str]
+    file_config: dict
+
+    @property
+    def data(self) -> Optional[str]:
+        return self._get_data()
+
+    @property
+    def preview_url(self) -> Optional[str]:
+        return self._get_data(force_url=True)
+
+    @property
+    def prompt_message_file(self) -> PromptMessageFile:
+        if self.type == FileType.IMAGE:
+            image_config = self.file_config.get('image')
+
+            return ImagePromptMessageFile(
+                data=self.data,
+                detail=ImagePromptMessageFile.DETAIL.HIGH
+                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
+            )
+
+    def _get_data(self, force_url: bool = False) -> Optional[str]:
+        if self.type == FileType.IMAGE:
+            if self.transfer_method == FileTransferMethod.REMOTE_URL:
+                return self.url
+            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
+                upload_file = (db.session.query(UploadFile)
+                               .filter(
+                    UploadFile.id == self.upload_file_id,
+                    UploadFile.tenant_id == self.tenant_id
+                ).first())
+
+                return UploadFileParser.get_image_data(
+                    upload_file=upload_file,
+                    force_url=force_url
+                )
+
+        return None

+ 180 - 0
api/core/file/message_file_parser.py

@@ -0,0 +1,180 @@
+from typing import List, Union, Optional, Dict
+
+import requests
+
+from core.file.file_obj import FileObj, FileType, FileTransferMethod
+from core.file.upload_file_parser import SUPPORT_EXTENSIONS
+from extensions.ext_database import db
+from models.account import Account
+from models.model import MessageFile, EndUser, AppModelConfig, UploadFile
+
+
+class MessageFileParser:
+
+    def __init__(self, tenant_id: str, app_id: str) -> None:
+        self.tenant_id = tenant_id
+        self.app_id = app_id
+
+    def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
+                                         user: Union[Account, EndUser]) -> List[FileObj]:
+        """
+        validate and transform files arg
+
+        :param files:
+        :param app_model_config:
+        :param user:
+        :return:
+        """
+        file_upload_config = app_model_config.file_upload_dict
+
+        for file in files:
+            if not isinstance(file, dict):
+                raise ValueError('Invalid file format, must be dict')
+            if not file.get('type'):
+                raise ValueError('Missing file type')
+            FileType.value_of(file.get('type'))
+            if not file.get('transfer_method'):
+                raise ValueError('Missing file transfer method')
+            FileTransferMethod.value_of(file.get('transfer_method'))
+            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
+                if not file.get('url'):
+                    raise ValueError('Missing file url')
+                if not file.get('url').startswith('http'):
+                    raise ValueError('Invalid file url')
+            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
+                raise ValueError('Missing file upload_file_id')
+
+        # transform files to file objs
+        type_file_objs = self._to_file_objs(files, file_upload_config)
+
+        # validate files
+        new_files = []
+        for file_type, file_objs in type_file_objs.items():
+            if file_type == FileType.IMAGE:
+                # parse and validate files
+                image_config = file_upload_config.get('image')
+
+                # check if image file feature is enabled
+                if not image_config['enabled']:
+                    continue
+
+                # Validate number of files
+                if len(files) > image_config['number_limits']:
+                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
+
+                for file_obj in file_objs:
+                    # Validate transfer method
+                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
+                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
+
+                    # Validate file type
+                    if file_obj.type != FileType.IMAGE:
+                        raise ValueError(f'Invalid file type: {file_obj.type}')
+
+                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
+                        # check remote url valid and is image
+                        result, error = self._check_image_remote_url(file_obj.url)
+                        if result is False:
+                            raise ValueError(error)
+                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
+                        # get upload file from upload_file_id
+                        upload_file = (db.session.query(UploadFile)
+                                       .filter(
+                            UploadFile.id == file_obj.upload_file_id,
+                            UploadFile.tenant_id == self.tenant_id,
+                            UploadFile.created_by == user.id,
+                            UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
+                            UploadFile.extension.in_(SUPPORT_EXTENSIONS)
+                        ).first())
+
+                        # check upload file is belong to tenant and user
+                        if not upload_file:
+                            raise ValueError('Invalid upload file')
+
+                    new_files.append(file_obj)
+
+        # return all file objs
+        return new_files
+
+    def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
+        """
+        transform message files
+
+        :param files:
+        :param app_model_config:
+        :return:
+        """
+        # transform files to file objs
+        type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
+
+        # return all file objs
+        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
+
+    def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
+                      file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
+        """
+        transform files to file objs
+
+        :param files:
+        :param file_upload_config:
+        :return:
+        """
+        type_file_objs: Dict[FileType, List[FileObj]] = {
+            # Currently only support image
+            FileType.IMAGE: []
+        }
+
+        if not files:
+            return type_file_objs
+
+        # group by file type and convert file args or message files to FileObj
+        for file in files:
+            file_obj = self._to_file_obj(file, file_upload_config)
+            if file_obj.type not in type_file_objs:
+                continue
+
+            type_file_objs[file_obj.type].append(file_obj)
+
+        return type_file_objs
+
+    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
+        """
+        transform file to file obj
+
+        :param file:
+        :return:
+        """
+        if isinstance(file, dict):
+            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
+            return FileObj(
+                tenant_id=self.tenant_id,
+                type=FileType.value_of(file.get('type')),
+                transfer_method=transfer_method,
+                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
+                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
+                file_config=file_upload_config
+            )
+        else:
+            return FileObj(
+                id=file.id,
+                tenant_id=self.tenant_id,
+                type=FileType.value_of(file.type),
+                transfer_method=FileTransferMethod.value_of(file.transfer_method),
+                url=file.url,
+                upload_file_id=file.upload_file_id or None,
+                file_config=file_upload_config
+            )
+
+    def _check_image_remote_url(self, url):
+        try:
+            headers = {
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
+            }
+
+            response = requests.head(url, headers=headers, allow_redirects=True)
+            if response.status_code == 200:
+                return True, ""
+            else:
+                return False, "URL does not exist."
+        except requests.RequestException as e:
+            return False, f"Error checking URL: {e}"

+ 79 - 0
api/core/file/upload_file_parser.py

@@ -0,0 +1,79 @@
+import base64
+import hashlib
+import hmac
+import logging
+import os
+import time
+from typing import Optional
+
+from flask import current_app
+
+from extensions.ext_storage import storage
+
+SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
+
+
+class UploadFileParser:
+    @classmethod
+    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
+        if not upload_file:
+            return None
+
+        if upload_file.extension not in SUPPORT_EXTENSIONS:
+            return None
+
+        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
+            return cls.get_signed_temp_image_url(upload_file)
+        else:
+            # get image file base64
+            try:
+                data = storage.load(upload_file.key)
+            except FileNotFoundError:
+                logging.error(f'File not found: {upload_file.key}')
+                return None
+
+            encoded_string = base64.b64encode(data).decode('utf-8')
+            return f'data:{upload_file.mime_type};base64,{encoded_string}'
+
+    @classmethod
+    def get_signed_temp_image_url(cls, upload_file) -> str:
+        """
+        get signed url from upload file
+
+        :param upload_file: UploadFile object
+        :return:
+        """
+        base_url = current_app.config.get('FILES_URL')
+        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
+
+        timestamp = str(int(time.time()))
+        nonce = os.urandom(16).hex()
+        data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
+        secret_key = current_app.config['SECRET_KEY'].encode()
+        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+
+    @classmethod
+    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
+        """
+        verify signature
+
+        :param upload_file_id: file id
+        :param timestamp: timestamp
+        :param nonce: nonce
+        :param sign: signature
+        :return:
+        """
+        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
+        secret_key = current_app.config['SECRET_KEY'].encode()
+        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
+
+        # verify signature
+        if sign != recalculated_encoded_sign:
+            return False
+
+        current_time = int(time.time())
+        return current_time - int(timestamp) <= 300  # expired after 5 minutes

+ 6 - 2
api/core/generator/llm_generator.py

@@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
 
 class LLMGenerator:
     @classmethod
-    def generate_conversation_name(cls, tenant_id: str, query, answer):
+    def generate_conversation_name(cls, tenant_id: str, query):
         prompt = CONVERSATION_TITLE_PROMPT
 
         if len(query) > 2000:
@@ -40,8 +40,12 @@ class LLMGenerator:
 
         result_dict = json.loads(answer)
         answer = result_dict['Your Output']
+        name = answer.strip()
 
-        return answer.strip()
+        if len(name) > 75:
+            name = name[:75] + '...'
+
+        return name
 
     @classmethod
     def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):

+ 19 - 1
api/core/memory/read_only_conversation_token_db_buffer_shared_memory.py

@@ -3,6 +3,7 @@ from typing import Any, List, Dict
 from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import get_buffer_string, BaseMessage
 
+from core.file.message_file_parser import MessageFileParser
 from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
 from core.model_providers.models.llm.base import BaseLLM
 from extensions.ext_database import db
@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
     @property
     def buffer(self) -> List[BaseMessage]:
         """String buffer of memory."""
+        app_model = self.conversation.app
+
         # fetch limited messages desc, and return reversed
         messages = db.session.query(Message).filter(
             Message.conversation_id == self.conversation.id,
@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
         ).order_by(Message.created_at.desc()).limit(self.message_limit).all()
 
         messages = list(reversed(messages))
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
 
         chat_messages: List[PromptMessage] = []
         for message in messages:
-            chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
+            files = message.message_files
+            if files:
+                file_objs = message_file_parser.transform_message_files(
+                    files, message.app_model_config
+                )
+
+                prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
+                chat_messages.append(PromptMessage(
+                    content=message.query,
+                    type=MessageType.USER,
+                    files=prompt_message_files
+                ))
+            else:
+                chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
+
             chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
 
         if not chat_messages:

+ 46 - 2
api/core/model_providers/models/entity/message.py

@@ -1,4 +1,5 @@
 import enum
+from typing import Any, cast, Union, List, Dict
 
 from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
 from pydantic import BaseModel
@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
     SYSTEM = 'system'
 
 
+class PromptMessageFileType(enum.Enum):
+    IMAGE = 'image'
+
+    @staticmethod
+    def value_of(value):
+        for member in PromptMessageFileType:
+            if member.value == value:
+                return member
+        raise ValueError(f"No matching enum found for value '{value}'")
+
+
+
+class PromptMessageFile(BaseModel):
+    type: PromptMessageFileType
+    data: Any
+
+
+class ImagePromptMessageFile(PromptMessageFile):
+    class DETAIL(enum.Enum):
+        LOW = 'low'
+        HIGH = 'high'
+
+    type: PromptMessageFileType = PromptMessageFileType.IMAGE
+    detail: DETAIL = DETAIL.LOW
+
+
 class PromptMessage(BaseModel):
     type: MessageType = MessageType.USER
     content: str = ''
+    files: list[PromptMessageFile] = []
     function_call: dict = None
 
 
+class LCHumanMessageWithFiles(HumanMessage):
+    # content: Union[str, List[Union[str, Dict]]]
+    content: str
+    files: list[PromptMessageFile]
+
+
 def to_lc_messages(messages: list[PromptMessage]):
     lc_messages = []
     for message in messages:
         if message.type == MessageType.USER:
-            lc_messages.append(HumanMessage(content=message.content))
+            if not message.files:
+                lc_messages.append(HumanMessage(content=message.content))
+            else:
+                lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
         elif message.type == MessageType.ASSISTANT:
             additional_kwargs = {}
             if message.function_call:
@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
     prompt_messages = []
     for message in messages:
         if isinstance(message, HumanMessage):
-            prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
+            if isinstance(message, LCHumanMessageWithFiles):
+                prompt_messages.append(PromptMessage(
+                    content=message.content,
+                    type=MessageType.USER,
+                    files=message.files
+                ))
+            else:
+                prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
         elif isinstance(message, AIMessage):
             message_kwargs = {
                 'content': message.content,

+ 0 - 2
api/core/model_providers/models/llm/openai_model.py

@@ -1,11 +1,9 @@
-import decimal
 import logging
 from typing import List, Optional, Any
 
 import openai
 from langchain.callbacks.manager import Callbacks
 from langchain.schema import LLMResult
-from openai import api_requestor
 
 from core.model_providers.providers.base import BaseModelProvider
 from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI

+ 153 - 76
api/core/prompt/prompt_transform.py

@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import BaseMessage
 
 from core.model_providers.models.entity.model_params import ModelMode
-from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
+from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.model_providers.models.llm.baichuan_model import BaichuanModel
 from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
@@ -16,32 +16,59 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel
 from core.model_providers.models.llm.xinference_model import XinferenceModel
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
+from models.model import AppModelConfig
+
 
 class AppMode(enum.Enum):
     COMPLETION = 'completion'
     CHAT = 'chat'
 
+
 class PromptTransform:
-    def get_prompt(self, mode: str,
-                   pre_prompt: str, inputs: dict,
+    def get_prompt(self,
+                   app_mode: str,
+                   app_model_config: AppModelConfig,
+                   pre_prompt: str,
+                   inputs: dict,
                    query: str,
+                   files: List[PromptMessageFile],
                    context: Optional[str],
                    memory: Optional[BaseChatMemory],
                    model_instance: BaseLLM) -> \
             Tuple[List[PromptMessage], Optional[List[str]]]:
-        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
-        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
-        return [PromptMessage(content=prompt)], stops
-
-    def get_advanced_prompt(self, 
-            app_mode: str,
-            app_model_config: str, 
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
-        
+        model_mode = app_model_config.model_dict['mode']
+
+        app_mode_enum = AppMode(app_mode)
+        model_mode_enum = ModelMode(model_mode)
+
+        prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))
+
+        if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
+            stops = None
+
+            prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
+                                                                                   query, context, memory,
+                                                                                   model_instance, files)
+        else:
+            stops = prompt_rules.get('stops')
+            if stops is not None and len(stops) == 0:
+                stops = None
+
+            prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
+                                                                      memory,
+                                                                      model_instance, files)
+        return prompt_messages, stops
+
+    def get_advanced_prompt(self,
+                            app_mode: str,
+                            app_model_config: AppModelConfig,
+                            inputs: dict,
+                            query: str,
+                            files: List[PromptMessageFile],
+                            context: Optional[str],
+                            memory: Optional[BaseChatMemory],
+                            model_instance: BaseLLM) -> List[PromptMessage]:
+
         model_mode = app_model_config.model_dict['mode']
 
         app_mode_enum = AppMode(app_mode)
@@ -51,15 +78,20 @@ class PromptTransform:
 
         if app_mode_enum == AppMode.CHAT:
             if model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+                prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
+                                                                                      files, context, memory,
+                                                                                      model_instance)
             elif model_mode_enum == ModelMode.CHAT:
-                prompt_messages =  self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
+                prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
+                                                                                context, memory, model_instance)
         elif app_mode_enum == AppMode.COMPLETION:
             if model_mode_enum == ModelMode.CHAT:
-                prompt_messages =  self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
+                prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
+                                                                                      files, context)
             elif model_mode_enum == ModelMode.COMPLETION:
-                prompt_messages =  self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
-            
+                prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
+                                                                                            files, context)
+
         return prompt_messages
 
     def _get_history_messages_from_memory(self, memory: BaseChatMemory,
@@ -71,7 +103,7 @@ class PromptTransform:
         return external_context[memory_key]
 
     def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
-                                          max_token_limit: int) -> List[PromptMessage]:
+                                               max_token_limit: int) -> List[PromptMessage]:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory.return_messages = True
@@ -79,7 +111,7 @@ class PromptTransform:
         external_context = memory.load_memory_variables({})
         memory.return_messages = False
         return to_prompt_messages(external_context[memory_key])
-    
+
     def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
         # baichuan
         if isinstance(model_instance, BaichuanModel):
@@ -94,13 +126,13 @@ class PromptTransform:
             return 'common_completion'
         else:
             return 'common_chat'
-        
+
     def _prompt_file_name_for_baichuan(self, mode: str) -> str:
         if mode == 'completion':
             return 'baichuan_completion'
         else:
             return 'baichuan_chat'
-    
+
     def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
         # Get the absolute path of the subdirectory
         prompt_path = os.path.join(
@@ -111,12 +143,53 @@ class PromptTransform:
         # Open the JSON file and read its content
         with open(json_file_path, 'r') as json_file:
             return json.load(json_file)
-        
-    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
-                             query: str,
-                             context: Optional[str],
-                             memory: Optional[BaseChatMemory],
-                             model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
+
+    def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                                                        query: str,
+                                                        context: Optional[str],
+                                                        memory: Optional[BaseChatMemory],
+                                                        model_instance: BaseLLM,
+                                                        files: List[PromptMessageFile]) -> List[PromptMessage]:
+        prompt_messages = []
+
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                {'context': context}
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = PromptTemplateParser(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += pre_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))
+
+        self._append_chat_histories(memory, prompt_messages, model_instance)
+
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
+
+        return prompt_messages
+
+    def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                                           query: str,
+                                           context: Optional[str],
+                                           memory: Optional[BaseChatMemory],
+                                           model_instance: BaseLLM,
+                                           files: List[PromptMessageFile]) -> List[PromptMessage]:
         context_prompt_content = ''
         if context and 'context_prompt' in prompt_rules:
             prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -175,16 +248,12 @@ class PromptTransform:
 
         prompt = re.sub(r'<\|.*?\|>', '', prompt)
 
-        stops = prompt_rules.get('stops')
-        if stops is not None and len(stops) == 0:
-            stops = None
+        return [PromptMessage(content=prompt, files=files)]
 
-        return prompt, stops
-    
     def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
         if '#context#' in prompt_template.variable_keys:
             if context:
-                prompt_inputs['#context#'] = context    
+                prompt_inputs['#context#'] = context
             else:
                 prompt_inputs['#context#'] = ''
 
@@ -195,17 +264,18 @@ class PromptTransform:
             else:
                 prompt_inputs['#query#'] = ''
 
-    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict, 
-                                prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
+    def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
+                                prompt_template: PromptTemplateParser, prompt_inputs: dict,
+                                model_instance: BaseLLM) -> None:
         if '#histories#' in prompt_template.variable_keys:
             if memory:
                 tmp_human_message = PromptBuilder.to_human_message(
                     prompt_content=raw_prompt,
-                    inputs={ '#histories#': '', **prompt_inputs }
+                    inputs={'#histories#': '', **prompt_inputs}
                 )
 
                 rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
-                
+
                 memory.human_prefix = conversation_histories_role['user_prefix']
                 memory.ai_prefix = conversation_histories_role['assistant_prefix']
                 histories = self._get_history_messages_from_memory(memory, rest_tokens)
@@ -213,7 +283,8 @@ class PromptTransform:
             else:
                 prompt_inputs['#histories#'] = ''
 
-    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
+    def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
+                               model_instance: BaseLLM) -> None:
         if memory:
             rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
 
@@ -242,19 +313,19 @@ class PromptTransform:
         return prompt
 
     def _get_chat_app_completion_model_prompt_messages(self,
-            app_model_config: str,
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
-        
+                                                       app_model_config: AppModelConfig,
+                                                       inputs: dict,
+                                                       query: str,
+                                                       files: List[PromptMessageFile],
+                                                       context: Optional[str],
+                                                       memory: Optional[BaseChatMemory],
+                                                       model_instance: BaseLLM) -> List[PromptMessage]:
+
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
         conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
 
         prompt_messages = []
-        prompt = ''
-        
+
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
@@ -262,28 +333,29 @@ class PromptTransform:
 
         self._set_query_variable(query, prompt_template, prompt_inputs)
 
-        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
+        self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
+                                     model_instance)
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))
 
         return prompt_messages
 
     def _get_chat_app_chat_model_prompt_messages(self,
-            app_model_config: str,
-            inputs: dict,
-            query: str,
-            context: Optional[str],
-            memory: Optional[BaseChatMemory],
-            model_instance: BaseLLM) -> List[PromptMessage]:
+                                                 app_model_config: AppModelConfig,
+                                                 inputs: dict,
+                                                 query: str,
+                                                 files: List[PromptMessageFile],
+                                                 context: Optional[str],
+                                                 memory: Optional[BaseChatMemory],
+                                                 model_instance: BaseLLM) -> List[PromptMessage]:
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
 
         prompt_messages = []
 
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item['text']
-            prompt = ''
 
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -292,23 +364,23 @@ class PromptTransform:
 
             prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
-        
+            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
+
         self._append_chat_histories(memory, prompt_messages, model_instance)
 
-        prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
+        prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
 
         return prompt_messages
 
     def _get_completion_app_completion_model_prompt_messages(self,
-                   app_model_config: str,
-                   inputs: dict,
-                   context: Optional[str]) -> List[PromptMessage]:
+                                                             app_model_config: AppModelConfig,
+                                                             inputs: dict,
+                                                             files: List[PromptMessageFile],
+                                                             context: Optional[str]) -> List[PromptMessage]:
         raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
 
         prompt_messages = []
-        prompt = ''
-        
+
         prompt_template = PromptTemplateParser(template=raw_prompt)
         prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
 
@@ -316,21 +388,21 @@ class PromptTransform:
 
         prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-        prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
+        prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))
 
         return prompt_messages
 
     def _get_completion_app_chat_model_prompt_messages(self,
-                   app_model_config: str,
-                   inputs: dict,
-                   context: Optional[str]) -> List[PromptMessage]:
+                                                       app_model_config: AppModelConfig,
+                                                       inputs: dict,
+                                                       files: List[PromptMessageFile],
+                                                       context: Optional[str]) -> List[PromptMessage]:
         raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
 
         prompt_messages = []
 
         for prompt_item in raw_prompt_list:
             raw_prompt = prompt_item['text']
-            prompt = ''
 
             prompt_template = PromptTemplateParser(template=raw_prompt)
             prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -339,6 +411,11 @@ class PromptTransform:
 
             prompt = self._format_prompt(prompt_template, prompt_inputs)
 
-            prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
-        
-        return prompt_messages
+            prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
+
+        for prompt_message in prompt_messages[::-1]:
+            if prompt_message.type == MessageType.USER:
+                prompt_message.files = files
+                break
+
+        return prompt_messages

+ 103 - 1
api/core/third_party/langchain/llms/chat_open_ai.py

@@ -1,10 +1,13 @@
 import os
 
-from typing import Dict, Any, Optional, Union, Tuple
+from typing import Dict, Any, Optional, Union, Tuple, List, cast
 
 from langchain.chat_models import ChatOpenAI
+from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
 from pydantic import root_validator
 
+from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
+
 
 class EnhanceChatOpenAI(ChatOpenAI):
     request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
             "api_key": self.openai_api_key,
             "organization": self.openai_organization if self.openai_organization else None,
         }
+
+    def _create_message_dicts(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
+        params = self._client_params
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        message_dicts = [self._convert_message_to_dict(m) for m in messages]
+        return message_dicts, params
+
+    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+        """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
+
+        Official documentation: https://github.com/openai/openai-cookbook/blob/
+        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        model, encoding = self._get_encoding_model()
+        if model.startswith("gpt-3.5-turbo-0301"):
+            # every message follows <im_start>{role/name}\n{content}<im_end>\n
+            tokens_per_message = 4
+            # if there's a name, the role is omitted
+            tokens_per_name = -1
+        elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
+            tokens_per_message = 3
+            tokens_per_name = 1
+        else:
+            raise NotImplementedError(
+                f"get_num_tokens_from_messages() is not presently implemented "
+                f"for model {model}."
+                "See https://github.com/openai/openai-python/blob/main/chatml.md for "
+                "information on how messages are converted to tokens."
+            )
+        num_tokens = 0
+        messages_dict = [self._convert_message_to_dict(m) for m in messages]
+        for message in messages_dict:
+            num_tokens += tokens_per_message
+            for key, value in message.items():
+                # Cast str(value) in case the message value is not a string
+                # This occurs with function messages
+                # TODO: The current token calculation method for the image type is not implemented,
+                #  which need to download the image and then get the resolution for calculation,
+                #  and will increase the request delay
+                if isinstance(value, list):
+                    text = ''
+                    for item in value:
+                        if isinstance(item, dict) and item['type'] == 'text':
+                            text += item['text']
+
+                    value = text
+                num_tokens += len(encoding.encode(str(value)))
+                if key == "name":
+                    num_tokens += tokens_per_name
+        # every reply is primed with <im_start>assistant
+        num_tokens += 3
+        return num_tokens
+
+    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
+        if isinstance(message, ChatMessage):
+            message_dict = {"role": message.role, "content": message.content}
+        elif isinstance(message, LCHumanMessageWithFiles):
+            content = [
+                {
+                    "type": "text",
+                    "text": message.content
+                }
+            ]
+
+            for file in message.files:
+                if file.type == PromptMessageFileType.IMAGE:
+                    file = cast(ImagePromptMessageFile, file)
+                    content.append({
+                        "type": "image_url",
+                        "image_url": {
+                            "url": file.data,
+                            "detail": file.detail.value
+                        }
+                    })
+
+            message_dict = {"role": "user", "content": content}
+        elif isinstance(message, HumanMessage):
+            message_dict = {"role": "user", "content": message.content}
+        elif isinstance(message, AIMessage):
+            message_dict = {"role": "assistant", "content": message.content}
+            if "function_call" in message.additional_kwargs:
+                message_dict["function_call"] = message.additional_kwargs["function_call"]
+        elif isinstance(message, SystemMessage):
+            message_dict = {"role": "system", "content": message.content}
+        elif isinstance(message, FunctionMessage):
+            message_dict = {
+                "role": "function",
+                "content": message.content,
+                "name": message.name,
+            }
+        else:
+            raise ValueError(f"Got unknown type {message}")
+        if "name" in message.additional_kwargs:
+            message_dict["name"] = message.additional_kwargs["name"]
+        return message_dict

+ 4 - 10
api/events/event_handlers/generate_conversation_name_when_first_message_created.py

@@ -1,5 +1,3 @@
-import logging
-
 from core.generator.llm_generator import LLMGenerator
 from events.message_event import message_was_created
 from extensions.ext_database import db
@@ -10,8 +8,9 @@ def handle(sender, **kwargs):
     message = sender
     conversation = kwargs.get('conversation')
     is_first_message = kwargs.get('is_first_message')
+    auto_generate_name = kwargs.get('auto_generate_name', True)
 
-    if is_first_message:
+    if auto_generate_name and is_first_message:
         if conversation.mode == 'chat':
             app_model = conversation.app
             if not app_model:
@@ -19,14 +18,9 @@ def handle(sender, **kwargs):
 
             # generate conversation name
             try:
-                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)
-
-                if len(name) > 75:
-                    name = name[:75] + '...'
-
+                name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
                 conversation.name = name
             except:
-                conversation.name = 'New conversation'
+                pass
 
-            db.session.add(conversation)
             db.session.commit()

+ 36 - 1
api/extensions/ext_storage.py

@@ -1,6 +1,7 @@
 import os
 import shutil
 from contextlib import closing
+from typing import Union, Generator
 
 import boto3
 from botocore.exceptions import ClientError
@@ -45,7 +46,13 @@ class Storage:
             with open(os.path.join(os.getcwd(), filename), "wb") as f:
                 f.write(data)
 
-    def load(self, filename):
+    def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
+        if stream:
+            return self.load_stream(filename)
+        else:
+            return self.load_once(filename)
+
+    def load_once(self, filename: str) -> bytes:
         if self.storage_type == 's3':
             try:
                 with closing(self.client) as client:
@@ -69,6 +76,34 @@ class Storage:
 
         return data
 
+    def load_stream(self, filename: str) -> Generator:
+        def generate(filename: str = filename) -> Generator:
+            if self.storage_type == 's3':
+                try:
+                    with closing(self.client) as client:
+                        response = client.get_object(Bucket=self.bucket_name, Key=filename)
+                        for chunk in response['Body'].iter_chunks():
+                            yield chunk
+                except ClientError as ex:
+                    if ex.response['Error']['Code'] == 'NoSuchKey':
+                        raise FileNotFoundError("File not found")
+                    else:
+                        raise
+            else:
+                if not self.folder or self.folder.endswith('/'):
+                    filename = self.folder + filename
+                else:
+                    filename = self.folder + '/' + filename
+
+                if not os.path.exists(filename):
+                    raise FileNotFoundError("File not found")
+
+                with open(filename, "rb") as f:
+                    while chunk := f.read(4096):  # Read in chunks of 4KB
+                        yield chunk
+
+        return generate()
+
     def download(self, filename, target_filepath):
         if self.storage_type == 's3':
             with closing(self.client) as client:

+ 3 - 2
api/fields/app_fields.py

@@ -32,7 +32,8 @@ model_config_fields = {
     'prompt_type': fields.String,
     'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
     'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
-    'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
+    'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
+    'file_upload': fields.Raw(attribute='file_upload_dict'),
 }
 
 app_detail_fields = {
@@ -140,4 +141,4 @@ app_site_fields = {
     'privacy_policy': fields.String,
     'customize_token_strategy': fields.String,
     'prompt_public': fields.Boolean
-}
+}

+ 9 - 7
api/fields/conversation_fields.py

@@ -28,6 +28,12 @@ annotation_fields = {
     'created_at': TimestampField
 }
 
+message_file_fields = {
+    'id': fields.String,
+    'type': fields.String,
+    'url': fields.String,
+}
+
 message_detail_fields = {
     'id': fields.String,
     'conversation_id': fields.String,
@@ -43,7 +49,8 @@ message_detail_fields = {
     'from_account_id': fields.String,
     'feedbacks': fields.List(fields.Nested(feedback_fields)),
     'annotation': fields.Nested(annotation_fields, allow_null=True),
-    'created_at': TimestampField
+    'created_at': TimestampField,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
 }
 
 feedback_stat_fields = {
@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
     'message': fields.Nested(message_detail_fields, attribute='first_message'),
 }
 
-simple_model_config_fields = {
-    'model': fields.Raw(attribute='model_dict'),
-    'pre_prompt': fields.String,
-}
-
 conversation_with_summary_fields = {
     'id': fields.String,
     'status': fields.String,
@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
     'limit': fields.Integer,
     'has_more': fields.Boolean,
     'data': fields.List(fields.Nested(conversation_with_model_config_fields))
-}
+}

+ 2 - 1
api/fields/file_fields.py

@@ -4,7 +4,8 @@ from libs.helper import TimestampField
 
 upload_config_fields = {
     'file_size_limit': fields.Integer,
-    'batch_count_limit': fields.Integer
+    'batch_count_limit': fields.Integer,
+    'image_file_size_limit': fields.Integer,
 }
 
 file_fields = {

+ 2 - 0
api/fields/message_fields.py

@@ -1,6 +1,7 @@
 from flask_restful import fields
 
 from libs.helper import TimestampField
+from fields.conversation_fields import message_file_fields
 
 feedback_fields = {
     'rating': fields.String
@@ -31,6 +32,7 @@ message_fields = {
     'inputs': fields.Raw,
     'query': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
     'created_at': TimestampField

+ 59 - 0
api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py

@@ -0,0 +1,59 @@
+"""add gpt4v supports
+
+Revision ID: 8fe468ba0ca5
+Revises: a9836e3baeee
+Create Date: 2023-11-09 11:39:00.006432
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '8fe468ba0ca5'
+down_revision = 'a9836e3baeee'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('message_files',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('message_id', postgresql.UUID(), nullable=False),
+    sa.Column('type', sa.String(length=255), nullable=False),
+    sa.Column('transfer_method', sa.String(length=255), nullable=False),
+    sa.Column('url', sa.Text(), nullable=True),
+    sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
+    sa.Column('created_by_role', sa.String(length=255), nullable=False),
+    sa.Column('created_by', postgresql.UUID(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='message_file_pkey')
+    )
+    with op.batch_alter_table('message_files', schema=None) as batch_op:
+        batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
+        batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
+
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
+
+    with op.batch_alter_table('upload_files', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('upload_files', schema=None) as batch_op:
+        batch_op.drop_column('created_by_role')
+
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('file_upload')
+
+    with op.batch_alter_table('message_files', schema=None) as batch_op:
+        batch_op.drop_index('message_file_message_idx')
+        batch_op.drop_index('message_file_created_by_idx')
+
+    op.drop_table('message_files')
+    # ### end Alembic commands ###

+ 63 - 4
api/models/model.py

@@ -1,10 +1,10 @@
 import json
-from json import JSONDecodeError
 
 from flask import current_app, request
 from flask_login import UserMixin
 from sqlalchemy.dialects.postgresql import UUID
 
+from core.file.upload_file_parser import UploadFileParser
 from libs.helper import generate_string
 from extensions.ext_database import db
 from .account import Account, Tenant
@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
     completion_prompt_config = db.Column(db.Text)
     dataset_configs = db.Column(db.Text)
     external_data_tools = db.Column(db.Text)
+    file_upload = db.Column(db.Text)
 
     @property
     def app(self):
@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
     def dataset_configs_dict(self) -> dict:
         return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
 
+    @property
+    def file_upload_dict(self) -> dict:
+        return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
+
     def to_dict(self) -> dict:
         return {
             "provider": "",
@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
             "prompt_type": self.prompt_type,
             "chat_prompt_config": self.chat_prompt_config_dict,
             "completion_prompt_config": self.completion_prompt_config_dict,
-            "dataset_configs": self.dataset_configs_dict
+            "dataset_configs": self.dataset_configs_dict,
+            "file_upload": self.file_upload_dict
         }
 
     def from_model_config_dict(self, model_config: dict):
@@ -213,6 +219,8 @@ class AppModelConfig(db.Model):
             if model_config.get('completion_prompt_config') else None
         self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
             if model_config.get('dataset_configs') else None
+        self.file_upload = json.dumps(model_config.get('file_upload')) \
+            if model_config.get('file_upload') else None
         return self
 
     def copy(self):
@@ -238,7 +246,8 @@ class AppModelConfig(db.Model):
             prompt_type=self.prompt_type,
             chat_prompt_config=self.chat_prompt_config,
             completion_prompt_config=self.completion_prompt_config,
-            dataset_configs=self.dataset_configs
+            dataset_configs=self.dataset_configs,
+            file_upload=self.file_upload
         )
 
         return new_app_model_config
@@ -512,6 +521,37 @@ class Message(db.Model):
         return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
             .order_by(DatasetRetrieverResource.position.asc()).all()
 
+    @property
+    def message_files(self):
+        return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
+
+    @property
+    def files(self):
+        message_files = self.message_files
+
+        files = []
+        for message_file in message_files:
+            url = message_file.url
+            if message_file.type == 'image':
+                if message_file.transfer_method == 'local_file':
+                    upload_file = (db.session.query(UploadFile)
+                                   .filter(
+                        UploadFile.id == message_file.upload_file_id
+                    ).first())
+
+                    url = UploadFileParser.get_image_data(
+                        upload_file=upload_file,
+                        force_url=True
+                    )
+
+            files.append({
+                'id': message_file.id,
+                'type': message_file.type,
+                'url': url
+            })
+
+        return files
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
@@ -540,6 +580,25 @@ class MessageFeedback(db.Model):
         return account
 
 
+class MessageFile(db.Model):
+    __tablename__ = 'message_files'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='message_file_pkey'),
+        db.Index('message_file_message_idx', 'message_id'),
+        db.Index('message_file_created_by_idx', 'created_by')
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    message_id = db.Column(UUID, nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    transfer_method = db.Column(db.String(255), nullable=False)
+    url = db.Column(db.Text, nullable=True)
+    upload_file_id = db.Column(UUID, nullable=True)
+    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(UUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
+
+
 class MessageAnnotation(db.Model):
     __tablename__ = 'message_annotations'
     __table_args__ = (
@@ -683,6 +742,7 @@ class UploadFile(db.Model):
     size = db.Column(db.Integer, nullable=False)
     extension = db.Column(db.String(255), nullable=False)
     mime_type = db.Column(db.String(255), nullable=True)
+    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -783,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
     retriever_from = db.Column(db.Text, nullable=False)
     created_by = db.Column(UUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
-

+ 33 - 1
api/services/app_model_config_service.py

@@ -315,6 +315,9 @@ class AppModelConfigService:
         # moderation validation
         cls.is_moderation_valid(tenant_id, config)
 
+        # file upload validation
+        cls.is_file_upload_valid(config)
+
         # Filter out extra parameters
         filtered_config = {
             "opening_statement": config["opening_statement"],
@@ -338,7 +341,8 @@ class AppModelConfigService:
             "prompt_type": config["prompt_type"],
             "chat_prompt_config": config["chat_prompt_config"],
             "completion_prompt_config": config["completion_prompt_config"],
-            "dataset_configs": config["dataset_configs"]
+            "dataset_configs": config["dataset_configs"],
+            "file_upload": config["file_upload"]
         }
 
         return filtered_config
@@ -371,6 +375,34 @@ class AppModelConfigService:
             config=config
         )
 
+    @classmethod
+    def is_file_upload_valid(cls, config: dict):
+        if 'file_upload' not in config or not config["file_upload"]:
+            config["file_upload"] = {}
+
+        if not isinstance(config["file_upload"], dict):
+            raise ValueError("file_upload must be of dict type")
+
+        # check image config
+        if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
+            config["file_upload"]["image"] = {"enabled": False}
+
+        if config['file_upload']['image']['enabled']:
+            number_limits = config['file_upload']['image']['number_limits']
+            if number_limits < 1 or number_limits > 6:
+                raise ValueError("number_limits must be in [1, 6]")
+
+            detail = config['file_upload']['image']['detail']
+            if detail not in ['high', 'low']:
+                raise ValueError("detail must be in ['high', 'low']")
+
+            transfer_methods = config['file_upload']['image']['transfer_methods']
+            if not isinstance(transfer_methods, list):
+                raise ValueError("transfer_methods must be of list type")
+            for method in transfer_methods:
+                if method not in ['remote_url', 'local_file']:
+                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
+
     @classmethod
     def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
         if 'external_data_tools' not in config or not config["external_data_tools"]:

+ 40 - 10
api/services/completion_service.py

@@ -3,7 +3,7 @@ import logging
 import threading
 import time
 import uuid
-from typing import Generator, Union, Any, Optional
+from typing import Generator, Union, Any, Optional, List
 
 from flask import current_app, Flask
 from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
 from core.completion import Completion
 from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
+from core.file.message_file_parser import MessageFileParser
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from core.model_providers.models.entity.message import PromptMessageFile
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
         # is streaming mode
         inputs = args['inputs']
         query = args['query']
+        files = args['files'] if 'files' in args and args['files'] else []
+        auto_generate_name = args['auto_generate_name'] \
+            if 'auto_generate_name' in args else True
 
         if app_model.mode != 'completion' and not query:
             raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
         # clean input by app_model_config form rules
         inputs = cls.get_cleaned_inputs(inputs, app_model_config)
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.validate_and_transform_files_arg(
+            files,
+            app_model_config,
+            user
+        )
+
         generate_task_id = str(uuid.uuid4())
 
         pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'query': query,
             'inputs': inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_conversation': conversation,
             'streaming': streaming,
             'is_model_config_override': is_model_config_override,
-            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
+            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
+            'auto_generate_name': auto_generate_name
         })
 
         generate_worker_thread.start()
 
         # wait for 10 minutes to close the thread
-        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
+        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
+                                generate_task_id)
 
         return cls.compact_response(pubsub, streaming)
 
@@ -172,10 +188,12 @@ class CompletionService:
         return user
 
     @classmethod
-    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
-                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
+    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
+                        app_model_config: AppModelConfig,
+                        query: str, inputs: dict, files: List[PromptMessageFile],
+                        detached_user: Union[Account, EndUser],
                         detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
-                        retriever_from: str = 'dev'):
+                        retriever_from: str = 'dev', auto_generate_name: bool = True):
         with flask_app.app_context():
             # fixed the state of the model object when it detached from the original session
             user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
                     query=query,
                     inputs=inputs,
                     user=user,
+                    files=files,
                     conversation=conversation,
                     streaming=streaming,
                     is_override=is_model_config_override,
-                    retriever_from=retriever_from
+                    retriever_from=retriever_from,
+                    auto_generate_name=auto_generate_name
                 )
             except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                 pass
@@ -215,7 +235,8 @@ class CompletionService:
                 db.session.commit()
 
     @classmethod
-    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
+    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
+                            generate_task_id) -> threading.Thread:
         # wait for 10 minutes to close the thread
         timeout = 600
 
@@ -274,6 +295,12 @@ class CompletionService:
         model_dict['completion_params'] = completion_params
         app_model_config.model = json.dumps(model_dict)
 
+        # parse files
+        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
+        file_objs = message_file_parser.transform_message_files(
+            message.files, app_model_config
+        )
+
         generate_task_id = str(uuid.uuid4())
 
         pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
             'app_model_config': app_model_config.copy(),
             'query': message.query,
             'inputs': message.inputs,
+            'files': file_objs,
             'detached_user': user,
             'detached_conversation': None,
             'streaming': streaming,
             'is_model_config_override': True,
-            'retriever_from': retriever_from
+            'retriever_from': retriever_from,
+            'auto_generate_name': False
         })
 
         generate_worker_thread.start()
@@ -388,7 +417,8 @@ class CompletionService:
                             if event == 'message':
                                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                             elif event == 'message_replace':
-                                yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
+                                yield "data: " + json.dumps(
+                                    cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
                             elif event == 'chain':
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                             elif event == 'agent_thought':

+ 36 - 4
api/services/conversation_service.py

@@ -1,17 +1,20 @@
 from typing import Union, Optional
 
+from core.generator.llm_generator import LLMGenerator
 from libs.infinite_scroll_pagination import InfiniteScrollPagination
 from extensions.ext_database import db
 from models.account import Account
-from models.model import Conversation, App, EndUser
+from models.model import Conversation, App, EndUser, Message
 from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
+from services.errors.message import MessageNotExistsError
 
 
 class ConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
                               last_id: Optional[str], limit: int,
-                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
+                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         if not user:
             return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
 
@@ -29,6 +32,9 @@ class ConversationService:
         if exclude_ids is not None:
             base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
 
+        if exclude_debug_conversation:
+            base_query = base_query.filter(Conversation.override_model_configs == None)
+
         if last_id:
             last_conversation = base_query.filter(
                 Conversation.id == last_id,
@@ -63,10 +69,36 @@ class ConversationService:
 
     @classmethod
     def rename(cls, app_model: App, conversation_id: str,
-               user: Optional[Union[Account | EndUser]], name: str):
+               user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
         conversation = cls.get_conversation(app_model, conversation_id, user)
 
-        conversation.name = name
+        if auto_generate:
+            return cls.auto_generate_name(app_model, conversation)
+        else:
+            conversation.name = name
+            db.session.commit()
+
+        return conversation
+
+    @classmethod
+    def auto_generate_name(cls, app_model: App, conversation: Conversation):
+        # get conversation first message
+        message = db.session.query(Message) \
+            .filter(
+                Message.app_id == app_model.id,
+                Message.conversation_id == conversation.id
+            ).order_by(Message.created_at.asc()).first()
+
+        if not message:
+            raise MessageNotExistsError()
+
+        # generate conversation name
+        try:
+            name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
+            conversation.name = name
+        except:
+            pass
+
         db.session.commit()
 
         return conversation

+ 54 - 21
api/services/file_service.py

@@ -1,46 +1,62 @@
 import datetime
 import hashlib
-import time
 import uuid
+from typing import Generator, Tuple, Union
 
-from cachetools import TTLCache
-from flask import request, current_app
+from flask import current_app
 from flask_login import current_user
 from werkzeug.datastructures import FileStorage
 from werkzeug.exceptions import NotFound
 
 from core.data_loader.file_extractor import FileExtractor
+from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_storage import storage
 from extensions.ext_database import db
-from models.model import UploadFile
+from models.account import Account
+from models.model import UploadFile, EndUser
 from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
 
-ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
+ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
+                      'jpg', 'jpeg', 'png', 'webp', 'gif']
+IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
 PREVIEW_WORDS_LIMIT = 3000
-cache = TTLCache(maxsize=None, ttl=30)
 
 
 class FileService:
 
     @staticmethod
-    def upload_file(file: FileStorage) -> UploadFile:
+    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
+        extension = file.filename.split('.')[-1]
+        if extension.lower() not in ALLOWED_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
         # read file content
         file_content = file.read()
+
         # get file size
         file_size = len(file_content)
 
-        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+        if extension.lower() in IMAGE_EXTENSIONS:
+            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
+        else:
+            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
+
         if file_size > file_size_limit:
             message = f'File size exceeded. {file_size} > {file_size_limit}'
             raise FileTooLargeError(message)
 
-        extension = file.filename.split('.')[-1]
-        if extension.lower() not in ALLOWED_EXTENSIONS:
-            raise UnsupportedFileTypeError()
-
         # user uuid as file name
         file_uuid = str(uuid.uuid4())
-        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
+
+        if isinstance(user, Account):
+            current_tenant_id = user.current_tenant_id
+        else:
+            # end_user
+            current_tenant_id = user.tenant_id
+
+        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
 
         # save file to storage
         storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
         # save file to db
         config = current_app.config
         upload_file = UploadFile(
-            tenant_id=current_user.current_tenant_id,
+            tenant_id=current_tenant_id,
             storage_type=config['STORAGE_TYPE'],
             key=file_key,
             name=file.filename,
             size=file_size,
             extension=extension,
             mime_type=file.mimetype,
-            created_by=current_user.id,
+            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
+            created_by=user.id,
             created_at=datetime.datetime.utcnow(),
             used=False,
             hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +116,6 @@ class FileService:
 
     @staticmethod
     def get_file_preview(file_id: str) -> str:
-        # get file storage key
-        key = file_id + request.path
-        cached_response = cache.get(key)
-        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
-            return cached_response['response']
-
         upload_file = db.session.query(UploadFile) \
             .filter(UploadFile.id == file_id) \
             .first()
@@ -121,3 +132,25 @@ class FileService:
         text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
 
         return text
+
+    @staticmethod
+    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
+        result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
+        if not result:
+            raise NotFound("File not found or signature is invalid")
+
+        upload_file = db.session.query(UploadFile) \
+            .filter(UploadFile.id == file_id) \
+            .first()
+
+        if not upload_file:
+            raise NotFound("File not found or signature is invalid")
+
+        # extract text from file
+        extension = upload_file.extension
+        if extension.lower() not in IMAGE_EXTENSIONS:
+            raise UnsupportedFileTypeError()
+
+        generator = storage.load(upload_file.key, stream=True)
+
+        return generator, upload_file.mime_type

+ 4 - 2
api/services/web_conversation_service.py

@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
 class WebConversationService:
     @classmethod
     def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
-                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
+                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
+                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
         include_ids = None
         exclude_ids = None
         if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
             last_id=last_id,
             limit=limit,
             include_ids=include_ids,
-            exclude_ids=exclude_ids
+            exclude_ids=exclude_ids,
+            exclude_debug_conversation=exclude_debug_conversation
         )
 
     @classmethod

+ 29 - 1
api/tests/integration_tests/models/llm/test_openai_model.py

@@ -5,7 +5,7 @@ from unittest.mock import patch
 from langchain.schema import Generation, ChatGeneration, AIMessage
 
 from core.model_providers.providers.openai_provider import OpenAIProvider
-from core.model_providers.models.entity.message import PromptMessage, MessageType
+from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
 from core.model_providers.models.entity.model_params import ModelKwargs
 from core.model_providers.models.llm.openai_model import OpenAIModel
 from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
     assert rst == 22
 
 
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_chat_get_num_tokens(mock_decrypt):
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(
+                data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.get_num_tokens(messages)
+    assert rst == 77
+
 @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
 def test_run(mock_decrypt, mocker):
     mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
         messages,
         stop=['\nHuman:'],
     )
+    assert (len(rst.content) > 0)
+
+
+@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
+def test_vision_run(mock_decrypt, mocker):
+    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
+
+    openai_model = get_mock_openai_model('gpt-4-vision-preview')
+    messages = [
+        PromptMessage(content='What’s in first image?', files=[
+            ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
+        ])
+    ]
+    rst = openai_model.run(
+        messages,
+    )
     assert len(rst.content) > 0

+ 7 - 3
docker/docker-compose.yaml

@@ -19,18 +19,22 @@ services:
       # different from api or web app domain.
       # example: http://cloud.dify.ai
       CONSOLE_API_URL: ''
-      # The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
+      # The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
       # different from console domain.
       # example: http://api.dify.ai
       SERVICE_API_URL: ''
-      # The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_API_URL: ''
-      # The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
+      # The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
       # console or api domain.
       # example: http://udify.app
       APP_WEB_URL: ''
+      # File preview or download Url prefix.
+      # used to display File preview or download Url to the front-end or as Multi-model inputs;
+      # Url is signed and has expiration time.
+      FILES_URL: ''
       # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
       MIGRATION_ENABLED: 'true'
       # The configurations of postgres database connection.

+ 5 - 0
docker/nginx/conf.d/default.conf

@@ -17,6 +17,11 @@ server {
       include proxy.conf;
     }
 
+    location /files {
+      proxy_pass http://api:5001;
+      include proxy.conf;
+    }
+
     location / {
       proxy_pass http://web:3000;
       include proxy.conf;