fake.py

import time
from typing import Any, Callable, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult

from core.model_providers.models.entity.message import str_to_prompt_messages


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    num_token_func: Optional[Callable] = None

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the canned response regardless of the input messages."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        # Delegate token counting to the injected hook when provided; otherwise report 0.
        return self.num_token_func(str_to_prompt_messages([text])) if self.num_token_func else 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
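            # Simulate token streaming by emitting the canned response one
            # character at a time through the callback manager.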
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {
            "token_usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
        }
        return ChatResult(generations=[generation], llm_output=llm_output)
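

# A minimal usage sketch, not part of the original file: it exercises
# _generate directly with a single message to show the canned response and
# the zeroed token-usage accounting. The HumanMessage input is an arbitrary
# example; any BaseMessage list would do, since _call ignores its input.
if __name__ == "__main__":
    from langchain.schema import HumanMessage

    llm = FakeLLM(response="hello", streaming=False)
    result = llm._generate([HumanMessage(content="hi")])
    print(result.generations[0].message.content)  # prints "hello"
    print(result.llm_output)  # token_usage counters are all zero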