12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849 |
- from typing import Any, Dict, List, Optional
- from langchain import LLMChain as LCLLMChain
- from langchain.callbacks.manager import CallbackManagerForChainRun
- from langchain.schema import Generation, LLMResult
- from langchain.schema.language_model import BaseLanguageModel
- from core.agent.agent.agent_llm_callback import AgentLLMCallback
- from core.entities.application_entities import ModelConfigEntity
- from core.entities.message_entities import lc_messages_to_prompt_messages
- from core.model_manager import ModelInstance
- from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
    """LLM chain whose generation step goes through a ``ModelInstance``.

    ``generate`` is overridden: the prepared prompt is converted to prompt
    messages and sent via ``ModelInstance.invoke_llm`` rather than through
    the ``llm`` attribute.
    """

    # Supplies the provider model bundle and model name for the ModelInstance.
    model_config: ModelConfigEntity
    # The language model instance field. generate() below never calls it; the
    # default is a FakeLLM with an empty response (presumably a placeholder
    # for the parent class's field -- confirm).
    llm: BaseLanguageModel = FakeLLM(response="")
    # Passed as ``model_parameters`` on every invocation.
    parameters: Dict[str, Any] = {}
    # Optional callback forwarded to the model invocation when set.
    agent_llm_callback: Optional[AgentLLMCallback] = None

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate an LLM result from inputs.

        Args:
            input_list: Input dicts used to prepare the prompts.
            run_manager: Optional chain-run callback manager, forwarded to
                ``prep_prompts``.

        Returns:
            An ``LLMResult`` holding a single generation built from the
            content of the model's reply message, or an ``LLMResult`` with no
            generations when no prompt could be prepared.
        """
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)

        # Guard the ``prompts[0]`` access below: with nothing to prompt there
        # is nothing to invoke, so return an empty result instead of raising
        # IndexError.
        if not prompts:
            return LLMResult(generations=[])

        # NOTE: only the first prepared prompt is sent to the model; any
        # further entries of ``input_list`` do not produce generations.
        prompt_messages = lc_messages_to_prompt_messages(prompts[0].to_messages())

        model_instance = ModelInstance(
            provider_model_bundle=self.model_config.provider_model_bundle,
            model=self.model_config.model,
        )

        callbacks = [self.agent_llm_callback] if self.agent_llm_callback else None
        result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            stream=False,
            stop=stop,
            callbacks=callbacks,
            model_parameters=self.parameters,
        )

        return LLMResult(generations=[[Generation(text=result.message.content)]])
|