Procházet zdrojové kódy

Feat/fix ops trace (#5672)

Co-authored-by: takatost <takatost@gmail.com>
Joe před 9 měsíci
rodič
revize
e8b8f6c6dd

+ 1 - 1
.devcontainer/post_create_command.sh

@@ -3,7 +3,7 @@
 cd web && npm install
 
 echo 'alias start-api="cd /workspaces/dify/api && flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc
-echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail"' >> ~/.bashrc
+echo 'alias start-worker="cd /workspaces/dify/api && celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace"' >> ~/.bashrc
 echo 'alias start-web="cd /workspaces/dify/web && npm run dev"' >> ~/.bashrc
 echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify up -d"' >> ~/.bashrc
 

+ 13 - 1
.vscode/launch.json

@@ -37,7 +37,19 @@
                 "FLASK_DEBUG": "1",
                 "GEVENT_SUPPORT": "True"
             },
-            "args": ["-A", "app.celery", "worker", "-P", "gevent", "-c", "1", "--loglevel", "info", "-Q", "dataset,generation,mail"],
+            "args": [
+                "-A",
+                "app.celery",
+                "worker",
+                "-P",
+                "gevent",
+                "-c",
+                "1",
+                "--loglevel",
+                "info",
+                "-Q",
+                "dataset,generation,mail,ops_trace"
+            ]
         },
     ]
 }

+ 1 - 1
api/README.md

@@ -66,7 +66,7 @@
 10. If you need to debug local async processing, please start the worker service.
 
    ```bash
-   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail
+   poetry run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace
    ```
 
    The started celery app handles the async tasks, e.g. dataset importing and documents indexing.

+ 0 - 2
api/app.py

@@ -26,7 +26,6 @@ from werkzeug.exceptions import Unauthorized
 from commands import register_commands
 
 # DO NOT REMOVE BELOW
-from events import event_handlers
 from extensions import (
     ext_celery,
     ext_code_based_extension,
@@ -43,7 +42,6 @@ from extensions import (
 from extensions.ext_database import db
 from extensions.ext_login import login_manager
 from libs.passport import PassportService
-from models import account, dataset, model, source, task, tool, tools, web
 from services.account_service import AccountService
 
 # DO NOT REMOVE ABOVE

+ 1 - 1
api/core/moderation/input_moderation.py

@@ -57,7 +57,7 @@ class InputModeration:
                     timer=timer
                 )
             )
-        
+
         if not moderation_result.flagged:
             return False, inputs, query
 

+ 11 - 1
api/core/ops/entities/trace_entity.py

@@ -94,5 +94,15 @@ class ToolTraceInfo(BaseTraceInfo):
 
 
 class GenerateNameTraceInfo(BaseTraceInfo):
-    conversation_id: str
+    conversation_id: Optional[str] = None
     tenant_id: str
+
+trace_info_info_map = {
+    'WorkflowTraceInfo': WorkflowTraceInfo,
+    'MessageTraceInfo': MessageTraceInfo,
+    'ModerationTraceInfo': ModerationTraceInfo,
+    'SuggestedQuestionTraceInfo': SuggestedQuestionTraceInfo,
+    'DatasetRetrievalTraceInfo': DatasetRetrievalTraceInfo,
+    'ToolTraceInfo': ToolTraceInfo,
+    'GenerateNameTraceInfo': GenerateNameTraceInfo,
+}

+ 29 - 2
api/core/ops/langfuse_trace/langfuse_trace.py

