Feat/fix ops trace (#5672)

Co-authored-by: takatost <takatost@gmail.com>
Joe committed 9 months ago · commit e8b8f6c6dd

+ 1 - 1
.devcontainer/post_create_command.sh

@@ -3,7 +3,7 @@
 cd web && npm install

 echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc


+ 13 - 1
.vscode/launch.json

@@ -37,7 +37,19 @@
                 "FLASK_DEBUG": "1",
                 "FLASK_DEBUG": "1",
                 "GEVENT_SUPPORT": "True"
                 "GEVENT_SUPPORT": "True"
             },
             },
-            "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+            "args": [
+                "-A",
+                "app.celery",
+                "worker",
+                "-P",
+                "gevent",
+                "-c",
+                "1",
+                "--loglevel",
+                "info",
+                "-Q",
+                "dataset,generation,mail,ops_trace"
+            ]
         },
     ]
 }

+ 1 - 1
api/README.md

@@ -66,7 +66,7 @@
 10. If you need to debug local async processing, please start the worker service.

    ```bash
-   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail
+   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace
    ```

    The started celery app handles the async tasks, e.g. dataset importing and documents indexing.

+ 0 - 2
api/app.py

@@ -26,7 +26,6 @@ from werkzeug.exceptions import Unauthorized
 from commands import register_commands

 # DO NOT REMOVE BELOW
-from events import event_handlers
 from extensions import (
     ext_celery,
     ext_code_based_extension,
@@ -43,7 +42,6 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
-from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService

 # DO NOT REMOVE ABOVE

+ 1 - 1
api/core/moderation/input_moderation.py

@@ -57,7 +57,7 @@ class InputModeration:
                     timer=timer
                 )
             )
-        
+
         if not moderation_result.flagged:
             return False, inputs, query


+ 11 - 1
api/core/ops/entities/trace_entity.py

@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):


 class GenerateNameTraceInfo(BaseTraceInfo):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     tenant_id: str
+
+trace_info_info_map = {
+    'WorkflowTraceInfo': WorkflowTraceInfo,
+    'MessageTraceInfo': MessageTraceInfo,
+    'ModerationTraceInfo': ModerationTraceInfo,
+    'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
+    'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
+    'ToolTraceInfo': ToolTraceInfo,
+    'GenerateNameTraceInfo': GenerateNameTraceInfo,
+}
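
This name-to-class registry is what lets the Celery consumer rebuild the right Pydantic model from a plain dict: the producer records `type(trace_info).__name__` next to `model_dump()`, and the task looks the class back up by that name. A minimal sketch of the round trip (mirroring `send_to_celery` and `process_trace_tasks` below; values are illustrative, and it assumes the inherited `BaseTraceInfo` fields are optional):

```python
from core.ops.entities.trace_entity import GenerateNameTraceInfo, trace_info_info_map

# Producer side: serialize the model and record its class name.
trace_info = GenerateNameTraceInfo(tenant_id="tenant-1")
payload = {
    "trace_info_type": type(trace_info).__name__,  # "GenerateNameTraceInfo"
    "trace_info": trace_info.model_dump(),
}

# Consumer side: look the class up by name and rehydrate the model.
trace_cls = trace_info_info_map.get(payload["trace_info_type"])
restored = trace_cls(**payload["trace_info"]) if trace_cls else None
```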

+ 29 - 2
api/core/ops/langfuse_trace/langfuse_trace.py

@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             # add span
             if trace_info.message_id:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 )
             else:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):

             self.add_span(langfuse_span_data=span_data)

+            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            if process_data and process_data.get("model_mode") == "chat":
+                total_token = metadata.get("total_tokens", 0)
+                # add generation
+                generation_usage = GenerationUsage(
+                    totalTokens=total_token,
+                )
+
+                node_generation_data = LangfuseGeneration(
+                    name=f"generation_{node_execution_id}",
+                    trace_id=trace_id,
+                    parent_observation_id=node_execution_id,
+                    start_time=created_at,
+                    end_time=finished_at,
+                    input=inputs,
+                    output=outputs,
+                    metadata=metadata,
+                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    status_message=trace_info.error if trace_info.error else "",
+                    usage=generation_usage,
+                )
+
+                self.add_generation(langfuse_generation_data=node_generation_data)
+
     def message_trace(
         self, trace_info: MessageTraceInfo, **kwargs
     ):
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         if message_data.from_end_user_id:
             end_user_data: EndUser = db.session.query(EndUser).filter(
                 EndUser.id == message_data.from_end_user_id
-            ).first().session_id
+            ).first()
             user_id = end_user_data.session_id

         trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
             unit=UnitEnum.TOKENS,
+            totalCost=message_data.total_price,
         )

         langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
+            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
             status_message=trace_info.error,
         )
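
Two details in this file: each node span now carries an explicit `id=node_execution_id`, and for chat-mode nodes a generation is attached beneath it through `parent_observation_id`, so LLM calls nest under their node span instead of floating at trace level. A rough sketch of the parent/child wiring (plain dicts, ids illustrative — not the real wrapper models):

```python
trace_id = "trace-abc"
node_execution_id = "node-123"

# The node span's id doubles as the generation's parent pointer.
span = {"id": node_execution_id, "trace_id": trace_id}
generation = {
    "name": f"generation_{node_execution_id}",
    "trace_id": trace_id,
    "parent_observation_id": node_execution_id,  # == span["id"], hence nested
}
assert generation["parent_observation_id"] == span["id"]
```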
 
 

+ 78 - 37
api/core/ops/ops_trace_manager.py

@@ -1,16 +1,17 @@
 import json
+import logging
 import os
 import queue
 import threading
+import time
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID

-from flask import Flask, current_app
+from flask import current_app

 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
-from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import (
     LangfuseConfig,
     LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks

 provider_config_map = {
     TracingProviderEnum.LANGFUSE.value: {
@@ -105,7 +107,7 @@ class OpsTraceManager:
         return config_class(**new_config).model_dump()

     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")

-    def execute(self, trace_instance: BaseTraceInstance):
+    def execute(self):
         method_name, trace_info = self.preprocess()
-        if trace_instance:
-            method = trace_instance.trace
-            method(trace_info)
+        return trace_info

     def preprocess(self):
         if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
@@ -372,7 +372,7 @@ class TraceTask:
         }

         workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run,
+            workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
             tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
         message_tokens = message_data.message_tokens

         message_trace_info = MessageTraceInfo(
-            message_data=message_data,
+            message_id=message_id,
+            message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
@@ -469,7 +470,7 @@ class TraceTask:
         moderation_trace_info = ModerationTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
             inputs=inputs,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
             action=moderation_result.action,
             preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:

         suggested_question_trace_info = SuggestedQuestionTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
             start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query if message_data.query else message_data.inputs,
-            documents=documents,
+            documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
         )

         return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:

         tool_trace_info = ToolTraceInfo(
             message_id=message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             tool_name=tool_name,
             start_time=timer.get("start") if timer else created_time,
             end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
         return generate_name_trace_info


+trace_manager_timer = None
+trace_manager_queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 1))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
 class TraceQueueManager:
     def __init__(self, app_id=None, conversation_id=None, message_id=None):
-        tracing_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
-        self.queue = queue.Queue()
-        self.is_running = True
-        self.thread = threading.Thread(
-            target=self.process_queue, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'trace_instance': tracing_instance
-            }
-        )
-        self.thread.start()
+        global trace_manager_timer

-    def stop(self):
-        self.is_running = False
-
-    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
-        with flask_app.app_context():
-            while self.is_running:
-                try:
-                    task = self.queue.get(timeout=60)
-                    task.execute(trace_instance)
-                    self.queue.task_done()
-                except queue.Empty:
-                    self.stop()
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.message_id = message_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+        self.flask_app = current_app._get_current_object()
+        if trace_manager_timer is None:
+            self.start_timer()

     def add_trace_task(self, trace_task: TraceTask):
-        self.queue.put(trace_task)
+        global trace_manager_timer
+        global trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.debug(f"Error adding trace task: {e}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.debug(f"Error processing trace tasks: {e}")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(
+                trace_manager_interval, self.run
+            )
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                trace_info = task.execute()
+                task_data = {
+                    "app_id": self.app_id,
+                    "conversation_id": self.conversation_id,
+                    "message_id": self.message_id,
+                    "trace_info_type": type(trace_info).__name__,
+                    "trace_info": trace_info.model_dump() if trace_info else {},
+                }
+                process_trace_tasks.delay(task_data)
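
The per-request worker thread is gone: all managers now share one module-level queue, which is drained in batches and handed to Celery. Since `threading.Timer` fires only once, the diff re-arms it from `add_trace_task`'s `finally` block rather than looping. A standalone sketch of that batching pattern, stripped of the Flask/Celery plumbing (names illustrative):

```python
import queue
import threading

INTERVAL = 1.0    # seconds, cf. TRACE_QUEUE_MANAGER_INTERVAL
BATCH_SIZE = 100  # cf. TRACE_QUEUE_MANAGER_BATCH_SIZE

task_queue: queue.Queue = queue.Queue()
timer = None

def dispatch(batch):
    # Stand-in for process_trace_tasks.delay(...) in the real code.
    print(f"dispatching {len(batch)} tasks")

def collect_batch():
    # Drain up to BATCH_SIZE items without blocking.
    batch = []
    while len(batch) < BATCH_SIZE and not task_queue.empty():
        batch.append(task_queue.get_nowait())
        task_queue.task_done()
    return batch

def run():
    batch = collect_batch()
    if batch:
        dispatch(batch)

def start_timer():
    # Arm a one-shot timer only if none is alive, as the diff does.
    global timer
    if timer is None or not timer.is_alive():
        timer = threading.Timer(INTERVAL, run)
        timer.start()

def add_task(task):
    task_queue.put(task)
    start_timer()  # each enqueue re-arms the one-shot timer
```

One consequence of the one-shot design: if a batch leaves items behind, or the timer has already fired, those tasks wait until the next `add_trace_task` call re-arms the timer.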

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -12,7 +12,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceTask, TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
@@ -357,7 +357,7 @@ class DatasetRetrieval:
             db.session.commit()

         # get tracing instance
-        trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(

+ 1 - 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -94,7 +94,7 @@ class ParameterExtractorNode(LLMNode):
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
-                and node_data.reasoning_mode == 'function_call':
+            and node_data.reasoning_mode == 'function_call':
             # use function call 
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
                 node_data, query, variable_pool, model_config, memory

+ 1 - 1
api/docker/entrypoint.sh

@@ -9,7 +9,7 @@ fi

 if [[ "${MODE}" == "worker" ]]; then
   celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \
-    -Q ${CELERY_QUEUES:-dataset,generation,mail}
+    -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace}
 elif [[ "${MODE}" == "beat" ]]; then
   celery -A app.celery beat --loglevel INFO
 else

+ 1 - 7
api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py

@@ -31,17 +31,11 @@ def upgrade():
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False)

-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.Text(), nullable=True))
-
     # ### end Alembic commands ###


 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
+    # ### commands auto generated by Alembic - please adjust! ##
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')


+ 0 - 7
api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py

@@ -35,18 +35,11 @@ def upgrade():

     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
-
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
     # ### end Alembic commands ###


 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.TEXT(), autoincrement=False, nullable=True))
-
     op.create_table('tracing_app_configs',
     sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
     sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False),

+ 95 - 0
api/models/dataset.py

@@ -352,6 +352,101 @@ class Document(db.Model):
         return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
             .filter(DocumentSegment.document_id == self.id).scalar()

+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'dataset_id': self.dataset_id,
+            'position': self.position,
+            'data_source_type': self.data_source_type,
+            'data_source_info': self.data_source_info,
+            'dataset_process_rule_id': self.dataset_process_rule_id,
+            'batch': self.batch,
+            'name': self.name,
+            'created_from': self.created_from,
+            'created_by': self.created_by,
+            'created_api_request_id': self.created_api_request_id,
+            'created_at': self.created_at,
+            'processing_started_at': self.processing_started_at,
+            'file_id': self.file_id,
+            'word_count': self.word_count,
+            'parsing_completed_at': self.parsing_completed_at,
+            'cleaning_completed_at': self.cleaning_completed_at,
+            'splitting_completed_at': self.splitting_completed_at,
+            'tokens': self.tokens,
+            'indexing_latency': self.indexing_latency,
+            'completed_at': self.completed_at,
+            'is_paused': self.is_paused,
+            'paused_by': self.paused_by,
+            'paused_at': self.paused_at,
+            'error': self.error,
+            'stopped_at': self.stopped_at,
+            'indexing_status': self.indexing_status,
+            'enabled': self.enabled,
+            'disabled_at': self.disabled_at,
+            'disabled_by': self.disabled_by,
+            'archived': self.archived,
+            'archived_reason': self.archived_reason,
+            'archived_by': self.archived_by,
+            'archived_at': self.archived_at,
+            'updated_at': self.updated_at,
+            'doc_type': self.doc_type,
+            'doc_metadata': self.doc_metadata,
+            'doc_form': self.doc_form,
+            'doc_language': self.doc_language,
+            'display_status': self.display_status,
+            'data_source_info_dict': self.data_source_info_dict,
+            'average_segment_length': self.average_segment_length,
+            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
+            'dataset': self.dataset.to_dict() if self.dataset else None,
+            'segment_count': self.segment_count,
+            'hit_count': self.hit_count
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            dataset_id=data.get('dataset_id'),
+            position=data.get('position'),
+            data_source_type=data.get('data_source_type'),
+            data_source_info=data.get('data_source_info'),
+            dataset_process_rule_id=data.get('dataset_process_rule_id'),
+            batch=data.get('batch'),
+            name=data.get('name'),
+            created_from=data.get('created_from'),
+            created_by=data.get('created_by'),
+            created_api_request_id=data.get('created_api_request_id'),
+            created_at=data.get('created_at'),
+            processing_started_at=data.get('processing_started_at'),
+            file_id=data.get('file_id'),
+            word_count=data.get('word_count'),
+            parsing_completed_at=data.get('parsing_completed_at'),
+            cleaning_completed_at=data.get('cleaning_completed_at'),
+            splitting_completed_at=data.get('splitting_completed_at'),
+            tokens=data.get('tokens'),
+            indexing_latency=data.get('indexing_latency'),
+            completed_at=data.get('completed_at'),
+            is_paused=data.get('is_paused'),
+            paused_by=data.get('paused_by'),
+            paused_at=data.get('paused_at'),
+            error=data.get('error'),
+            stopped_at=data.get('stopped_at'),
+            indexing_status=data.get('indexing_status'),
+            enabled=data.get('enabled'),
+            disabled_at=data.get('disabled_at'),
+            disabled_by=data.get('disabled_by'),
+            archived=data.get('archived'),
+            archived_reason=data.get('archived_reason'),
+            archived_by=data.get('archived_by'),
+            archived_at=data.get('archived_at'),
+            updated_at=data.get('updated_at'),
+            doc_type=data.get('doc_type'),
+            doc_metadata=data.get('doc_metadata'),
+            doc_form=data.get('doc_form'),
+            doc_language=data.get('doc_language')
+        )

 class DocumentSegment(db.Model):
     __tablename__ = 'document_segments'

+ 43 - 0
api/models/model.py

@@ -838,6 +838,49 @@ class Message(db.Model):

         return None

+    def to_dict(self) -> dict:
+        return {
+            'id': self.id,
+            'app_id': self.app_id,
+            'conversation_id': self.conversation_id,
+            'inputs': self.inputs,
+            'query': self.query,
+            'message': self.message,
+            'answer': self.answer,
+            'status': self.status,
+            'error': self.error,
+            'message_metadata': self.message_metadata_dict,
+            'from_source': self.from_source,
+            'from_end_user_id': self.from_end_user_id,
+            'from_account_id': self.from_account_id,
+            'created_at': self.created_at.isoformat(),
+            'updated_at': self.updated_at.isoformat(),
+            'agent_based': self.agent_based,
+            'workflow_run_id': self.workflow_run_id
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data['id'],
+            app_id=data['app_id'],
+            conversation_id=data['conversation_id'],
+            inputs=data['inputs'],
+            query=data['query'],
+            message=data['message'],
+            answer=data['answer'],
+            status=data['status'],
+            error=data['error'],
+            message_metadata=json.dumps(data['message_metadata']),
+            from_source=data['from_source'],
+            from_end_user_id=data['from_end_user_id'],
+            from_account_id=data['from_account_id'],
+            created_at=data['created_at'],
+            updated_at=data['updated_at'],
+            agent_based=data['agent_based'],
+            workflow_run_id=data['workflow_run_id']
+        )
+

 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'
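
`to_dict()`/`from_dict()` let a `Message` cross the Celery boundary as JSON, but note the asymmetry: `to_dict()` renders `created_at`/`updated_at` with `.isoformat()`, while `from_dict()` passes those strings straight through. A small sketch of the caveat (field values illustrative; the manual parse is an assumption, not something the diff handles):

```python
from datetime import datetime

from models.model import Message

payload = {
    'id': 'msg-1', 'app_id': 'app-1', 'conversation_id': 'conv-1',
    'inputs': {}, 'query': 'hi', 'message': [], 'answer': 'hello',
    'status': 'normal', 'error': None, 'message_metadata': {},
    'from_source': 'api', 'from_end_user_id': None, 'from_account_id': None,
    'created_at': '2024-06-24T12:00:00', 'updated_at': '2024-06-24T12:00:00',
    'agent_based': False, 'workflow_run_id': None,
}

restored = Message.from_dict(payload)  # ISO strings are passed through unparsed

# Parse manually when a real datetime object is needed:
created_at = datetime.fromisoformat(payload['created_at'])
```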

+ 49 - 0
api/models/workflow.py

@@ -324,6 +324,55 @@ class WorkflowRun(db.Model):
     def workflow(self):
         return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()

+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'app_id': self.app_id,
+            'sequence_number': self.sequence_number,
+            'workflow_id': self.workflow_id,
+            'type': self.type,
+            'triggered_from': self.triggered_from,
+            'version': self.version,
+            'graph': self.graph_dict,
+            'inputs': self.inputs_dict,
+            'status': self.status,
+            'outputs': self.outputs_dict,
+            'error': self.error,
+            'elapsed_time': self.elapsed_time,
+            'total_tokens': self.total_tokens,
+            'total_steps': self.total_steps,
+            'created_by_role': self.created_by_role,
+            'created_by': self.created_by,
+            'created_at': self.created_at,
+            'finished_at': self.finished_at,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'WorkflowRun':
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            app_id=data.get('app_id'),
+            sequence_number=data.get('sequence_number'),
+            workflow_id=data.get('workflow_id'),
+            type=data.get('type'),
+            triggered_from=data.get('triggered_from'),
+            version=data.get('version'),
+            graph=json.dumps(data.get('graph')),
+            inputs=json.dumps(data.get('inputs')),
+            status=data.get('status'),
+            outputs=json.dumps(data.get('outputs')),
+            error=data.get('error'),
+            elapsed_time=data.get('elapsed_time'),
+            total_tokens=data.get('total_tokens'),
+            total_steps=data.get('total_steps'),
+            created_by_role=data.get('created_by_role'),
+            created_by=data.get('created_by'),
+            created_at=data.get('created_at'),
+            finished_at=data.get('finished_at'),
+        )
+

 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """
     """

+ 46 - 0
api/tasks/ops_trace_task.py

@@ -0,0 +1,46 @@
+import logging
+import time
+
+from celery import shared_task
+from flask import current_app
+
+from core.ops.entities.trace_entity import trace_info_info_map
+from core.rag.models.document import Document
+from models.model import Message
+from models.workflow import WorkflowRun
+
+
+@shared_task(queue='ops_trace')
+def process_trace_tasks(tasks_data):
+    """
+    Async process trace tasks
+    :param tasks_data: List of dictionaries containing task data
+
+    Usage: process_trace_tasks.delay(tasks_data)
+    """
+    from core.ops.ops_trace_manager import OpsTraceManager
+
+    trace_info = tasks_data.get('trace_info')
+    app_id = tasks_data.get('app_id')
+    conversation_id = tasks_data.get('conversation_id')
+    message_id = tasks_data.get('message_id')
+    trace_info_type = tasks_data.get('trace_info_type')
+    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+
+    if trace_info.get('message_data'):
+        trace_info['message_data'] = Message.from_dict(data=trace_info['message_data'])
+    if trace_info.get('workflow_data'):
+        trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data'])
+    if trace_info.get('documents'):
+        trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']]
+
+    try:
+        if trace_instance:
+            with current_app.app_context():
+                trace_type = trace_info_info_map.get(trace_info_type)
+                if trace_type:
+                    trace_info = trace_type(**trace_info)
+                trace_instance.trace(trace_info)
+            end_at = time.perf_counter()
+    except Exception:
+        logging.exception("Processing trace tasks failed")
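
`shared_task(queue='ops_trace')` is what ties this file to the `-Q dataset,generation,mail,ops_trace` changes above: a worker that is not listening on `ops_trace` will never pick these tasks up. The producer side, as wired in `send_to_celery`, looks roughly like this (ids and payload values illustrative):

```python
from tasks.ops_trace_task import process_trace_tasks

# Enqueue one serialized trace; a worker started with `-Q ops_trace`
# (see the entrypoint and devcontainer changes above) consumes it.
process_trace_tasks.delay({
    "app_id": "app-uuid",
    "conversation_id": "conv-uuid",
    "message_id": "msg-uuid",
    "trace_info_type": "MessageTraceInfo",
    "trace_info": {},  # trace_info.model_dump() in the real call
})
```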