langfuse_trace.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. import json
  2. import logging
  3. import os
  4. from datetime import datetime, timedelta
  5. from typing import Optional
  6. from langfuse import Langfuse
  7. from core.ops.base_trace_instance import BaseTraceInstance
  8. from core.ops.entities.config_entity import LangfuseConfig
  9. from core.ops.entities.trace_entity import (
  10. BaseTraceInfo,
  11. DatasetRetrievalTraceInfo,
  12. GenerateNameTraceInfo,
  13. MessageTraceInfo,
  14. ModerationTraceInfo,
  15. SuggestedQuestionTraceInfo,
  16. ToolTraceInfo,
  17. WorkflowTraceInfo,
  18. )
  19. from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
  20. GenerationUsage,
  21. LangfuseGeneration,
  22. LangfuseSpan,
  23. LangfuseTrace,
  24. LevelEnum,
  25. UnitEnum,
  26. )
  27. from core.ops.utils import filter_none_values
  28. from extensions.ext_database import db
  29. from models.model import EndUser
  30. from models.workflow import WorkflowNodeExecution
  31. logger = logging.getLogger(__name__)
  32. class LangFuseDataTrace(BaseTraceInstance):
  33. def __init__(
  34. self,
  35. langfuse_config: LangfuseConfig,
  36. ):
  37. super().__init__(langfuse_config)
  38. self.langfuse_client = Langfuse(
  39. public_key=langfuse_config.public_key,
  40. secret_key=langfuse_config.secret_key,
  41. host=langfuse_config.host,
  42. )
  43. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  44. def trace(self, trace_info: BaseTraceInfo):
  45. if isinstance(trace_info, WorkflowTraceInfo):
  46. self.workflow_trace(trace_info)
  47. if isinstance(trace_info, MessageTraceInfo):
  48. self.message_trace(trace_info)
  49. if isinstance(trace_info, ModerationTraceInfo):
  50. self.moderation_trace(trace_info)
  51. if isinstance(trace_info, SuggestedQuestionTraceInfo):
  52. self.suggested_question_trace(trace_info)
  53. if isinstance(trace_info, DatasetRetrievalTraceInfo):
  54. self.dataset_retrieval_trace(trace_info)
  55. if isinstance(trace_info, ToolTraceInfo):
  56. self.tool_trace(trace_info)
  57. if isinstance(trace_info, GenerateNameTraceInfo):
  58. self.generate_name_trace(trace_info)
  59. def workflow_trace(self, trace_info: WorkflowTraceInfo):
  60. trace_id = trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id
  61. if trace_info.message_id:
  62. trace_id = trace_info.message_id
  63. name = f"message_{trace_info.message_id}"
  64. trace_data = LangfuseTrace(
  65. id=trace_info.message_id,
  66. user_id=trace_info.tenant_id,
  67. name=name,
  68. input=trace_info.workflow_run_inputs,
  69. output=trace_info.workflow_run_outputs,
  70. metadata=trace_info.metadata,
  71. session_id=trace_info.conversation_id,
  72. tags=["message", "workflow"],
  73. )
  74. self.add_trace(langfuse_trace_data=trace_data)
  75. workflow_span_data = LangfuseSpan(
  76. id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  77. name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
  78. input=trace_info.workflow_run_inputs,
  79. output=trace_info.workflow_run_outputs,
  80. trace_id=trace_id,
  81. start_time=trace_info.start_time,
  82. end_time=trace_info.end_time,
  83. metadata=trace_info.metadata,
  84. level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
  85. status_message=trace_info.error if trace_info.error else "",
  86. )
  87. self.add_span(langfuse_span_data=workflow_span_data)
  88. else:
  89. trace_data = LangfuseTrace(
  90. id=trace_id,
  91. user_id=trace_info.tenant_id,
  92. name=f"workflow_{trace_info.workflow_app_log_id}" if trace_info.workflow_app_log_id else f"workflow_{trace_info.workflow_run_id}",
  93. input=trace_info.workflow_run_inputs,
  94. output=trace_info.workflow_run_outputs,
  95. metadata=trace_info.metadata,
  96. session_id=trace_info.conversation_id,
  97. tags=["workflow"],
  98. )
  99. self.add_trace(langfuse_trace_data=trace_data)
  100. # through workflow_run_id get all_nodes_execution
  101. workflow_nodes_executions = (
  102. db.session.query(WorkflowNodeExecution)
  103. .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
  104. .order_by(WorkflowNodeExecution.index.desc())
  105. .all()
  106. )
  107. for node_execution in workflow_nodes_executions:
  108. node_execution_id = node_execution.id
  109. tenant_id = node_execution.tenant_id
  110. app_id = node_execution.app_id
  111. node_name = node_execution.title
  112. node_type = node_execution.node_type
  113. status = node_execution.status
  114. if node_type == "llm":
  115. inputs = json.loads(node_execution.process_data).get(
  116. "prompts", {}
  117. ) if node_execution.process_data else {}
  118. else:
  119. inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
  120. outputs = (
  121. json.loads(node_execution.outputs) if node_execution.outputs else {}
  122. )
  123. created_at = node_execution.created_at if node_execution.created_at else datetime.now()
  124. elapsed_time = node_execution.elapsed_time
  125. finished_at = created_at + timedelta(seconds=elapsed_time)
  126. metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  127. metadata.update(
  128. {
  129. "workflow_run_id": trace_info.workflow_run_id,
  130. "node_execution_id": node_execution_id,
  131. "tenant_id": tenant_id,
  132. "app_id": app_id,
  133. "node_name": node_name,
  134. "node_type": node_type,
  135. "status": status,
  136. }
  137. )
  138. # add span
  139. if trace_info.message_id:
  140. span_data = LangfuseSpan(
  141. id=node_execution_id,
  142. name=f"{node_name}_{node_execution_id}",
  143. input=inputs,
  144. output=outputs,
  145. trace_id=trace_id,
  146. start_time=created_at,
  147. end_time=finished_at,
  148. metadata=metadata,
  149. level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
  150. status_message=trace_info.error if trace_info.error else "",
  151. parent_observation_id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  152. )
  153. else:
  154. span_data = LangfuseSpan(
  155. id=node_execution_id,
  156. name=f"{node_name}_{node_execution_id}",
  157. input=inputs,
  158. output=outputs,
  159. trace_id=trace_id,
  160. start_time=created_at,
  161. end_time=finished_at,
  162. metadata=metadata,
  163. level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
  164. status_message=trace_info.error if trace_info.error else "",
  165. )
  166. self.add_span(langfuse_span_data=span_data)
  167. process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
  168. if process_data and process_data.get("model_mode") == "chat":
  169. total_token = metadata.get("total_tokens", 0)
  170. # add generation
  171. generation_usage = GenerationUsage(
  172. totalTokens=total_token,
  173. )
  174. node_generation_data = LangfuseGeneration(
  175. name=f"generation_{node_execution_id}",
  176. trace_id=trace_id,
  177. parent_observation_id=node_execution_id,
  178. start_time=created_at,
  179. end_time=finished_at,
  180. input=inputs,
  181. output=outputs,
  182. metadata=metadata,
  183. level=LevelEnum.DEFAULT if status == 'succeeded' else LevelEnum.ERROR,
  184. status_message=trace_info.error if trace_info.error else "",
  185. usage=generation_usage,
  186. )
  187. self.add_generation(langfuse_generation_data=node_generation_data)
  188. def message_trace(
  189. self, trace_info: MessageTraceInfo, **kwargs
  190. ):
  191. # get message file data
  192. file_list = trace_info.file_list
  193. metadata = trace_info.metadata
  194. message_data = trace_info.message_data
  195. message_id = message_data.id
  196. user_id = message_data.from_account_id
  197. if message_data.from_end_user_id:
  198. end_user_data: EndUser = db.session.query(EndUser).filter(
  199. EndUser.id == message_data.from_end_user_id
  200. ).first()
  201. if end_user_data is not None:
  202. user_id = end_user_data.session_id
  203. metadata["user_id"] = user_id
  204. trace_data = LangfuseTrace(
  205. id=message_id,
  206. user_id=user_id,
  207. name=f"message_{message_id}",
  208. input={
  209. "message": trace_info.inputs,
  210. "files": file_list,
  211. "message_tokens": trace_info.message_tokens,
  212. "answer_tokens": trace_info.answer_tokens,
  213. "total_tokens": trace_info.total_tokens,
  214. "error": trace_info.error,
  215. "provider_response_latency": message_data.provider_response_latency,
  216. "created_at": trace_info.start_time,
  217. },
  218. output=trace_info.outputs,
  219. metadata=metadata,
  220. session_id=message_data.conversation_id,
  221. tags=["message", str(trace_info.conversation_mode)],
  222. version=None,
  223. release=None,
  224. public=None,
  225. )
  226. self.add_trace(langfuse_trace_data=trace_data)
  227. # start add span
  228. generation_usage = GenerationUsage(
  229. totalTokens=trace_info.total_tokens,
  230. input=trace_info.message_tokens,
  231. output=trace_info.answer_tokens,
  232. total=trace_info.total_tokens,
  233. unit=UnitEnum.TOKENS,
  234. totalCost=message_data.total_price,
  235. )
  236. langfuse_generation_data = LangfuseGeneration(
  237. name=f"generation_{message_id}",
  238. trace_id=message_id,
  239. start_time=trace_info.start_time,
  240. end_time=trace_info.end_time,
  241. model=message_data.model_id,
  242. input=trace_info.inputs,
  243. output=message_data.answer,
  244. metadata=metadata,
  245. level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
  246. status_message=message_data.error if message_data.error else "",
  247. usage=generation_usage,
  248. )
  249. self.add_generation(langfuse_generation_data)
  250. def moderation_trace(self, trace_info: ModerationTraceInfo):
  251. span_data = LangfuseSpan(
  252. name="moderation",
  253. input=trace_info.inputs,
  254. output={
  255. "action": trace_info.action,
  256. "flagged": trace_info.flagged,
  257. "preset_response": trace_info.preset_response,
  258. "inputs": trace_info.inputs,
  259. },
  260. trace_id=trace_info.message_id,
  261. start_time=trace_info.start_time or trace_info.message_data.created_at,
  262. end_time=trace_info.end_time or trace_info.message_data.created_at,
  263. metadata=trace_info.metadata,
  264. )
  265. self.add_span(langfuse_span_data=span_data)
  266. def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
  267. message_data = trace_info.message_data
  268. generation_usage = GenerationUsage(
  269. totalTokens=len(str(trace_info.suggested_question)),
  270. input=len(trace_info.inputs),
  271. output=len(trace_info.suggested_question),
  272. total=len(trace_info.suggested_question),
  273. unit=UnitEnum.CHARACTERS,
  274. )
  275. generation_data = LangfuseGeneration(
  276. name="suggested_question",
  277. input=trace_info.inputs,
  278. output=str(trace_info.suggested_question),
  279. trace_id=trace_info.message_id,
  280. start_time=trace_info.start_time,
  281. end_time=trace_info.end_time,
  282. metadata=trace_info.metadata,
  283. level=LevelEnum.DEFAULT if message_data.status != 'error' else LevelEnum.ERROR,
  284. status_message=message_data.error if message_data.error else "",
  285. usage=generation_usage,
  286. )
  287. self.add_generation(langfuse_generation_data=generation_data)
  288. def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
  289. dataset_retrieval_span_data = LangfuseSpan(
  290. name="dataset_retrieval",
  291. input=trace_info.inputs,
  292. output={"documents": trace_info.documents},
  293. trace_id=trace_info.message_id,
  294. start_time=trace_info.start_time or trace_info.message_data.created_at,
  295. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  296. metadata=trace_info.metadata,
  297. )
  298. self.add_span(langfuse_span_data=dataset_retrieval_span_data)
  299. def tool_trace(self, trace_info: ToolTraceInfo):
  300. tool_span_data = LangfuseSpan(
  301. name=trace_info.tool_name,
  302. input=trace_info.tool_inputs,
  303. output=trace_info.tool_outputs,
  304. trace_id=trace_info.message_id,
  305. start_time=trace_info.start_time,
  306. end_time=trace_info.end_time,
  307. metadata=trace_info.metadata,
  308. level=LevelEnum.DEFAULT if trace_info.error == "" or trace_info.error is None else LevelEnum.ERROR,
  309. status_message=trace_info.error,
  310. )
  311. self.add_span(langfuse_span_data=tool_span_data)
  312. def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
  313. name_generation_trace_data = LangfuseTrace(
  314. name="generate_name",
  315. input=trace_info.inputs,
  316. output=trace_info.outputs,
  317. user_id=trace_info.tenant_id,
  318. metadata=trace_info.metadata,
  319. session_id=trace_info.conversation_id,
  320. )
  321. self.add_trace(langfuse_trace_data=name_generation_trace_data)
  322. name_generation_span_data = LangfuseSpan(
  323. name="generate_name",
  324. input=trace_info.inputs,
  325. output=trace_info.outputs,
  326. trace_id=trace_info.conversation_id,
  327. start_time=trace_info.start_time,
  328. end_time=trace_info.end_time,
  329. metadata=trace_info.metadata,
  330. )
  331. self.add_span(langfuse_span_data=name_generation_span_data)
  332. def add_trace(self, langfuse_trace_data: Optional[LangfuseTrace] = None):
  333. format_trace_data = (
  334. filter_none_values(langfuse_trace_data.model_dump()) if langfuse_trace_data else {}
  335. )
  336. try:
  337. self.langfuse_client.trace(**format_trace_data)
  338. logger.debug("LangFuse Trace created successfully")
  339. except Exception as e:
  340. raise ValueError(f"LangFuse Failed to create trace: {str(e)}")
  341. def add_span(self, langfuse_span_data: Optional[LangfuseSpan] = None):
  342. format_span_data = (
  343. filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
  344. )
  345. try:
  346. self.langfuse_client.span(**format_span_data)
  347. logger.debug("LangFuse Span created successfully")
  348. except Exception as e:
  349. raise ValueError(f"LangFuse Failed to create span: {str(e)}")
  350. def update_span(self, span, langfuse_span_data: Optional[LangfuseSpan] = None):
  351. format_span_data = (
  352. filter_none_values(langfuse_span_data.model_dump()) if langfuse_span_data else {}
  353. )
  354. span.end(**format_span_data)
  355. def add_generation(
  356. self, langfuse_generation_data: Optional[LangfuseGeneration] = None
  357. ):
  358. format_generation_data = (
  359. filter_none_values(langfuse_generation_data.model_dump())
  360. if langfuse_generation_data
  361. else {}
  362. )
  363. try:
  364. self.langfuse_client.generation(**format_generation_data)
  365. logger.debug("LangFuse Generation created successfully")
  366. except Exception as e:
  367. raise ValueError(f"LangFuse Failed to create generation: {str(e)}")
  368. def update_generation(
  369. self, generation, langfuse_generation_data: Optional[LangfuseGeneration] = None
  370. ):
  371. format_generation_data = (
  372. filter_none_values(langfuse_generation_data.model_dump())
  373. if langfuse_generation_data
  374. else {}
  375. )
  376. generation.end(**format_generation_data)
  377. def api_check(self):
  378. try:
  379. return self.langfuse_client.auth_check()
  380. except Exception as e:
  381. logger.debug(f"LangFuse API check failed: {str(e)}")
  382. raise ValueError(f"LangFuse API check failed: {str(e)}")