import time
from typing import Any, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.chat_models.base import SimpleChatModel
from langchain.schema import (
    AIMessage,
    BaseLanguageModel,
    BaseMessage,
    ChatGeneration,
    ChatResult,
)


class FakeLLM(SimpleChatModel):
    """Fake chat model for testing purposes: always returns a canned response."""

    streaming: bool = False
    """Whether to stream the results or not."""
    response: str
    """The canned response returned for every call."""
    origin_llm: Optional[BaseLanguageModel] = None
    """Optional real model, used only to delegate token counting."""

    @property
    def _llm_type(self) -> str:
        return "fake-chat-model"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Ignore the input messages and return the canned response."""
        return self.response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        return {"response": self.response}

    def get_num_tokens(self, text: str) -> int:
        # Delegate token counting to the wrapped model if one was provided.
        return self.origin_llm.get_num_tokens(text) if self.origin_llm else 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
        if self.streaming:
            # Emit the response one character at a time so streaming
            # callbacks can be exercised in tests.
            for token in output_str:
                if run_manager:
                    run_manager.on_llm_new_token(token)
                time.sleep(0.01)
        message = AIMessage(content=output_str)
        generation = ChatGeneration(message=message)
        llm_output = {
            "token_usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
            }
        }
        return ChatResult(generations=[generation], llm_output=llm_output)
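A minimal usage sketch, assuming the same pre-0.1 LangChain API that the class above imports from (calling a chat model with a list of messages returns an AIMessage):

from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.schema import HumanMessage

# Non-streaming: the canned response comes back as a single AIMessage.
fake = FakeLLM(response="pong")
reply = fake([HumanMessage(content="ping")])
assert reply.content == "pong"

# Streaming: each character of the response is forwarded to the callback
# handler via on_llm_new_token before the final message is returned.
streaming_fake = FakeLLM(
    response="pong",
    streaming=True,
    callbacks=[StreamingStdOutCallbackHandler()],
)
streaming_fake([HumanMessage(content="ping")])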