from copy import deepcopy from typing import Any from core.entities.model_entities import ModelStatus from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.provider_manager import ProviderConfiguration, ProviderManager, ProviderModelBundle from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ModelToolPropertyKey, ToolDescription, ToolIdentity, ToolParameter, ToolProviderCredentials, ToolProviderIdentity, ToolProviderType, ) from core.tools.errors import ToolNotFoundError from core.tools.provider.tool_provider import ToolProviderController from core.tools.tool.model_tool import ModelTool from core.tools.tool.tool import Tool from core.tools.utils.configuration import ModelToolConfigurationManager class ModelToolProviderController(ToolProviderController): configuration: ProviderConfiguration = None is_active: bool = False def __init__(self, configuration: ProviderConfiguration = None, **kwargs): """ init the provider :param data: the data of the provider """ super().__init__(**kwargs) self.configuration = configuration @staticmethod def from_db(configuration: ProviderConfiguration = None) -> 'ModelToolProviderController': """ init the provider from db :param configuration: the configuration of the provider """ # check if all models are active if configuration is None: return None is_active = True models = configuration.get_provider_models() for model in models: if model.status != ModelStatus.ACTIVE: is_active = False break # get the provider configuration model_tool_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) if model_tool_configuration is None: raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') # override the configuration if model_tool_configuration.label: label = deepcopy(model_tool_configuration.label) if label.en_US: label.en_US = model_tool_configuration.label.en_US if label.zh_Hans: label.zh_Hans = model_tool_configuration.label.zh_Hans else: label = I18nObject( en_US=configuration.provider.label.en_US, zh_Hans=configuration.provider.label.zh_Hans ) return ModelToolProviderController( is_active=is_active, identity=ToolProviderIdentity( author='Dify', name=configuration.provider.provider, description=I18nObject( zh_Hans=f'{label.zh_Hans} 模型能力提供商', en_US=f'{label.en_US} model capability provider' ), label=I18nObject( zh_Hans=label.zh_Hans, en_US=label.en_US ), icon=configuration.provider.icon_small.en_US, ), configuration=configuration, credentials_schema={}, ) @staticmethod def is_configuration_valid(configuration: ProviderConfiguration) -> bool: """ check if the configuration has a model can be used as a tool """ models = configuration.get_provider_models() for model in models: if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): return True return False def _get_model_tools(self, tenant_id: str = None) -> list[ModelTool]: """ returns a list of tools that the provider can provide :return: list of tools """ tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff' provider_manager = ProviderManager() if self.configuration is None: configurations = provider_manager.get_configurations(tenant_id=tenant_id).values() self.configuration = next(filter(lambda x: x.provider == self.identity.name, configurations), None) # get all tools tools: list[ModelTool] = [] # get all models if not self.configuration: return tools configuration = self.configuration provider_configuration = ModelToolConfigurationManager.get_configuration(configuration.provider.provider) if provider_configuration is None: raise RuntimeError(f'no configuration found for provider {configuration.provider.provider}') for model in configuration.get_provider_models(): model_configuration = ModelToolConfigurationManager.get_model_configuration(self.configuration.provider.provider, model.model) if model_configuration is None: continue if model.model_type == ModelType.LLM and ModelFeature.VISION in (model.features or []): provider_instance = configuration.get_provider_instance() model_type_instance = provider_instance.get_model_instance(model.model_type) provider_model_bundle = ProviderModelBundle( configuration=configuration, provider_instance=provider_instance, model_type_instance=model_type_instance ) try: model_instance = ModelInstance(provider_model_bundle, model.model) except ProviderTokenNotInitError: model_instance = None tools.append(ModelTool( identity=ToolIdentity( author='Dify', name=model.model, label=model_configuration.label, ), parameters=[ ToolParameter( name=ModelToolPropertyKey.IMAGE_PARAMETER_NAME.value, label=I18nObject(zh_Hans='图片ID', en_US='Image ID'), human_description=I18nObject(zh_Hans='图片ID', en_US='Image ID'), type=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.LLM, required=True, default=Tool.VARIABLE_KEY.IMAGE.value ) ], description=ToolDescription( human=I18nObject(zh_Hans='图生文工具', en_US='Convert image to text'), llm='Vision tool used to extract text and other visual information from images, can be used for OCR, image captioning, etc.', ), is_team_authorization=model.status == ModelStatus.ACTIVE, tool_type=ModelTool.ModelToolType.VISION, model_instance=model_instance, model=model.model, )) self.tools = tools return tools def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]: """ returns the credentials schema of the provider :return: the credentials schema """ return {} def get_tools(self, user_id: str, tenant_id: str) -> list[ModelTool]: """ returns a list of tools that the provider can provide :return: list of tools """ return self._get_model_tools(tenant_id=tenant_id) def get_tool(self, tool_name: str) -> ModelTool: """ get tool by name :param tool_name: the name of the tool :return: the tool """ if self.tools is None: self.get_tools(user_id='', tenant_id=self.configuration.tenant_id) for tool in self.tools: if tool.identity.name == tool_name: return tool raise ValueError(f'tool {tool_name} not found') def get_parameters(self, tool_name: str) -> list[ToolParameter]: """ returns the parameters of the tool :param tool_name: the name of the tool, defined in `get_tools` :return: list of parameters """ tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None) if tool is None: raise ToolNotFoundError(f'tool {tool_name} not found') return tool.parameters @property def app_type(self) -> ToolProviderType: """ returns the type of the provider :return: type of the provider """ return ToolProviderType.MODEL def validate_credentials(self, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider :param tool_name: the name of the tool, defined in `get_tools` :param credentials: the credentials of the tool """ pass def _validate_credentials(self, credentials: dict[str, Any]) -> None: """ validate the credentials of the provider :param tool_name: the name of the tool, defined in `get_tools` :param credentials: the credentials of the tool """ pass