# builtin_tool.py
  1. from core.model_runtime.entities.llm_entities import LLMResult
  2. from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
  3. from core.tools.entities.tool_entities import ToolProviderType
  4. from core.tools.entities.user_entities import UserToolProvider
  5. from core.tools.model.tool_model_manager import ToolModelManager
  6. from core.tools.tool.tool import Tool
  7. from core.tools.utils.web_reader_tool import get_url
  8. _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
  9. and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
  10. retain the original meaning and keep the key points.
  11. however, the text you got is too long, what you got is possible a part of the text.
  12. Please summarize the text you got.
  13. """
  14. class BuiltinTool(Tool):
  15. """
  16. Builtin tool
  17. :param meta: the meta data of a tool call processing
  18. """
  19. def invoke_model(
  20. self, user_id: str, prompt_messages: list[PromptMessage], stop: list[str]
  21. ) -> LLMResult:
  22. """
  23. invoke model
  24. :param model_config: the model config
  25. :param prompt_messages: the prompt messages
  26. :param stop: the stop words
  27. :return: the model result
  28. """
  29. # invoke model
  30. return ToolModelManager.invoke(
  31. user_id=user_id,
  32. tenant_id=self.runtime.tenant_id,
  33. tool_type='builtin',
  34. tool_name=self.identity.name,
  35. prompt_messages=prompt_messages,
  36. )
  37. def tool_provider_type(self) -> ToolProviderType:
  38. return UserToolProvider.ProviderType.BUILTIN
  39. def get_max_tokens(self) -> int:
  40. """
  41. get max tokens
  42. :param model_config: the model config
  43. :return: the max tokens
  44. """
  45. return ToolModelManager.get_max_llm_context_tokens(
  46. tenant_id=self.runtime.tenant_id,
  47. )
  48. def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
  49. """
  50. get prompt tokens
  51. :param prompt_messages: the prompt messages
  52. :return: the tokens
  53. """
  54. return ToolModelManager.calculate_tokens(
  55. tenant_id=self.runtime.tenant_id,
  56. prompt_messages=prompt_messages
  57. )
  58. def summary(self, user_id: str, content: str) -> str:
  59. max_tokens = self.get_max_tokens()
  60. if self.get_prompt_tokens(prompt_messages=[
  61. UserPromptMessage(content=content)
  62. ]) < max_tokens * 0.6:
  63. return content
  64. def get_prompt_tokens(content: str) -> int:
  65. return self.get_prompt_tokens(prompt_messages=[
  66. SystemPromptMessage(content=_SUMMARY_PROMPT),
  67. UserPromptMessage(content=content)
  68. ])
  69. def summarize(content: str) -> str:
  70. summary = self.invoke_model(user_id=user_id, prompt_messages=[
  71. SystemPromptMessage(content=_SUMMARY_PROMPT),
  72. UserPromptMessage(content=content)
  73. ], stop=[])
  74. return summary.message.content
  75. lines = content.split('\n')
  76. new_lines = []
  77. # split long line into multiple lines
  78. for i in range(len(lines)):
  79. line = lines[i]
  80. if not line.strip():
  81. continue
  82. if len(line) < max_tokens * 0.5:
  83. new_lines.append(line)
  84. elif get_prompt_tokens(line) > max_tokens * 0.7:
  85. while get_prompt_tokens(line) > max_tokens * 0.7:
  86. new_lines.append(line[:int(max_tokens * 0.5)])
  87. line = line[int(max_tokens * 0.5):]
  88. new_lines.append(line)
  89. else:
  90. new_lines.append(line)
  91. # merge lines into messages with max tokens
  92. messages: list[str] = []
  93. for i in new_lines:
  94. if len(messages) == 0:
  95. messages.append(i)
  96. else:
  97. if len(messages[-1]) + len(i) < max_tokens * 0.5:
  98. messages[-1] += i
  99. if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
  100. messages.append(i)
  101. else:
  102. messages[-1] += i
  103. summaries = []
  104. for i in range(len(messages)):
  105. message = messages[i]
  106. summary = summarize(message)
  107. summaries.append(summary)
  108. result = '\n'.join(summaries)
  109. if self.get_prompt_tokens(prompt_messages=[
  110. UserPromptMessage(content=result)
  111. ]) > max_tokens * 0.7:
  112. return self.summary(user_id=user_id, content=result)
  113. return result
  114. def get_url(self, url: str, user_agent: str = None) -> str:
  115. """
  116. get url
  117. """
  118. return get_url(url, user_agent=user_agent)