_positions.py 975 B

12345678910111213141516171819202122232425262728
  1. import os.path
  2. from yaml import FullLoader, load
  3. from core.tools.entities.user_entities import UserToolProvider
  4. class BuiltinToolProviderSort:
  5. _position = {}
  6. @classmethod
  7. def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
  8. if not cls._position:
  9. tmp_position = {}
  10. file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
  11. with open(file_path) as f:
  12. for pos, val in enumerate(load(f, Loader=FullLoader)):
  13. tmp_position[val] = pos
  14. cls._position = tmp_position
  15. def sort_compare(provider: UserToolProvider) -> int:
  16. if provider.type == UserToolProvider.ProviderType.MODEL:
  17. return cls._position.get(f'model.{provider.name}', 10000)
  18. return cls._position.get(provider.name, 10000)
  19. sorted_providers = sorted(providers, key=sort_compare)
  20. return sorted_providers