huggingface_chat.py 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import re
  2. from typing import Any, Generator, List, Literal, Optional, Union
  3. from _pytest.monkeypatch import MonkeyPatch
  4. from huggingface_hub import InferenceClient
  5. from huggingface_hub.inference._text_generation import (Details, StreamDetails, TextGenerationResponse,
  6. TextGenerationStreamResponse, Token)
  7. from huggingface_hub.utils import BadRequestError
  8. class MockHuggingfaceChatClass(object):
  9. @staticmethod
  10. def generate_create_sync(model: str) -> TextGenerationResponse:
  11. response = TextGenerationResponse(
  12. generated_text="You can call me Miku Miku o~e~o~",
  13. details=Details(
  14. finish_reason="length",
  15. generated_tokens=6,
  16. tokens=[
  17. Token(id=0, text="You", logprob=0.0, special=False) for i in range(0, 6)
  18. ]
  19. )
  20. )
  21. return response
  22. @staticmethod
  23. def generate_create_stream(model: str) -> Generator[TextGenerationStreamResponse, None, None]:
  24. full_text = "You can call me Miku Miku o~e~o~"
  25. for i in range(0, len(full_text)):
  26. response = TextGenerationStreamResponse(
  27. token = Token(id=i, text=full_text[i], logprob=0.0, special=False),
  28. )
  29. response.generated_text = full_text[i]
  30. response.details = StreamDetails(finish_reason='stop_sequence', generated_tokens=1)
  31. yield response
  32. def text_generation(self: InferenceClient, prompt: str, *,
  33. stream: Literal[False] = ...,
  34. model: Optional[str] = None,
  35. **kwargs: Any
  36. ) -> Union[TextGenerationResponse, Generator[TextGenerationStreamResponse, None, None]]:
  37. # check if key is valid
  38. if not re.match(r'Bearer\shf\-[a-zA-Z0-9]{16,}', self.headers['authorization']):
  39. raise BadRequestError('Invalid API key')
  40. if model is None:
  41. raise BadRequestError('Invalid model')
  42. if stream:
  43. return MockHuggingfaceChatClass.generate_create_stream(model)
  44. return MockHuggingfaceChatClass.generate_create_sync(model)