tag_service.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import uuid
  2. from typing import Optional
  3. from flask_login import current_user
  4. from sqlalchemy import func
  5. from werkzeug.exceptions import NotFound
  6. from extensions.ext_database import db
  7. from models.dataset import Dataset
  8. from models.model import App, Tag, TagBinding
  class TagService:
      """Tenant-scoped helpers for creating, listing and binding tags.

      A tag binding links a tag to a target; ``check_target_exists`` accepts the
      target types "knowledge" (a Dataset) and "app" (an App). Methods that
      mutate data commit ``db.session`` before returning.
      """

      # Escape character for user text embedded in LIKE/ILIKE patterns. "/" is used
      # instead of a backslash, whose handling inside SQL string literals varies by backend.
      _LIKE_ESCAPE = "/"

      @staticmethod
      def _escape_like(value: str) -> str:
          """Return *value* with ``%``, ``_`` and the escape char escaped for LIKE."""
          esc = TagService._LIKE_ESCAPE
          # Escape the escape char first; doing it last would re-escape the "/" that
          # the "%" and "_" replacements insert.
          return value.replace(esc, esc + esc).replace("%", esc + "%").replace("_", esc + "_")

      @staticmethod
      def get_tags(tag_type: str, current_tenant_id: str, keyword: Optional[str] = None) -> list:
          """List a tenant's tags of one type with their binding counts.

          Args:
              tag_type: Value matched against ``Tag.type``.
              current_tenant_id: Tenant whose tags are listed.
              keyword: Optional case-insensitive substring filter on the tag name.
                  It is matched literally; ``%`` and ``_`` are not wildcards.

          Returns:
              Rows with ``id``, ``type``, ``name`` and ``binding_count``, newest tag first.
          """
          query = (
              db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
              .outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
              .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
          )
          if keyword:
              pattern = f"%{TagService._escape_like(keyword)}%"
              query = query.filter(Tag.name.ilike(pattern, escape=TagService._LIKE_ESCAPE))
          # Group by every selected/ordered column so strict GROUP BY modes accept the query.
          query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
          return query.order_by(Tag.created_at.desc()).all()

      @staticmethod
      def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
          """Return the target ids bound to any of *tag_ids* within the tenant.

          Ids that do not match a tag of *tag_type* in the tenant are ignored. A
          target bound to several matching tags appears once per binding.
          """
          tags = (
              db.session.query(Tag)
              .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
              .all()
          )
          if not tags:
              return []
          valid_tag_ids = [tag.id for tag in tags]
          tag_bindings = (
              db.session.query(TagBinding.target_id)
              .filter(TagBinding.tag_id.in_(valid_tag_ids), TagBinding.tenant_id == current_tenant_id)
              .all()
          )
          return [tag_binding.target_id for tag_binding in tag_bindings]

      @staticmethod
      def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list:
          """Return the tenant's tags of *tag_type* that are bound to *target_id*."""
          return (
              db.session.query(Tag)
              .join(TagBinding, Tag.id == TagBinding.tag_id)
              .filter(
                  TagBinding.target_id == target_id,
                  TagBinding.tenant_id == current_tenant_id,
                  Tag.tenant_id == current_tenant_id,
                  Tag.type == tag_type,
              )
              .all()
          )

      @staticmethod
      def save_tags(args: dict) -> Tag:
          """Create a tag in the current user's tenant.

          Args:
              args: Mapping with "name" and "type".

          Returns:
              The committed ``Tag``.
          """
          tag = Tag(
              id=str(uuid.uuid4()),
              name=args["name"],
              type=args["type"],
              created_by=current_user.id,
              tenant_id=current_user.current_tenant_id,
          )
          db.session.add(tag)
          db.session.commit()
          return tag

      @staticmethod
      def update_tags(args: dict, tag_id: str) -> Tag:
          """Rename a tag owned by the current user's tenant.

          Args:
              args: Mapping with the new "name".
              tag_id: Id of the tag to rename.

          Raises:
              NotFound: if no tag with *tag_id* exists in the current tenant.
          """
          # Scope by tenant so a tag id from another workspace cannot be modified.
          tag = (
              db.session.query(Tag)
              .filter(Tag.id == tag_id, Tag.tenant_id == current_user.current_tenant_id)
              .first()
          )
          if not tag:
              raise NotFound("Tag not found")
          tag.name = args["name"]
          db.session.commit()
          return tag

      @staticmethod
      def get_tag_binding_count(tag_id: str) -> int:
          """Count all bindings of *tag_id*.

          Not tenant-filtered; callers are expected to have resolved the tag through
          a tenant-scoped lookup first.
          """
          return db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()

      @staticmethod
      def delete_tag(tag_id: str) -> None:
          """Delete a tag of the current user's tenant together with its bindings.

          Raises:
              NotFound: if no tag with *tag_id* exists in the current tenant.
          """
          # Scope by tenant so a tag id from another workspace cannot be deleted.
          tag = (
              db.session.query(Tag)
              .filter(Tag.id == tag_id, Tag.tenant_id == current_user.current_tenant_id)
              .first()
          )
          if not tag:
              raise NotFound("Tag not found")
          db.session.delete(tag)
          # Remove the tag's bindings in the same commit so none are left dangling.
          for tag_binding in db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all():
              db.session.delete(tag_binding)
          db.session.commit()

      @staticmethod
      def save_tag_binding(args) -> None:
          """Bind tags to a target, skipping bindings that already exist.

          Args:
              args: Mapping with "type" (target kind, see ``check_target_exists``),
                  "target_id" and "tag_ids". Tag ids that are not owned by the
                  current tenant are skipped.

          Raises:
              NotFound: if the target is missing in the current tenant or the type
                  is unsupported.
          """
          TagService.check_target_exists(args["type"], args["target_id"])
          tenant_id = current_user.current_tenant_id
          target_id = args["target_id"]
          # dict.fromkeys drops duplicate ids while keeping request order.
          requested_ids = list(dict.fromkeys(args["tag_ids"]))
          if requested_ids:
              # Only tags of this tenant may be bound. Ids are compared as strings in
              # case the column type returns non-str values.
              owned_ids = {
                  str(row.id)
                  for row in db.session.query(Tag.id)
                  .filter(Tag.id.in_(requested_ids), Tag.tenant_id == tenant_id)
                  .all()
              }
              # One query for existing bindings instead of one query per tag id.
              bound_ids = {
                  str(row.tag_id)
                  for row in db.session.query(TagBinding.tag_id)
                  .filter(TagBinding.target_id == target_id, TagBinding.tag_id.in_(requested_ids))
                  .all()
              }
              for tag_id in requested_ids:
                  key = str(tag_id)
                  if key not in owned_ids or key in bound_ids:
                      continue
                  db.session.add(
                      TagBinding(
                          tag_id=tag_id,
                          target_id=target_id,
                          tenant_id=tenant_id,
                          created_by=current_user.id,
                      )
                  )
          db.session.commit()

      @staticmethod
      def delete_tag_binding(args) -> None:
          """Remove the binding between ``args["tag_id"]`` and ``args["target_id"]``, if any.

          Raises:
              NotFound: if the target is missing in the current tenant or the type
                  is unsupported.
          """
          TagService.check_target_exists(args["type"], args["target_id"])
          tag_binding = (
              db.session.query(TagBinding)
              .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == args["tag_id"])
              .first()
          )
          if tag_binding:
              db.session.delete(tag_binding)
              db.session.commit()

      @staticmethod
      def check_target_exists(type: str, target_id: str) -> None:
          """Ensure *target_id* exists in the current tenant for the given target type.

          ``type`` shadows the builtin but is kept for keyword-argument compatibility.

          Raises:
              NotFound: if the dataset/app is missing or *type* is not "knowledge"/"app".
          """
          if type == "knowledge":
              dataset = (
                  db.session.query(Dataset)
                  .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
                  .first()
              )
              if not dataset:
                  raise NotFound("Dataset not found")
          elif type == "app":
              app = (
                  db.session.query(App)
                  .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
                  .first()
              )
              if not app:
                  raise NotFound("App not found")
          else:
              raise NotFound("Invalid binding type")