# _positions.py
import os.path

from yaml import FullLoader, SafeLoader, load

from core.tools.entities.user_entities import UserToolProvider
  4. position = {}
  5. class BuiltinToolProviderSort:
  6. @staticmethod
  7. def sort(providers: list[UserToolProvider]) -> list[UserToolProvider]:
  8. global position
  9. if not position:
  10. tmp_position = {}
  11. file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
  12. with open(file_path) as f:
  13. for pos, val in enumerate(load(f, Loader=FullLoader)):
  14. tmp_position[val] = pos
  15. position = tmp_position
  16. def sort_compare(provider: UserToolProvider) -> int:
  17. # if provider.type == UserToolProvider.ProviderType.MODEL:
  18. # return position.get(f'model_provider.{provider.name}', 10000)
  19. return position.get(provider.name, 10000)
  20. sorted_providers = sorted(providers, key=sort_compare)
  21. return sorted_providers