_positions.py 1.0 KB

1234567891011121314151617181920212223242526272829
  1. from core.tools.entities.user_entities import UserToolProvider
  2. from core.tools.entities.tool_entities import ToolProviderType
  3. from typing import List
  4. from yaml import load, FullLoader
  5. import os.path
  6. position = {}
  7. class BuiltinToolProviderSort:
  8. @staticmethod
  9. def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
  10. global position
  11. if not position:
  12. tmp_position = {}
  13. file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
  14. with open(file_path, 'r') as f:
  15. for pos, val in enumerate(load(f, Loader=FullLoader)):
  16. tmp_position[val] = pos
  17. position = tmp_position
  18. def sort_compare(provider: UserToolProvider) -> int:
  19. # if provider.type == UserToolProvider.ProviderType.MODEL:
  20. # return position.get(f'model_provider.{provider.name}', 10000)
  21. return position.get(provider.name, 10000)
  22. sorted_providers = sorted(providers, key=sort_compare)
  23. return sorted_providers