langsmith_trace.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. import json
  2. import logging
  3. import os
  4. from datetime import datetime, timedelta
  5. from langsmith import Client
  6. from core.ops.base_trace_instance import BaseTraceInstance
  7. from core.ops.entities.config_entity import LangSmithConfig
  8. from core.ops.entities.trace_entity import (
  9. BaseTraceInfo,
  10. DatasetRetrievalTraceInfo,
  11. GenerateNameTraceInfo,
  12. MessageTraceInfo,
  13. ModerationTraceInfo,
  14. SuggestedQuestionTraceInfo,
  15. ToolTraceInfo,
  16. TraceTaskName,
  17. WorkflowTraceInfo,
  18. )
  19. from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
  20. LangSmithRunModel,
  21. LangSmithRunType,
  22. LangSmithRunUpdateModel,
  23. )
  24. from core.ops.utils import filter_none_values
  25. from extensions.ext_database import db
  26. from models.model import EndUser, MessageFile
  27. from models.workflow import WorkflowNodeExecution
  28. logger = logging.getLogger(__name__)
  29. class LangSmithDataTrace(BaseTraceInstance):
  30. def __init__(
  31. self,
  32. langsmith_config: LangSmithConfig,
  33. ):
  34. super().__init__(langsmith_config)
  35. self.langsmith_key = langsmith_config.api_key
  36. self.project_name = langsmith_config.project
  37. self.project_id = None
  38. self.langsmith_client = Client(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint)
  39. self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
  40. def trace(self, trace_info: BaseTraceInfo):
  41. if isinstance(trace_info, WorkflowTraceInfo):
  42. self.workflow_trace(trace_info)
  43. if isinstance(trace_info, MessageTraceInfo):
  44. self.message_trace(trace_info)
  45. if isinstance(trace_info, ModerationTraceInfo):
  46. self.moderation_trace(trace_info)
  47. if isinstance(trace_info, SuggestedQuestionTraceInfo):
  48. self.suggested_question_trace(trace_info)
  49. if isinstance(trace_info, DatasetRetrievalTraceInfo):
  50. self.dataset_retrieval_trace(trace_info)
  51. if isinstance(trace_info, ToolTraceInfo):
  52. self.tool_trace(trace_info)
  53. if isinstance(trace_info, GenerateNameTraceInfo):
  54. self.generate_name_trace(trace_info)
  55. def workflow_trace(self, trace_info: WorkflowTraceInfo):
  56. if trace_info.message_id:
  57. message_run = LangSmithRunModel(
  58. id=trace_info.message_id,
  59. name=TraceTaskName.MESSAGE_TRACE.value,
  60. inputs=trace_info.workflow_run_inputs,
  61. outputs=trace_info.workflow_run_outputs,
  62. run_type=LangSmithRunType.chain,
  63. start_time=trace_info.start_time,
  64. end_time=trace_info.end_time,
  65. extra={
  66. "metadata": trace_info.metadata,
  67. },
  68. tags=["message", "workflow"],
  69. error=trace_info.error,
  70. )
  71. self.add_run(message_run)
  72. langsmith_run = LangSmithRunModel(
  73. file_list=trace_info.file_list,
  74. total_tokens=trace_info.total_tokens,
  75. id=trace_info.workflow_app_log_id if trace_info.workflow_app_log_id else trace_info.workflow_run_id,
  76. name=TraceTaskName.WORKFLOW_TRACE.value,
  77. inputs=trace_info.workflow_run_inputs,
  78. run_type=LangSmithRunType.tool,
  79. start_time=trace_info.workflow_data.created_at,
  80. end_time=trace_info.workflow_data.finished_at,
  81. outputs=trace_info.workflow_run_outputs,
  82. extra={
  83. "metadata": trace_info.metadata,
  84. },
  85. error=trace_info.error,
  86. tags=["workflow"],
  87. parent_run_id=trace_info.message_id if trace_info.message_id else None,
  88. )
  89. self.add_run(langsmith_run)
  90. # through workflow_run_id get all_nodes_execution
  91. workflow_nodes_executions = (
  92. db.session.query(
  93. WorkflowNodeExecution.id,
  94. WorkflowNodeExecution.tenant_id,
  95. WorkflowNodeExecution.app_id,
  96. WorkflowNodeExecution.title,
  97. WorkflowNodeExecution.node_type,
  98. WorkflowNodeExecution.status,
  99. WorkflowNodeExecution.inputs,
  100. WorkflowNodeExecution.outputs,
  101. WorkflowNodeExecution.created_at,
  102. WorkflowNodeExecution.elapsed_time,
  103. WorkflowNodeExecution.process_data,
  104. WorkflowNodeExecution.execution_metadata,
  105. )
  106. .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
  107. .all()
  108. )
  109. for node_execution in workflow_nodes_executions:
  110. node_execution_id = node_execution.id
  111. tenant_id = node_execution.tenant_id
  112. app_id = node_execution.app_id
  113. node_name = node_execution.title
  114. node_type = node_execution.node_type
  115. status = node_execution.status
  116. if node_type == "llm":
  117. inputs = (
  118. json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
  119. )
  120. else:
  121. inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
  122. outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
  123. created_at = node_execution.created_at if node_execution.created_at else datetime.now()
  124. elapsed_time = node_execution.elapsed_time
  125. finished_at = created_at + timedelta(seconds=elapsed_time)
  126. execution_metadata = (
  127. json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  128. )
  129. node_total_tokens = execution_metadata.get("total_tokens", 0)
  130. metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
  131. metadata.update(
  132. {
  133. "workflow_run_id": trace_info.workflow_run_id,
  134. "node_execution_id": node_execution_id,
  135. "tenant_id": tenant_id,
  136. "app_id": app_id,
  137. "app_name": node_name,
  138. "node_type": node_type,
  139. "status": status,
  140. }
  141. )
  142. process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
  143. if process_data and process_data.get("model_mode") == "chat":
  144. run_type = LangSmithRunType.llm
  145. elif node_type == "knowledge-retrieval":
  146. run_type = LangSmithRunType.retriever
  147. else:
  148. run_type = LangSmithRunType.tool
  149. langsmith_run = LangSmithRunModel(
  150. total_tokens=node_total_tokens,
  151. name=node_type,
  152. inputs=inputs,
  153. run_type=run_type,
  154. start_time=created_at,
  155. end_time=finished_at,
  156. outputs=outputs,
  157. file_list=trace_info.file_list,
  158. extra={
  159. "metadata": metadata,
  160. },
  161. parent_run_id=trace_info.workflow_app_log_id
  162. if trace_info.workflow_app_log_id
  163. else trace_info.workflow_run_id,
  164. tags=["node_execution"],
  165. )
  166. self.add_run(langsmith_run)
  167. def message_trace(self, trace_info: MessageTraceInfo):
  168. # get message file data
  169. file_list = trace_info.file_list
  170. message_file_data: MessageFile = trace_info.message_file_data
  171. file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
  172. file_list.append(file_url)
  173. metadata = trace_info.metadata
  174. message_data = trace_info.message_data
  175. message_id = message_data.id
  176. user_id = message_data.from_account_id
  177. metadata["user_id"] = user_id
  178. if message_data.from_end_user_id:
  179. end_user_data: EndUser = (
  180. db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
  181. )
  182. if end_user_data is not None:
  183. end_user_id = end_user_data.session_id
  184. metadata["end_user_id"] = end_user_id
  185. message_run = LangSmithRunModel(
  186. input_tokens=trace_info.message_tokens,
  187. output_tokens=trace_info.answer_tokens,
  188. total_tokens=trace_info.total_tokens,
  189. id=message_id,
  190. name=TraceTaskName.MESSAGE_TRACE.value,
  191. inputs=trace_info.inputs,
  192. run_type=LangSmithRunType.chain,
  193. start_time=trace_info.start_time,
  194. end_time=trace_info.end_time,
  195. outputs=message_data.answer,
  196. extra={
  197. "metadata": metadata,
  198. },
  199. tags=["message", str(trace_info.conversation_mode)],
  200. error=trace_info.error,
  201. file_list=file_list,
  202. )
  203. self.add_run(message_run)
  204. # create llm run parented to message run
  205. llm_run = LangSmithRunModel(
  206. input_tokens=trace_info.message_tokens,
  207. output_tokens=trace_info.answer_tokens,
  208. total_tokens=trace_info.total_tokens,
  209. name="llm",
  210. inputs=trace_info.inputs,
  211. run_type=LangSmithRunType.llm,
  212. start_time=trace_info.start_time,
  213. end_time=trace_info.end_time,
  214. outputs=message_data.answer,
  215. extra={
  216. "metadata": metadata,
  217. },
  218. parent_run_id=message_id,
  219. tags=["llm", str(trace_info.conversation_mode)],
  220. error=trace_info.error,
  221. file_list=file_list,
  222. )
  223. self.add_run(llm_run)
  224. def moderation_trace(self, trace_info: ModerationTraceInfo):
  225. langsmith_run = LangSmithRunModel(
  226. name=TraceTaskName.MODERATION_TRACE.value,
  227. inputs=trace_info.inputs,
  228. outputs={
  229. "action": trace_info.action,
  230. "flagged": trace_info.flagged,
  231. "preset_response": trace_info.preset_response,
  232. "inputs": trace_info.inputs,
  233. },
  234. run_type=LangSmithRunType.tool,
  235. extra={
  236. "metadata": trace_info.metadata,
  237. },
  238. tags=["moderation"],
  239. parent_run_id=trace_info.message_id,
  240. start_time=trace_info.start_time or trace_info.message_data.created_at,
  241. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  242. )
  243. self.add_run(langsmith_run)
  244. def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
  245. message_data = trace_info.message_data
  246. suggested_question_run = LangSmithRunModel(
  247. name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
  248. inputs=trace_info.inputs,
  249. outputs=trace_info.suggested_question,
  250. run_type=LangSmithRunType.tool,
  251. extra={
  252. "metadata": trace_info.metadata,
  253. },
  254. tags=["suggested_question"],
  255. parent_run_id=trace_info.message_id,
  256. start_time=trace_info.start_time or message_data.created_at,
  257. end_time=trace_info.end_time or message_data.updated_at,
  258. )
  259. self.add_run(suggested_question_run)
  260. def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
  261. dataset_retrieval_run = LangSmithRunModel(
  262. name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
  263. inputs=trace_info.inputs,
  264. outputs={"documents": trace_info.documents},
  265. run_type=LangSmithRunType.retriever,
  266. extra={
  267. "metadata": trace_info.metadata,
  268. },
  269. tags=["dataset_retrieval"],
  270. parent_run_id=trace_info.message_id,
  271. start_time=trace_info.start_time or trace_info.message_data.created_at,
  272. end_time=trace_info.end_time or trace_info.message_data.updated_at,
  273. )
  274. self.add_run(dataset_retrieval_run)
  275. def tool_trace(self, trace_info: ToolTraceInfo):
  276. tool_run = LangSmithRunModel(
  277. name=trace_info.tool_name,
  278. inputs=trace_info.tool_inputs,
  279. outputs=trace_info.tool_outputs,
  280. run_type=LangSmithRunType.tool,
  281. extra={
  282. "metadata": trace_info.metadata,
  283. },
  284. tags=["tool", trace_info.tool_name],
  285. parent_run_id=trace_info.message_id,
  286. start_time=trace_info.start_time,
  287. end_time=trace_info.end_time,
  288. file_list=[trace_info.file_url],
  289. )
  290. self.add_run(tool_run)
  291. def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
  292. name_run = LangSmithRunModel(
  293. name=TraceTaskName.GENERATE_NAME_TRACE.value,
  294. inputs=trace_info.inputs,
  295. outputs=trace_info.outputs,
  296. run_type=LangSmithRunType.tool,
  297. extra={
  298. "metadata": trace_info.metadata,
  299. },
  300. tags=["generate_name"],
  301. start_time=trace_info.start_time or datetime.now(),
  302. end_time=trace_info.end_time or datetime.now(),
  303. )
  304. self.add_run(name_run)
  305. def add_run(self, run_data: LangSmithRunModel):
  306. data = run_data.model_dump()
  307. if self.project_id:
  308. data["session_id"] = self.project_id
  309. elif self.project_name:
  310. data["session_name"] = self.project_name
  311. data = filter_none_values(data)
  312. try:
  313. self.langsmith_client.create_run(**data)
  314. logger.debug("LangSmith Run created successfully.")
  315. except Exception as e:
  316. raise ValueError(f"LangSmith Failed to create run: {str(e)}")
  317. def update_run(self, update_run_data: LangSmithRunUpdateModel):
  318. data = update_run_data.model_dump()
  319. data = filter_none_values(data)
  320. try:
  321. self.langsmith_client.update_run(**data)
  322. logger.debug("LangSmith Run updated successfully.")
  323. except Exception as e:
  324. raise ValueError(f"LangSmith Failed to update run: {str(e)}")
  325. def api_check(self):
  326. try:
  327. random_project_name = f"test_project_{datetime.now().strftime('%Y%m%d%H%M%S')}"
  328. self.langsmith_client.create_project(project_name=random_project_name)
  329. self.langsmith_client.delete_project(project_name=random_project_name)
  330. return True
  331. except Exception as e:
  332. logger.debug(f"LangSmith API check failed: {str(e)}")
  333. raise ValueError(f"LangSmith API check failed: {str(e)}")