completion_service.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621
  1. import json
  2. import logging
  3. import threading
  4. import time
  5. import uuid
  6. from typing import Generator, Union, Any, Optional, List
  7. from flask import current_app, Flask
  8. from redis.client import PubSub
  9. from sqlalchemy import and_
  10. from core.completion import Completion
  11. from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
  12. ConversationTaskInterruptException
  13. from core.file.message_file_parser import MessageFileParser
  14. from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
  15. LLMRateLimitError, \
  16. LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
  17. from core.model_providers.models.entity.message import PromptMessageFile
  18. from extensions.ext_database import db
  19. from extensions.ext_redis import redis_client
  20. from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
  21. from services.app_model_config_service import AppModelConfigService
  22. from services.errors.app import MoreLikeThisDisabledError
  23. from services.errors.app_model_config import AppModelConfigBrokenError
  24. from services.errors.completion import CompletionStoppedError
  25. from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
  26. from services.errors.message import MessageNotExistsError
  27. class CompletionService:
  28. @classmethod
  29. def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
  30. from_source: str, streaming: bool = True,
  31. is_model_config_override: bool = False) -> Union[dict, Generator]:
  32. # is streaming mode
  33. inputs = args['inputs']
  34. query = args['query']
  35. files = args['files'] if 'files' in args and args['files'] else []
  36. auto_generate_name = args['auto_generate_name'] \
  37. if 'auto_generate_name' in args else True
  38. if app_model.mode != 'completion' and not query:
  39. raise ValueError('query is required')
  40. query = query.replace('\x00', '')
  41. conversation_id = args['conversation_id'] if 'conversation_id' in args else None
  42. conversation = None
  43. if conversation_id:
  44. conversation_filter = [
  45. Conversation.id == args['conversation_id'],
  46. Conversation.app_id == app_model.id,
  47. Conversation.status == 'normal'
  48. ]
  49. if from_source == 'console':
  50. conversation_filter.append(Conversation.from_account_id == user.id)
  51. else:
  52. conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
  53. conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
  54. if not conversation:
  55. raise ConversationNotExistsError()
  56. if conversation.status != 'normal':
  57. raise ConversationCompletedError()
  58. if not conversation.override_model_configs:
  59. app_model_config = db.session.query(AppModelConfig).filter(
  60. AppModelConfig.id == conversation.app_model_config_id,
  61. AppModelConfig.app_id == app_model.id
  62. ).first()
  63. if not app_model_config:
  64. raise AppModelConfigBrokenError()
  65. else:
  66. conversation_override_model_configs = json.loads(conversation.override_model_configs)
  67. app_model_config = AppModelConfig(
  68. id=conversation.app_model_config_id,
  69. app_id=app_model.id,
  70. )
  71. app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
  72. if is_model_config_override:
  73. # build new app model config
  74. if 'model' not in args['model_config']:
  75. raise ValueError('model_config.model is required')
  76. if 'completion_params' not in args['model_config']['model']:
  77. raise ValueError('model_config.model.completion_params is required')
  78. completion_params = AppModelConfigService.validate_model_completion_params(
  79. cp=args['model_config']['model']['completion_params'],
  80. model_name=app_model_config.model_dict["name"]
  81. )
  82. app_model_config_model = app_model_config.model_dict
  83. app_model_config_model['completion_params'] = completion_params
  84. app_model_config.retriever_resource = json.dumps({'enabled': True})
  85. app_model_config = app_model_config.copy()
  86. app_model_config.model = json.dumps(app_model_config_model)
  87. else:
  88. if app_model.app_model_config_id is None:
  89. raise AppModelConfigBrokenError()
  90. app_model_config = app_model.app_model_config
  91. if not app_model_config:
  92. raise AppModelConfigBrokenError()
  93. if is_model_config_override:
  94. if not isinstance(user, Account):
  95. raise Exception("Only account can override model config")
  96. # validate config
  97. model_config = AppModelConfigService.validate_configuration(
  98. tenant_id=app_model.tenant_id,
  99. account=user,
  100. config=args['model_config'],
  101. mode=app_model.mode
  102. )
  103. app_model_config = AppModelConfig(
  104. id=app_model_config.id,
  105. app_id=app_model.id,
  106. )
  107. app_model_config = app_model_config.from_model_config_dict(model_config)
  108. # clean input by app_model_config form rules
  109. inputs = cls.get_cleaned_inputs(inputs, app_model_config)
  110. # parse files
  111. message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
  112. file_objs = message_file_parser.validate_and_transform_files_arg(
  113. files,
  114. app_model_config,
  115. user
  116. )
  117. generate_task_id = str(uuid.uuid4())
  118. pubsub = redis_client.pubsub()
  119. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  120. user = cls.get_real_user_instead_of_proxy_obj(user)
  121. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  122. 'flask_app': current_app._get_current_object(),
  123. 'generate_task_id': generate_task_id,
  124. 'detached_app_model': app_model,
  125. 'app_model_config': app_model_config.copy(),
  126. 'query': query,
  127. 'inputs': inputs,
  128. 'files': file_objs,
  129. 'detached_user': user,
  130. 'detached_conversation': conversation,
  131. 'streaming': streaming,
  132. 'is_model_config_override': is_model_config_override,
  133. 'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
  134. 'auto_generate_name': auto_generate_name,
  135. 'from_source': from_source
  136. })
  137. generate_worker_thread.start()
  138. # wait for 10 minutes to close the thread
  139. cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
  140. generate_task_id)
  141. return cls.compact_response(pubsub, streaming)
  142. @classmethod
  143. def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
  144. if isinstance(user, Account):
  145. user = db.session.query(Account).filter(Account.id == user.id).first()
  146. elif isinstance(user, EndUser):
  147. user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
  148. else:
  149. raise Exception("Unknown user type")
  150. return user
  151. @classmethod
  152. def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
  153. app_model_config: AppModelConfig,
  154. query: str, inputs: dict, files: List[PromptMessageFile],
  155. detached_user: Union[Account, EndUser],
  156. detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
  157. retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
  158. with flask_app.app_context():
  159. # fixed the state of the model object when it detached from the original session
  160. user = db.session.merge(detached_user)
  161. app_model = db.session.merge(detached_app_model)
  162. if detached_conversation:
  163. conversation = db.session.merge(detached_conversation)
  164. else:
  165. conversation = None
  166. try:
  167. # run
  168. Completion.generate(
  169. task_id=generate_task_id,
  170. app=app_model,
  171. app_model_config=app_model_config,
  172. query=query,
  173. inputs=inputs,
  174. user=user,
  175. files=files,
  176. conversation=conversation,
  177. streaming=streaming,
  178. is_override=is_model_config_override,
  179. retriever_from=retriever_from,
  180. auto_generate_name=auto_generate_name,
  181. from_source=from_source
  182. )
  183. except (ConversationTaskInterruptException, ConversationTaskStoppedException):
  184. pass
  185. except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  186. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  187. ModelCurrentlyNotSupportError) as e:
  188. PubHandler.pub_error(user, generate_task_id, e)
  189. except LLMAuthorizationError:
  190. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  191. except Exception as e:
  192. logging.exception("Unknown Error in completion")
  193. PubHandler.pub_error(user, generate_task_id, e)
  194. finally:
  195. db.session.remove()
  196. @classmethod
  197. def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
  198. generate_task_id) -> threading.Thread:
  199. # wait for 10 minutes to close the thread
  200. timeout = 600
  201. def close_pubsub():
  202. with flask_app.app_context():
  203. try:
  204. user = db.session.merge(detached_user)
  205. sleep_iterations = 0
  206. while sleep_iterations < timeout and worker_thread.is_alive():
  207. if sleep_iterations > 0 and sleep_iterations % 10 == 0:
  208. PubHandler.ping(user, generate_task_id)
  209. time.sleep(1)
  210. sleep_iterations += 1
  211. if worker_thread.is_alive():
  212. PubHandler.stop(user, generate_task_id)
  213. try:
  214. pubsub.close()
  215. except Exception:
  216. pass
  217. finally:
  218. db.session.remove()
  219. countdown_thread = threading.Thread(target=close_pubsub)
  220. countdown_thread.start()
  221. return countdown_thread
  222. @classmethod
  223. def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
  224. message_id: str, streaming: bool = True,
  225. retriever_from: str = 'dev') -> Union[dict, Generator]:
  226. if not user:
  227. raise ValueError('user cannot be None')
  228. message = db.session.query(Message).filter(
  229. Message.id == message_id,
  230. Message.app_id == app_model.id,
  231. Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  232. Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  233. Message.from_account_id == (user.id if isinstance(user, Account) else None),
  234. ).first()
  235. if not message:
  236. raise MessageNotExistsError()
  237. current_app_model_config = app_model.app_model_config
  238. more_like_this = current_app_model_config.more_like_this_dict
  239. if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
  240. raise MoreLikeThisDisabledError()
  241. app_model_config = message.app_model_config
  242. model_dict = app_model_config.model_dict
  243. completion_params = model_dict.get('completion_params')
  244. completion_params['temperature'] = 0.9
  245. model_dict['completion_params'] = completion_params
  246. app_model_config.model = json.dumps(model_dict)
  247. # parse files
  248. message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
  249. file_objs = message_file_parser.transform_message_files(
  250. message.files, app_model_config
  251. )
  252. generate_task_id = str(uuid.uuid4())
  253. pubsub = redis_client.pubsub()
  254. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  255. user = cls.get_real_user_instead_of_proxy_obj(user)
  256. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  257. 'flask_app': current_app._get_current_object(),
  258. 'generate_task_id': generate_task_id,
  259. 'detached_app_model': app_model,
  260. 'app_model_config': app_model_config.copy(),
  261. 'query': message.query,
  262. 'inputs': message.inputs,
  263. 'files': file_objs,
  264. 'detached_user': user,
  265. 'detached_conversation': None,
  266. 'streaming': streaming,
  267. 'is_model_config_override': True,
  268. 'retriever_from': retriever_from,
  269. 'auto_generate_name': False
  270. })
  271. generate_worker_thread.start()
  272. # wait for 10 minutes to close the thread
  273. cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
  274. generate_task_id)
  275. return cls.compact_response(pubsub, streaming)
  276. @classmethod
  277. def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
  278. if user_inputs is None:
  279. user_inputs = {}
  280. filtered_inputs = {}
  281. # Filter input variables from form configuration, handle required fields, default values, and option values
  282. input_form_config = app_model_config.user_input_form_list
  283. for config in input_form_config:
  284. input_config = list(config.values())[0]
  285. variable = input_config["variable"]
  286. input_type = list(config.keys())[0]
  287. if variable not in user_inputs or not user_inputs[variable]:
  288. if "required" in input_config and input_config["required"]:
  289. raise ValueError(f"{variable} is required in input form")
  290. else:
  291. filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
  292. continue
  293. value = user_inputs[variable]
  294. if input_type == "select":
  295. options = input_config["options"] if "options" in input_config else []
  296. if value not in options:
  297. raise ValueError(f"{variable} in input form must be one of the following: {options}")
  298. else:
  299. if 'max_length' in input_config:
  300. max_length = input_config['max_length']
  301. if len(value) > max_length:
  302. raise ValueError(f'{variable} in input form must be less than {max_length} characters')
  303. filtered_inputs[variable] = value.replace('\x00', '') if value else None
  304. return filtered_inputs
  305. @classmethod
  306. def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
  307. generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
  308. if not streaming:
  309. try:
  310. message_result = {}
  311. for message in pubsub.listen():
  312. if message["type"] == "message":
  313. result = message["data"].decode('utf-8')
  314. result = json.loads(result)
  315. if result.get('error'):
  316. cls.handle_error(result)
  317. if result['event'] == 'annotation' and 'data' in result:
  318. message_result['annotation'] = result.get('data')
  319. return cls.get_blocking_annotation_message_response_data(message_result)
  320. if result['event'] == 'message' and 'data' in result:
  321. message_result['message'] = result.get('data')
  322. if result['event'] == 'message_end' and 'data' in result:
  323. message_result['message_end'] = result.get('data')
  324. return cls.get_blocking_message_response_data(message_result)
  325. except ValueError as e:
  326. if e.args[0] != "I/O operation on closed file.": # ignore this error
  327. raise CompletionStoppedError()
  328. else:
  329. logging.exception(e)
  330. raise
  331. finally:
  332. db.session.remove()
  333. try:
  334. pubsub.unsubscribe(generate_channel)
  335. except ConnectionError:
  336. pass
  337. else:
  338. def generate() -> Generator:
  339. try:
  340. for message in pubsub.listen():
  341. if message["type"] == "message":
  342. result = message["data"].decode('utf-8')
  343. result = json.loads(result)
  344. if result.get('error'):
  345. cls.handle_error(result)
  346. event = result.get('event')
  347. if event == "end":
  348. logging.debug("{} finished".format(generate_channel))
  349. break
  350. if event == 'message':
  351. yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
  352. elif event == 'message_replace':
  353. yield "data: " + json.dumps(
  354. cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
  355. elif event == 'chain':
  356. yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
  357. elif event == 'agent_thought':
  358. yield "data: " + json.dumps(
  359. cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
  360. elif event == 'annotation':
  361. yield "data: " + json.dumps(
  362. cls.get_annotation_response_data(result.get('data'))) + "\n\n"
  363. elif event == 'message_end':
  364. yield "data: " + json.dumps(
  365. cls.get_message_end_data(result.get('data'))) + "\n\n"
  366. elif event == 'ping':
  367. yield "event: ping\n\n"
  368. else:
  369. yield "data: " + json.dumps(result) + "\n\n"
  370. except ValueError as e:
  371. if e.args[0] != "I/O operation on closed file.": # ignore this error
  372. logging.exception(e)
  373. raise
  374. finally:
  375. db.session.remove()
  376. try:
  377. pubsub.unsubscribe(generate_channel)
  378. except ConnectionError:
  379. pass
  380. return generate()
  381. @classmethod
  382. def get_message_response_data(cls, data: dict):
  383. response_data = {
  384. 'event': 'message',
  385. 'task_id': data.get('task_id'),
  386. 'id': data.get('message_id'),
  387. 'answer': data.get('text'),
  388. 'created_at': int(time.time())
  389. }
  390. if data.get('mode') == 'chat':
  391. response_data['conversation_id'] = data.get('conversation_id')
  392. return response_data
  393. @classmethod
  394. def get_message_replace_response_data(cls, data: dict):
  395. response_data = {
  396. 'event': 'message_replace',
  397. 'task_id': data.get('task_id'),
  398. 'id': data.get('message_id'),
  399. 'answer': data.get('text'),
  400. 'created_at': int(time.time())
  401. }
  402. if data.get('mode') == 'chat':
  403. response_data['conversation_id'] = data.get('conversation_id')
  404. return response_data
  405. @classmethod
  406. def get_blocking_message_response_data(cls, data: dict):
  407. message = data.get('message')
  408. response_data = {
  409. 'event': 'message',
  410. 'task_id': message.get('task_id'),
  411. 'id': message.get('message_id'),
  412. 'answer': message.get('text'),
  413. 'metadata': {},
  414. 'created_at': int(time.time())
  415. }
  416. if message.get('mode') == 'chat':
  417. response_data['conversation_id'] = message.get('conversation_id')
  418. if 'message_end' in data:
  419. message_end = data.get('message_end')
  420. if 'retriever_resources' in message_end:
  421. response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')
  422. return response_data
  423. @classmethod
  424. def get_blocking_annotation_message_response_data(cls, data: dict):
  425. message = data.get('annotation')
  426. response_data = {
  427. 'event': 'annotation',
  428. 'task_id': message.get('task_id'),
  429. 'id': message.get('message_id'),
  430. 'answer': message.get('text'),
  431. 'metadata': {},
  432. 'created_at': int(time.time()),
  433. 'annotation_id': message.get('annotation_id'),
  434. 'annotation_author_name': message.get('annotation_author_name')
  435. }
  436. if message.get('mode') == 'chat':
  437. response_data['conversation_id'] = message.get('conversation_id')
  438. return response_data
  439. @classmethod
  440. def get_message_end_data(cls, data: dict):
  441. response_data = {
  442. 'event': 'message_end',
  443. 'task_id': data.get('task_id'),
  444. 'id': data.get('message_id')
  445. }
  446. if 'retriever_resources' in data:
  447. response_data['retriever_resources'] = data.get('retriever_resources')
  448. if data.get('mode') == 'chat':
  449. response_data['conversation_id'] = data.get('conversation_id')
  450. return response_data
  451. @classmethod
  452. def get_chain_response_data(cls, data: dict):
  453. response_data = {
  454. 'event': 'chain',
  455. 'id': data.get('chain_id'),
  456. 'task_id': data.get('task_id'),
  457. 'message_id': data.get('message_id'),
  458. 'type': data.get('type'),
  459. 'input': data.get('input'),
  460. 'output': data.get('output'),
  461. 'created_at': int(time.time())
  462. }
  463. if data.get('mode') == 'chat':
  464. response_data['conversation_id'] = data.get('conversation_id')
  465. return response_data
  466. @classmethod
  467. def get_agent_thought_response_data(cls, data: dict):
  468. response_data = {
  469. 'event': 'agent_thought',
  470. 'id': data.get('id'),
  471. 'chain_id': data.get('chain_id'),
  472. 'task_id': data.get('task_id'),
  473. 'message_id': data.get('message_id'),
  474. 'position': data.get('position'),
  475. 'thought': data.get('thought'),
  476. 'tool': data.get('tool'),
  477. 'tool_input': data.get('tool_input'),
  478. 'created_at': int(time.time())
  479. }
  480. if data.get('mode') == 'chat':
  481. response_data['conversation_id'] = data.get('conversation_id')
  482. return response_data
  483. @classmethod
  484. def get_annotation_response_data(cls, data: dict):
  485. response_data = {
  486. 'event': 'annotation',
  487. 'task_id': data.get('task_id'),
  488. 'id': data.get('message_id'),
  489. 'answer': data.get('text'),
  490. 'created_at': int(time.time()),
  491. 'annotation_id': data.get('annotation_id'),
  492. 'annotation_author_name': data.get('annotation_author_name'),
  493. }
  494. if data.get('mode') == 'chat':
  495. response_data['conversation_id'] = data.get('conversation_id')
  496. return response_data
  497. @classmethod
  498. def handle_error(cls, result: dict):
  499. logging.debug("error: %s", result)
  500. error = result.get('error')
  501. description = result.get('description')
  502. # handle errors
  503. llm_errors = {
  504. 'ValueError': LLMBadRequestError,
  505. 'LLMBadRequestError': LLMBadRequestError,
  506. 'LLMAPIConnectionError': LLMAPIConnectionError,
  507. 'LLMAPIUnavailableError': LLMAPIUnavailableError,
  508. 'LLMRateLimitError': LLMRateLimitError,
  509. 'ProviderTokenNotInitError': ProviderTokenNotInitError,
  510. 'QuotaExceededError': QuotaExceededError,
  511. 'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
  512. }
  513. if error in llm_errors:
  514. raise llm_errors[error](description)
  515. elif error == 'LLMAuthorizationError':
  516. raise LLMAuthorizationError('Incorrect API key provided')
  517. else:
  518. raise Exception(description)