completion_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import json
  2. from collections.abc import Generator
  3. from typing import Any, Union
  4. from sqlalchemy import and_
  5. from core.application_manager import ApplicationManager
  6. from core.entities.application_entities import InvokeFrom
  7. from core.file.message_file_parser import MessageFileParser
  8. from extensions.ext_database import db
  9. from models.model import Account, App, AppModelConfig, Conversation, EndUser, Message
  10. from services.app_model_config_service import AppModelConfigService
  11. from services.errors.app import MoreLikeThisDisabledError
  12. from services.errors.app_model_config import AppModelConfigBrokenError
  13. from services.errors.conversation import ConversationCompletedError, ConversationNotExistsError
  14. from services.errors.message import MessageNotExistsError
  15. class CompletionService:
  16. @classmethod
  17. def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
  18. invoke_from: InvokeFrom, streaming: bool = True,
  19. is_model_config_override: bool = False) -> Union[dict, Generator]:
  20. # is streaming mode
  21. inputs = args['inputs']
  22. query = args['query']
  23. files = args['files'] if 'files' in args and args['files'] else []
  24. auto_generate_name = args['auto_generate_name'] \
  25. if 'auto_generate_name' in args else True
  26. if app_model.mode != 'completion':
  27. if not query:
  28. raise ValueError('query is required')
  29. if query:
  30. if not isinstance(query, str):
  31. raise ValueError('query must be a string')
  32. query = query.replace('\x00', '')
  33. conversation_id = args['conversation_id'] if 'conversation_id' in args else None
  34. conversation = None
  35. if conversation_id:
  36. conversation_filter = [
  37. Conversation.id == args['conversation_id'],
  38. Conversation.app_id == app_model.id,
  39. Conversation.status == 'normal'
  40. ]
  41. if isinstance(user, Account):
  42. conversation_filter.append(Conversation.from_account_id == user.id)
  43. else:
  44. conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
  45. conversation = db.session.query(Conversation).filter(and_(*conversation_filter)).first()
  46. if not conversation:
  47. raise ConversationNotExistsError()
  48. if conversation.status != 'normal':
  49. raise ConversationCompletedError()
  50. if not conversation.override_model_configs:
  51. app_model_config = db.session.query(AppModelConfig).filter(
  52. AppModelConfig.id == conversation.app_model_config_id,
  53. AppModelConfig.app_id == app_model.id
  54. ).first()
  55. if not app_model_config:
  56. raise AppModelConfigBrokenError()
  57. else:
  58. conversation_override_model_configs = json.loads(conversation.override_model_configs)
  59. app_model_config = AppModelConfig(
  60. id=conversation.app_model_config_id,
  61. app_id=app_model.id,
  62. )
  63. app_model_config = app_model_config.from_model_config_dict(conversation_override_model_configs)
  64. if is_model_config_override:
  65. # build new app model config
  66. if 'model' not in args['model_config']:
  67. raise ValueError('model_config.model is required')
  68. if 'completion_params' not in args['model_config']['model']:
  69. raise ValueError('model_config.model.completion_params is required')
  70. completion_params = AppModelConfigService.validate_model_completion_params(
  71. cp=args['model_config']['model']['completion_params'],
  72. model_name=app_model_config.model_dict["name"]
  73. )
  74. app_model_config_model = app_model_config.model_dict
  75. app_model_config_model['completion_params'] = completion_params
  76. app_model_config.retriever_resource = json.dumps({'enabled': True})
  77. app_model_config = app_model_config.copy()
  78. app_model_config.model = json.dumps(app_model_config_model)
  79. else:
  80. if app_model.app_model_config_id is None:
  81. raise AppModelConfigBrokenError()
  82. app_model_config = app_model.app_model_config
  83. if not app_model_config:
  84. raise AppModelConfigBrokenError()
  85. if is_model_config_override:
  86. if not isinstance(user, Account):
  87. raise Exception("Only account can override model config")
  88. # validate config
  89. model_config = AppModelConfigService.validate_configuration(
  90. tenant_id=app_model.tenant_id,
  91. account=user,
  92. config=args['model_config'],
  93. app_mode=app_model.mode
  94. )
  95. app_model_config = AppModelConfig(
  96. id=app_model_config.id,
  97. app_id=app_model.id,
  98. )
  99. app_model_config = app_model_config.from_model_config_dict(model_config)
  100. # clean input by app_model_config form rules
  101. inputs = cls.get_cleaned_inputs(inputs, app_model_config)
  102. # parse files
  103. message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
  104. file_objs = message_file_parser.validate_and_transform_files_arg(
  105. files,
  106. app_model_config,
  107. user
  108. )
  109. application_manager = ApplicationManager()
  110. return application_manager.generate(
  111. tenant_id=app_model.tenant_id,
  112. app_id=app_model.id,
  113. app_model_config_id=app_model_config.id,
  114. app_model_config_dict=app_model_config.to_dict(),
  115. app_model_config_override=is_model_config_override,
  116. user=user,
  117. invoke_from=invoke_from,
  118. inputs=inputs,
  119. query=query,
  120. files=file_objs,
  121. conversation=conversation,
  122. stream=streaming,
  123. extras={
  124. "auto_generate_conversation_name": auto_generate_name
  125. }
  126. )
  127. @classmethod
  128. def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
  129. message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
  130. -> Union[dict, Generator]:
  131. if not user:
  132. raise ValueError('user cannot be None')
  133. message = db.session.query(Message).filter(
  134. Message.id == message_id,
  135. Message.app_id == app_model.id,
  136. Message.from_source == ('api' if isinstance(user, EndUser) else 'console'),
  137. Message.from_end_user_id == (user.id if isinstance(user, EndUser) else None),
  138. Message.from_account_id == (user.id if isinstance(user, Account) else None),
  139. ).first()
  140. if not message:
  141. raise MessageNotExistsError()
  142. current_app_model_config = app_model.app_model_config
  143. more_like_this = current_app_model_config.more_like_this_dict
  144. if not current_app_model_config.more_like_this or more_like_this.get("enabled", False) is False:
  145. raise MoreLikeThisDisabledError()
  146. app_model_config = message.app_model_config
  147. model_dict = app_model_config.model_dict
  148. completion_params = model_dict.get('completion_params')
  149. completion_params['temperature'] = 0.9
  150. model_dict['completion_params'] = completion_params
  151. app_model_config.model = json.dumps(model_dict)
  152. # parse files
  153. message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
  154. file_objs = message_file_parser.transform_message_files(
  155. message.files, app_model_config
  156. )
  157. application_manager = ApplicationManager()
  158. return application_manager.generate(
  159. tenant_id=app_model.tenant_id,
  160. app_id=app_model.id,
  161. app_model_config_id=app_model_config.id,
  162. app_model_config_dict=app_model_config.to_dict(),
  163. app_model_config_override=True,
  164. user=user,
  165. invoke_from=invoke_from,
  166. inputs=message.inputs,
  167. query=message.query,
  168. files=file_objs,
  169. conversation=None,
  170. stream=streaming,
  171. extras={
  172. "auto_generate_conversation_name": False
  173. }
  174. )
  175. @classmethod
  176. def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
  177. if user_inputs is None:
  178. user_inputs = {}
  179. filtered_inputs = {}
  180. # Filter input variables from form configuration, handle required fields, default values, and option values
  181. input_form_config = app_model_config.user_input_form_list
  182. for config in input_form_config:
  183. input_config = list(config.values())[0]
  184. variable = input_config["variable"]
  185. input_type = list(config.keys())[0]
  186. if variable not in user_inputs or not user_inputs[variable]:
  187. if input_type == "external_data_tool":
  188. continue
  189. if "required" in input_config and input_config["required"]:
  190. raise ValueError(f"{variable} is required in input form")
  191. else:
  192. filtered_inputs[variable] = input_config["default"] if "default" in input_config else ""
  193. continue
  194. value = user_inputs[variable]
  195. if value:
  196. if not isinstance(value, str):
  197. raise ValueError(f"{variable} in input form must be a string")
  198. if input_type == "select":
  199. options = input_config["options"] if "options" in input_config else []
  200. if value not in options:
  201. raise ValueError(f"{variable} in input form must be one of the following: {options}")
  202. else:
  203. if 'max_length' in input_config:
  204. max_length = input_config['max_length']
  205. if len(value) > max_length:
  206. raise ValueError(f'{variable} in input form must be less than {max_length} characters')
  207. filtered_inputs[variable] = value.replace('\x00', '') if value else None
  208. return filtered_inputs