@@ -147,6 +147,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             # add span
             if trace_info.message_id:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -160,6 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance):
                 )
             else:
                 span_data = LangfuseSpan(
+                    id=node_execution_id,
                     name=f"{node_name}_{node_execution_id}",
                     input=inputs,
                     output=outputs,
@@ -173,6 +175,30 @@ class LangFuseDataTrace(BaseTraceInstance):
 
             self.add_span(langfuse_span_data=span_data)
 
+            process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+            if process_data and process_data.get("model_mode") == "chat":
+                total_token = metadata.get("total_tokens", 0)
+                # add generation
+                generation_usage = GenerationUsage(
+                    totalTokens=total_token,
+                )
+
+                node_generation_data = LangfuseGeneration(
+                    name=f"generation_{node_execution_id}",
+                    trace_id=trace_id,
+                    parent_observation_id=node_execution_id,
+                    start_time=created_at,
+                    end_time=finished_at,
+                    input=inputs,
+                    output=outputs,
+                    metadata=metadata,
+                    level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
+                    status_message=trace_info.error if trace_info.error else "",
+                    usage=generation_usage,
+                )
+
+                self.add_generation(langfuse_generation_data=node_generation_data)
+
     def message_trace(
         self, trace_info: MessageTraceInfo, **kwargs
     ):
@@ -186,7 +212,7 @@ class LangFuseDataTrace(BaseTraceInstance):
         if message_data.from_end_user_id:
             end_user_data: EndUser = db.session.query(EndUser).filter(
                 EndUser.id == message_data.from_end_user_id
-            ).first().session_id
+            ).first()
             user_id = end_user_data.session_id
 
         trace_data = LangfuseTrace(
@@ -220,6 +246,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             output=trace_info.answer_tokens,
             total=trace_info.total_tokens,
             unit=UnitEnum.TOKENS,
+            totalCost=message_data.total_price,
         )
 
         langfuse_generation_data = LangfuseGeneration(
@@ -303,7 +330,7 @@ class LangFuseDataTrace(BaseTraceInstance):
             start_time=trace_info.start_time,
             end_time=trace_info.end_time,
             metadata=trace_info.metadata,
-            level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
+            level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
             status_message=trace_info.error,
         )
 

+ 78 - 37
api/core/ops/ops_trace_manager.py

@@ -1,16 +1,17 @@
 import json
+import logging
 import os
 import queue
 import threading
+import time
 from datetime import timedelta
 from enum import Enum
 from typing import Any, Optional, Union
 from uuid import UUID
 
-from flask import Flask, current_app
+from flask import current_app
 
 from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
