tool_label_manager.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from core.tools.entities.values import default_tool_label_name_list
  2. from core.tools.provider.api_tool_provider import ApiToolProviderController
  3. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  4. from core.tools.provider.tool_provider import ToolProviderController
  5. from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
  6. from extensions.ext_database import db
  7. from models.tools import ToolLabelBinding
  8. class ToolLabelManager:
  9. @classmethod
  10. def filter_tool_labels(cls, tool_labels: list[str]) -> list[str]:
  11. """
  12. Filter tool labels
  13. """
  14. tool_labels = [label for label in tool_labels if label in default_tool_label_name_list]
  15. return list(set(tool_labels))
  16. @classmethod
  17. def update_tool_labels(cls, controller: ToolProviderController, labels: list[str]):
  18. """
  19. Update tool labels
  20. """
  21. labels = cls.filter_tool_labels(labels)
  22. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  23. provider_id = controller.provider_id
  24. else:
  25. raise ValueError('Unsupported tool type')
  26. # delete old labels
  27. db.session.query(ToolLabelBinding).filter(
  28. ToolLabelBinding.tool_id == provider_id
  29. ).delete()
  30. # insert new labels
  31. for label in labels:
  32. db.session.add(ToolLabelBinding(
  33. tool_id=provider_id,
  34. tool_type=controller.provider_type.value,
  35. label_name=label,
  36. ))
  37. db.session.commit()
  38. @classmethod
  39. def get_tool_labels(cls, controller: ToolProviderController) -> list[str]:
  40. """
  41. Get tool labels
  42. """
  43. if isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  44. provider_id = controller.provider_id
  45. elif isinstance(controller, BuiltinToolProviderController):
  46. return controller.tool_labels
  47. else:
  48. raise ValueError('Unsupported tool type')
  49. labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding.label_name).filter(
  50. ToolLabelBinding.tool_id == provider_id,
  51. ToolLabelBinding.tool_type == controller.provider_type.value,
  52. ).all()
  53. return [label.label_name for label in labels]
  54. @classmethod
  55. def get_tools_labels(cls, tool_providers: list[ToolProviderController]) -> dict[str, list[str]]:
  56. """
  57. Get tools labels
  58. :param tool_providers: list of tool providers
  59. :return: dict of tool labels
  60. :key: tool id
  61. :value: list of tool labels
  62. """
  63. if not tool_providers:
  64. return {}
  65. for controller in tool_providers:
  66. if not isinstance(controller, ApiToolProviderController | WorkflowToolProviderController):
  67. raise ValueError('Unsupported tool type')
  68. provider_ids = [controller.provider_id for controller in tool_providers]
  69. labels: list[ToolLabelBinding] = db.session.query(ToolLabelBinding).filter(
  70. ToolLabelBinding.tool_id.in_(provider_ids)
  71. ).all()
  72. tool_labels = {
  73. label.tool_id: [] for label in labels
  74. }
  75. for label in labels:
  76. tool_labels[label.tool_id].append(label.label_name)
  77. return tool_labels