completion_service.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import json
  2. import logging
  3. import threading
  4. import time
  5. import uuid
  6. from typing import Generator, Union, Any
  7. from flask import current_app, Flask
  8. from redis.client import PubSub
  9. from sqlalchemy import and_
  10. from core.completion import Completion
  11. from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
  12. from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
  13. LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
  14. from extensions.ext_database import db
  15. from extensions.ext_redis import redis_client
  16. from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
  17. from services.app_model_config_service import AppModelConfigService
  18. from services.errors.app import MoreLikeThisDisabledError
  19. from services.errors.app_model_config import AppModelConfigBrokenError
  20. from services.errors.completion import CompletionStoppedError
  21. from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
  22. from services.errors.message import MessageNotExistsError
  23. class CompletionService:
  24. @classmethod
  25. def completion(cls, app_model: App, user: Union[Account | EndUser], args: Any,
  26. from_source: str, streaming: bool = True,
  27. is_model_config_override: bool = False) -> Union[dict | Generator]:
  28. # is streaming mode
  29. inputs = args['inputs']
  30. query = args['query']
  31. if not query:
  32. raise ValueError('query is required')
  33. query = query.replace('\x00', '')
  34. conversation_id = args['conversation_id'] if 'conversation_id' in args else None
  35. conversation = None
  36. if conversation_id:
  37. conversation_filter = [
  38. Conversation.id == args['conversation_id'],
  39. Conversation.app_id == app_model.id,
  40. Conversation.status == 'normal'
  41. ]
  42. if from_source == 'console':
  43. conversation_filter.append(Conversation.from_account_id == user.id)
  44. else:
  45. conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
  46. conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
  47. if not conversation:
  48. raise ConversationNotExistsError()
  49. if conversation.status != 'normal':
  50. raise ConversationCompletedError()
  51. if not conversation.override_model_configs:
  52. app_model_config = db.session.query(AppModelConfig).filter(
  53. AppModelConfig.id == conversation.app_model_config_id,
  54. AppModelConfig.app_id == app_model.id
  55. ).first()
  56. if not app_model_config:
  57. raise AppModelConfigBrokenError()
  58. else:
  59. conversation_override_model_configs = json.loads(conversation.override_model_configs)
  60. app_model_config = AppModelConfig(
  61. id=conversation.app_model_config_id,
  62. app_id=app_model.id,
  63. )
  64. app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
  65. if is_model_config_override:
  66. # build new app model config
  67. if 'model' not in args['model_config']:
  68. raise ValueError('model_config.model is required')
  69. if 'completion_params' not in args['model_config']['model']:
  70. raise ValueError('model_config.model.completion_params is required')
  71. completion_params = AppModelConfigService.validate_model_completion_params(
  72. cp=args['model_config']['model']['completion_params'],
  73. model_name=app_model_config.model_dict["name"]
  74. )
  75. app_model_config_model = app_model_config.model_dict
  76. app_model_config_model['completion_params'] = completion_params
  77. app_model_config = app_model_config.copy()
  78. app_model_config.model = json.dumps(app_model_config_model)
  79. else:
  80. if app_model.app_model_config_id is None:
  81. raise AppModelConfigBrokenError()
  82. app_model_config = app_model.app_model_config
  83. if not app_model_config:
  84. raise AppModelConfigBrokenError()
  85. if is_model_config_override:
  86. if not isinstance(user, Account):
  87. raise Exception("Only account can override model config")
  88. # validate config
  89. model_config = AppModelConfigService.validate_configuration(
  90. tenant_id=app_model.tenant_id,
  91. account=user,
  92. config=args['model_config']
  93. )
  94. app_model_config = AppModelConfig(
  95. id=app_model_config.id,
  96. app_id=app_model.id,
  97. )
  98. app_model_config = app_model_config.from_model_config_dict(model_config)
  99. # clean input by app_model_config form rules
  100. inputs = cls.get_cleaned_inputs(inputs, app_model_config)
  101. generate_task_id = str(uuid.uuid4())
  102. pubsub = redis_client.pubsub()
  103. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  104. user = cls.get_real_user_instead_of_proxy_obj(user)
  105. generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
  106. 'flask_app': current_app._get_current_object(),
  107. 'generate_task_id': generate_task_id,
  108. 'app_model': app_model,
  109. 'app_model_config': app_model_config,
  110. 'query': query,
  111. 'inputs': inputs,
  112. 'user': user,
  113. 'conversation': conversation,
  114. 'streaming': streaming,
  115. 'is_model_config_override': is_model_config_override
  116. })
  117. generate_worker_thread.start()
  118. # wait for 10 minutes to close the thread
  119. cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
  120. return cls.compact_response(pubsub, streaming)
  121. @classmethod
  122. def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
  123. if isinstance(user, Account):
  124. user = db.session.query(Account).filter(Account.id == user.id).first()
  125. elif isinstance(user, EndUser):
  126. user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
  127. else:
  128. raise Exception("Unknown user type")
  129. return user
  130. @classmethod
  131. def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
  132. query: str, inputs: dict, user: Union[Account, EndUser],
  133. conversation: Conversation, streaming: bool, is_model_config_override: bool):
  134. with flask_app.app_context():
  135. try:
  136. if conversation:
  137. # fixed the state of the conversation object when it detached from the original session
  138. conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
  139. # run
  140. Completion.generate(
  141. task_id=generate_task_id,
  142. app=app_model,
  143. app_model_config=app_model_config,
  144. query=query,
  145. inputs=inputs,
  146. user=user,
  147. conversation=conversation,
  148. streaming=streaming,
  149. is_override=is_model_config_override,
  150. )
  151. except ConversationTaskStoppedException:
  152. pass
  153. except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  154. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  155. ModelCurrentlyNotSupportError) as e:
  156. db.session.rollback()
  157. PubHandler.pub_error(user, generate_task_id, e)
  158. except LLMAuthorizationError:
  159. db.session.rollback()
  160. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  161. except Exception as e:
  162. db.session.rollback()
  163. logging.exception("Unknown Error in completion")
  164. PubHandler.pub_error(user, generate_task_id, e)
  165. @classmethod
  166. def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
  167. # wait for 10 minutes to close the thread
  168. timeout = 600
  169. def close_pubsub():
  170. sleep_iterations = 0
  171. while sleep_iterations < timeout and worker_thread.is_alive():
  172. if sleep_iterations > 0 and sleep_iterations % 10 == 0:
  173. PubHandler.ping(user, generate_task_id)
  174. time.sleep(1)
  175. sleep_iterations += 1
  176. if worker_thread.is_alive():
  177. PubHandler.stop(user, generate_task_id)
  178. try:
  179. pubsub.close()
  180. except:
  181. pass
  182. countdown_thread = threading.Thread(target=close_pubsub)
  183. countdown_thread.start()
  184. return countdown_thread
  185. @classmethod
  186. def generate_more_like_this(cls, app_model: App, user: Union[Account | EndUser],
  187. message_id: str, streaming: bool = True) -> Union[dict | Generator]:
  188. if not user:
  189. raise ValueError('user cannot be None')
  190. message = db.session.query(Message).filter(
  191. Message.id == message_id,
  192. Message.app_id == app_model.id,
  193. Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  194. Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  195. Message.from_account_id == (user.id if isinstance(user, Account) else None),
  196. ).first()
  197. if not message:
  198. raise MessageNotExistsError()
  199. current_app_model_config = app_model.app_model_config
  200. more_like_this = current_app_model_config.more_like_this_dict
  201. if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
  202. raise MoreLikeThisDisabledError()
  203. app_model_config = message.app_model_config
  204. if message.override_model_configs:
  205. override_model_configs = json.loads(message.override_model_configs)
  206. pre_prompt = override_model_configs.get("pre_prompt", '')
  207. elif app_model_config:
  208. pre_prompt = app_model_config.pre_prompt
  209. else:
  210. raise AppModelConfigBrokenError()
  211. generate_task_id = str(uuid.uuid4())
  212. pubsub = redis_client.pubsub()
  213. pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))
  214. user = cls.get_real_user_instead_of_proxy_obj(user)
  215. generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
  216. 'flask_app': current_app._get_current_object(),
  217. 'generate_task_id': generate_task_id,
  218. 'app_model': app_model,
  219. 'app_model_config': app_model_config,
  220. 'message': message,
  221. 'pre_prompt': pre_prompt,
  222. 'user': user,
  223. 'streaming': streaming
  224. })
  225. generate_worker_thread.start()
  226. cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
  227. return cls.compact_response(pubsub, streaming)
  228. @classmethod
  229. def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
  230. app_model_config: AppModelConfig, message: Message, pre_prompt: str,
  231. user: Union[Account, EndUser], streaming: bool):
  232. with flask_app.app_context():
  233. try:
  234. # run
  235. Completion.generate_more_like_this(
  236. task_id=generate_task_id,
  237. app=app_model,
  238. user=user,
  239. message=message,
  240. pre_prompt=pre_prompt,
  241. app_model_config=app_model_config,
  242. streaming=streaming
  243. )
  244. except ConversationTaskStoppedException:
  245. pass
  246. except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
  247. LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
  248. ModelCurrentlyNotSupportError) as e:
  249. db.session.rollback()
  250. PubHandler.pub_error(user, generate_task_id, e)
  251. except LLMAuthorizationError:
  252. db.session.rollback()
  253. PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
  254. except Exception as e:
  255. db.session.rollback()
  256. logging.exception("Unknown Error in completion")
  257. PubHandler.pub_error(user, generate_task_id, e)
  258. @classmethod
  259. def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
  260. if user_inputs is None:
  261. user_inputs = {}
  262. filtered_inputs = {}
  263. # Filter input variables from form configuration, handle required fields, default values, and option values
  264. input_form_config = app_model_config.user_input_form_list
  265. for config in input_form_config:
  266. input_config = list(config.values())[0]
  267. variable = input_config["variable"]
  268. input_type = list(config.keys())[0]
  269. if variable not in user_inputs or not user_inputs[variable]:
  270. if "required" in input_config and input_config["required"]:
  271. raise ValueError(f"{variable} is required in input form")
  272. else:
  273. filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
  274. continue
  275. value = user_inputs[variable]
  276. if input_type == "select":
  277. options = input_config["options"] if "options" in input_config else []
  278. if value not in options:
  279. raise ValueError(f"{variable} in input form must be one of the following: {options}")
  280. else:
  281. if 'max_length' in variable:
  282. max_length = variable['max_length']
  283. if len(value) > max_length:
  284. raise ValueError(f'{variable} in input form must be less than {max_length} characters')
  285. filtered_inputs[variable] = value.replace('\x00', '') if value else None
  286. return filtered_inputs
  287. @classmethod
  288. def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict | Generator]:
  289. generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
  290. if not streaming:
  291. try:
  292. for message in pubsub.listen():
  293. if message["type"] == "message":
  294. result = message["data"].decode('utf-8')
  295. result = json.loads(result)
  296. if result.get('error'):
  297. cls.handle_error(result)
  298. if 'data' in result:
  299. return cls.get_message_response_data(result.get('data'))
  300. except ValueError as e:
  301. if e.args[0] != "I/O operation on closed file.": # ignore this error
  302. raise CompletionStoppedError()
  303. else:
  304. logging.exception(e)
  305. raise
  306. finally:
  307. try:
  308. pubsub.unsubscribe(generate_channel)
  309. except ConnectionError:
  310. pass
  311. else:
  312. def generate() -> Generator:
  313. try:
  314. for message in pubsub.listen():
  315. if message["type"] == "message":
  316. result = message["data"].decode('utf-8')
  317. result = json.loads(result)
  318. if result.get('error'):
  319. cls.handle_error(result)
  320. event = result.get('event')
  321. if event == "end":
  322. logging.debug("{} finished".format(generate_channel))
  323. break
  324. if event == 'message':
  325. yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
  326. elif event == 'chain':
  327. yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
  328. elif event == 'agent_thought':
  329. yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
  330. elif event == 'ping':
  331. yield "event: ping\n\n"
  332. else:
  333. yield "data: " + json.dumps(result) + "\n\n"
  334. except ValueError as e:
  335. if e.args[0] != "I/O operation on closed file.": # ignore this error
  336. logging.exception(e)
  337. raise
  338. finally:
  339. try:
  340. pubsub.unsubscribe(generate_channel)
  341. except ConnectionError:
  342. pass
  343. return generate()
  344. @classmethod
  345. def get_message_response_data(cls, data: dict):
  346. response_data = {
  347. 'event': 'message',
  348. 'task_id': data.get('task_id'),
  349. 'id': data.get('message_id'),
  350. 'answer': data.get('text'),
  351. 'created_at': int(time.time())
  352. }
  353. if data.get('mode') == 'chat':
  354. response_data['conversation_id'] = data.get('conversation_id')
  355. return response_data
  356. @classmethod
  357. def get_chain_response_data(cls, data: dict):
  358. response_data = {
  359. 'event': 'chain',
  360. 'id': data.get('chain_id'),
  361. 'task_id': data.get('task_id'),
  362. 'message_id': data.get('message_id'),
  363. 'type': data.get('type'),
  364. 'input': data.get('input'),
  365. 'output': data.get('output'),
  366. 'created_at': int(time.time())
  367. }
  368. if data.get('mode') == 'chat':
  369. response_data['conversation_id'] = data.get('conversation_id')
  370. return response_data
  371. @classmethod
  372. def get_agent_thought_response_data(cls, data: dict):
  373. response_data = {
  374. 'event': 'agent_thought',
  375. 'id': data.get('id'),
  376. 'chain_id': data.get('chain_id'),
  377. 'task_id': data.get('task_id'),
  378. 'message_id': data.get('message_id'),
  379. 'position': data.get('position'),
  380. 'thought': data.get('thought'),
  381. 'tool': data.get('tool'),
  382. 'tool_input': data.get('tool_input'),
  383. 'created_at': int(time.time())
  384. }
  385. if data.get('mode') == 'chat':
  386. response_data['conversation_id'] = data.get('conversation_id')
  387. return response_data
  388. @classmethod
  389. def handle_error(cls, result: dict):
  390. logging.debug("error: %s", result)
  391. error = result.get('error')
  392. description = result.get('description')
  393. # handle errors
  394. llm_errors = {
  395. 'LLMBadRequestError': LLMBadRequestError,
  396. 'LLMAPIConnectionError': LLMAPIConnectionError,
  397. 'LLMAPIUnavailableError': LLMAPIUnavailableError,
  398. 'LLMRateLimitError': LLMRateLimitError,
  399. 'ProviderTokenNotInitError': ProviderTokenNotInitError,
  400. 'QuotaExceededError': QuotaExceededError,
  401. 'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
  402. }
  403. if error in llm_errors:
  404. raise llm_errors[error](description)
  405. elif error == 'LLMAuthorizationError':
  406. raise LLMAuthorizationError('Incorrect API key provided')
  407. else:
  408. raise Exception(description)