completion_service.py 21 KB


  1. import json
  2. import logging
  3. import threading
  4. import time
  5. import uuid
  6. from typing import Generator, Union, Any
  7. from flask import current_app, Flask
  8. from redis.client import PubSub
  9. from sqlalchemy import and_
  10. from core.completion import Completion
  11. from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
  12. from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
  13. LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
  17. from services.app_model_config_service import AppModelConfigService
  18. from services.errors.app import MoreLikeThisDisabledError
  19. from services.errors.app_model_config import AppModelConfigBrokenError
  20. from services.errors.completion import CompletionStoppedError
  21. from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
  22. from services.errors.message import MessageNotExistsError
  class CompletionService:
      """Service layer for completion / chat and "more like this" generation.

      Each request is executed by ``Completion`` in a background thread; results are
      relayed back to the request thread through a per-task Redis pub/sub channel
      (channel name from ``PubHandler.generate_channel_name``) and returned either as a
      single dict (blocking) or as a generator of SSE ``data: ...`` lines (streaming).
      """

      @classmethod
      def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
                     from_source: str, streaming: bool = True,
                     is_model_config_override: bool = False) -> Union[dict | Generator]:
          """Run a completion/chat request for ``app_model`` in a background worker.

          :param app_model: the app being invoked.
          :param user: the calling Account or EndUser.
          :param args: request payload; reads ``inputs``, ``query`` and the optional
              ``conversation_id`` / ``model_config`` keys.
          :param from_source: ``'console'`` matches conversations on ``from_account_id``;
              any other value matches on ``from_end_user_id``.
          :param streaming: return an SSE generator instead of a single dict.
          :param is_model_config_override: apply ``args['model_config']`` on top of the
              stored model config (only completion params for an existing conversation).
          :return: response dict (blocking) or generator of SSE lines (streaming).
          :raises ValueError: empty query, malformed override config or invalid inputs.
          :raises ConversationNotExistsError: ``conversation_id`` does not match a
              conversation of this app/user.
          :raises ConversationCompletedError: the conversation is not ``'normal'``.
          :raises AppModelConfigBrokenError: the referenced model config is missing.
          """
          inputs = args['inputs']
          query = args['query']
          if not query:
              raise ValueError('query is required')

          conversation_id = args.get('conversation_id')
          conversation = None
          if conversation_id:
              conversation = cls._get_user_conversation(app_model, user, conversation_id, from_source)
              app_model_config = cls._get_conversation_app_model_config(app_model, conversation)
              if is_model_config_override:
                  # A missing model_config is reported as the ValueError below, not a KeyError.
                  app_model_config = cls._override_completion_params(
                      app_model, app_model_config, args.get('model_config') or {}
                  )
          else:
              if app_model.app_model_config_id is None:
                  raise AppModelConfigBrokenError()
              app_model_config = app_model.app_model_config
              if not app_model_config:
                  raise AppModelConfigBrokenError()
              if is_model_config_override:
                  if not isinstance(user, Account):
                      raise Exception("Only account can override model config")
                  app_model_config = cls._build_override_app_model_config(
                      app_model, app_model_config, user, args['model_config']
                  )

          # clean input by app_model_config form rules
          inputs = cls.get_cleaned_inputs(inputs, app_model_config)

          return cls._dispatch_to_worker(
              cls.generate_worker, user, streaming,
              app_model=app_model,
              app_model_config=app_model_config,
              query=query,
              inputs=inputs,
              conversation=conversation,
              is_model_config_override=is_model_config_override,
          )

      @classmethod
      def _get_user_conversation(cls, app_model: App, user: Union[Account, EndUser],
                                 conversation_id: str, from_source: str) -> Conversation:
          """Load the ``'normal'`` conversation ``conversation_id`` of ``app_model`` owned by ``user``.

          :raises ConversationNotExistsError: no user given, or no matching conversation.
          :raises ConversationCompletedError: the conversation is not ``'normal'``.
          """
          if not user:
              # Without a user no ownership filter can be built; treat as "not found"
              # rather than filtering on NULL (which could match other owners' rows).
              raise ConversationNotExistsError()

          conversation_filter = [
              Conversation.id == conversation_id,
              Conversation.app_id == app_model.id,
              Conversation.status == 'normal'
          ]
          if from_source == 'console':
              conversation_filter.append(Conversation.from_account_id == user.id)
          else:
              conversation_filter.append(Conversation.from_end_user_id == user.id)

          conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
          if not conversation:
              raise ConversationNotExistsError()
          # Defensive: the query above already restricts status to 'normal'.
          if conversation.status != 'normal':
              raise ConversationCompletedError()
          return conversation

      @classmethod
      def _get_conversation_app_model_config(cls, app_model: App, conversation: Conversation) -> AppModelConfig:
          """Return the model config a conversation runs with.

          Uses the stored ``AppModelConfig`` row, or a transient one rebuilt from the
          conversation's ``override_model_configs`` JSON when that is set.

          :raises AppModelConfigBrokenError: the referenced config row does not exist.
          """
          if not conversation.override_model_configs:
              app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
              if not app_model_config:
                  raise AppModelConfigBrokenError()
              return app_model_config

          override_configs = json.loads(conversation.override_model_configs)
          # NOTE(review): suggested_questions_after_answer / more_like_this are not carried
          # over from the override JSON here — confirm this is intended.
          return AppModelConfig(
              id=conversation.app_model_config_id,
              app_id=app_model.id,
              provider="",
              model_id="",
              configs="",
              opening_statement=override_configs['opening_statement'],
              suggested_questions=json.dumps(override_configs['suggested_questions']),
              model=json.dumps(override_configs['model']),
              user_input_form=json.dumps(override_configs['user_input_form']),
              pre_prompt=override_configs['pre_prompt'],
              agent_mode=json.dumps(override_configs['agent_mode']),
          )

      @classmethod
      def _override_completion_params(cls, app_model: App, app_model_config: AppModelConfig,
                                      model_config: dict) -> AppModelConfig:
          """Return a transient copy of ``app_model_config`` with only ``model.completion_params`` replaced.

          :raises ValueError: ``model_config`` lacks ``model`` or ``model.completion_params``.
          """
          if 'model' not in model_config:
              raise ValueError('model_config.model is required')
          if 'completion_params' not in model_config['model']:
              raise ValueError('model_config.model.completion_params is required')

          completion_params = AppModelConfigService.validate_model_completion_params(
              cp=model_config['model']['completion_params'],
              model_name=app_model_config.model_dict["name"]
          )
          app_model_config_model = app_model_config.model_dict
          app_model_config_model['completion_params'] = completion_params

          return AppModelConfig(
              id=app_model_config.id,
              app_id=app_model.id,
              provider="",
              model_id="",
              configs="",
              opening_statement=app_model_config.opening_statement,
              suggested_questions=app_model_config.suggested_questions,
              model=json.dumps(app_model_config_model),
              user_input_form=app_model_config.user_input_form,
              pre_prompt=app_model_config.pre_prompt,
              agent_mode=app_model_config.agent_mode,
          )

      @classmethod
      def _build_override_app_model_config(cls, app_model: App, app_model_config: AppModelConfig,
                                           account: Account, config: Any) -> AppModelConfig:
          """Validate a full override ``config`` and return it as a transient ``AppModelConfig``."""
          model_config = AppModelConfigService.validate_configuration(
              account=account,
              config=config,
              mode=app_model.mode
          )
          return AppModelConfig(
              id=app_model_config.id,
              app_id=app_model.id,
              provider="",
              model_id="",
              configs="",
              opening_statement=model_config['opening_statement'],
              suggested_questions=json.dumps(model_config['suggested_questions']),
              suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
              more_like_this=json.dumps(model_config['more_like_this']),
              model=json.dumps(model_config['model']),
              user_input_form=json.dumps(model_config['user_input_form']),
              pre_prompt=model_config['pre_prompt'],
              agent_mode=json.dumps(model_config['agent_mode']),
          )

      @classmethod
      def _dispatch_to_worker(cls, worker, user: Union[Account, EndUser], streaming: bool, **worker_kwargs):
          """Subscribe to a fresh task channel, run ``worker`` in a thread and relay its output.

          ``worker`` is called with ``flask_app``, ``generate_task_id``, ``user`` and
          ``streaming`` plus ``worker_kwargs``. Returns ``compact_response(...)``.
          """
          generate_task_id = str(uuid.uuid4())
          pubsub = redis_client.pubsub()
          # Subscribe before the worker starts: Redis pub/sub does not buffer messages
          # for subscribers that join late.
          pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))

          try:
              user = cls.get_real_user_instead_of_proxy_obj(user)
              worker_thread = threading.Thread(target=worker, kwargs={
                  'flask_app': current_app._get_current_object(),
                  'generate_task_id': generate_task_id,
                  'user': user,
                  'streaming': streaming,
                  **worker_kwargs,
              })
              worker_thread.start()
          except Exception:
              # No watchdog exists yet to close the subscription, so release it here.
              pubsub.close()
              raise

          # wait for 5 minutes to close the thread
          cls.countdown_and_close(worker_thread, pubsub, user, generate_task_id)
          return cls.compact_response(pubsub, streaming)

      @classmethod
      def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
          """Re-load ``user`` from the database by id so a plain model instance
          (rather than, presumably, a request-bound proxy) is handed to the worker thread.

          :raises Exception: ``user`` is neither an Account nor an EndUser.
          """
          if isinstance(user, Account):
              user = db.session.query(Account).get(user.id)
          elif isinstance(user, EndUser):
              user = db.session.query(EndUser).get(user.id)
          else:
              raise Exception("Unknown user type")
          return user

      @classmethod
      def _publish_worker_error(cls, user: Union[Account, EndUser], generate_task_id: str, error: Exception):
          """Roll back the worker's DB session and publish ``error`` on the task channel.

          Must be called from inside an ``except`` block so ``logging.exception`` can
          attach the active traceback for unexpected errors.
          """
          db.session.rollback()
          known_llm_errors = (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                              LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                              ModelCurrentlyNotSupportError)
          if not isinstance(error, known_llm_errors):
              if isinstance(error, LLMAuthorizationError):
                  # Published with a fixed message instead of the provider's text.
                  error = LLMAuthorizationError('Incorrect API key provided')
              else:
                  logging.exception("Unknown Error in completion")
          PubHandler.pub_error(user, generate_task_id, error)

      @classmethod
      def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
                          query: str, inputs: dict, user: Union[Account, EndUser],
                          conversation: Conversation, streaming: bool, is_model_config_override: bool):
          """Thread target: run ``Completion.generate`` inside an app context.

          A ``ConversationTaskStoppedException`` ends the task silently; any other
          exception is rolled back and published to the task channel.
          """
          with flask_app.app_context():
              try:
                  if conversation:
                      # fixed the state of the conversation object when it detached from the original session
                      conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
                  # run
                  Completion.generate(
                      task_id=generate_task_id,
                      app=app_model,
                      app_model_config=app_model_config,
                      query=query,
                      inputs=inputs,
                      user=user,
                      conversation=conversation,
                      streaming=streaming,
                      is_override=is_model_config_override,
                  )
              except ConversationTaskStoppedException:
                  pass
              except Exception as e:
                  cls._publish_worker_error(user, generate_task_id, e)

      @classmethod
      def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
          """Start (and return) a watchdog thread that closes ``pubsub`` once the worker is done.

          If the worker is still alive after ``timeout`` seconds, the task is stopped
          via ``PubHandler.stop`` before the pubsub is closed.
          """
          # wait for 5 minutes to close the thread
          timeout = 300

          def close_pubsub():
              # Poll once per second instead of join(). NOTE(review): the up-to-1s delay
              # after the worker exits may also give the reader time to drain the final
              # events before the connection is closed — confirm before tightening.
              sleep_iterations = 0
              while sleep_iterations < timeout and worker_thread.is_alive():
                  time.sleep(1)
                  sleep_iterations += 1
              if worker_thread.is_alive():
                  PubHandler.stop(user, generate_task_id)
              try:
                  pubsub.close()
              except Exception:
                  # Best effort: closing may fail if the connection is already broken.
                  logging.debug("failed to close pubsub for task %s", generate_task_id, exc_info=True)

          countdown_thread = threading.Thread(target=close_pubsub)
          countdown_thread.start()
          return countdown_thread

      @classmethod
      def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
                                  message_id: str, streaming: bool = True) -> Union[dict | Generator]:
          """Generate an alternative answer for an existing message of ``app_model``.

          :raises ValueError: ``user`` is falsy.
          :raises MessageNotExistsError: no such message for this app/user.
          :raises AppModelConfigBrokenError: the app has no model config, or the message
              has neither override configs nor a model config.
          :raises MoreLikeThisDisabledError: the feature is disabled in the app's config.
          """
          if not user:
              raise ValueError('user cannot be None')

          message = db.session.query(Message).filter(
              Message.id == message_id,
              Message.app_id == app_model.id,
              Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
              Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
              Message.from_account_id == (user.id if isinstance(user, Account) else None),
          ).first()
          if not message:
              raise MessageNotExistsError()

          current_app_model_config = app_model.app_model_config
          if not current_app_model_config:
              raise AppModelConfigBrokenError()
          more_like_this = current_app_model_config.more_like_this_dict
          if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
              raise MoreLikeThisDisabledError()

          # The prompt comes from the config the original message was generated with.
          app_model_config = message.app_model_config
          if message.override_model_configs:
              override_model_configs = json.loads(message.override_model_configs)
              pre_prompt = override_model_configs.get("pre_prompt", '')
          elif app_model_config:
              pre_prompt = app_model_config.pre_prompt
          else:
              raise AppModelConfigBrokenError()

          return cls._dispatch_to_worker(
              cls.generate_more_like_this_worker, user, streaming,
              app_model=app_model,
              app_model_config=app_model_config,
              message=message,
              pre_prompt=pre_prompt,
          )

      @classmethod
      def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
                                         app_model_config: AppModelConfig, message: Message, pre_prompt: str,
                                         user: Union[Account, EndUser], streaming: bool):
          """Thread target: run ``Completion.generate_more_like_this`` inside an app context.

          Error handling matches ``generate_worker``.
          """
          with flask_app.app_context():
              try:
                  # run
                  Completion.generate_more_like_this(
                      task_id=generate_task_id,
                      app=app_model,
                      user=user,
                      message=message,
                      pre_prompt=pre_prompt,
                      app_model_config=app_model_config,
                      streaming=streaming
                  )
              except ConversationTaskStoppedException:
                  pass
              except Exception as e:
                  cls._publish_worker_error(user, generate_task_id, e)

      @classmethod
      def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
          """Validate ``user_inputs`` against the app's input form and return the cleaned dict.

          Only variables declared in ``app_model_config.user_input_form_list`` are kept.
          Empty/missing values fall back to the field's ``default`` (or ``""``) unless the
          field is ``required``. ``select`` values must be one of ``options``; other
          string values are checked against ``max_length`` when configured.

          :raises ValueError: a required field is empty, a select value is not an
              allowed option, or a value exceeds ``max_length``.
          """
          if user_inputs is None:
              user_inputs = {}

          filtered_inputs = {}
          # Each form entry is a single-key dict: {input_type: input_config}.
          for config in app_model_config.user_input_form_list:
              input_type, input_config = list(config.items())[0]
              variable = input_config["variable"]

              value = user_inputs.get(variable)
              if not value:
                  if input_config.get("required"):
                      raise ValueError(f"{variable} is required in input form")
                  filtered_inputs[variable] = input_config.get("default", "")
                  continue

              if input_type == "select":
                  options = input_config.get("options") or []
                  if value not in options:
                      raise ValueError(f"{variable} in input form must be one of the following: {options}")
              else:
                  # The limit lives on the field config, not on the variable name.
                  max_length = input_config.get('max_length')
                  if max_length is not None and isinstance(value, str) and len(value) > max_length:
                      raise ValueError(f'{variable} in input form must be less than {max_length} characters')

              filtered_inputs[variable] = value

          return filtered_inputs

      @classmethod
      def _unsubscribe_quietly(cls, pubsub: PubSub, channel: str):
          """Unsubscribe ``pubsub`` from ``channel``, ignoring connection failures."""
          # redis-py's ConnectionError derives from RedisError, not the builtin ConnectionError.
          from redis.exceptions import ConnectionError as RedisConnectionError
          try:
              pubsub.unsubscribe(channel)
          except (ConnectionError, RedisConnectionError):
              pass

      @classmethod
      def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
          """Turn the worker's pub/sub events into the HTTP response payload.

          Blocking mode returns the first published message as a response dict. Streaming
          mode returns a generator yielding ``data: <json>\\n\\n`` for ``message``,
          ``chain`` and ``agent_thought`` events until an ``end`` event arrives.
          Published errors are re-raised through ``handle_error``.

          :raises CompletionStoppedError: (blocking) the pubsub was closed under the
              reader, presumably by the ``countdown_and_close`` watchdog.
          """
          generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')

          if not streaming:
              try:
                  for message in pubsub.listen():
                      if message["type"] == "message":
                          result = json.loads(message["data"].decode('utf-8'))
                          if result.get('error'):
                              cls.handle_error(result)
                          return cls.get_message_response_data(result.get('data'))
              except ValueError as e:
                  if str(e) == "I/O operation on closed file.":
                      # The connection was closed while waiting: the task was stopped.
                      raise CompletionStoppedError() from e
                  logging.exception(e)
                  raise
              finally:
                  cls._unsubscribe_quietly(pubsub, generate_channel)
          else:
              def generate() -> Generator:
                  try:
                      for message in pubsub.listen():
                          if message["type"] != "message":
                              continue
                          result = json.loads(message["data"].decode('utf-8'))
                          if result.get('error'):
                              cls.handle_error(result)

                          event = result.get('event')
                          if event == "end":
                              logging.debug("%s finished", generate_channel)
                              break

                          if event == 'message':
                              yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                          elif event == 'chain':
                              yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                          elif event == 'agent_thought':
                              yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
                  except ValueError as e:
                      # A closed connection simply ends the stream; anything else is a real error.
                      if str(e) != "I/O operation on closed file.":
                          logging.exception(e)
                          raise
                  finally:
                      cls._unsubscribe_quietly(pubsub, generate_channel)

              return generate()

      @classmethod
      def get_message_response_data(cls, data: dict):
          """Build the ``message`` event payload; ``conversation_id`` is added in chat mode."""
          response_data = {
              'event': 'message',
              'task_id': data.get('task_id'),
              'id': data.get('message_id'),
              'answer': data.get('text'),
              'created_at': int(time.time())
          }
          if data.get('mode') == 'chat':
              response_data['conversation_id'] = data.get('conversation_id')
          return response_data

      @classmethod
      def get_chain_response_data(cls, data: dict):
          """Build the ``chain`` event payload; ``conversation_id`` is added in chat mode."""
          response_data = {
              'event': 'chain',
              'id': data.get('chain_id'),
              'task_id': data.get('task_id'),
              'message_id': data.get('message_id'),
              'type': data.get('type'),
              'input': data.get('input'),
              'output': data.get('output'),
              'created_at': int(time.time())
          }
          if data.get('mode') == 'chat':
              response_data['conversation_id'] = data.get('conversation_id')
          return response_data

      @classmethod
      def get_agent_thought_response_data(cls, data: dict):
          """Build the ``agent_thought`` event payload.

          ``answer`` is only filled when the event carries no ``thought``;
          ``conversation_id`` is added in chat mode.
          """
          response_data = {
              'event': 'agent_thought',
              'id': data.get('agent_thought_id'),
              'chain_id': data.get('chain_id'),
              'task_id': data.get('task_id'),
              'message_id': data.get('message_id'),
              'position': data.get('position'),
              'thought': data.get('thought'),
              'tool': data.get('tool'),  # todo use real dataset obj replace it
              'tool_input': data.get('tool_input'),
              'observation': data.get('observation'),
              'answer': data.get('answer') if not data.get('thought') else '',
              'created_at': int(time.time())
          }
          if data.get('mode') == 'chat':
              response_data['conversation_id'] = data.get('conversation_id')
          return response_data

      @classmethod
      def handle_error(cls, result: dict):
          """Re-raise an error event published by a worker as the matching exception.

          Known LLM error names map to their classes (with the published
          ``description``); ``LLMAuthorizationError`` gets a fixed message; anything
          else becomes a plain ``Exception(description)``. Always raises.
          """
          logging.debug("error: %s", result)
          error = result.get('error')
          description = result.get('description')

          # handle errors
          llm_errors = {
              'LLMBadRequestError': LLMBadRequestError,
              'LLMAPIConnectionError': LLMAPIConnectionError,
              'LLMAPIUnavailableError': LLMAPIUnavailableError,
              'LLMRateLimitError': LLMRateLimitError,
              'ProviderTokenNotInitError': ProviderTokenNotInitError,
              'QuotaExceededError': QuotaExceededError,
              'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
          }

          if error in llm_errors:
              raise llm_errors[error](description)
          elif error == 'LLMAuthorizationError':
              raise LLMAuthorizationError('Incorrect API key provided')
          else:
              raise Exception(description)
  426. raise LLMAuthorizationError('Incorrect API key provided')
  427. else:
  428. raise Exception(description)