_positions.py 1011 B

1234567891011121314151617181920212223242526272829
  1. import os.path
  2. from typing import List
  3. from yaml import FullLoader, load
  4. from core.tools.entities.user_entities import UserToolProvider
  5. position = {}
  6. class BuiltinToolProviderSort:
  7. @staticmethod
  8. def sort(providers: List[UserToolProvider]) -> List[UserToolProvider]:
  9. global position
  10. if not position:
  11. tmp_position = {}
  12. file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
  13. with open(file_path, 'r') as f:
  14. for pos, val in enumerate(load(f, Loader=FullLoader)):
  15. tmp_position[val] = pos
  16. position = tmp_position
  17. def sort_compare(provider: UserToolProvider) -> int:
  18. # if provider.type == UserToolProvider.ProviderType.MODEL:
  19. # return position.get(f'model_provider.{provider.name}', 10000)
  20. return position.get(provider.name, 10000)
  21. sorted_providers = sorted(providers, key=sort_compare)
  22. return sorted_providers