123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- from core.tools.entities.values import default_tool_label_name_list
- from core.tools.provider.api_tool_provider import ApiToolProviderController
- from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
- from core.tools.provider.tool_provider import ToolProviderController
- from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
- from extensions.ext_database import db
- from models.tools import ToolLabelBinding
- class ToolLabelManager:
- @classmethod
- def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
- """
- Filter tool labels
- """
- tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
- return list(set(tool_labels))
-
- @classmethod
- def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
- """
- Update tool labels
- """
- labels = cls.filter_tool_labels(labels)
- if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
- provider_id = controller.provider_id
- else:
- raise ValueError('Unsupported tool type')
- # delete old labels
- db.session.query(ToolLabelBinding).filter(
- ToolLabelBinding.tool_id == provider_id
- ).delete()
- # insert new labels
- for label in labels:
- db.session.add(ToolLabelBinding(
- tool_id=provider_id,
- tool_type=controller.provider_type.value,
- label_name=label,
- ))
- db.session.commit()
- @classmethod
- def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
- """
- Get tool labels
- """
- if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
- provider_id = controller.provider_id
- elif isinstance(controller, BuiltinToolProviderController):
- return controller.tool_labels
- else:
- raise ValueError('Unsupported tool type')
- labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
- ToolLabelBinding.tool_id == provider_id,
- ToolLabelBinding.tool_type == controller.provider_type.value,
- ).all()
- return [label.label_name for label in labels]
- @classmethod
- def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
- """
- Get tools labels
- :param tool_providers: list of tool providers
- :return: dict of tool labels
- :key: tool id
- :value: list of tool labels
- """
- if not tool_providers:
- return {}
-
- for controller in tool_providers:
- if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
- raise ValueError('Unsupported tool type')
-
- provider_ids = [controller.provider_id for controller in tool_providers]
- labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
- ToolLabelBinding.tool_id.in_(provider_ids)
- ).all()
- tool_labels = {
- label.tool_id: [] for label in labels
- }
- for label in labels:
- tool_labels[label.tool_id].append(label.label_name)
- return tool_labels
|