# completion_service.py
import json
import logging
import threading
import time
import uuid
from typing import Any, Generator, Optional, Union

from flask import Flask, current_app
from redis.client import PubSub
from redis.exceptions import RedisError
from sqlalchemy import and_

from core.completion import Completion
from core.conversation_message_task import (
    ConversationTaskInterruptException,
    ConversationTaskStoppedException,
    PubHandler,
)
from core.model_providers.error import (
    LLMAPIConnectionError,
    LLMAPIUnavailableError,
    LLMAuthorizationError,
    LLMBadRequestError,
    LLMRateLimitError,
    ModelCurrentlyNotSupportError,
    ProviderTokenNotInitError,
    QuotaExceededError,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.completion import CompletionStoppedError
from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
from services.errors.message import MessageNotExistsError


  25. class CompletionService:
  26. @classmethod
  27. def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
  28. from_source: str, streaming: bool = True,
  29. is_model_config_override: bool = False) -> Union[dict, Generator]:
  30. # is streaming mode
  31. inputs = args['inputs']
  32. query = args['query']
  33. if app_model.mode != 'completion' and not query:
  34. raise ValueError('query is required')
  35. query = query.replace('\x00', '')
  36. conversation_id = args['conversation_id'] if 'conversation_id' in args else None
  37. conversation = None
  38. if conversation_id:
  39. conversation_filter = [
  40. Conversation.id == args['conversation_id'],
  41. Conversation.app_id == app_model.id,
  42. Conversation.status == 'normal'
  43. ]
  44. if from_source == 'console':
  45. conversation_filter.append(Conversation.from_account_id == user.id)
  46. else:
  47. conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
  48. conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
  49. if not conversation:
  50. raise ConversationNotExistsError()
  51. if conversation.status != 'normal':
  52. raise ConversationCompletedError()
  53. if not conversation.override_model_configs:
  54. app_model_config = db.session.query(AppModelConfig).filter(
  55. AppModelConfig.id == conversation.app_model_config_id,
  56. AppModelConfig.app_id == app_model.id
  57. ).first()
  58. if not app_model_config:
  59. raise AppModelConfigBrokenError()
  60. else:
  61. conversation_override_model_configs = json.loads(conversation.override_model_configs)
  62. app_model_config = AppModelConfig(
  63. id=conversation.app_model_config_id,
  64. app_id=app_model.id,
  65. )
  66. app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
  67. if is_model_config_override:
  68. # build new app model config
  69. if 'model' not in args['model_config']:
  70. raise ValueError('model_config.model is required')
  71. if 'completion_params' not in args['model_config']['model']:
  72. raise ValueError('model_config.model.completion_params is required')
  73. completion_params = AppModelConfigService.validate_model_completion_params(
  74. cp=args['model_config']['model']['completion_params'],
  75. model_name=app_model_config.model_dict["name"]
  76. )
  77. app_model_config_model = app_model_config.model_dict
  78. app_model_config_model['completion_params'] = completion_params
  79. app_model_config.retriever_resource = json.dumps({'enabled': True})
  80. app_model_config = app_model_config.copy()
  81. app_model_config.model = json.dumps(app_model_config_model)
  82. else:
  83. if app_model.app_model_config_id is None:
  84. raise AppModelConfigBrokenError()
  85. app_model_config = app_model.app_model_config
  86. if not app_model_config:
  87. raise AppModelConfigBrokenError()
  88. if is_model_config_override:
  89. if not isinstance(user, Account):
  90. raise Exception("Only account can override model config")
  91. # validate config
  92. model_config = AppModelConfigService.validate_configuration(
  93. tenant_id=app_model.tenant_id,
  94. account=user,
  95. config=args['model_config'],
  96. mode=app_model.mode
  97. )
  98. app_model_config = AppModelConfig(
  99. id=app_model_config.id,
  100. app_id=app_model.id,
  101. )
  102. app_model_config = app_model_config.from_model_config_dict(model_config)
  103. # clean input by app_model_config form rules
  104. inputs = cls.get_cleaned_inputs(inputs, app_model_config)
  105. generate_task_id = str(uuid.uuid4())
  106. pubsub = redis_client.pubsub()
  107. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  108. user = cls.get_real_user_instead_of_proxy_obj(user)
  109. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  110. 'flask_app': current_app._get_current_object(),
  111. 'generate_task_id': generate_task_id,
  112. 'detached_app_model': app_model,
  113. 'app_model_config': app_model_config.copy(),
  114. 'query': query,
  115. 'inputs': inputs,
  116. 'detached_user': user,
  117. 'detached_conversation': conversation,
  118. 'streaming': streaming,
  119. 'is_model_config_override': is_model_config_override,
  120. 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
  121. })
  122. generate_worker_thread.start()
  123. # wait for 10 minutes to close the thread
  124. cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
  125. return cls.compact_response(pubsub, streaming)
  126. @classmethod
  127. def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
  128. if isinstance(user, Account):
  129. user = db.session.query(Account).filter(Account.id == user.id).first()
  130. elif isinstance(user, EndUser):
  131. user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
  132. else:
  133. raise Exception("Unknown user type")
  134. return user
  135. @classmethod
  136. def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
  137. query: str, inputs: dict, detached_user: Union[Account, EndUser],
  138. detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
  139. retriever_from: str = 'dev'):
  140. with flask_app.app_context():
  141. # fixed the state of the model object when it detached from the original session
  142. user = db.session.merge(detached_user)
  143. app_model = db.session.merge(detached_app_model)
  144. if detached_conversation:
  145. conversation = db.session.merge(detached_conversation)
  146. else:
  147. conversation = None
  148. try:
  149. # run
  150. Completion.generate(
  151. task_id=generate_task_id,
  152. app=app_model,
  153. app_model_config=app_model_config,
  154. query=query,
  155. inputs=inputs,
  156. user=user,
  157. conversation=conversation,
  158. streaming=streaming,
  159. is_override=is_model_config_override,
  160. retriever_from=retriever_from
  161. )
  162. except (ConversationTaskInterruptException, ConversationTaskStoppedException):
  163. pass
  164. except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  165. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  166. ModelCurrentlyNotSupportError) as e:
  167. PubHandler.pub_error(user, generate_task_id, e)
  168. except LLMAuthorizationError:
  169. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  170. except Exception as e:
  171. logging.exception("Unknown Error in completion")
  172. PubHandler.pub_error(user, generate_task_id, e)
  173. finally:
  174. db.session.commit()
  175. @classmethod
  176. def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
  177. # wait for 10 minutes to close the thread
  178. timeout = 600
  179. def close_pubsub():
  180. with flask_app.app_context():
  181. user = db.session.merge(detached_user)
  182. sleep_iterations = 0
  183. while sleep_iterations < timeout and worker_thread.is_alive():
  184. if sleep_iterations > 0 and sleep_iterations % 10 == 0:
  185. PubHandler.ping(user, generate_task_id)
  186. time.sleep(1)
  187. sleep_iterations += 1
  188. if worker_thread.is_alive():
  189. PubHandler.stop(user, generate_task_id)
  190. try:
  191. pubsub.close()
  192. except Exception:
  193. pass
  194. countdown_thread = threading.Thread(target=close_pubsub)
  195. countdown_thread.start()
  196. return countdown_thread
  197. @classmethod
  198. def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
  199. message_id: str, streaming: bool = True,
  200. retriever_from: str = 'dev') -> Union[dict, Generator]:
  201. if not user:
  202. raise ValueError('user cannot be None')
  203. message = db.session.query(Message).filter(
  204. Message.id == message_id,
  205. Message.app_id == app_model.id,
  206. Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  207. Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  208. Message.from_account_id == (user.id if isinstance(user, Account) else None),
  209. ).first()
  210. if not message:
  211. raise MessageNotExistsError()
  212. current_app_model_config = app_model.app_model_config
  213. more_like_this = current_app_model_config.more_like_this_dict
  214. if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
  215. raise MoreLikeThisDisabledError()
  216. app_model_config = message.app_model_config
  217. model_dict = app_model_config.model_dict
  218. completion_params = model_dict.get('completion_params')
  219. completion_params['temperature'] = 0.9
  220. model_dict['completion_params'] = completion_params
  221. app_model_config.model = json.dumps(model_dict)
  222. generate_task_id = str(uuid.uuid4())
  223. pubsub = redis_client.pubsub()
  224. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  225. user = cls.get_real_user_instead_of_proxy_obj(user)
  226. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  227. 'flask_app': current_app._get_current_object(),
  228. 'generate_task_id': generate_task_id,
  229. 'detached_app_model': app_model,
  230. 'app_model_config': app_model_config.copy(),
  231. 'query': message.query,
  232. 'inputs': message.inputs,
  233. 'detached_user': user,
  234. 'detached_conversation': None,
  235. 'streaming': streaming,
  236. 'is_model_config_override': True,
  237. 'retriever_from': retriever_from
  238. })
  239. generate_worker_thread.start()
  240. # wait for 10 minutes to close the thread
  241. cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
  242. generate_task_id)
  243. return cls.compact_response(pubsub, streaming)
  244. @classmethod
  245. def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
  246. if user_inputs is None:
  247. user_inputs = {}
  248. filtered_inputs = {}
  249. # Filter input variables from form configuration, handle required fields, default values, and option values
  250. input_form_config = app_model_config.user_input_form_list
  251. for config in input_form_config:
  252. input_config = list(config.values())[0]
  253. variable = input_config["variable"]
  254. input_type = list(config.keys())[0]
  255. if variable not in user_inputs or not user_inputs[variable]:
  256. if "required" in input_config and input_config["required"]:
  257. raise ValueError(f"{variable} is required in input form")
  258. else:
  259. filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
  260. continue
  261. value = user_inputs[variable]
  262. if input_type == "select":
  263. options = input_config["options"] if "options" in input_config else []
  264. if value not in options:
  265. raise ValueError(f"{variable} in input form must be one of the following: {options}")
  266. else:
  267. if 'max_length' in input_config:
  268. max_length = input_config['max_length']
  269. if len(value) > max_length:
  270. raise ValueError(f'{variable} in input form must be less than {max_length} characters')
  271. filtered_inputs[variable] = value.replace('\x00', '') if value else None
  272. return filtered_inputs
  273. @classmethod
  274. def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
  275. generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
  276. if not streaming:
  277. try:
  278. message_result = {}
  279. for message in pubsub.listen():
  280. if message["type"] == "message":
  281. result = message["data"].decode('utf-8')
  282. result = json.loads(result)
  283. if result.get('error'):
  284. cls.handle_error(result)
  285. if result['event'] == 'message' and 'data' in result:
  286. message_result['message'] = result.get('data')
  287. if result['event'] == 'message_end' and 'data' in result:
  288. message_result['message_end'] = result.get('data')
  289. return cls.get_blocking_message_response_data(message_result)
  290. except ValueError as e:
  291. if e.args[0] != "I/O operation on closed file.": # ignore this error
  292. raise CompletionStoppedError()
  293. else:
  294. logging.exception(e)
  295. raise
  296. finally:
  297. db.session.commit()
  298. try:
  299. pubsub.unsubscribe(generate_channel)
  300. except ConnectionError:
  301. pass
  302. else:
  303. def generate() -> Generator:
  304. try:
  305. for message in pubsub.listen():
  306. if message["type"] == "message":
  307. result = message["data"].decode('utf-8')
  308. result = json.loads(result)
  309. if result.get('error'):
  310. cls.handle_error(result)
  311. event = result.get('event')
  312. if event == "end":
  313. logging.debug("{} finished".format(generate_channel))
  314. break
  315. if event == 'message':
  316. yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
  317. elif event == 'message_replace':
  318. yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
  319. elif event == 'chain':
  320. yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
  321. elif event == 'agent_thought':
  322. yield "data: " + json.dumps(
  323. cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
  324. elif event == 'message_end':
  325. yield "data: " + json.dumps(
  326. cls.get_message_end_data(result.get('data'))) + "\n\n"
  327. elif event == 'ping':
  328. yield "event: ping\n\n"
  329. else:
  330. yield "data: " + json.dumps(result) + "\n\n"
  331. except ValueError as e:
  332. if e.args[0] != "I/O operation on closed file.": # ignore this error
  333. logging.exception(e)
  334. raise
  335. finally:
  336. db.session.commit()
  337. try:
  338. pubsub.unsubscribe(generate_channel)
  339. except ConnectionError:
  340. pass
  341. return generate()
  342. @classmethod
  343. def get_message_response_data(cls, data: dict):
  344. response_data = {
  345. 'event': 'message',
  346. 'task_id': data.get('task_id'),
  347. 'id': data.get('message_id'),
  348. 'answer': data.get('text'),
  349. 'created_at': int(time.time())
  350. }
  351. if data.get('mode') == 'chat':
  352. response_data['conversation_id'] = data.get('conversation_id')
  353. return response_data
  354. @classmethod
  355. def get_message_replace_response_data(cls, data: dict):
  356. response_data = {
  357. 'event': 'message_replace',
  358. 'task_id': data.get('task_id'),
  359. 'id': data.get('message_id'),
  360. 'answer': data.get('text'),
  361. 'created_at': int(time.time())
  362. }
  363. if data.get('mode') == 'chat':
  364. response_data['conversation_id'] = data.get('conversation_id')
  365. return response_data
  366. @classmethod
  367. def get_blocking_message_response_data(cls, data: dict):
  368. message = data.get('message')
  369. response_data = {
  370. 'event': 'message',
  371. 'task_id': message.get('task_id'),
  372. 'id': message.get('message_id'),
  373. 'answer': message.get('text'),
  374. 'metadata': {},
  375. 'created_at': int(time.time())
  376. }
  377. if message.get('mode') == 'chat':
  378. response_data['conversation_id'] = message.get('conversation_id')
  379. if 'message_end' in data:
  380. message_end = data.get('message_end')
  381. if 'retriever_resources' in message_end:
  382. response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
  383. return response_data
  384. @classmethod
  385. def get_message_end_data(cls, data: dict):
  386. response_data = {
  387. 'event': 'message_end',
  388. 'task_id': data.get('task_id'),
  389. 'id': data.get('message_id')
  390. }
  391. if 'retriever_resources' in data:
  392. response_data['retriever_resources'] = data.get('retriever_resources')
  393. if data.get('mode') == 'chat':
  394. response_data['conversation_id'] = data.get('conversation_id')
  395. return response_data
  396. @classmethod
  397. def get_chain_response_data(cls, data: dict):
  398. response_data = {
  399. 'event': 'chain',
  400. 'id': data.get('chain_id'),
  401. 'task_id': data.get('task_id'),
  402. 'message_id': data.get('message_id'),
  403. 'type': data.get('type'),
  404. 'input': data.get('input'),
  405. 'output': data.get('output'),
  406. 'created_at': int(time.time())
  407. }
  408. if data.get('mode') == 'chat':
  409. response_data['conversation_id'] = data.get('conversation_id')
  410. return response_data
  411. @classmethod
  412. def get_agent_thought_response_data(cls, data: dict):
  413. response_data = {
  414. 'event': 'agent_thought',
  415. 'id': data.get('id'),
  416. 'chain_id': data.get('chain_id'),
  417. 'task_id': data.get('task_id'),
  418. 'message_id': data.get('message_id'),
  419. 'position': data.get('position'),
  420. 'thought': data.get('thought'),
  421. 'tool': data.get('tool'),
  422. 'tool_input': data.get('tool_input'),
  423. 'created_at': int(time.time())
  424. }
  425. if data.get('mode') == 'chat':
  426. response_data['conversation_id'] = data.get('conversation_id')
  427. return response_data
  428. @classmethod
  429. def handle_error(cls, result: dict):
  430. logging.debug("error: %s", result)
  431. error = result.get('error')
  432. description = result.get('description')
  433. # handle errors
  434. llm_errors = {
  435. 'ValueError': LLMBadRequestError,
  436. 'LLMBadRequestError': LLMBadRequestError,
  437. 'LLMAPIConnectionError': LLMAPIConnectionError,
  438. 'LLMAPIUnavailableError': LLMAPIUnavailableError,
  439. 'LLMRateLimitError': LLMRateLimitError,
  440. 'ProviderTokenNotInitError': ProviderTokenNotInitError,
  441. 'QuotaExceededError': QuotaExceededError,
  442. 'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
  443. }
  444. if error in llm_errors:
  445. raise llm_errors[error](description)
  446. elif error == 'LLMAuthorizationError':
  447. raise LLMAuthorizationError('Incorrect API key provided')
  448. else:
  449. raise Exception(description)