import logging
import threading
import time
from typing import Any, Dict, List, Union, Optional

from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel

from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, \
    LCHumanMessageWithFiles, ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class LLMCallbackHandler(BaseCallbackHandler):
    raise_error: bool = True

    def __init__(self, model_instance: BaseLLM,
                 conversation_message_task: ConversationMessageTask):
        self.model_instance = model_instance
        self.llm_message = LLMMessage()
        self.start_at = None
        self.conversation_message_task = conversation_message_task

        self.output_moderation_handler = None
        self.init_output_moderation()

    def init_output_moderation(self):
        # Enable output moderation only when the app config turns on
        # sensitive word avoidance.
        app_model_config = self.conversation_message_task.app_model_config
        sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict

        if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
            self.output_moderation_handler = OutputModerationHandler(
                tenant_id=self.conversation_message_task.tenant_id,
                app_id=self.conversation_message_task.app.id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance_dict.get("type"),
                    config=sensitive_word_avoidance_dict.get("config")
                ),
                on_message_replace_func=self.conversation_message_task.on_message_replace
            )
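
    # A minimal sketch of the dict consumed above. The "enabled", "type" and
    # "config" keys come from the code; the example values are illustrative
    # assumptions only:
    #
    #     sensitive_word_avoidance_dict = {
    #         "enabled": True,
    #         "type": "keywords",            # name of a registered moderation rule
    #         "config": {"keywords": "..."}  # rule-specific settings
    #     }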

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_chat_model_start(
            self,
            serialized: Dict[str, Any],
            messages: List[List[BaseMessage]],
            **kwargs: Any
    ) -> Any:
        real_prompts = []
        for message in messages[0]:
            if message.type == 'human':
                role = 'user'
            elif message.type == 'ai':
                role = 'assistant'
            else:
                role = 'system'

            real_prompts.append({
                "role": role,
                "text": message.content,
                # Truncate file payloads (e.g. base64 image data) so the saved
                # prompt stays small.
                "files": [{
                    "type": file.type.value,
                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
            })

        self.llm_message.prompt = real_prompts
        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens(to_prompt_messages(messages[0]))
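
    # For reference, each entry appended to `real_prompts` above has this
    # shape (the values here are illustrative assumptions):
    #
    #     {
    #         "role": "user",
    #         "text": "What is in this image?",
    #         "files": [{"type": "image",
    #                    "data": "iVBORw0KGg...[TRUNCATED]...AAAASUVORK",
    #                    "detail": "low"}]
    #     }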

    def on_llm_start(
            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.llm_message.prompt = [{
            "role": 'user',
            "text": prompts[0]
        }]

        self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

            # Run one final moderation pass over the complete answer before
            # it is persisted.
            self.llm_message.completion = self.output_moderation_handler.moderation_completion(
                completion=response.generations[0][0].text,
                public_event=self.conversation_message_task.streaming
            )
        else:
            self.llm_message.completion = response.generations[0][0].text

        if not self.conversation_message_task.streaming:
            self.conversation_message_task.append_message_text(self.llm_message.completion)

        # Prefer the provider-reported token usage; fall back to counting
        # tokens locally when it is absent.
        if response.llm_output and 'token_usage' in response.llm_output:
            if 'prompt_tokens' in response.llm_output['token_usage']:
                self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']

            if 'completion_tokens' in response.llm_output['token_usage']:
                self.llm_message.completion_tokens = response.llm_output['token_usage']['completion_tokens']
            else:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)])
        else:
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)])

        self.conversation_message_task.save_message(self.llm_message)

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
            # Stop consuming new tokens once output moderation has decided to
            # replace the stream with a preset response.
            ex = ConversationTaskInterruptException()
            self.on_llm_error(error=ex)
            raise ex

        try:
            self.conversation_message_task.append_message_text(token)
            self.llm_message.completion += token

            if self.output_moderation_handler:
                self.output_moderation_handler.append_new_token(token)
        except ConversationTaskStoppedException as ex:
            self.on_llm_error(error=ex)
            raise ex

    def on_llm_error(
            self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
    ) -> None:
        """Persist whatever was generated so far when the task is stopped or
        interrupted; otherwise just log the error."""
        if self.output_moderation_handler:
            self.output_moderation_handler.stop_thread()

        if isinstance(error, ConversationTaskStoppedException):
            if self.conversation_message_task.streaming:
                self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                    [PromptMessage(content=self.llm_message.completion)]
                )

            self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
        elif isinstance(error, ConversationTaskInterruptException):
            # An interrupt means moderation produced a direct output; save it
            # as the final completion.
            self.llm_message.completion = self.output_moderation_handler.get_final_output()
            self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
                [PromptMessage(content=self.llm_message.completion)]
            )
            self.conversation_message_task.save_message(llm_message=self.llm_message)
        else:
            logging.debug("on_llm_error: %s", error)


class OutputModerationHandler(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            # Start the moderation worker lazily on the first streamed token.
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output
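
    # Illustrative behavior of moderation_completion(): a flagged result with
    # action DIRECT_OUTPUT replaces the whole completion with the rule's
    # preset_response, while any other flagged action returns the (possibly
    # rewritten) result.text; unflagged completions pass through unchanged.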

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            # Pass the concrete Flask app (not the proxy) so the worker can
            # push its own app context outside the request thread.
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })
        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        # Wait until at least `buffer_size` new characters have
                        # accumulated since the last moderation pass.
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    # Signal on_llm_new_token() to interrupt the stream.
                    self.final_output = final_output
                else:
                    # Keep any tokens that arrived while moderation was running.
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logging.error("Moderation Output error: %s", e)
            return None
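

# Usage sketch (illustrative only, not part of this module): the handler is
# meant to be registered as a LangChain callback when the model is invoked.
# The surrounding setup names below are assumptions:
#
#     handler = LLMCallbackHandler(
#         model_instance=model_instance,
#         conversation_message_task=conversation_message_task
#     )
#     model_instance.run(prompt_messages, callbacks=[handler])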