
feat: add api-based extension & external data tool & moderation backend (#1403)

Co-authored-by: takatost <takatost@gmail.com>
Garfield Dai, 1 year ago
commit db43ed6f41
50 changed files with 1622 additions and 271 deletions
1. api/.vscode/launch.json (+1 -1)
2. api/app.py (+2 -1)
3. api/config.py (+4 -0)
4. api/controllers/console/__init__.py (+1 -1)
5. api/controllers/console/explore/parameter.py (+3 -1)
6. api/controllers/console/extension.py (+114 -0)
7. api/controllers/service_api/app/app.py (+3 -1)
8. api/controllers/service_api/app/completion.py (+0 -1)
9. api/controllers/web/app.py (+3 -1)
10. api/controllers/web/completion.py (+1 -1)
11. api/core/__init__.py (+1 -0)
12. api/core/callback_handler/llm_callback_handler.py (+183 -6)
13. api/core/chain/sensitive_word_avoidance_chain.py (+0 -92)
14. api/core/completion.py (+132 -25)
15. api/core/conversation_message_task.py (+26 -0)
16. api/core/extension/__init__.py (+0 -0)
17. api/core/extension/api_based_extension_requestor.py (+62 -0)
18. api/core/extension/extensible.py (+111 -0)
19. api/core/extension/extension.py (+47 -0)
20. api/core/external_data_tool/__init__.py (+0 -0)
21. api/core/external_data_tool/api/__builtin__ (+1 -0)
22. api/core/external_data_tool/api/__init__.py (+0 -0)
23. api/core/external_data_tool/api/api.py (+92 -0)
24. api/core/external_data_tool/base.py (+45 -0)
25. api/core/external_data_tool/factory.py (+40 -0)
26. api/core/moderation/__init__.py (+0 -0)
27. api/core/moderation/api/__builtin__ (+1 -0)
28. api/core/moderation/api/__init__.py (+0 -0)
29. api/core/moderation/api/api.py (+88 -0)
30. api/core/moderation/base.py (+113 -0)
31. api/core/moderation/factory.py (+48 -0)
32. api/core/moderation/keywords/__builtin__ (+1 -0)
33. api/core/moderation/keywords/__init__.py (+0 -0)
34. api/core/moderation/keywords/keywords.py (+60 -0)
35. api/core/moderation/openai_moderation/__builtin__ (+1 -0)
36. api/core/moderation/openai_moderation/__init__.py (+0 -0)
37. api/core/moderation/openai_moderation/openai_moderation.py (+46 -0)
38. api/core/orchestrator_rule_parser.py (+0 -47)
39. api/extensions/ext_code_based_extension.py (+8 -0)
40. api/fields/api_based_extension_fields.py (+17 -0)
41. api/fields/app_fields.py (+1 -0)
42. api/migrations/versions/968fff4c0ab9_add_api_based_extension.py (+45 -0)
43. api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py (+32 -0)
44. api/models/api_based_extension.py (+27 -0)
45. api/models/model.py (+14 -30)
46. api/services/api_based_extension_service.py (+98 -0)
47. api/services/app_model_config_service.py (+89 -54)
48. api/services/code_based_extension_service.py (+13 -0)
49. api/services/completion_service.py (+28 -9)
50. api/services/moderation_service.py (+20 -0)

+ 1 - 1
.vscode/launch.json → api/.vscode/launch.json

@@ -10,7 +10,7 @@
             "request": "launch",
             "module": "flask",
             "env": {
-                "FLASK_APP": "api/app.py",
+                "FLASK_APP": "app.py",
                 "FLASK_DEBUG": "1",
                 "GEVENT_SUPPORT": "True"
             },

+ 2 - 1
api/app.py

@@ -19,7 +19,7 @@ from flask_cors import CORS
 
 from core.model_providers.providers import hosted
 from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
-    ext_database, ext_storage, ext_mail, ext_stripe
+    ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 
@@ -79,6 +79,7 @@ def create_app(test_config=None) -> Flask:
 def initialize_extensions(app):
     # Since the application instance is now created, pass it to each Flask
     # extension instance to bind it to the Flask application instance (app)
+    ext_code_based_extension.init()
     ext_database.init_app(app)
     ext_migrate.init(app, db)
     ext_redis.init_app(app)

+ 4 - 0
api/config.py

@@ -57,6 +57,7 @@ DEFAULTS = {
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
+    'OUTPUT_MODERATION_BUFFER_SIZE': 300
 }
 
 
@@ -228,6 +229,9 @@ class Config:
         self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
         self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
 
+        # moderation settings
+        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
+
 
 class CloudEditionConfig(Config):
 

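The new setting follows the same get_env pattern as the entries above, so it can presumably be overridden from the environment; a hypothetical .env entry (300 characters is the default when unset):

    OUTPUT_MODERATION_BUFFER_SIZE=300
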
+ 1 - 1
api/controllers/console/__init__.py

@@ -6,7 +6,7 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
 api = ExternalApi(bp)
 
 # Import other controllers
-from . import setup, version, apikey, admin
+from . import extension, setup, version, apikey, admin
 
 # Import app controllers
 from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio

+ 3 - 1
api/controllers/console/explore/parameter.py

@@ -27,6 +27,7 @@ class AppParameterApi(InstalledAppResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
+        'sensitive_word_avoidance': fields.Raw
     }
 
     @marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(InstalledAppResource):
             'speech_to_text': app_model_config.speech_to_text_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
-            'user_input_form': app_model_config.user_input_form_list
+            'user_input_form': app_model_config.user_input_form_list,
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
         }
 
 

+ 114 - 0
api/controllers/console/extension.py

@@ -0,0 +1,114 @@
+from flask_restful import Resource, reqparse, marshal_with
+from flask_login import current_user
+
+from controllers.console import api
+from controllers.console.setup import setup_required
+from controllers.console.wraps import account_initialization_required
+from libs.login import login_required
+from models.api_based_extension import APIBasedExtension
+from fields.api_based_extension_fields import api_based_extension_fields
+from services.code_based_extension_service import CodeBasedExtensionService
+from services.api_based_extension_service import APIBasedExtensionService
+
+
+class CodeBasedExtensionAPI(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('module', type=str, required=True, location='args')
+        args = parser.parse_args()
+
+        return {
+            'module': args['module'],
+            'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
+        }
+
+
+class APIBasedExtensionAPI(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_based_extension_fields)
+    def get(self):
+        tenant_id = current_user.current_tenant_id
+        return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_based_extension_fields)
+    def post(self):
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('api_endpoint', type=str, required=True, location='json')
+        parser.add_argument('api_key', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        extension_data = APIBasedExtension(
+            tenant_id=current_user.current_tenant_id,
+            name=args['name'],
+            api_endpoint=args['api_endpoint'],
+            api_key=args['api_key']
+        )
+
+        return APIBasedExtensionService.save(extension_data)
+
+
+class APIBasedExtensionDetailAPI(Resource):
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_based_extension_fields)
+    def get(self, id):
+        api_based_extension_id = str(id)
+        tenant_id = current_user.current_tenant_id
+
+        return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(api_based_extension_fields)
+    def post(self, id):
+        api_based_extension_id = str(id)
+        tenant_id = current_user.current_tenant_id
+
+        extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
+
+        parser = reqparse.RequestParser()
+        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('api_endpoint', type=str, required=True, location='json')
+        parser.add_argument('api_key', type=str, required=True, location='json')
+        args = parser.parse_args()
+
+        extension_data_from_db.name = args['name']
+        extension_data_from_db.api_endpoint = args['api_endpoint']
+
+        if args['api_key'] != '[__HIDDEN__]':
+            extension_data_from_db.api_key = args['api_key']
+
+        return APIBasedExtensionService.save(extension_data_from_db)
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, id):
+        api_based_extension_id = str(id)
+        tenant_id = current_user.current_tenant_id
+
+        extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
+
+        APIBasedExtensionService.delete(extension_data_from_db)
+
+        return {'result': 'success'}
+
+
+api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
+
+api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
+api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')

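A rough client-side sketch (not part of this commit) of how the new console endpoints could be exercised with requests. The base URL, the authentication handling and all field values are assumptions; the console API additionally requires a logged-in, initialized account, as the decorators above show.

    import requests

    BASE = "http://localhost:5001/console/api"   # assumed local API address
    session = requests.Session()                 # session auth (cookie/token) assumed to be set up elsewhere

    # list the code-based extensions available for a module, e.g. moderation
    print(session.get(f"{BASE}/code-based-extension", params={"module": "moderation"}).json())

    # create an API-based extension
    created = session.post(f"{BASE}/api-based-extension", json={
        "name": "my-extension",
        "api_endpoint": "https://example.com/dify-extension",
        "api_key": "secret-key",
    }).json()
    extension_id = created["id"]                 # assuming the marshalled fields include the record id

    # update it; sending '[__HIDDEN__]' keeps the stored api_key unchanged
    session.post(f"{BASE}/api-based-extension/{extension_id}", json={
        "name": "my-extension",
        "api_endpoint": "https://example.com/dify-extension",
        "api_key": "[__HIDDEN__]",
    })

    # delete it
    session.delete(f"{BASE}/api-based-extension/{extension_id}")
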
+ 3 - 1
api/controllers/service_api/app/app.py

@@ -28,6 +28,7 @@ class AppParameterApi(AppApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
+        'sensitive_word_avoidance': fields.Raw
     }
 
     @marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(AppApiResource):
             'speech_to_text': app_model_config.speech_to_text_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
-            'user_input_form': app_model_config.user_input_form_list
+            'user_input_form': app_model_config.user_input_form_list,
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
         }
 
 

+ 0 - 1
api/controllers/service_api/app/completion.py

@@ -183,4 +183,3 @@ api.add_resource(CompletionApi, '/completion-messages')
 api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
 api.add_resource(ChatApi, '/chat-messages')
 api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')
-

+ 3 - 1
api/controllers/web/app.py

@@ -27,6 +27,7 @@ class AppParameterApi(WebApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
+        'sensitive_word_avoidance': fields.Raw
     }
 
     @marshal_with(parameters_fields)
@@ -41,7 +42,8 @@ class AppParameterApi(WebApiResource):
             'speech_to_text': app_model_config.speech_to_text_dict,
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
-            'user_input_form': app_model_config.user_input_form_list
+            'user_input_form': app_model_config.user_input_form_list,
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
         }
 
 

+ 1 - 1
api/controllers/web/completion.py

@@ -139,7 +139,7 @@ class ChatStopApi(WebApiResource):
         return {'result': 'success'}, 200
 
 
-def compact_response(response: Union[dict | Generator]) -> Response:
+def compact_response(response: Union[dict, Generator]) -> Response:
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:

+ 1 - 0
api/core/__init__.py

@@ -0,0 +1 @@
+import core.moderation.base

+ 183 - 6
api/core/callback_handler/llm_callback_handler.py

@@ -1,13 +1,25 @@
 import logging
-from typing import Any, Dict, List, Union
+import threading
+import time
+from typing import Any, Dict, List, Union, Optional
 
+from flask import Flask, current_app
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.schema import LLMResult, BaseMessage
+from pydantic import BaseModel
 
 from core.callback_handler.entity.llm_message import LLMMessage
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
+    ConversationTaskInterruptException
 from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
 from core.model_providers.models.llm.base import BaseLLM
+from core.moderation.base import ModerationOutputsResult, ModerationAction
+from core.moderation.factory import ModerationFactory
+
+
+class ModerationRule(BaseModel):
+    type: str
+    config: Dict[str, Any]
 
 
 class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.start_at = None
         self.conversation_message_task = conversation_message_task
 
+        self.output_moderation_handler = None
+        self.init_output_moderation()
+
+    def init_output_moderation(self):
+        app_model_config = self.conversation_message_task.app_model_config
+        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
+
+        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
+            self.output_moderation_handler = OutputModerationHandler(
+                tenant_id=self.conversation_message_task.tenant_id,
+                app_id=self.conversation_message_task.app.id,
+                rule=ModerationRule(
+                    type=sensitive_word_avoidance_dict.get("type"),
+                    config=sensitive_word_avoidance_dict.get("config")
+                ),
+                on_message_replace_func=self.conversation_message_task.on_message_replace
+            )
+
     @property
     def always_verbose(self) -> bool:
         """Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
-        if not self.conversation_message_task.streaming:
-            self.conversation_message_task.append_message_text(response.generations[0][0].text)
+        if self.output_moderation_handler:
+            self.output_moderation_handler.stop_thread()
+
+            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
+                completion=response.generations[0][0].text,
+                public_event=True if self.conversation_message_task.streaming else False
+            )
+        else:
             self.llm_message.completion = response.generations[0][0].text
 
+        if not self.conversation_message_task.streaming:
+            self.conversation_message_task.append_message_text(self.llm_message.completion)
+
         if response.llm_output and 'token_usage' in response.llm_output:
             if 'prompt_tokens' in response.llm_output['token_usage']:
                 self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
         self.conversation_message_task.save_message(self.llm_message)
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
+            # stop subscribe new token when output moderation should direct output
+            ex = ConversationTaskInterruptException()
+            self.on_llm_error(error=ex)
+            raise ex
+
         try:
             self.conversation_message_task.append_message_text(token)
+            self.llm_message.completion += token
+
+            if self.output_moderation_handler:
+                self.output_moderation_handler.append_new_token(token)
         except ConversationTaskStoppedException as ex:
             self.on_llm_error(error=ex)
             raise ex
 
-        self.llm_message.completion += token
-
     def on_llm_error(
             self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
     ) -> None:
         """Do nothing."""
+        if self.output_moderation_handler:
+            self.output_moderation_handler.stop_thread()
+
         if isinstance(error, ConversationTaskStoppedException):
             if self.conversation_message_task.streaming:
                 self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                     [PromptMessage(content=self.llm_message.completion)]
                 )
                 self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
+        if isinstance(error, ConversationTaskInterruptException):
+            self.llm_message.completion = self.output_moderation_handler.get_final_output()
+            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
+                [PromptMessage(content=self.llm_message.completion)]
+            )
+            self.conversation_message_task.save_message(llm_message=self.llm_message)
         else:
             logging.debug("on_llm_error: %s", error)
+
+
+class OutputModerationHandler(BaseModel):
+    DEFAULT_BUFFER_SIZE: int = 300
+
+    tenant_id: str
+    app_id: str
+
+    rule: ModerationRule
+    on_message_replace_func: Any
+
+    thread: Optional[threading.Thread] = None
+    thread_running: bool = True
+    buffer: str = ''
+    is_final_chunk: bool = False
+    final_output: Optional[str] = None
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    def should_direct_output(self):
+        return self.final_output is not None
+
+    def get_final_output(self):
+        return self.final_output
+
+    def append_new_token(self, token: str):
+        self.buffer += token
+
+        if not self.thread:
+            self.thread = self.start_thread()
+
+    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
+        self.buffer = completion
+        self.is_final_chunk = True
+
+        result = self.moderation(
+            tenant_id=self.tenant_id,
+            app_id=self.app_id,
+            moderation_buffer=completion
+        )
+
+        if not result or not result.flagged:
+            return completion
+
+        if result.action == ModerationAction.DIRECT_OUTPUT:
+            final_output = result.preset_response
+        else:
+            final_output = result.text
+
+        if public_event:
+            self.on_message_replace_func(final_output)
+
+        return final_output
+
+    def start_thread(self) -> threading.Thread:
+        buffer_size = int(current_app.config.get('OUTPUT_MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
+        thread = threading.Thread(target=self.worker, kwargs={
+            'flask_app': current_app._get_current_object(),
+            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
+        })
+
+        thread.start()
+
+        return thread
+
+    def stop_thread(self):
+        if self.thread and self.thread.is_alive():
+            self.thread_running = False
+
+    def worker(self, flask_app: Flask, buffer_size: int):
+        with flask_app.app_context():
+            current_length = 0
+            while self.thread_running:
+                moderation_buffer = self.buffer
+                buffer_length = len(moderation_buffer)
+                if not self.is_final_chunk:
+                    chunk_length = buffer_length - current_length
+                    if 0 <= chunk_length < buffer_size:
+                        time.sleep(1)
+                        continue
+
+                current_length = buffer_length
+
+                result = self.moderation(
+                    tenant_id=self.tenant_id,
+                    app_id=self.app_id,
+                    moderation_buffer=moderation_buffer
+                )
+
+                if not result or not result.flagged:
+                    continue
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    final_output = result.preset_response
+                    self.final_output = final_output
+                else:
+                    final_output = result.text + self.buffer[len(moderation_buffer):]
+
+                # trigger replace event
+                if self.thread_running:
+                    self.on_message_replace_func(final_output)
+
+                if result.action == ModerationAction.DIRECT_OUTPUT:
+                    break
+
+    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
+        try:
+            moderation_factory = ModerationFactory(
+                name=self.rule.type,
+                app_id=app_id,
+                tenant_id=tenant_id,
+                config=self.rule.config
+            )
+
+            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
+            return result
+        except Exception as e:
+            logging.error("Moderation Output error: %s", e)
+
+        return None

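To make the buffering strategy above easier to follow, here is a simplified, single-threaded sketch of what OutputModerationHandler does: tokens accumulate in a buffer, moderation only re-runs once at least buffer_size new characters have arrived, and a flagged DIRECT_OUTPUT result replaces the streamed text. The real handler additionally runs the check in a background thread and pushes replacements to the client via on_message_replace_func.

    def stream_with_moderation(tokens, buffer_size, check):
        buffer = ""
        checked_length = 0
        for token in tokens:
            buffer += token
            # re-moderate only once at least `buffer_size` new characters have arrived
            if len(buffer) - checked_length < buffer_size:
                continue
            checked_length = len(buffer)
            flagged, replacement = check(buffer)
            if flagged:
                return replacement          # corresponds to ModerationAction.DIRECT_OUTPUT
        # the final chunk is always moderated, regardless of size
        flagged, replacement = check(buffer)
        return replacement if flagged else buffer

    # usage with a no-op moderation callback (hypothetical)
    print(stream_with_moderation(["Hello ", "wor", "ld"], 300, lambda text: (False, "")))
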
+ 0 - 92
api/core/chain/sensitive_word_avoidance_chain.py

@@ -1,92 +0,0 @@
-import enum
-import logging
-from typing import List, Dict, Optional, Any
-
-from langchain.callbacks.manager import CallbackManagerForChainRun
-from langchain.chains.base import Chain
-from pydantic import BaseModel
-
-from core.model_providers.error import LLMBadRequestError
-from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.llm.base import BaseLLM
-from core.model_providers.models.moderation import openai_moderation
-
-
-class SensitiveWordAvoidanceRule(BaseModel):
-    class Type(enum.Enum):
-        MODERATION = "moderation"
-        KEYWORDS = "keywords"
-
-    type: Type
-    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
-    extra_params: dict = {}
-
-
-class SensitiveWordAvoidanceChain(Chain):
-    input_key: str = "input"  #: :meta private:
-    output_key: str = "output"  #: :meta private:
-
-    model_instance: BaseLLM
-    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
-
-    @property
-    def _chain_type(self) -> str:
-        return "sensitive_word_avoidance_chain"
-
-    @property
-    def input_keys(self) -> List[str]:
-        """Expect input key.
-
-        :meta private:
-        """
-        return [self.input_key]
-
-    @property
-    def output_keys(self) -> List[str]:
-        """Return output key.
-
-        :meta private:
-        """
-        return [self.output_key]
-
-    def _check_sensitive_word(self, text: str) -> bool:
-        for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
-            if word in text:
-                return False
-        return True
-
-    def _check_moderation(self, text: str) -> bool:
-        moderation_model_instance = ModelFactory.get_moderation_model(
-            tenant_id=self.model_instance.model_provider.provider.tenant_id,
-            model_provider_name='openai',
-            model_name=openai_moderation.DEFAULT_MODEL
-        )
-
-        try:
-            return moderation_model_instance.run(text=text)
-        except Exception as ex:
-            logging.exception(ex)
-            raise LLMBadRequestError('Rate limit exceeded, please try again later.')
-
-    def _call(
-            self,
-            inputs: Dict[str, Any],
-            run_manager: Optional[CallbackManagerForChainRun] = None,
-    ) -> Dict[str, Any]:
-        text = inputs[self.input_key]
-
-        if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
-            result = self._check_sensitive_word(text)
-        else:
-            result = self._check_moderation(text)
-
-        if not result:
-            raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
-
-        return {self.output_key: text}
-
-
-class SensitiveWordAvoidanceError(Exception):
-    def __init__(self, message):
-        super().__init__(message)
-        self.message = message

+ 132 - 25
api/core/completion.py

@@ -1,13 +1,18 @@
+import concurrent
+import json
 import logging
-from typing import Optional, List, Union
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional, List, Union, Tuple
 
+from flask import current_app, Flask
 from requests.exceptions import ChunkedEncodingError
 
 from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.llm_callback_handler import LLMCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
-from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
+from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
+    ConversationTaskInterruptException
+from core.external_data_tool.factory import ExternalDataToolFactory
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
@@ -18,6 +23,8 @@ from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
 from core.prompt.prompt_transform import PromptTransform
 from models.model import App, AppModelConfig, Account, Conversation, EndUser
+from core.moderation.base import ModerationException, ModerationAction
+from core.moderation.factory import ModerationFactory
 
 
 class Completion:
@@ -76,26 +83,35 @@ class Completion:
         )
 
         try:
-            # parse sensitive_word_avoidance_chain
             chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
-            sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
-                final_model_instance, [chain_callback])
-            if sensitive_word_avoidance_chain:
-                try:
-                    query = sensitive_word_avoidance_chain.run(query)
-                except SensitiveWordAvoidanceError as ex:
-                    cls.run_final_llm(
-                        model_instance=final_model_instance,
-                        mode=app.mode,
-                        app_model_config=app_model_config,
-                        query=query,
-                        inputs=inputs,
-                        agent_execute_result=None,
-                        conversation_message_task=conversation_message_task,
-                        memory=memory,
-                        fake_response=ex.message
-                    )
-                    return
+
+            try:
+                # process sensitive_word_avoidance
+                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
+            except ModerationException as e:
+                cls.run_final_llm(
+                    model_instance=final_model_instance,
+                    mode=app.mode,
+                    app_model_config=app_model_config,
+                    query=query,
+                    inputs=inputs,
+                    agent_execute_result=None,
+                    conversation_message_task=conversation_message_task,
+                    memory=memory,
+                    fake_response=str(e)
+                )
+                return
+
+            # fill in variable inputs from external data tools if exists
+            external_data_tools = app_model_config.external_data_tools_list
+            if external_data_tools:
+                inputs = cls.fill_in_inputs_from_external_data_tools(
+                    tenant_id=app.tenant_id,
+                    app_id=app.id,
+                    external_data_tools=external_data_tools,
+                    inputs=inputs,
+                    query=query
+                )
 
             # get agent executor
             agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -135,19 +151,110 @@ class Completion:
                 memory=memory,
                 fake_response=fake_response
             )
-        except ConversationTaskStoppedException:
+        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
             return
         except ChunkedEncodingError as e:
             # Interrupt by LLM (like OpenAI), handle it.
             logging.warning(f'ChunkedEncodingError: {e}')
             conversation_message_task.end()
             return
-        
+
+    @classmethod
+    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
+        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
+            return inputs, query
+
+        type = app_model_config.sensitive_word_avoidance_dict['type']
+
+        moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
+        moderation_result = moderation.moderation_for_inputs(inputs, query)
+
+        if not moderation_result.flagged:
+            return inputs, query
+
+        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
+            raise ModerationException(moderation_result.preset_response)
+        elif moderation_result.action == ModerationAction.OVERRIDED:
+            inputs = moderation_result.inputs
+            query = moderation_result.query
+
+        return inputs, query
+
+    @classmethod
+    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
+                                                inputs: dict, query: str) -> dict:
+        """
+        Fill in variable inputs from external data tools if exists.
+
+        :param tenant_id: workspace id
+        :param app_id: app id
+        :param external_data_tools: external data tools configs
+        :param inputs: the inputs
+        :param query: the query
+        :return: the filled inputs
+        """
+        # Group tools by type and config
+        grouped_tools = {}
+        for tool in external_data_tools:
+            if not tool.get("enabled"):
+                continue
+
+            tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
+            grouped_tools.setdefault(tool_key, []).append(tool)
+
+        results = {}
+        with ThreadPoolExecutor() as executor:
+            futures = {}
+            for tools in grouped_tools.values():
+                # Only query the first tool in each group
+                first_tool = tools[0]
+                future = executor.submit(
+                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
+                    inputs, query
+                )
+                for tool in tools:
+                    futures[future] = tool
+
+            for future in concurrent.futures.as_completed(futures):
+                tool_key, result = future.result()
+                if tool_key in grouped_tools:
+                    for tool in grouped_tools[tool_key]:
+                        results[tool['variable']] = result
+
+        inputs.update(results)
+        return inputs
+
+    @classmethod
+    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
+                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
+        with flask_app.app_context():
+            tool_variable = external_data_tool.get("variable")
+            tool_type = external_data_tool.get("type")
+            tool_config = external_data_tool.get("config")
+
+            external_data_tool_factory = ExternalDataToolFactory(
+                name=tool_type,
+                tenant_id=tenant_id,
+                app_id=app_id,
+                variable=tool_variable,
+                config=tool_config
+            )
+
+            # query external data tool
+            result = external_data_tool_factory.query(
+                inputs=inputs,
+                query=query
+            )
+
+            tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
+
+            return tool_key, result
+
     @classmethod
     def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
         if app.mode != 'completion':
             return query
-        
+
         return inputs.get(app_model_config.dataset_query_variable, "")
 
     @classmethod

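For orientation, the shapes of the two app_model_config structures consumed above, as inferred from this commit; the keys shown are the ones the code reads, and all values are illustrative:

    sensitive_word_avoidance_dict = {
        "enabled": True,
        "type": "api",                                 # name of a moderation extension
        "config": {                                    # extension-specific form config
            "api_based_extension_id": "<uuid>",        # used by the built-in 'api' type
            "inputs_config": {"enabled": True},
            "outputs_config": {"enabled": True},
        },
    }

    external_data_tools_list = [
        {
            "enabled": True,
            "variable": "weather",                     # app input variable this tool fills in
            "type": "api",                             # name of an external data tool extension
            "config": {"api_based_extension_id": "<uuid>"},
        },
    ]
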
+ 26 - 0
api/core/conversation_message_task.py

@@ -290,6 +290,10 @@ class ConversationMessageTask:
                 db.session.commit()
             self.retriever_resource = resource
 
+    def on_message_replace(self, text: str):
+        if text is not None:
+            self._pub_handler.pub_message_replace(text)
+
     def message_end(self):
         self._pub_handler.pub_message_end(self.retriever_resource)
 
@@ -342,6 +346,24 @@ class PubHandler:
             self.pub_end()
             raise ConversationTaskStoppedException()
 
+    def pub_message_replace(self, text: str):
+        content = {
+            'event': 'message_replace',
+            'data': {
+                'task_id': self._task_id,
+                'message_id': str(self._message.id),
+                'text': text,
+                'mode': self._conversation.mode,
+                'conversation_id': str(self._conversation.id)
+            }
+        }
+
+        redis_client.publish(self._channel, json.dumps(content))
+
+        if self._is_stopped():
+            self.pub_end()
+            raise ConversationTaskStoppedException()
+
     def pub_chain(self, message_chain: MessageChain):
         if self._chain_pub:
             content = {
@@ -443,3 +465,7 @@ class PubHandler:
 
 class ConversationTaskStoppedException(Exception):
     pass
+
+
+class ConversationTaskInterruptException(Exception):
+    pass

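Clients subscribed to the completion stream receive the new event as JSON published to the Redis channel, shaped roughly like this (values illustrative):

    {
        "event": "message_replace",
        "data": {
            "task_id": "<task-id>",
            "message_id": "<message-id>",
            "text": "Your content violates our usage policy.",
            "mode": "chat",
            "conversation_id": "<conversation-id>"
        }
    }
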
+ 0 - 0
api/core/extension/__init__.py


+ 62 - 0
api/core/extension/api_based_extension_requestor.py

@@ -0,0 +1,62 @@
+import os
+
+import requests
+
+from models.api_based_extension import APIBasedExtensionPoint
+
+
+class APIBasedExtensionRequestor:
+    timeout: (int, int) = (5, 60)
+    """timeout for request connect and read"""
+
+    def __init__(self, api_endpoint: str, api_key: str) -> None:
+        self.api_endpoint = api_endpoint
+        self.api_key = api_key
+
+    def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
+        """
+        Request the api.
+
+        :param point: the api point
+        :param params: the request params
+        :return: the response json
+        """
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": "Bearer {}".format(self.api_key)
+        }
+
+        url = self.api_endpoint
+
+        try:
+            # proxy support for security
+            proxies = None
+            if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
+                proxies = {
+                    'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
+                    'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
+                }
+
+            response = requests.request(
+                method='POST',
+                url=url,
+                json={
+                    'point': point.value,
+                    'params': params
+                },
+                headers=headers,
+                timeout=self.timeout,
+                proxies=proxies
+            )
+        except requests.exceptions.Timeout:
+            raise ValueError("request timeout")
+        except requests.exceptions.ConnectionError:
+            raise ValueError("request connection error")
+
+        if response.status_code != 200:
+            raise ValueError("request error, status_code: {}, content: {}".format(
+                response.status_code,
+                response.text[:100]
+            ))
+
+        return response.json()

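For context, a hypothetical third-party endpoint satisfying the contract this requestor expects: a POST with a JSON body of {"point": ..., "params": ...}, authenticated via "Authorization: Bearer <api_key>", answered with a 200 JSON response. The route, point matching and dispatch logic below are assumptions; external data tool queries must answer with a 'result' field, and moderation points with a ModerationInputsResult/ModerationOutputsResult-shaped object (see api/core/moderation/ later in this diff).

    from flask import Flask, request, jsonify

    app = Flask(__name__)
    EXPECTED_API_KEY = "secret-key"          # must match the api_key stored in Dify

    @app.route("/dify-extension", methods=["POST"])
    def dify_extension():
        if request.headers.get("Authorization") != f"Bearer {EXPECTED_API_KEY}":
            return jsonify({"error": "unauthorized"}), 401

        payload = request.get_json()
        point = payload.get("point", "")     # exact values live in models/api_based_extension.py (not shown here)
        params = payload.get("params", {})

        if "external_data_tool" in point:    # assumption about the point naming
            return jsonify({"result": f"echo: {params.get('query')}"})

        # moderation points: answer with a ModerationInputsResult/OutputsResult shape
        return jsonify({"flagged": False, "action": "direct_output", "preset_response": ""})
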
+ 111 - 0
api/core/extension/extensible.py

@@ -0,0 +1,111 @@
+import enum
+import importlib.util
+import json
+import logging
+import os
+from collections import OrderedDict
+from typing import Any, Optional
+
+from pydantic import BaseModel
+
+
+class ExtensionModule(enum.Enum):
+    MODERATION = 'moderation'
+    EXTERNAL_DATA_TOOL = 'external_data_tool'
+
+
+class ModuleExtension(BaseModel):
+    extension_class: Any
+    name: str
+    label: Optional[dict] = None
+    form_schema: Optional[list] = None
+    builtin: bool = True
+    position: Optional[int] = None
+
+
+class Extensible:
+    module: ExtensionModule
+
+    name: str
+    tenant_id: str
+    config: Optional[dict] = None
+
+    def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
+        self.tenant_id = tenant_id
+        self.config = config
+
+    @classmethod
+    def scan_extensions(cls):
+        extensions = {}
+
+        # get the path of the current class
+        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
+        current_dir_path = os.path.dirname(current_path)
+
+        # traverse subdirectories
+        for subdir_name in os.listdir(current_dir_path):
+            if subdir_name.startswith('__'):
+                continue
+
+            subdir_path = os.path.join(current_dir_path, subdir_name)
+            extension_name = subdir_name
+            if os.path.isdir(subdir_path):
+                file_names = os.listdir(subdir_path)
+
+                # builtin extensions get special treatment
+                # in the front-end page and business logic.
+                builtin = False
+                position = None
+                if '__builtin__' in file_names:
+                    builtin = True
+
+                    builtin_file_path = os.path.join(subdir_path, '__builtin__')
+                    if os.path.exists(builtin_file_path):
+                        with open(builtin_file_path, 'r') as f:
+                            position = int(f.read().strip())
+
+                if (extension_name + '.py') not in file_names:
+                    logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
+                    continue
+
+                # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
+                py_path = os.path.join(subdir_path, extension_name + '.py')
+                spec = importlib.util.spec_from_file_location(extension_name, py_path)
+                mod = importlib.util.module_from_spec(spec)
+                spec.loader.exec_module(mod)
+
+                extension_class = None
+                for name, obj in vars(mod).items():
+                    if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
+                        extension_class = obj
+                        break
+
+                if not extension_class:
+                    logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
+                    continue
+
+                json_data = {}
+                if not builtin:
+                    if 'schema.json' not in file_names:
+                        logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
+                        continue
+
+                    json_path = os.path.join(subdir_path, 'schema.json')
+                    json_data = {}
+                    if os.path.exists(json_path):
+                        with open(json_path, 'r') as f:
+                            json_data = json.load(f)
+
+                extensions[extension_name] = ModuleExtension(
+                    extension_class=extension_class,
+                    name=extension_name,
+                    label=json_data.get('label'),
+                    form_schema=json_data.get('form_schema'),
+                    builtin=builtin,
+                    position=position
+                )
+
+        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
+        sorted_extensions = OrderedDict(sorted_items)
+
+        return sorted_extensions

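Based on the scanning rules above, a custom (non-builtin) extension for either module would presumably be laid out like this, with the directory name, the module file name and the exported subclass all using the same extension name:

    api/core/moderation/cloud_check/         # hypothetical custom moderation extension
        __init__.py
        cloud_check.py                       # must define a subclass of Moderation
        schema.json                          # required for non-builtin extensions;
                                             # provides "label" and "form_schema"

Builtin extensions instead ship a __builtin__ file whose integer content determines their position in the sorted listing.
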
+ 47 - 0
api/core/extension/extension.py

@@ -0,0 +1,47 @@
+from core.extension.extensible import ModuleExtension, ExtensionModule
+from core.external_data_tool.base import ExternalDataTool
+from core.moderation.base import Moderation
+
+
+class Extension:
+    __module_extensions: dict[str, dict[str, ModuleExtension]] = {}
+
+    module_classes = {
+        ExtensionModule.MODERATION: Moderation,
+        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
+    }
+
+    def init(self):
+        for module, module_class in self.module_classes.items():
+            self.__module_extensions[module.value] = module_class.scan_extensions()
+
+    def module_extensions(self, module: str) -> list[ModuleExtension]:
+        module_extensions = self.__module_extensions.get(module)
+
+        if not module_extensions:
+            raise ValueError(f"Extension Module {module} not found")
+
+        return list(module_extensions.values())
+
+    def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
+        module_extensions = self.__module_extensions.get(module.value)
+
+        if not module_extensions:
+            raise ValueError(f"Extension Module {module} not found")
+
+        module_extension = module_extensions.get(extension_name)
+
+        if not module_extension:
+            raise ValueError(f"Extension {extension_name} not found")
+
+        return module_extension
+
+    def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
+        module_extension = self.module_extension(module, extension_name)
+        return module_extension.extension_class
+
+    def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
+        module_extension = self.module_extension(module, extension_name)
+        form_schema = module_extension.form_schema
+
+        # TODO validate form_schema

+ 0 - 0
api/core/external_data_tool/__init__.py


+ 1 - 0
api/core/external_data_tool/api/__builtin__

@@ -0,0 +1 @@
+1

+ 0 - 0
api/core/external_data_tool/api/__init__.py


+ 92 - 0
api/core/external_data_tool/api/api.py

@@ -0,0 +1,92 @@
+from typing import Optional
+
+from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
+from core.external_data_tool.base import ExternalDataTool
+from core.helper import encrypter
+from extensions.ext_database import db
+from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
+
+
+class ApiExternalDataTool(ExternalDataTool):
+    """
+    The api external data tool.
+    """
+
+    name: str = "api"
+    """the unique name of external data tool"""
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        # own validation logic
+        api_based_extension_id = config.get("api_based_extension_id")
+        if not api_based_extension_id:
+            raise ValueError("api_based_extension_id is required")
+
+        # get api_based_extension
+        api_based_extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        if not api_based_extension:
+            raise ValueError("api_based_extension_id is invalid")
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        # get params from config
+        api_based_extension_id = self.config.get("api_based_extension_id")
+
+        # get api_based_extension
+        api_based_extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == self.tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        if not api_based_extension:
+            raise ValueError("[External data tool] API query failed, variable: {}, "
+                             "error: api_based_extension_id is invalid"
+                             .format(self.config.get('variable')))
+
+        # decrypt api_key
+        api_key = encrypter.decrypt_token(
+            tenant_id=self.tenant_id,
+            token=api_based_extension.api_key
+        )
+
+        try:
+            # request api
+            requestor = APIBasedExtensionRequestor(
+                api_endpoint=api_based_extension.api_endpoint,
+                api_key=api_key
+            )
+        except Exception as e:
+            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
+                self.config.get('variable'),
+                e
+            ))
+
+        response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
+            'app_id': self.app_id,
+            'tool_variable': self.variable,
+            'inputs': inputs,
+            'query': query
+        })
+
+        if 'result' not in response_json:
+            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
+                             .format(self.config.get('variable')))
+
+        return response_json['result']

+ 45 - 0
api/core/external_data_tool/base.py

@@ -0,0 +1,45 @@
+from abc import abstractmethod, ABC
+from typing import Optional
+
+from core.extension.extensible import Extensible, ExtensionModule
+
+
+class ExternalDataTool(Extensible, ABC):
+    """
+    The base class of external data tool.
+    """
+
+    module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
+
+    app_id: str
+    """the id of app"""
+    variable: str
+    """the tool variable name of app tool"""
+
+    def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
+        super().__init__(tenant_id, config)
+        self.app_id = app_id
+        self.variable = variable
+
+    @classmethod
+    @abstractmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        raise NotImplementedError

+ 40 - 0
api/core/external_data_tool/factory.py

@@ -0,0 +1,40 @@
+from typing import Optional
+
+from core.extension.extensible import ExtensionModule
+from extensions.ext_code_based_extension import code_based_extension
+
+
+class ExternalDataToolFactory:
+
+    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
+        extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
+        self.__extension_instance = extension_class(
+            tenant_id=tenant_id,
+            app_id=app_id,
+            variable=variable,
+            config=config
+        )
+
+    @classmethod
+    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param name: the name of external data tool
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
+        extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
+        extension_class.validate_config(tenant_id, config)
+
+    def query(self, inputs: dict, query: Optional[str] = None) -> str:
+        """
+        Query the external data tool.
+
+        :param inputs: user inputs
+        :param query: the query of chat app
+        :return: the tool query result
+        """
+        return self.__extension_instance.query(inputs, query)

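A short usage sketch mirroring how Completion.query_external_data_tool (earlier in this commit) drives the factory; identifiers and values are illustrative:

    from core.external_data_tool.factory import ExternalDataToolFactory

    factory = ExternalDataToolFactory(
        name="api",                                    # extension name, e.g. the built-in 'api' tool
        tenant_id="<tenant-uuid>",
        app_id="<app-uuid>",
        variable="weather",
        config={"api_based_extension_id": "<uuid>"},
    )
    value = factory.query(inputs={"city": "Berlin"}, query="What should I wear today?")

ExternalDataToolFactory.validate_config(name, tenant_id, config) can be called beforehand, for example when the app model config is saved.
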
+ 0 - 0
api/core/moderation/__init__.py


+ 1 - 0
api/core/moderation/api/__builtin__

@@ -0,0 +1 @@
+3

+ 0 - 0
api/core/moderation/api/__init__.py


+ 88 - 0
api/core/moderation/api/api.py

@@ -0,0 +1,88 @@
+from pydantic import BaseModel
+
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
+from core.helper.encrypter import decrypt_token
+from extensions.ext_database import db
+from models.api_based_extension import APIBasedExtension
+
+
+class ModerationInputParams(BaseModel):
+    app_id: str = ""
+    inputs: dict = {}
+    query: str = ""
+
+
+class ModerationOutputParams(BaseModel):
+    app_id: str = ""
+    text: str
+
+
+class ApiModeration(Moderation):
+    name: str = "api"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, False)
+
+        api_based_extension_id = config.get("api_based_extension_id")
+        if not api_based_extension_id:
+            raise ValueError("api_based_extension_id is required")
+
+        extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
+        if not extension:
+            raise ValueError("API-based Extension not found. Please check it again.")
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            params = ModerationInputParams(
+                app_id=self.app_id,
+                inputs=inputs,
+                query=query
+            )
+
+            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
+            return ModerationInputsResult(**result)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            params = ModerationOutputParams(
+                app_id=self.app_id,
+                text=text
+            )
+
+            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
+            return ModerationOutputsResult(**result)
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
+        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
+        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
+
+        result = requestor.request(extension_point, params)
+        return result
+
+    @staticmethod
+    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+        extension = db.session.query(APIBasedExtension).filter(
+            APIBasedExtension.tenant_id == tenant_id,
+            APIBasedExtension.id == api_based_extension_id
+        ).first()
+
+        return extension

+ 113 - 0
api/core/moderation/base.py

@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+from pydantic import BaseModel
+from enum import Enum
+
+from core.extension.extensible import Extensible, ExtensionModule
+
+
+class ModerationAction(Enum):
+    DIRECT_OUTPUT = 'direct_output'
+    OVERRIDED = 'overrided'
+
+
+class ModerationInputsResult(BaseModel):
+    flagged: bool = False
+    action: ModerationAction
+    preset_response: str = ""
+    inputs: dict = {}
+    query: str = ""
+
+
+class ModerationOutputsResult(BaseModel):
+    flagged: bool = False
+    action: ModerationAction
+    preset_response: str = ""
+    text: str = ""
+
+
+class Moderation(Extensible, ABC):
+    """
+    The base class of moderation.
+    """
+    module: ExtensionModule = ExtensionModule.MODERATION
+
+    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
+        super().__init__(tenant_id, config)
+        self.app_id = app_id
+
+    @classmethod
+    @abstractmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        """
+        Moderation for inputs.
+        After the user inputs, this method will be called to perform sensitive content review
+        on the user inputs and return the processed results.
+
+        :param inputs: user inputs
+        :param query: query string (required in chat app)
+        :return:
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        """
+        Moderation for outputs.
+        When LLM outputs content, the front end will pass the output content (may be segmented)
+        to this method for sensitive content review, and the output content will be shielded if the review fails.
+
+        :param text: LLM output content
+        :return:
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
+        # inputs_config
+        inputs_config = config.get("inputs_config")
+        if not isinstance(inputs_config, dict):
+            raise ValueError("inputs_config must be a dict")
+
+        # outputs_config
+        outputs_config = config.get("outputs_config")
+        if not isinstance(outputs_config, dict):
+            raise ValueError("outputs_config must be a dict")
+
+        inputs_config_enabled = inputs_config.get("enabled")
+        outputs_config_enabled = outputs_config.get("enabled")
+        if not inputs_config_enabled and not outputs_config_enabled:
+            raise ValueError("At least one of inputs_config or outputs_config must be enabled")
+
+        # preset_response
+        if not is_preset_response_required:
+            return
+
+        if inputs_config_enabled:
+            if not inputs_config.get("preset_response"):
+                raise ValueError("inputs_config.preset_response is required")
+
+            if len(inputs_config.get("preset_response")) > 100:
+                raise ValueError("inputs_config.preset_response must be less than 100 characters")
+
+        if outputs_config_enabled:
+            if not outputs_config.get("preset_response"):
+                raise ValueError("outputs_config.preset_response is required")
+
+            if len(outputs_config.get("preset_response")) > 100:
+                raise ValueError("outputs_config.preset_response must be less than 100 characters")
+
+
+class ModerationException(Exception):
+    pass
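
To make the abstract contract above concrete, here is a minimal, hedged sketch of a custom moderation extension; the class name and the "phrase" config key are hypothetical and not part of this commit:

from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult


class BannedPhraseModeration(Moderation):
    # hypothetical example extension; "banned_phrase" is not shipped with this commit
    name: str = "banned_phrase"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        cls._validate_inputs_and_outputs_config(config, True)
        if not config.get("phrase"):
            raise ValueError("phrase is required")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        flagged = self.config["phrase"] in query
        return ModerationInputsResult(
            flagged=flagged,
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["inputs_config"].get("preset_response", "")
        )

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        flagged = self.config["phrase"] in text
        return ModerationOutputsResult(
            flagged=flagged,
            action=ModerationAction.DIRECT_OUTPUT,
            preset_response=self.config["outputs_config"].get("preset_response", "")
        )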

+ 48 - 0
api/core/moderation/factory.py

@@ -0,0 +1,48 @@
+from core.extension.extensible import ExtensionModule
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
+from extensions.ext_code_based_extension import code_based_extension
+
+
+class ModerationFactory:
+    __extension_instance: Moderation
+
+    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
+        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
+        self.__extension_instance = extension_class(app_id, tenant_id, config)
+
+    @classmethod
+    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param name: the name of extension
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
+        extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
+        extension_class.validate_config(tenant_id, config)
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        """
+        Moderation for inputs.
+        After the user submits their inputs, this method is called to review them for
+        sensitive content and return the processed result.
+
+        :param inputs: user inputs
+        :param query: query string (required in chat app)
+        :return:
+        """
+        return self.__extension_instance.moderation_for_inputs(inputs, query)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        """
+        Moderation for outputs.
+        When the LLM produces output, the front end passes that content (possibly in segments)
+        to this method for sensitive content review; if the review fails, the output is masked.
+
+        :param text: LLM output content
+        :return:
+        """
+        return self.__extension_instance.moderation_for_outputs(text)
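
For orientation, a hedged usage sketch of the factory; tenant_id and app_id are placeholders, and the "keywords" extension plus its config shape are defined later in this diff:

from core.moderation.factory import ModerationFactory

tenant_id = "your-tenant-uuid"  # placeholder
app_id = "your-app-uuid"        # placeholder

config = {
    "keywords": "badword\nanother bad word",
    "inputs_config": {"enabled": True, "preset_response": "Please rephrase your question."},
    "outputs_config": {"enabled": False},
}

# validate the form config first, then instantiate and run the input check
ModerationFactory.validate_config(name="keywords", tenant_id=tenant_id, config=config)

moderation = ModerationFactory("keywords", app_id, tenant_id, config)
result = moderation.moderation_for_inputs(inputs={"question": "..."}, query="the user's question")
if result.flagged:
    answer = result.preset_response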

+ 1 - 0
api/core/moderation/keywords/__builtin__

@@ -0,0 +1 @@
+2

+ 0 - 0
api/core/moderation/keywords/__init__.py


+ 60 - 0
api/core/moderation/keywords/keywords.py

@@ -0,0 +1,60 @@
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+
+
+class KeywordsModeration(Moderation):
+    name: str = "keywords"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, True)
+
+        if not config.get("keywords"):
+            raise ValueError("keywords is required")
+
+        if len(config.get("keywords")) > 1000:
+            raise ValueError("keywords length must be less than 1000")
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            preset_response = self.config['inputs_config']['preset_response']
+
+            if query:
+                inputs['query__'] = query
+            keywords_list = self.config['keywords'].split('\n')
+            flagged = self._is_violated(inputs, keywords_list)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            keywords_list = self.config['keywords'].split('\n')
+            flagged = self._is_violated({'text': text}, keywords_list)
+            preset_response = self.config['outputs_config']['preset_response']
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
+        for value in inputs.values():
+            if self._check_keywords_in_value(keywords_list, value):
+                return True
+
+        return False
+
+    def _check_keywords_in_value(self, keywords_list, value):
+        for keyword in keywords_list:
+            if keyword.lower() in value.lower():
+                return True
+        return False
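
A small standalone illustration of the matching semantics above: keywords are newline-separated and matched case-insensitively as substrings against every input value (with the query folded in under the reserved "query__" key). The snippet mirrors _is_violated / _check_keywords_in_value:

keywords_list = "refund\nChargeback".split('\n')  # -> ["refund", "Chargeback"]

inputs = {"topic": "How do I request a ChargeBack?", "query__": "hello"}


def is_violated(values: dict, keywords: list) -> bool:
    # mirrors KeywordsModeration._is_violated / _check_keywords_in_value
    return any(
        keyword.lower() in value.lower()
        for value in values.values()
        for keyword in keywords
    )


print(is_violated(inputs, keywords_list))  # True: "chargeback" matches case-insensitively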

+ 1 - 0
api/core/moderation/openai_moderation/__builtin__

@@ -0,0 +1 @@
+1

+ 0 - 0
api/core/moderation/openai_moderation/__init__.py


+ 46 - 0
api/core/moderation/openai_moderation/openai_moderation.py

@@ -0,0 +1,46 @@
+from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
+from core.model_providers.model_factory import ModelFactory
+
+
+class OpenAIModeration(Moderation):
+    name: str = "openai_moderation"
+
+    @classmethod
+    def validate_config(cls, tenant_id: str, config: dict) -> None:
+        """
+        Validate the incoming form config data.
+
+        :param tenant_id: the id of workspace
+        :param config: the form config data
+        :return:
+        """
+        cls._validate_inputs_and_outputs_config(config, True)
+
+    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['inputs_config']['enabled']:
+            preset_response = self.config['inputs_config']['preset_response']
+
+            if query:
+                inputs['query__'] = query
+            flagged = self._is_violated(inputs)
+
+        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
+        flagged = False
+        preset_response = ""
+
+        if self.config['outputs_config']['enabled']:
+            flagged = self._is_violated({'text': text})
+            preset_response = self.config['outputs_config']['preset_response']
+
+        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
+
+    def _is_violated(self, inputs: dict) -> bool:
+        text = '\n'.join(inputs.values())
+        openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")
+        moderation_passed = openai_moderation.run(text)  # True when the text passes review
+        return not moderation_passed

+ 0 - 47
api/core/orchestrator_rule_parser.py

@@ -11,7 +11,6 @@ from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGa
 from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
 from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
-from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain, SensitiveWordAvoidanceRule
 from core.conversation_message_task import ConversationMessageTask
 from core.model_providers.error import ProviderTokenNotInitError
 from core.model_providers.model_factory import ModelFactory
@@ -125,52 +124,6 @@ class OrchestratorRuleParser:
 
         return chain
 
-    def to_sensitive_word_avoidance_chain(self, model_instance: BaseLLM, callbacks: Callbacks = None, **kwargs) \
-            -> Optional[SensitiveWordAvoidanceChain]:
-        """
-        Convert app sensitive word avoidance config to chain
-
-        :param model_instance: model instance
-        :param callbacks: callbacks for the chain
-        :param kwargs:
-        :return:
-        """
-        sensitive_word_avoidance_rule = None
-
-        if self.app_model_config.sensitive_word_avoidance_dict:
-            sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
-            if sensitive_word_avoidance_config.get("enabled", False):
-                if sensitive_word_avoidance_config.get('type') == 'moderation':
-                    sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
-                        type=SensitiveWordAvoidanceRule.Type.MODERATION,
-                        canned_response=sensitive_word_avoidance_config.get("canned_response")
-                        if sensitive_word_avoidance_config.get("canned_response")
-                        else 'Your content violates our usage policy. Please revise and try again.',
-                    )
-                else:
-                    sensitive_words = sensitive_word_avoidance_config.get("words", "")
-                    if sensitive_words:
-                        sensitive_word_avoidance_rule = SensitiveWordAvoidanceRule(
-                            type=SensitiveWordAvoidanceRule.Type.KEYWORDS,
-                            canned_response=sensitive_word_avoidance_config.get("canned_response")
-                            if sensitive_word_avoidance_config.get("canned_response")
-                            else 'Your content violates our usage policy. Please revise and try again.',
-                            extra_params={
-                                'sensitive_words': sensitive_words.split(','),
-                            }
-                        )
-
-        if sensitive_word_avoidance_rule:
-            return SensitiveWordAvoidanceChain(
-                model_instance=model_instance,
-                sensitive_word_avoidance_rule=sensitive_word_avoidance_rule,
-                output_key="sensitive_word_avoidance_output",
-                callbacks=callbacks,
-                **kwargs
-            )
-
-        return None
-
     def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
         """
         Convert app agent tool configs to tools

+ 8 - 0
api/extensions/ext_code_based_extension.py

@@ -0,0 +1,8 @@
+from core.extension.extension import Extension
+
+
+def init():
+    code_based_extension.init()
+
+
+code_based_extension = Extension()

+ 17 - 0
api/fields/api_based_extension_fields.py

@@ -0,0 +1,17 @@
+from flask_restful import fields
+
+from libs.helper import TimestampField
+
+
+class HiddenAPIKey(fields.Raw):
+    def output(self, key, obj):
+        return obj.api_key[:3] + '***' + obj.api_key[-3:]
+
+
+api_based_extension_fields = {
+    'id': fields.String,
+    'name': fields.String,
+    'api_endpoint': fields.String,
+    'api_key': HiddenAPIKey,
+    'created_at': TimestampField
+}
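
A quick sketch of the masking behaviour of HiddenAPIKey; the sample key is made up:

from types import SimpleNamespace

from fields.api_based_extension_fields import HiddenAPIKey

extension = SimpleNamespace(api_key="sk-1234567890abc")
print(HiddenAPIKey().output("api_key", extension))  # "sk-***abc"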

+ 1 - 0
api/fields/app_fields.py

@@ -23,6 +23,7 @@ model_config_fields = {
     'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
     'more_like_this': fields.Raw(attribute='more_like_this_dict'),
     'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
+    'external_data_tools': fields.Raw(attribute='external_data_tools_list'),
     'model': fields.Raw(attribute='model_dict'),
     'user_input_form': fields.Raw(attribute='user_input_form_list'),
     'dataset_query_variable': fields.String,

+ 45 - 0
api/migrations/versions/968fff4c0ab9_add_api_based_extension.py

@@ -0,0 +1,45 @@
+"""add_api_based_extension
+
+Revision ID: 968fff4c0ab9
+Revises: b3a09c049e8e
+Create Date: 2023-10-27 13:05:58.901858
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '968fff4c0ab9'
+down_revision = 'b3a09c049e8e'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    op.create_table('api_based_extensions',
+    sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', postgresql.UUID(), nullable=False),
+    sa.Column('name', sa.String(length=255), nullable=False),
+    sa.Column('api_endpoint', sa.String(length=255), nullable=False),
+    sa.Column('api_key', sa.Text(), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey')
+    )
+    with op.batch_alter_table('api_based_extensions', schema=None) as batch_op:
+        batch_op.create_index('api_based_extension_tenant_idx', ['tenant_id'], unique=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+
+    with op.batch_alter_table('api_based_extensions', schema=None) as batch_op:
+        batch_op.drop_index('api_based_extension_tenant_idx')
+
+    op.drop_table('api_based_extensions')
+
+    # ### end Alembic commands ###

+ 32 - 0
api/migrations/versions/a9836e3baeee_add_external_data_tools_in_app_model_.py

@@ -0,0 +1,32 @@
+"""add external_data_tools in app model config
+
+Revision ID: a9836e3baeee
+Revises: 968fff4c0ab9
+Create Date: 2023-11-02 04:04:57.609485
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'a9836e3baeee'
+down_revision = '968fff4c0ab9'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.add_column(sa.Column('external_data_tools', sa.Text(), nullable=True))
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
+        batch_op.drop_column('external_data_tools')
+
+    # ### end Alembic commands ###

+ 27 - 0
api/models/api_based_extension.py

@@ -0,0 +1,27 @@
+import enum
+
+from sqlalchemy.dialects.postgresql import UUID
+
+from extensions.ext_database import db
+
+
+class APIBasedExtensionPoint(enum.Enum):
+    APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query'
+    PING = 'ping'
+    APP_MODERATION_INPUT = 'app.moderation.input'
+    APP_MODERATION_OUTPUT = 'app.moderation.output'
+
+
+class APIBasedExtension(db.Model):
+    __tablename__ = 'api_based_extensions'
+    __table_args__ = (
+        db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'),
+        db.Index('api_based_extension_tenant_idx', 'tenant_id'),
+    )
+
+    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
+    tenant_id = db.Column(UUID, nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    api_endpoint = db.Column(db.String(255), nullable=False)
+    api_key = db.Column(db.Text, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))

+ 14 - 30
api/models/model.py

@@ -97,6 +97,7 @@ class AppModelConfig(db.Model):
     chat_prompt_config = db.Column(db.Text)
     completion_prompt_config = db.Column(db.Text)
     dataset_configs = db.Column(db.Text)
+    external_data_tools = db.Column(db.Text)
 
     @property
     def app(self):
@@ -133,7 +134,12 @@ class AppModelConfig(db.Model):
     @property
     def sensitive_word_avoidance_dict(self) -> dict:
         return json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance \
-            else {"enabled": False, "words": [], "canned_response": []}
+            else {"enabled": False, "type": "", "configs": []}
+
+    @property
+    def external_data_tools_list(self) -> list[dict]:
+        return json.loads(self.external_data_tools) if self.external_data_tools \
+            else []
 
     @property
     def user_input_form_list(self) -> dict:
@@ -167,6 +173,7 @@ class AppModelConfig(db.Model):
             "retriever_resource": self.retriever_resource_dict,
             "more_like_this": self.more_like_this_dict,
             "sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
+            "external_data_tools": self.external_data_tools_list,
             "model": self.model_dict,
             "user_input_form": self.user_input_form_list,
             "dataset_query_variable": self.dataset_query_variable,
@@ -190,6 +197,7 @@ class AppModelConfig(db.Model):
         self.more_like_this = json.dumps(model_config['more_like_this'])
         self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
             if model_config.get('sensitive_word_avoidance') else None
+        self.external_data_tools = json.dumps(model_config['external_data_tools'])
         self.model = json.dumps(model_config['model'])
         self.user_input_form = json.dumps(model_config['user_input_form'])
         self.dataset_query_variable = model_config.get('dataset_query_variable')
@@ -219,6 +227,7 @@ class AppModelConfig(db.Model):
             speech_to_text=self.speech_to_text,
             more_like_this=self.more_like_this,
             sensitive_word_avoidance=self.sensitive_word_avoidance,
+            external_data_tools=self.external_data_tools,
             model=self.model,
             user_input_form=self.user_input_form,
             dataset_query_variable=self.dataset_query_variable,
@@ -332,41 +341,16 @@ class Conversation(db.Model):
             override_model_configs = json.loads(self.override_model_configs)
 
             if 'model' in override_model_configs:
-                model_config['model'] = override_model_configs['model']
-                model_config['pre_prompt'] = override_model_configs['pre_prompt']
-                model_config['agent_mode'] = override_model_configs['agent_mode']
-                model_config['opening_statement'] = override_model_configs['opening_statement']
-                model_config['suggested_questions'] = override_model_configs['suggested_questions']
-                model_config['suggested_questions_after_answer'] = override_model_configs[
-                    'suggested_questions_after_answer'] \
-                    if 'suggested_questions_after_answer' in override_model_configs else {"enabled": False}
-                model_config['speech_to_text'] = override_model_configs[
-                    'speech_to_text'] \
-                    if 'speech_to_text' in override_model_configs else {"enabled": False}
-                model_config['more_like_this'] = override_model_configs['more_like_this'] \
-                    if 'more_like_this' in override_model_configs else {"enabled": False}
-                model_config['sensitive_word_avoidance'] = override_model_configs['sensitive_word_avoidance'] \
-                    if 'sensitive_word_avoidance' in override_model_configs \
-                    else {"enabled": False, "words": [], "canned_response": []}
-                model_config['user_input_form'] = override_model_configs['user_input_form']
+                app_model_config = AppModelConfig().from_model_config_dict(override_model_configs)
+                model_config = app_model_config.to_dict()
             else:
                 model_config['configs'] = override_model_configs
         else:
             app_model_config = db.session.query(AppModelConfig).filter(
                 AppModelConfig.id == self.app_model_config_id).first()
 
-            model_config['configs'] = app_model_config.configs
-            model_config['model'] = app_model_config.model_dict
-            model_config['pre_prompt'] = app_model_config.pre_prompt
-            model_config['agent_mode'] = app_model_config.agent_mode_dict
-            model_config['opening_statement'] = app_model_config.opening_statement
-            model_config['suggested_questions'] = app_model_config.suggested_questions_list
-            model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
-            model_config['speech_to_text'] = app_model_config.speech_to_text_dict
-            model_config['retriever_resource'] = app_model_config.retriever_resource_dict
-            model_config['more_like_this'] = app_model_config.more_like_this_dict
-            model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
-            model_config['user_input_form'] = app_model_config.user_input_form_list
+            model_config = app_model_config.to_dict()
 
         model_config['model_id'] = self.model_id
         model_config['provider'] = self.model_provider

+ 98 - 0
api/services/api_based_extension_service.py

@@ -0,0 +1,98 @@
+from extensions.ext_database import db
+from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
+from core.helper.encrypter import encrypt_token, decrypt_token
+from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
+
+
+class APIBasedExtensionService:
+
+    @staticmethod
+    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
+        extension_list = db.session.query(APIBasedExtension) \
+                    .filter_by(tenant_id=tenant_id) \
+                    .order_by(APIBasedExtension.created_at.desc()) \
+                    .all()
+
+        for extension in extension_list:
+            extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
+
+        return extension_list
+
+    @classmethod
+    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
+        cls._validation(extension_data)
+
+        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)
+
+        db.session.add(extension_data)
+        db.session.commit()
+        return extension_data
+
+    @staticmethod
+    def delete(extension_data: APIBasedExtension) -> None:
+        db.session.delete(extension_data)
+        db.session.commit()
+
+    @staticmethod
+    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
+        extension = db.session.query(APIBasedExtension) \
+            .filter_by(tenant_id=tenant_id) \
+            .filter_by(id=api_based_extension_id) \
+            .first()
+
+        if not extension:
+            raise ValueError("API based extension is not found")
+
+        extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)
+
+        return extension
+
+    @classmethod
+    def _validation(cls, extension_data: APIBasedExtension) -> None:
+        # name
+        if not extension_data.name:
+            raise ValueError("name must not be empty")
+
+        if not extension_data.id:
+            # case one: check new data, name must be unique
+            is_name_existed = db.session.query(APIBasedExtension) \
+                .filter_by(tenant_id=extension_data.tenant_id) \
+                .filter_by(name=extension_data.name) \
+                .first()
+
+            if is_name_existed:
+                raise ValueError("name must be unique, it is already existed")
+        else:
+            # case two: check existing data, name must be unique
+            is_name_existed = db.session.query(APIBasedExtension) \
+                .filter_by(tenant_id=extension_data.tenant_id) \
+                .filter_by(name=extension_data.name) \
+                .filter(APIBasedExtension.id != extension_data.id) \
+                .first()
+
+            if is_name_existed:
+                raise ValueError("name must be unique, it is already existed")
+
+        # api_endpoint
+        if not extension_data.api_endpoint:
+            raise ValueError("api_endpoint must not be empty")
+
+        # api_key
+        if not extension_data.api_key:
+            raise ValueError("api_key must not be empty")
+
+        if len(extension_data.api_key) < 5:
+            raise ValueError("api_key must be at least 5 characters")
+
+        # check endpoint
+        cls._ping_connection(extension_data)
+
+    @staticmethod
+    def _ping_connection(extension_data: APIBasedExtension) -> None:
+        try:
+            client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
+            resp = client.request(point=APIBasedExtensionPoint.PING, params={})
+            if resp.get('result') != 'pong':
+                raise ValueError(resp)
+        except Exception as e:
+            raise ValueError("connection error: {}".format(e))

+ 89 - 54
api/services/app_model_config_service.py

@@ -1,6 +1,8 @@
 import re
 import uuid
 
+from core.external_data_tool.factory import ExternalDataToolFactory
+from core.moderation.factory import ModerationFactory
 from core.prompt.prompt_transform import AppMode
 from core.agent.agent_executor import PlanningStrategy
 from core.model_providers.model_provider_factory import ModelProviderFactory
@@ -13,8 +15,8 @@ SUPPORT_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current
 
 
 class AppModelConfigService:
-    @staticmethod
-    def is_dataset_exists(account: Account, dataset_id: str) -> bool:
+    @classmethod
+    def is_dataset_exists(cls, account: Account, dataset_id: str) -> bool:
         # verify if the dataset ID exists
         dataset = DatasetService.get_dataset(dataset_id)
 
@@ -26,8 +28,8 @@ class AppModelConfigService:
 
         return True
 
-    @staticmethod
-    def validate_model_completion_params(cp: dict, model_name: str) -> dict:
+    @classmethod
+    def validate_model_completion_params(cls, cp: dict, model_name: str) -> dict:
         # 6. model.completion_params
         if not isinstance(cp, dict):
             raise ValueError("model.completion_params must be of object type")
@@ -57,7 +59,7 @@ class AppModelConfigService:
             cp["stop"] = []
         elif not isinstance(cp["stop"], list):
             raise ValueError("stop in model.completion_params must be of list type")
-        
+
         if len(cp["stop"]) > 4:
             raise ValueError("stop sequences must be less than 4")
 
@@ -73,8 +75,8 @@ class AppModelConfigService:
 
         return filtered_cp
 
-    @staticmethod
-    def validate_configuration(tenant_id: str, account: Account, config: dict, mode: str) -> dict:
+    @classmethod
+    def validate_configuration(cls, tenant_id: str, account: Account, config: dict, mode: str) -> dict:
         # opening_statement
         if 'opening_statement' not in config or not config["opening_statement"]:
             config["opening_statement"] = ""
@@ -153,33 +155,6 @@ class AppModelConfigService:
         if not isinstance(config["more_like_this"]["enabled"], bool):
             raise ValueError("enabled in more_like_this must be of boolean type")
 
-        # sensitive_word_avoidance
-        if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
-            config["sensitive_word_avoidance"] = {
-                "enabled": False
-            }
-
-        if not isinstance(config["sensitive_word_avoidance"], dict):
-            raise ValueError("sensitive_word_avoidance must be of dict type")
-
-        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
-            config["sensitive_word_avoidance"]["enabled"] = False
-
-        if not isinstance(config["sensitive_word_avoidance"]["enabled"], bool):
-            raise ValueError("enabled in sensitive_word_avoidance must be of boolean type")
-
-        if "words" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["words"]:
-            config["sensitive_word_avoidance"]["words"] = ""
-
-        if not isinstance(config["sensitive_word_avoidance"]["words"], str):
-            raise ValueError("words in sensitive_word_avoidance must be of string type")
-
-        if "canned_response" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["canned_response"]:
-            config["sensitive_word_avoidance"]["canned_response"] = ""
-
-        if not isinstance(config["sensitive_word_avoidance"]["canned_response"], str):
-            raise ValueError("canned_response in sensitive_word_avoidance must be of string type")
-
         # model
         if 'model' not in config:
             raise ValueError("model is required")
@@ -204,7 +179,7 @@ class AppModelConfigService:
         model_ids = [m['id'] for m in model_list]
         if config["model"]["name"] not in model_ids:
             raise ValueError("model.name must be in the specified model list")
-        
+
         # model.mode
         if 'mode' not in config['model'] or not config['model']["mode"]:
             config['model']["mode"] = ""
@@ -213,7 +188,7 @@ class AppModelConfigService:
         if 'completion_params' not in config["model"]:
             raise ValueError("model.completion_params is required")
 
-        config["model"]["completion_params"] = AppModelConfigService.validate_model_completion_params(
+        config["model"]["completion_params"] = cls.validate_model_completion_params(
             config["model"]["completion_params"],
             config["model"]["name"]
         )
@@ -330,14 +305,20 @@ class AppModelConfigService:
                 except ValueError:
                     raise ValueError("id in dataset must be of UUID type")
 
-                if not AppModelConfigService.is_dataset_exists(account, tool_item["id"]):
+                if not cls.is_dataset_exists(account, tool_item["id"]):
                     raise ValueError("Dataset ID does not exist, please check your permission.")
-        
+
         # dataset_query_variable
-        AppModelConfigService.is_dataset_query_variable_valid(config, mode)
+        cls.is_dataset_query_variable_valid(config, mode)
 
         # advanced prompt validation
-        AppModelConfigService.is_advanced_prompt_valid(config, mode)
+        cls.is_advanced_prompt_valid(config, mode)
+
+        # external data tools validation
+        cls.is_external_data_tools_valid(tenant_id, config)
+
+        # moderation validation
+        cls.is_moderation_valid(tenant_id, config)
 
         # Filter out extra parameters
         filtered_config = {
@@ -348,6 +329,7 @@ class AppModelConfigService:
             "retriever_resource": config["retriever_resource"],
             "more_like_this": config["more_like_this"],
             "sensitive_word_avoidance": config["sensitive_word_avoidance"],
+            "external_data_tools": config["external_data_tools"],
             "model": {
                 "provider": config["model"]["provider"],
                 "name": config["model"]["name"],
@@ -365,32 +347,86 @@ class AppModelConfigService:
         }
 
         return filtered_config
-    
-    @staticmethod
-    def is_dataset_query_variable_valid(config: dict, mode: str) -> None:
+
+    @classmethod
+    def is_moderation_valid(cls, tenant_id: str, config: dict):
+        if 'sensitive_word_avoidance' not in config or not config["sensitive_word_avoidance"]:
+            config["sensitive_word_avoidance"] = {
+                "enabled": False
+            }
+
+        if not isinstance(config["sensitive_word_avoidance"], dict):
+            raise ValueError("sensitive_word_avoidance must be of dict type")
+
+        if "enabled" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["enabled"]:
+            config["sensitive_word_avoidance"]["enabled"] = False
+
+        if not config["sensitive_word_avoidance"]["enabled"]:
+            return
+
+        if "type" not in config["sensitive_word_avoidance"] or not config["sensitive_word_avoidance"]["type"]:
+            raise ValueError("sensitive_word_avoidance.type is required")
+
+        moderation_type = config["sensitive_word_avoidance"]["type"]
+        moderation_config = config["sensitive_word_avoidance"]["config"]
+
+        ModerationFactory.validate_config(
+            name=moderation_type,
+            tenant_id=tenant_id,
+            config=moderation_config
+        )
+
+    @classmethod
+    def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
+        if 'external_data_tools' not in config or not config["external_data_tools"]:
+            config["external_data_tools"] = []
+
+        if not isinstance(config["external_data_tools"], list):
+            raise ValueError("external_data_tools must be of list type")
+
+        for tool in config["external_data_tools"]:
+            if "enabled" not in tool or not tool["enabled"]:
+                tool["enabled"] = False
+
+            if not tool["enabled"]:
+                continue
+
+            if "type" not in tool or not tool["type"]:
+                raise ValueError("external_data_tools[].type is required")
+
+            tool_type = tool["type"]
+            tool_config = tool["config"]
+
+            ExternalDataToolFactory.validate_config(
+                name=tool_type,
+                tenant_id=tenant_id,
+                config=tool_config
+            )
+
+    @classmethod
+    def is_dataset_query_variable_valid(cls, config: dict, mode: str) -> None:
         # Only check when mode is completion
         if mode != 'completion':
             return
-        
+
         agent_mode = config.get("agent_mode", {})
         tools = agent_mode.get("tools", [])
         dataset_exists = "dataset" in str(tools)
-        
+
         dataset_query_variable = config.get("dataset_query_variable")
 
         if dataset_exists and not dataset_query_variable:
             raise ValueError("Dataset query variable is required when dataset is exist")
-        
 
-    @staticmethod
-    def is_advanced_prompt_valid(config: dict, app_mode: str) -> None:
+    @classmethod
+    def is_advanced_prompt_valid(cls, config: dict, app_mode: str) -> None:
         # prompt_type
         if 'prompt_type' not in config or not config["prompt_type"]:
             config["prompt_type"] = "simple"
 
         if config['prompt_type'] not in ['simple', 'advanced']:
             raise ValueError("prompt_type must be in ['simple', 'advanced']")
-        
+
         # chat_prompt_config
         if 'chat_prompt_config' not in config or not config["chat_prompt_config"]:
             config["chat_prompt_config"] = {}
@@ -404,7 +440,7 @@ class AppModelConfigService:
 
         if not isinstance(config["completion_prompt_config"], dict):
             raise ValueError("completion_prompt_config must be of object type")
-        
+
         # dataset_configs
         if 'dataset_configs' not in config or not config["dataset_configs"]:
             config["dataset_configs"] = {"top_k": 2, "score_threshold": {"enable": False}}
@@ -415,10 +451,10 @@ class AppModelConfigService:
         if config['prompt_type'] == 'advanced':
             if not config['chat_prompt_config'] and not config['completion_prompt_config']:
                 raise ValueError("chat_prompt_config or completion_prompt_config is required when prompt_type is advanced")
-            
+
             if config['model']["mode"] not in ['chat', 'completion']:
                 raise ValueError("model.mode must be in ['chat', 'completion'] when prompt_type is advanced")
-            
+
             if app_mode == AppMode.CHAT.value and config['model']["mode"] == ModelMode.COMPLETION.value:
                 user_prefix = config['completion_prompt_config']['conversation_histories_role']['user_prefix']
                 assistant_prefix = config['completion_prompt_config']['conversation_histories_role']['assistant_prefix']
@@ -429,9 +465,8 @@ class AppModelConfigService:
                 if not assistant_prefix:
                     config['completion_prompt_config']['conversation_histories_role']['assistant_prefix'] = 'Assistant'
 
-
             if config['model']["mode"] == ModelMode.CHAT.value:
                 prompt_list = config['chat_prompt_config']['prompt']
 
                 if len(prompt_list) > 10:
-                    raise ValueError("prompt messages must be less than 10")
+                    raise ValueError("prompt messages must be less than 10")

+ 13 - 0
api/services/code_based_extension_service.py

@@ -0,0 +1,13 @@
+from extensions.ext_code_based_extension import code_based_extension
+
+
+class CodeBasedExtensionService:
+
+    @staticmethod
+    def get_code_based_extension(module: str) -> list[dict]:
+        module_extensions = code_based_extension.module_extensions(module)
+        return [{
+            'name': module_extension.name,
+            'label': module_extension.label,
+            'form_schema': module_extension.form_schema
+        } for module_extension in module_extensions if not module_extension.builtin]

+ 28 - 9
api/services/completion_service.py

@@ -10,7 +10,8 @@ from redis.client import PubSub
 from sqlalchemy import and_
 
 from core.completion import Completion
-from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
+from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
+    ConversationTaskInterruptException
 from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
     LLMRateLimitError, \
     LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
@@ -28,9 +29,9 @@ from services.errors.message import MessageNotExistsError
 class CompletionService:
 
     @classmethod
-    def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
+    def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
                    from_source: str, streaming: bool = True,
-                   is_model_config_override: bool = False) -> Union[dict | Generator]:
+                   is_model_config_override: bool = False) -> Union[dict, Generator]:
         # is streaming mode
         inputs = args['inputs']
         query = args['query']
@@ -199,9 +200,9 @@ class CompletionService:
                     is_override=is_model_config_override,
                     retriever_from=retriever_from
                 )
-            except ConversationTaskStoppedException:
+            except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                 pass
-            except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
+            except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                     ModelCurrentlyNotSupportError) as e:
                 PubHandler.pub_error(user, generate_task_id, e)
@@ -234,7 +235,7 @@ class CompletionService:
                     PubHandler.stop(user, generate_task_id)
                     try:
                         pubsub.close()
-                    except:
+                    except Exception:
                         pass
 
         countdown_thread = threading.Thread(target=close_pubsub)
@@ -243,9 +244,9 @@ class CompletionService:
         return countdown_thread
 
     @classmethod
-    def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
+    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
                                 message_id: str, streaming: bool = True,
-                                retriever_from: str = 'dev') -> Union[dict | Generator]:
+                                retriever_from: str = 'dev') -> Union[dict, Generator]:
         if not user:
             raise ValueError('user cannot be None')
 
@@ -341,7 +342,7 @@ class CompletionService:
         return filtered_inputs
 
     @classmethod
-    def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
+    def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
         generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
         if not streaming:
             try:
@@ -386,6 +387,8 @@ class CompletionService:
                                 break
                             if event == 'message':
                                 yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
+                            elif event == 'message_replace':
+                                yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
                             elif event == 'chain':
                                 yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                             elif event == 'agent_thought':
@@ -427,6 +430,21 @@ class CompletionService:
 
         return response_data
 
+    @classmethod
+    def get_message_replace_response_data(cls, data: dict):
+        response_data = {
+            'event': 'message_replace',
+            'task_id': data.get('task_id'),
+            'id': data.get('message_id'),
+            'answer': data.get('text'),
+            'created_at': int(time.time())
+        }
+
+        if data.get('mode') == 'chat':
+            response_data['conversation_id'] = data.get('conversation_id')
+
+        return response_data
+
     @classmethod
     def get_blocking_message_response_data(cls, data: dict):
         message = data.get('message')
@@ -508,6 +526,7 @@ class CompletionService:
 
         # handle errors
         llm_errors = {
+            'ValueError': LLMBadRequestError,
             'LLMBadRequestError': LLMBadRequestError,
             'LLMAPIConnectionError': LLMAPIConnectionError,
             'LLMAPIUnavailableError': LLMAPIUnavailableError,
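
For reference, a hedged sketch of the new "message_replace" frame emitted by the streaming branch above (field values are illustrative); it mirrors get_message_replace_response_data plus the SSE framing:

import json
import time

data = {
    "task_id": "generate-task-id",
    "message_id": "message-id",
    "text": "Your content violates our usage policy.",
    "mode": "chat",
    "conversation_id": "conversation-id",
}

# mirrors CompletionService.get_message_replace_response_data and the "data: ..." framing
response_data = {
    "event": "message_replace",
    "task_id": data["task_id"],
    "id": data["message_id"],
    "answer": data["text"],
    "created_at": int(time.time()),
    "conversation_id": data["conversation_id"],  # only included in chat mode
}
print("data: " + json.dumps(response_data) + "\n\n")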

+ 20 - 0
api/services/moderation_service.py

@@ -0,0 +1,20 @@
+from models.model import AppModelConfig, App
+from core.moderation.factory import ModerationFactory, ModerationOutputsResult
+from extensions.ext_database import db
+
+
+class ModerationService:
+
+    def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
+        app_model_config: AppModelConfig = db.session.query(AppModelConfig) \
+            .filter(AppModelConfig.id == app_model.app_model_config_id) \
+            .first()
+
+        if not app_model_config:
+            raise ValueError("app model config not found")
+
+        name = app_model_config.sensitive_word_avoidance_dict['type']
+        config = app_model_config.sensitive_word_avoidance_dict['config']
+
+        moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
+        return moderation.moderation_for_outputs(text)
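
Finally, a hedged usage sketch of the new service, assuming an App row app_model whose active model config has sensitive_word_avoidance configured:

from core.moderation.base import ModerationAction
from services.moderation_service import ModerationService

# app_model: an App instance loaded elsewhere (assumption)
result = ModerationService().moderation_for_outputs(
    app_id=app_model.id,
    app_model=app_model,
    text="candidate LLM output",
)

if result.flagged and result.action == ModerationAction.DIRECT_OUTPUT:
    text_to_show = result.preset_response
else:
    text_to_show = "candidate LLM output"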