# model_tool_provider.py

from typing import Any, Optional

from core.entities.model_entities import ModelStatus
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ModelToolPropertyKey,
    ToolDescription,
    ToolIdentity,
    ToolParameter,
    ToolProviderCredentials,
    ToolProviderIdentity,
    ToolProviderType,
)
from core.tools.errors import ToolNotFoundError
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool.model_tool import ModelTool
from core.tools.tool.tool import Tool
from core.tools.utils.configuration import ModelToolConfigurationManager


  22. class ModelToolProviderController(ToolProviderController):
  23. configuration: ProviderConfiguration = None
  24. is_active: bool = False
  25. def __init__(self, configuration: ProviderConfiguration = None, **kwargs):
  26. """
  27. init the provider
  28. :param data: the data of the provider
  29. """
  30. super().__init__(**kwargs)
  31. self.configuration = configuration
  32. @staticmethod
  33. def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController':
  34. """
  35. init the provider from db
  36. :param configuration: the configuration of the provider
  37. """
  38. # check if all models are active
  39. if configuration is None:
  40. return None
  41. is_active = True
  42. models = configuration.get_provider_models()
  43. for model in models:
  44. if model.status != ModelStatus.ACTIVE:
  45. is_active = False
  46. break
  47. # get the provider configuration
  48. model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
  49. if model_tool_configuration is None:
  50. raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
  51. # override the configuration
  52. if model_tool_configuration.label:
  53. if model_tool_configuration.label.en_US:
  54. configuration.provider.label.en_US = model_tool_configuration.label.en_US
  55. if model_tool_configuration.label.zh_Hans:
  56. configuration.provider.label.zh_Hans = model_tool_configuration.label.zh_Hans
  57. return ModelToolProviderController(
  58. is_active=is_active,
  59. identity=ToolProviderIdentity(
  60. author='Dify',
  61. name=configuration.provider.provider,
  62. description=I18nObject(
  63. zh_Hans=f'{configuration.provider.label.zh_Hans} 模型能力提供商',
  64. en_US=f'{configuration.provider.label.en_US} model capability provider'
  65. ),
  66. label=I18nObject(
  67. zh_Hans=configuration.provider.label.zh_Hans,
  68. en_US=configuration.provider.label.en_US
  69. ),
  70. icon=configuration.provider.icon_small.en_US,
  71. ),
  72. configuration=configuration,
  73. credentials_schema={},
  74. )
  75. @staticmethod
  76. def is_configuration_valid(configuration: ProviderConfiguration) -> bool:
  77. """
  78. check if the configuration has a model can be used as a tool
  79. """
  80. models = configuration.get_provider_models()
  81. for model in models:
  82. if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
  83. return True
  84. return False
  85. def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]:
  86. """
  87. returns a list of tools that the provider can provide
  88. :return: list of tools
  89. """
  90. tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
  91. provider_manager = ProviderManager()
  92. if self.configuration is None:
  93. configurations = provider_manager.get_configurations(tenant_id=tenant_id).values()
  94. self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None)
  95. # get all tools
  96. tools: list[ModelTool] = []
  97. # get all models
  98. if not self.configuration:
  99. return tools
  100. configuration = self.configuration
  101. provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider)
  102. if provider_configuration is None:
  103. raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}')
  104. for model in configuration.get_provider_models():
  105. model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model)
  106. if model_configuration is None:
  107. continue
  108. if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []):
  109. provider_instance = configuration.get_provider_instance()
  110. model_type_instance = provider_instance.get_model_instance(model.model_type)
  111. provider_model_bundle = ProviderModelBundle(
  112. configuration=configuration,
  113. provider_instance=provider_instance,
  114. model_type_instance=model_type_instance
  115. )
  116. try:
  117. model_instance = ModelInstance(provider_model_bundle, model.model)
  118. except ProviderTokenNotInitError:
  119. model_instance = None
  120. tools.append(ModelTool(
  121. identity=ToolIdentity(
  122. author='Dify',
  123. name=model.model,
  124. label=model_configuration.label,
  125. ),
  126. parameters=[
  127. ToolParameter(
  128. name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value,
  129. label=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
  130. human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'),
  131. type=ToolParameter.ToolParameterType.STRING,
  132. form=ToolParameter.ToolParameterForm.LLM,
  133. required=True,
  134. default=Tool.VARIABLE_KEY.IMAGE.value
  135. )
  136. ],
  137. description=ToolDescription(
  138. human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'),
  139. llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.',
  140. ),
  141. is_team_authorization=model.status == ModelStatus.ACTIVE,
  142. tool_type=ModelTool.ModelToolType.VISION,
  143. model_instance=model_instance,
  144. model=model.model,
  145. ))
  146. self.tools = tools
  147. return tools
  148. def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
  149. """
  150. returns the credentials schema of the provider
  151. :return: the credentials schema
  152. """
  153. return {}
  154. def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]:
  155. """
  156. returns a list of tools that the provider can provide
  157. :return: list of tools
  158. """
  159. return self._get_model_tools(tenant_id=tenant_id)
  160. def get_tool(self, tool_name: str) -> ModelTool:
  161. """
  162. get tool by name
  163. :param tool_name: the name of the tool
  164. :return: the tool
  165. """
  166. if self.tools is None:
  167. self.get_tools(user_id='', tenant_id=self.configuration.tenant_id)
  168. for tool in self.tools:
  169. if tool.identity.name == tool_name:
  170. return tool
  171. raise ValueError(f'tool {tool_name} not found')
  172. def get_parameters(self, tool_name: str) -> list[ToolParameter]:
  173. """
  174. returns the parameters of the tool
  175. :param tool_name: the name of the tool, defined in `get_tools`
  176. :return: list of parameters
  177. """
  178. tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
  179. if tool is None:
  180. raise ToolNotFoundError(f'tool {tool_name} not found')
  181. return tool.parameters
  182. @property
  183. def app_type(self) -> ToolProviderType:
  184. """
  185. returns the type of the provider
  186. :return: type of the provider
  187. """
  188. return ToolProviderType.MODEL
  189. def validate_credentials(self, credentials: dict[str, Any]) -> None:
  190. """
  191. validate the credentials of the provider
  192. :param tool_name: the name of the tool, defined in `get_tools`
  193. :param credentials: the credentials of the tool
  194. """
  195. pass
  196. def _validate_credentials(self, credentials: dict[str, Any]) -> None:
  197. """
  198. validate the credentials of the provider
  199. :param tool_name: the name of the tool, defined in `get_tools`
  200. :param credentials: the credentials of the tool
  201. """
  202. pass