-from core.ops.base_trace_instance import BaseTraceInstance
 from core.ops.entities.config_entity import (
     LangfuseConfig,
     LangSmithConfig,
@@ -31,6 +32,7 @@ from core.ops.utils import get_message_data
 from extensions.ext_database import db
 from models.model import App, AppModelConfig, Conversation, Message, MessageAgentThought, MessageFile, TraceAppConfig
 from models.workflow import WorkflowAppLog, WorkflowRun
+from tasks.ops_trace_task import process_trace_tasks
 
 provider_config_map = {
     TracingProviderEnum.LANGFUSE.value: {
@@ -105,7 +107,7 @@ class OpsTraceManager:
         return config_class(**new_config).model_dump()
 
     @classmethod
-    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config:dict):
+    def obfuscated_decrypt_token(cls, tracing_provider: str, decrypt_tracing_config: dict):
         """
         Decrypt tracing config
         :param tracing_provider: tracing provider
@@ -295,11 +297,9 @@ class TraceTask:
         self.kwargs = kwargs
         self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
 
-    def execute(self, trace_instance: BaseTraceInstance):
+    def execute(self):
         method_name, trace_info = self.preprocess()
-        if trace_instance:
-            method = trace_instance.trace
-            method(trace_info)
+        return trace_info
 
     def preprocess(self):
         if self.trace_type == TraceTaskName.CONVERSATION_TRACE:
@@ -372,7 +372,7 @@ class TraceTask:
         }
 
         workflow_trace_info = WorkflowTraceInfo(
-            workflow_data=workflow_run,
+            workflow_data=workflow_run.to_dict(),
             conversation_id=conversation_id,
             workflow_id=workflow_id,
             tenant_id=tenant_id,
@@ -427,7 +427,8 @@ class TraceTask:
         message_tokens = message_data.message_tokens
 
         message_trace_info = MessageTraceInfo(
-            message_data=message_data,
+            message_id=message_id,
+            message_data=message_data.to_dict(),
             conversation_model=conversation_mode,
             message_tokens=message_tokens,
             answer_tokens=message_data.answer_tokens,
@@ -469,7 +470,7 @@ class TraceTask:
         moderation_trace_info = ModerationTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
             inputs=inputs,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             flagged=moderation_result.flagged,
             action=moderation_result.action,
             preset_response=moderation_result.preset_response,
@@ -508,7 +509,7 @@ class TraceTask:
 
         suggested_question_trace_info = SuggestedQuestionTraceInfo(
             message_id=workflow_app_log_id if workflow_app_log_id else message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             inputs=message_data.message,
             outputs=message_data.answer,
             start_time=timer.get("start"),
@@ -550,11 +551,11 @@ class TraceTask:
         dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
             message_id=message_id,
             inputs=message_data.query if message_data.query else message_data.inputs,
-            documents=documents,
+            documents=[doc.model_dump() for doc in documents],
             start_time=timer.get("start"),
             end_time=timer.get("end"),
             metadata=metadata,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
         )
 
         return dataset_retrieval_trace_info
@@ -613,7 +614,7 @@ class TraceTask:
 
         tool_trace_info = ToolTraceInfo(
             message_id=message_id,
-            message_data=message_data,
+            message_data=message_data.to_dict(),
             tool_name=tool_name,
             start_time=timer.get("start") if timer else created_time,
             end_time=timer.get("end") if timer else end_time,
@@ -657,31 +658,71 @@ class TraceTask:
         return generate_name_trace_info
 
 
+trace_manager_timer = None
+trace_manager_queue = queue.Queue()
+trace_manager_interval = int(os.getenv("TRACE_QUEUE_MANAGER_INTERVAL", 1))
+trace_manager_batch_size = int(os.getenv("TRACE_QUEUE_MANAGER_BATCH_SIZE", 100))
+
+
 class TraceQueueManager:
     def __init__(self, app_id=None, conversation_id=None, message_id=None):
-        tracing_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
-        self.queue = queue.Queue()
-        self.is_running = True
-        self.thread = threading.Thread(
-            target=self.process_queue, kwargs={
-                'flask_app': current_app._get_current_object(),
-                'trace_instance': tracing_instance
-            }
-        )
-        self.thread.start()
+        global trace_manager_timer
 
-    def stop(self):
-        self.is_running = False
-
-    def process_queue(self, flask_app: Flask, trace_instance: BaseTraceInstance):
-        with flask_app.app_context():
-            while self.is_running:
-                try:
-                    task = self.queue.get(timeout=60)
-                    task.execute(trace_instance)
-                    self.queue.task_done()
-                except queue.Empty:
-                    self.stop()
+        self.app_id = app_id
+        self.conversation_id = conversation_id
+        self.message_id = message_id
+        self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+        self.flask_app = current_app._get_current_object()
+        if trace_manager_timer is None:
+            self.start_timer()
 
     def add_trace_task(self, trace_task: TraceTask):
-        self.queue.put(trace_task)
+        global trace_manager_timer
+        global trace_manager_queue
+        try:
+            if self.trace_instance:
+                trace_manager_queue.put(trace_task)
+        except Exception as e:
+            logging.debug(f"Error adding trace task: {e}")
+        finally:
+            self.start_timer()
+
+    def collect_tasks(self):
+        global trace_manager_queue
+        tasks = []
+        while len(tasks) < trace_manager_batch_size and not trace_manager_queue.empty():
+            task = trace_manager_queue.get_nowait()
+            tasks.append(task)
+            trace_manager_queue.task_done()
+        return tasks
+
+    def run(self):
+        try:
+            tasks = self.collect_tasks()
+            if tasks:
+                self.send_to_celery(tasks)
+        except Exception as e:
+            logging.debug(f"Error processing trace tasks: {e}")
+
+    def start_timer(self):
+        global trace_manager_timer
+        if trace_manager_timer is None or not trace_manager_timer.is_alive():
+            trace_manager_timer = threading.Timer(
+                trace_manager_interval, self.run
+            )
+            trace_manager_timer.name = f"trace_manager_timer_{time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}"
+            trace_manager_timer.daemon = False
+            trace_manager_timer.start()
+
+    def send_to_celery(self, tasks: list[TraceTask]):
+        with self.flask_app.app_context():
+            for task in tasks:
+                trace_info = task.execute()
+                task_data = {
+                    "app_id": self.app_id,
+                    "conversation_id": self.conversation_id,
+                    "message_id": self.message_id,
+                    "trace_info_type": type(trace_info).__name__,
+                    "trace_info": trace_info.model_dump() if trace_info else {},
+                }
+                process_trace_tasks.delay(task_data)

+ 2 - 2
api/core/rag/retrieval/dataset_retrieval.py

@@ -12,7 +12,7 @@ from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.message_entities import PromptMessageTool
 from core.model_runtime.entities.model_entities import ModelFeature, ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.ops.ops_trace_manager import TraceTask, TraceTaskName
+from core.ops.ops_trace_manager import TraceQueueManager, TraceTask, TraceTaskName
 from core.ops.utils import measure_time
 from core.rag.datasource.retrieval_service import RetrievalService
 from core.rag.models.document import Document
@@ -357,7 +357,7 @@ class DatasetRetrieval:
             db.session.commit()
 
         # get tracing instance
-        trace_manager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
+        trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
         if trace_manager:
             trace_manager.add_trace_task(
                 TraceTask(

+ 1 - 1
api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py

@@ -94,7 +94,7 @@ class ParameterExtractorNode(LLMNode):
         memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
 
         if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} \
-                and node_data.reasoning_mode == 'function_call':
+            and node_data.reasoning_mode == 'function_call':
             # use function call 
             prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
                 node_data, query, variable_pool, model_config, memory

+ 1 - 1
api/docker/entrypoint.sh

@@ -9,7 +9,7 @@ fi
 
 if [[ "${MODE}" == "worker" ]]; then
   celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} -c ${CELERY_WORKER_AMOUNT:-1} --loglevel INFO \
-    -Q ${CELERY_QUEUES:-dataset,generation,mail}
+    -Q ${CELERY_QUEUES:-dataset,generation,mail,ops_trace}
 elif [[ "${MODE}" == "beat" ]]; then
   celery -A app.celery beat --loglevel INFO
 else

+ 1 - 7
api/migrations/versions/04c602f5dc9b_update_appmodelconfig_and_add_table_.py

@@ -31,17 +31,11 @@ def upgrade():
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.create_index('tracing_app_config_app_id_idx', ['app_id'], unique=False)
 
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.Text(), nullable=True))
-
     # ### end Alembic commands ###
 
 
 def downgrade():
-    # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
+    # ### commands auto generated by Alembic - please adjust! ##
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
 

+ 0 - 7
api/migrations/versions/c031d46af369_remove_app_model_config_trace_config_.py

@@ -35,18 +35,11 @@ def upgrade():
 
     with op.batch_alter_table('tracing_app_configs', schema=None) as batch_op:
         batch_op.drop_index('tracing_app_config_app_id_idx')
-
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.drop_column('trace_config')
-
     # ### end Alembic commands ###
 
 
 def downgrade():
     # ### commands auto generated by Alembic - please adjust! ###
-    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
-        batch_op.add_column(sa.Column('trace_config', sa.TEXT(), autoincrement=False, nullable=True))
-
     op.create_table('tracing_app_configs',
     sa.Column('id', sa.UUID(), server_default=sa.text('uuid_generate_v4()'), autoincrement=False, nullable=False),
     sa.Column('app_id', sa.UUID(), autoincrement=False, nullable=False),

+ 95 - 0
api/models/dataset.py

@@ -352,6 +352,101 @@ class Document(db.Model):
         return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
             .filter(DocumentSegment.document_id == self.id).scalar()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'dataset_id': self.dataset_id,
+            'position': self.position,
+            'data_source_type': self.data_source_type,
+            'data_source_info': self.data_source_info,
+            'dataset_process_rule_id': self.dataset_process_rule_id,
+            'batch': self.batch,
+            'name': self.name,
+            'created_from': self.created_from,
+            'created_by': self.created_by,
+            'created_api_request_id': self.created_api_request_id,
+            'created_at': self.created_at,
+            'processing_started_at': self.processing_started_at,
+            'file_id': self.file_id,
+            'word_count': self.word_count,
+            'parsing_completed_at': self.parsing_completed_at,
+            'cleaning_completed_at': self.cleaning_completed_at,
+            'splitting_completed_at': self.splitting_completed_at,
+            'tokens': self.tokens,
+            'indexing_latency': self.indexing_latency,
+            'completed_at': self.completed_at,
+            'is_paused': self.is_paused,
+            'paused_by': self.paused_by,
+            'paused_at': self.paused_at,
+            'error': self.error,
+            'stopped_at': self.stopped_at,
+            'indexing_status': self.indexing_status,
+            'enabled': self.enabled,
+            'disabled_at': self.disabled_at,
+            'disabled_by': self.disabled_by,
+            'archived': self.archived,
+            'archived_reason': self.archived_reason,
+            'archived_by': self.archived_by,
+            'archived_at': self.archived_at,
+            'updated_at': self.updated_at,
+            'doc_type': self.doc_type,
+            'doc_metadata': self.doc_metadata,
+            'doc_form': self.doc_form,
+            'doc_language': self.doc_language,
+            'display_status': self.display_status,
+            'data_source_info_dict': self.data_source_info_dict,
+            'average_segment_length': self.average_segment_length,
+            'dataset_process_rule': self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
+            'dataset': self.dataset.to_dict() if self.dataset else None,
+            'segment_count': self.segment_count,
+            'hit_count': self.hit_count
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            dataset_id=data.get('dataset_id'),
+            position=data.get('position'),
+            data_source_type=data.get('data_source_type'),
+            data_source_info=data.get('data_source_info'),
+            dataset_process_rule_id=data.get('dataset_process_rule_id'),
+            batch=data.get('batch'),
+            name=data.get('name'),
+            created_from=data.get('created_from'),
+            created_by=data.get('created_by'),
+            created_api_request_id=data.get('created_api_request_id'),
+            created_at=data.get('created_at'),
+            processing_started_at=data.get('processing_started_at'),
+            file_id=data.get('file_id'),
+            word_count=data.get('word_count'),
+            parsing_completed_at=data.get('parsing_completed_at'),
+            cleaning_completed_at=data.get('cleaning_completed_at'),
+            splitting_completed_at=data.get('splitting_completed_at'),
+            tokens=data.get('tokens'),
+            indexing_latency=data.get('indexing_latency'),
+            completed_at=data.get('completed_at'),
+            is_paused=data.get('is_paused'),
+            paused_by=data.get('paused_by'),
+            paused_at=data.get('paused_at'),
+            error=data.get('error'),
+            stopped_at=data.get('stopped_at'),
+            indexing_status=data.get('indexing_status'),
+            enabled=data.get('enabled'),
+            disabled_at=data.get('disabled_at'),
+            disabled_by=data.get('disabled_by'),
+            archived=data.get('archived'),
+            archived_reason=data.get('archived_reason'),
+            archived_by=data.get('archived_by'),
+            archived_at=data.get('archived_at'),
+            updated_at=data.get('updated_at'),
+            doc_type=data.get('doc_type'),
+            doc_metadata=data.get('doc_metadata'),
+            doc_form=data.get('doc_form'),
+            doc_language=data.get('doc_language')
+        )
 
 class DocumentSegment(db.Model):
     __tablename__ = 'document_segments'

+ 43 - 0
api/models/model.py

@@ -838,6 +838,49 @@ class Message(db.Model):
 
         return None
 
+    def to_dict(self) -> dict:
+        return {
+            'id': self.id,
+            'app_id': self.app_id,
+            'conversation_id': self.conversation_id,
+            'inputs': self.inputs,
+            'query': self.query,
+            'message': self.message,
+            'answer': self.answer,
+            'status': self.status,
+            'error': self.error,
+            'message_metadata': self.message_metadata_dict,
+            'from_source': self.from_source,
+            'from_end_user_id': self.from_end_user_id,
+            'from_account_id': self.from_account_id,
+            'created_at': self.created_at.isoformat(),
+            'updated_at': self.updated_at.isoformat(),
+            'agent_based': self.agent_based,
+            'workflow_run_id': self.workflow_run_id
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data['id'],
+            app_id=data['app_id'],
+            conversation_id=data['conversation_id'],
+            inputs=data['inputs'],
+            query=data['query'],
+            message=data['message'],
+            answer=data['answer'],
+            status=data['status'],
+            error=data['error'],
+            message_metadata=json.dumps(data['message_metadata']),
+            from_source=data['from_source'],
+            from_end_user_id=data['from_end_user_id'],
+            from_account_id=data['from_account_id'],
+            created_at=data['created_at'],
+            updated_at=data['updated_at'],
+            agent_based=data['agent_based'],
+            workflow_run_id=data['workflow_run_id']
+        )
+
 
 class MessageFeedback(db.Model):
     __tablename__ = 'message_feedbacks'

+ 49 - 0
api/models/workflow.py

@@ -324,6 +324,55 @@ class WorkflowRun(db.Model):
     def workflow(self):
         return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
 
+    def to_dict(self):
+        return {
+            'id': self.id,
+            'tenant_id': self.tenant_id,
+            'app_id': self.app_id,
+            'sequence_number': self.sequence_number,
+            'workflow_id': self.workflow_id,
+            'type': self.type,
+            'triggered_from': self.triggered_from,
+            'version': self.version,
+            'graph': self.graph_dict,
+            'inputs': self.inputs_dict,
+            'status': self.status,
+            'outputs': self.outputs_dict,
+            'error': self.error,
+            'elapsed_time': self.elapsed_time,
+            'total_tokens': self.total_tokens,
+            'total_steps': self.total_steps,
+            'created_by_role': self.created_by_role,
+            'created_by': self.created_by,
+            'created_at': self.created_at,
+            'finished_at': self.finished_at,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict) -> 'WorkflowRun':
+        return cls(
+            id=data.get('id'),
+            tenant_id=data.get('tenant_id'),
+            app_id=data.get('app_id'),
+            sequence_number=data.get('sequence_number'),
+            workflow_id=data.get('workflow_id'),
+            type=data.get('type'),
+            triggered_from=data.get('triggered_from'),
+            version=data.get('version'),
+            graph=json.dumps(data.get('graph')),
+            inputs=json.dumps(data.get('inputs')),
+            status=data.get('status'),
+            outputs=json.dumps(data.get('outputs')),
+            error=data.get('error'),
+            elapsed_time=data.get('elapsed_time'),
+            total_tokens=data.get('total_tokens'),
+            total_steps=data.get('total_steps'),
+            created_by_role=data.get('created_by_role'),
+            created_by=data.get('created_by'),
+            created_at=data.get('created_at'),
+            finished_at=data.get('finished_at'),
+        )
+
 
 class WorkflowNodeExecutionTriggeredFrom(Enum):
     """

+ 46 - 0
api/tasks/ops_trace_task.py

@@ -0,0 +1,46 @@
+import logging
+import time
+
+from celery import shared_task
+from flask import current_app
+
+from core.ops.entities.trace_entity import trace_info_info_map
+from core.rag.models.document import Document
+from models.model import Message
+from models.workflow import WorkflowRun
+
+
+@shared_task(queue='ops_trace')
+def process_trace_tasks(tasks_data):
+    """
+    Async process trace tasks
+    :param tasks_data: List of dictionaries containing task data
+
+    Usage: process_trace_tasks.delay(tasks_data)
+    """
+    from core.ops.ops_trace_manager import OpsTraceManager
+
+    trace_info = tasks_data.get('trace_info')
+    app_id = tasks_data.get('app_id')
+    conversation_id = tasks_data.get('conversation_id')
+    message_id = tasks_data.get('message_id')
+    trace_info_type = tasks_data.get('trace_info_type')
+    trace_instance = OpsTraceManager.get_ops_trace_instance(app_id, conversation_id, message_id)
+
+    if trace_info.get('message_data'):
+        trace_info['message_data'] = Message.from_dict(data=trace_info['message_data'])
+    if trace_info.get('workflow_data'):
+        trace_info['workflow_data'] = WorkflowRun.from_dict(data=trace_info['workflow_data'])
+    if trace_info.get('documents'):
+        trace_info['documents'] = [Document(**doc) for doc in trace_info['documents']]
+
+    try:
+        if trace_instance:
+            with current_app.app_context():
+                trace_type = trace_info_info_map.get(trace_info_type)
+                if trace_type:
+                    trace_info = trace_type(**trace_info)
+                trace_instance.trace(trace_info)
+            end_at = time.perf_counter()
+    except Exception:
+        logging.exception("Processing trace tasks failed")