completion_service.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506
  1. import json
  2. import logging
  3. import threading
  4. import time
  5. import uuid
  6. from typing import Generator, Union, Any
  7. from flask import current_app, Flask
  8. from redis.client import PubSub
  9. from sqlalchemy import and_
  10. from core.completion import Completion
  11. from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
  12. from core.llm.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
  13. LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
  17. from services.app_model_config_service import AppModelConfigService
  18. from services.errors.app import MoreLikeThisDisabledError
  19. from services.errors.app_model_config import AppModelConfigBrokenError
  20. from services.errors.completion import CompletionStoppedError
  21. from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
  22. from services.errors.message import MessageNotExistsError
  23. class CompletionService:
  24. @classmethod
  25. def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
  26. from_source: str, streaming: bool = True,
  27. is_model_config_override: bool = False) -> Union[dict | Generator]:
  28. # is streaming mode
  29. inputs = args['inputs']
  30. query = args['query']
  31. conversation_id = args['conversation_id'] if 'conversation_id' in args else None
  32. conversation = None
  33. if conversation_id:
  34. conversation_filter = [
  35. Conversation.id == args['conversation_id'],
  36. Conversation.app_id == app_model.id,
  37. Conversation.status == 'normal'
  38. ]
  39. if from_source == 'console':
  40. conversation_filter.append(Conversation.from_account_id == user.id)
  41. else:
  42. conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
  43. conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
  44. if not conversation:
  45. raise ConversationNotExistsError()
  46. if conversation.status != 'normal':
  47. raise ConversationCompletedError()
  48. if not conversation.override_model_configs:
  49. app_model_config = db.session.query(AppModelConfig).get(conversation.app_model_config_id)
  50. if not app_model_config:
  51. raise AppModelConfigBrokenError()
  52. else:
  53. conversation_override_model_configs = json.loads(conversation.override_model_configs)
  54. app_model_config = AppModelConfig(
  55. id=conversation.app_model_config_id,
  56. app_id=app_model.id,
  57. provider="",
  58. model_id="",
  59. configs="",
  60. opening_statement=conversation_override_model_configs['opening_statement'],
  61. suggested_questions=json.dumps(conversation_override_model_configs['suggested_questions']),
  62. model=json.dumps(conversation_override_model_configs['model']),
  63. user_input_form=json.dumps(conversation_override_model_configs['user_input_form']),
  64. pre_prompt=conversation_override_model_configs['pre_prompt'],
  65. agent_mode=json.dumps(conversation_override_model_configs['agent_mode']),
  66. )
  67. if is_model_config_override:
  68. # build new app model config
  69. if 'model' not in args['model_config']:
  70. raise ValueError('model_config.model is required')
  71. if 'completion_params' not in args['model_config']['model']:
  72. raise ValueError('model_config.model.completion_params is required')
  73. completion_params = AppModelConfigService.validate_model_completion_params(
  74. cp=args['model_config']['model']['completion_params'],
  75. model_name=app_model_config.model_dict["name"]
  76. )
  77. app_model_config_model = app_model_config.model_dict
  78. app_model_config_model['completion_params'] = completion_params
  79. app_model_config = AppModelConfig(
  80. id=app_model_config.id,
  81. app_id=app_model.id,
  82. provider="",
  83. model_id="",
  84. configs="",
  85. opening_statement=app_model_config.opening_statement,
  86. suggested_questions=app_model_config.suggested_questions,
  87. model=json.dumps(app_model_config_model),
  88. user_input_form=app_model_config.user_input_form,
  89. pre_prompt=app_model_config.pre_prompt,
  90. agent_mode=app_model_config.agent_mode,
  91. )
  92. else:
  93. if app_model.app_model_config_id is None:
  94. raise AppModelConfigBrokenError()
  95. app_model_config = app_model.app_model_config
  96. if not app_model_config:
  97. raise AppModelConfigBrokenError()
  98. if is_model_config_override:
  99. if not isinstance(user, Account):
  100. raise Exception("Only account can override model config")
  101. # validate config
  102. model_config = AppModelConfigService.validate_configuration(
  103. account=user,
  104. config=args['model_config'],
  105. mode=app_model.mode
  106. )
  107. app_model_config = AppModelConfig(
  108. id=app_model_config.id,
  109. app_id=app_model.id,
  110. provider="",
  111. model_id="",
  112. configs="",
  113. opening_statement=model_config['opening_statement'],
  114. suggested_questions=json.dumps(model_config['suggested_questions']),
  115. suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
  116. more_like_this=json.dumps(model_config['more_like_this']),
  117. model=json.dumps(model_config['model']),
  118. user_input_form=json.dumps(model_config['user_input_form']),
  119. pre_prompt=model_config['pre_prompt'],
  120. agent_mode=json.dumps(model_config['agent_mode']),
  121. )
  122. # clean input by app_model_config form rules
  123. inputs = cls.get_cleaned_inputs(inputs, app_model_config)
  124. generate_task_id = str(uuid.uuid4())
  125. pubsub = redis_client.pubsub()
  126. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  127. user = cls.get_real_user_instead_of_proxy_obj(user)
  128. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  129. 'flask_app': current_app._get_current_object(),
  130. 'generate_task_id': generate_task_id,
  131. 'app_model': app_model,
  132. 'app_model_config': app_model_config,
  133. 'query': query,
  134. 'inputs': inputs,
  135. 'user': user,
  136. 'conversation': conversation,
  137. 'streaming': streaming,
  138. 'is_model_config_override': is_model_config_override
  139. })
  140. generate_worker_thread.start()
  141. # wait for 5 minutes to close the thread
  142. cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
  143. return cls.compact_response(pubsub, streaming)
  144. @classmethod
  145. def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
  146. if isinstance(user, Account):
  147. user = db.session.query(Account).get(user.id)
  148. elif isinstance(user, EndUser):
  149. user = db.session.query(EndUser).get(user.id)
  150. else:
  151. raise Exception("Unknown user type")
  152. return user
  153. @classmethod
  154. def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
  155. query: str, inputs: dict, user: Union[Account, EndUser],
  156. conversation: Conversation, streaming: bool, is_model_config_override: bool):
  157. with flask_app.app_context():
  158. try:
  159. if conversation:
  160. # fixed the state of the conversation object when it detached from the original session
  161. conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
  162. # run
  163. Completion.generate(
  164. task_id=generate_task_id,
  165. app=app_model,
  166. app_model_config=app_model_config,
  167. query=query,
  168. inputs=inputs,
  169. user=user,
  170. conversation=conversation,
  171. streaming=streaming,
  172. is_override=is_model_config_override,
  173. )
  174. except ConversationTaskStoppedException:
  175. pass
  176. except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  177. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  178. ModelCurrentlyNotSupportError) as e:
  179. db.session.rollback()
  180. PubHandler.pub_error(user, generate_task_id, e)
  181. except LLMAuthorizationError:
  182. db.session.rollback()
  183. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  184. except Exception as e:
  185. db.session.rollback()
  186. logging.exception("Unknown Error in completion")
  187. PubHandler.pub_error(user, generate_task_id, e)
  188. @classmethod
  189. def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
  190. # wait for 5 minutes to close the thread
  191. timeout = 300
  192. def close_pubsub():
  193. sleep_iterations = 0
  194. while sleep_iterations < timeout and worker_thread.is_alive():
  195. time.sleep(1)
  196. sleep_iterations += 1
  197. if worker_thread.is_alive():
  198. PubHandler.stop(user, generate_task_id)
  199. try:
  200. pubsub.close()
  201. except:
  202. pass
  203. countdown_thread = threading.Thread(target=close_pubsub)
  204. countdown_thread.start()
  205. return countdown_thread
  206. @classmethod
  207. def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
  208. message_id: str, streaming: bool = True) -> Union[dict | Generator]:
  209. if not user:
  210. raise ValueError('user cannot be None')
  211. message = db.session.query(Message).filter(
  212. Message.id == message_id,
  213. Message.app_id == app_model.id,
  214. Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  215. Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  216. Message.from_account_id == (user.id if isinstance(user, Account) else None),
  217. ).first()
  218. if not message:
  219. raise MessageNotExistsError()
  220. current_app_model_config = app_model.app_model_config
  221. more_like_this = current_app_model_config.more_like_this_dict
  222. if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
  223. raise MoreLikeThisDisabledError()
  224. app_model_config = message.app_model_config
  225. if message.override_model_configs:
  226. override_model_configs = json.loads(message.override_model_configs)
  227. pre_prompt = override_model_configs.get("pre_prompt", '')
  228. elif app_model_config:
  229. pre_prompt = app_model_config.pre_prompt
  230. else:
  231. raise AppModelConfigBrokenError()
  232. generate_task_id = str(uuid.uuid4())
  233. pubsub = redis_client.pubsub()
  234. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  235. user = cls.get_real_user_instead_of_proxy_obj(user)
  236. generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
  237. 'flask_app': current_app._get_current_object(),
  238. 'generate_task_id': generate_task_id,
  239. 'app_model': app_model,
  240. 'app_model_config': app_model_config,
  241. 'message': message,
  242. 'pre_prompt': pre_prompt,
  243. 'user': user,
  244. 'streaming': streaming
  245. })
  246. generate_worker_thread.start()
  247. cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
  248. return cls.compact_response(pubsub, streaming)
  249. @classmethod
  250. def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
  251. app_model_config: AppModelConfig, message: Message, pre_prompt: str,
  252. user: Union[Account, EndUser], streaming: bool):
  253. with flask_app.app_context():
  254. try:
  255. # run
  256. Completion.generate_more_like_this(
  257. task_id=generate_task_id,
  258. app=app_model,
  259. user=user,
  260. message=message,
  261. pre_prompt=pre_prompt,
  262. app_model_config=app_model_config,
  263. streaming=streaming
  264. )
  265. except ConversationTaskStoppedException:
  266. pass
  267. except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  268. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  269. ModelCurrentlyNotSupportError) as e:
  270. db.session.rollback()
  271. PubHandler.pub_error(user, generate_task_id, e)
  272. except LLMAuthorizationError:
  273. db.session.rollback()
  274. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  275. except Exception as e:
  276. db.session.rollback()
  277. logging.exception("Unknown Error in completion")
  278. PubHandler.pub_error(user, generate_task_id, e)
  279. @classmethod
  280. def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
  281. if user_inputs is None:
  282. user_inputs = {}
  283. filtered_inputs = {}
  284. # Filter input variables from form configuration, handle required fields, default values, and option values
  285. input_form_config = app_model_config.user_input_form_list
  286. for config in input_form_config:
  287. input_config = list(config.values())[0]
  288. variable = input_config["variable"]
  289. input_type = list(config.keys())[0]
  290. if variable not in user_inputs or not user_inputs[variable]:
  291. if "required" in input_config and input_config["required"]:
  292. raise ValueError(f"{variable} is required in input form")
  293. else:
  294. filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
  295. continue
  296. value = user_inputs[variable]
  297. if input_type == "select":
  298. options = input_config["options"] if "options" in input_config else []
  299. if value not in options:
  300. raise ValueError(f"{variable} in input form must be one of the following: {options}")
  301. else:
  302. if 'max_length' in variable:
  303. max_length = variable['max_length']
  304. if len(value) > max_length:
  305. raise ValueError(f'{variable} in input form must be less than {max_length} characters')
  306. filtered_inputs[variable] = value
  307. return filtered_inputs
  308. @classmethod
  309. def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
  310. generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
  311. if not streaming:
  312. try:
  313. for message in pubsub.listen():
  314. if message["type"] == "message":
  315. result = message["data"].decode('utf-8')
  316. result = json.loads(result)
  317. if result.get('error'):
  318. cls.handle_error(result)
  319. return cls.get_message_response_data(result.get('data'))
  320. except ValueError as e:
  321. if e.args[0] != "I/O operation on closed file.": # ignore this error
  322. raise CompletionStoppedError()
  323. else:
  324. logging.exception(e)
  325. raise
  326. finally:
  327. try:
  328. pubsub.unsubscribe(generate_channel)
  329. except ConnectionError:
  330. pass
  331. else:
  332. def generate() -> Generator:
  333. try:
  334. for message in pubsub.listen():
  335. if message["type"] == "message":
  336. result = message["data"].decode('utf-8')
  337. result = json.loads(result)
  338. if result.get('error'):
  339. cls.handle_error(result)
  340. event = result.get('event')
  341. if event == "end":
  342. logging.debug("{} finished".format(generate_channel))
  343. break
  344. if event == 'message':
  345. yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
  346. elif event == 'chain':
  347. yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
  348. elif event == 'agent_thought':
  349. yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
  350. except ValueError as e:
  351. if e.args[0] != "I/O operation on closed file.": # ignore this error
  352. logging.exception(e)
  353. raise
  354. finally:
  355. try:
  356. pubsub.unsubscribe(generate_channel)
  357. except ConnectionError:
  358. pass
  359. return generate()
  360. @classmethod
  361. def get_message_response_data(cls, data: dict):
  362. response_data = {
  363. 'event': 'message',
  364. 'task_id': data.get('task_id'),
  365. 'id': data.get('message_id'),
  366. 'answer': data.get('text'),
  367. 'created_at': int(time.time())
  368. }
  369. if data.get('mode') == 'chat':
  370. response_data['conversation_id'] = data.get('conversation_id')
  371. return response_data
  372. @classmethod
  373. def get_chain_response_data(cls, data: dict):
  374. response_data = {
  375. 'event': 'chain',
  376. 'id': data.get('chain_id'),
  377. 'task_id': data.get('task_id'),
  378. 'message_id': data.get('message_id'),
  379. 'type': data.get('type'),
  380. 'input': data.get('input'),
  381. 'output': data.get('output'),
  382. 'created_at': int(time.time())
  383. }
  384. if data.get('mode') == 'chat':
  385. response_data['conversation_id'] = data.get('conversation_id')
  386. return response_data
  387. @classmethod
  388. def get_agent_thought_response_data(cls, data: dict):
  389. response_data = {
  390. 'event': 'agent_thought',
  391. 'id': data.get('agent_thought_id'),
  392. 'chain_id': data.get('chain_id'),
  393. 'task_id': data.get('task_id'),
  394. 'message_id': data.get('message_id'),
  395. 'position': data.get('position'),
  396. 'thought': data.get('thought'),
  397. 'tool': data.get('tool'), # todo use real dataset obj replace it
  398. 'tool_input': data.get('tool_input'),
  399. 'observation': data.get('observation'),
  400. 'answer': data.get('answer') if not data.get('thought') else '',
  401. 'created_at': int(time.time())
  402. }
  403. if data.get('mode') == 'chat':
  404. response_data['conversation_id'] = data.get('conversation_id')
  405. return response_data
  406. @classmethod
  407. def handle_error(cls, result: dict):
  408. logging.debug("error: %s", result)
  409. error = result.get('error')
  410. description = result.get('description')
  411. # handle errors
  412. llm_errors = {
  413. 'LLMBadRequestError': LLMBadRequestError,
  414. 'LLMAPIConnectionError': LLMAPIConnectionError,
  415. 'LLMAPIUnavailableError': LLMAPIUnavailableError,
  416. 'LLMRateLimitError': LLMRateLimitError,
  417. 'ProviderTokenNotInitError': ProviderTokenNotInitError,
  418. 'QuotaExceededError': QuotaExceededError,
  419. 'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
  420. }
  421. if error in llm_errors:
  422. raise llm_errors[error](description)
  423. elif error == 'LLMAuthorizationError':
  424. raise LLMAuthorizationError('Incorrect API key provided')
  425. else:
  426. raise Exception(description)