fake_llm.py

import time
from collections.abc import Mapping
from typing import Any, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult


class FakeLLM(SimpleChatModel):
    """Fake ChatModel for testing purposes."""

    streaming: bool = False
    """Whether to stream the results or not."""

    response: str
    """The canned text returned for every call."""

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Ignore the input messages and return the configured response."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        return 0

    def _generate(
        self,
        messages: list[BaseMessage],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            # Emit the response one character at a time so streaming
            # callbacks fire, with a small delay to mimic token latency.
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {
            "token_usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
        }
        return ChatResult(generations=[generation], llm_output=llm_output)
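

# Usage sketch (an assumption, not part of the original file): with a
# LangChain version that exposes the Runnable interface, BaseChatModel.invoke
# accepts a plain string or a list of messages, so the fake model can be
# exercised in a test roughly like this:
if __name__ == "__main__":
    llm = FakeLLM(response="hello", streaming=True)
    result = llm.invoke("any prompt")  # returns an AIMessage whose content is "hello"
    print(result.content)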