fake.py

import time
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


class FakeLLM(SimpleChatModel):
    """Fake chat model for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""

    response: str
    """Canned response returned for every call."""

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Return the canned response, regardless of the input messages."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            # Simulate streaming by emitting the response one character at a time.
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {
            "token_usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
        }
        return ChatResult(generations=[generation], llm_output=llm_output)
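

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how this fake model might be exercised in a test.
# It assumes the legacy LangChain 0.0.x call convention, where invoking a
# chat model with a list of messages returns a single AIMessage; the names
# below ("pong", "ping") are placeholders chosen for the example.
if __name__ == "__main__":
    from langchain.schema import HumanMessage

    fake = FakeLLM(response="pong", streaming=False)
    reply = fake([HumanMessage(content="ping")])
    assert isinstance(reply, AIMessage)
    assert reply.content == "pong"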