# llm_chain.py
  1. from typing import Any, Dict, List, Optional
  2. from langchain import LLMChain as LCLLMChain
  3. from langchain.callbacks.manager import CallbackManagerForChainRun
  4. from langchain.schema import Generation, LLMResult
  5. from langchain.schema.language_model import BaseLanguageModel
  6. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  7. from core.entities.application_entities import ModelConfigEntity
  8. from core.entities.message_entities import lc_messages_to_prompt_messages
  9. from core.model_manager import ModelInstance
  10. from core.third_party.langchain.llms.fake import FakeLLM
  11. class LLMChain(LCLLMChain):
  12. model_config: ModelConfigEntity
  13. """The language model instance to use."""
  14. llm: BaseLanguageModel = FakeLLM(response="")
  15. parameters: Dict[str, Any] = {}
  16. agent_llm_callback: Optional[AgentLLMCallback] = None
  17. def generate(
  18. self,
  19. input_list: List[Dict[str, Any]],
  20. run_manager: Optional[CallbackManagerForChainRun] = None,
  21. ) -> LLMResult:
  22. """Generate LLM result from inputs."""
  23. prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
  24. messages = prompts[0].to_messages()
  25. prompt_messages = lc_messages_to_prompt_messages(messages)
  26. model_instance = ModelInstance(
  27. provider_model_bundle=self.model_config.provider_model_bundle,
  28. model=self.model_config.model,
  29. )
  30. result = model_instance.invoke_llm(
  31. prompt_messages=prompt_messages,
  32. stream=False,
  33. stop=stop,
  34. callbacks=[self.agent_llm_callback] if self.agent_llm_callback else None,
  35. model_parameters=self.parameters
  36. )
  37. generations = [
  38. [Generation(text=result.message.content)]
  39. ]
  40. return LLMResult(generations=generations)