Browse Source

[Python SDK] Add KnowledgeBaseClient and the corresponding test cases. (#8465)

Co-authored-by: Wang Ying <wangying@xkool.org>
Ying Wang 7 months ago
parent
commit
4788e1c8c8

+ 280 - 1
sdks/python-client/dify_client/client.py

@@ -1,3 +1,4 @@
+import json
 import requests
 
 
@@ -133,4 +134,282 @@ class WorkflowClient(DifyClient):
         return self._send_request("POST", f"/workflows/tasks/{task_id}/stop", data)
 
     def get_result(self, workflow_run_id):
-        return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
+        return self._send_request("GET", f"/workflows/run/{workflow_run_id}")
+
+
+
+class KnowledgeBaseClient(DifyClient):
+
+    def __init__(self, api_key, base_url: str = 'https://api.dify.ai/v1', dataset_id: str = None):
+        """
+        Construct a KnowledgeBaseClient object.
+
+        Args:
+            api_key (str): API key of Dify.
+            base_url (str, optional): Base URL of Dify API. Defaults to 'https://api.dify.ai/v1'.
+            dataset_id (str, optional): ID of the dataset. Defaults to None. You don't need this if you just want to
+                create a new dataset. or list datasets. otherwise you need to set this.
+        """
+        super().__init__(
+            api_key=api_key,
+            base_url=base_url
+        )
+        self.dataset_id = dataset_id
+
+    def _get_dataset_id(self):
+        if self.dataset_id is None:
+            raise ValueError("dataset_id is not set")
+        return self.dataset_id
+
+    def create_dataset(self, name: str, **kwargs):
+        return self._send_request('POST', '/datasets', {'name': name}, **kwargs)
+
+    def list_datasets(self, page: int = 1, page_size: int = 20, **kwargs):
+        return self._send_request('GET', f'/datasets?page={page}&limit={page_size}', **kwargs)
+
+    def create_document_by_text(self, name, text, extra_params: dict = None, **kwargs):
+        """
+        Create a document by text.
+
+        :param name: Name of the document
+        :param text: Text content of the document
+        :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'rules': {
+                    'pre_processing_rules': [
+                        {'id': 'remove_extra_spaces', 'enabled': True},
+                        {'id': 'remove_urls_emails', 'enabled': True}
+                    ],
+                    'segmentation': {
+                        'separator': '\n',
+                        'max_tokens': 500
+                    }
+                },
+                'mode': 'custom'
+            }
+        }
+        :return: Response from the API
+        """
+        data = {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'mode': 'automatic'
+            },
+            'name': name,
+            'text': text
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/document/create_by_text"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def update_document_by_text(self, document_id, name, text, extra_params: dict = None, **kwargs):
+        """
+        Update a document by text.
+
+        :param document_id: ID of the document
+        :param name: Name of the document
+        :param text: Text content of the document
+        :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'rules': {
+                    'pre_processing_rules': [
+                        {'id': 'remove_extra_spaces', 'enabled': True},
+                        {'id': 'remove_urls_emails', 'enabled': True}
+                    ],
+                    'segmentation': {
+                        'separator': '\n',
+                        'max_tokens': 500
+                    }
+                },
+                'mode': 'custom'
+            }
+        }
+        :return: Response from the API
+        """
+        data = {
+            'name': name,
+            'text': text
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_text"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def create_document_by_file(self, file_path, original_document_id=None, extra_params: dict = None):
+        """
+        Create a document by file.
+
+        :param file_path: Path to the file
+        :param original_document_id: pass this ID if you want to replace the original document (optional)
+        :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'rules': {
+                    'pre_processing_rules': [
+                        {'id': 'remove_extra_spaces', 'enabled': True},
+                        {'id': 'remove_urls_emails', 'enabled': True}
+                    ],
+                    'segmentation': {
+                        'separator': '\n',
+                        'max_tokens': 500
+                    }
+                },
+                'mode': 'custom'
+            }
+        }
+        :return: Response from the API
+        """
+        files = {"file": open(file_path, "rb")}
+        data = {
+            'process_rule': {
+                'mode': 'automatic'
+            },
+            'indexing_technique': 'high_quality'
+        }
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        if original_document_id is not None:
+            data['original_document_id'] = original_document_id
+        url = f"/datasets/{self._get_dataset_id()}/document/create_by_file"
+        return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
+
+    def update_document_by_file(self, document_id, file_path, extra_params: dict = None):
+        """
+        Update a document by file.
+
+        :param document_id: ID of the document
+        :param file_path: Path to the file
+        :param extra_params: extra parameters pass to the API, such as indexing_technique, process_rule. (optional)
+            e.g.
+            {
+            'indexing_technique': 'high_quality',
+            'process_rule': {
+                'rules': {
+                    'pre_processing_rules': [
+                        {'id': 'remove_extra_spaces', 'enabled': True},
+                        {'id': 'remove_urls_emails', 'enabled': True}
+                    ],
+                    'segmentation': {
+                        'separator': '\n',
+                        'max_tokens': 500
+                    }
+                },
+                'mode': 'custom'
+            }
+        }
+        :return:
+        """
+        files = {"file": open(file_path, "rb")}
+        data = {}
+        if extra_params is not None and isinstance(extra_params, dict):
+            data.update(extra_params)
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/update_by_file"
+        return self._send_request_with_files("POST", url, {"data": json.dumps(data)}, files)
+
+    def batch_indexing_status(self, batch_id: str, **kwargs):
+        """
+        Get the status of the batch indexing.
+
+        :param batch_id: ID of the batch uploading
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{batch_id}/indexing-status"
+        return self._send_request("GET", url, **kwargs)
+
+    def delete_dataset(self):
+        """
+        Delete this dataset.
+
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}"
+        return self._send_request("DELETE", url)
+
+    def delete_document(self, document_id):
+        """
+        Delete a document.
+
+        :param document_id: ID of the document
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}"
+        return self._send_request("DELETE", url)
+
+    def list_documents(self, page: int = None, page_size: int = None, keyword: str = None, **kwargs):
+        """
+        Get a list of documents in this dataset.
+
+        :return: Response from the API
+        """
+        params = {}
+        if page is not None:
+            params['page'] = page
+        if page_size is not None:
+            params['limit'] = page_size
+        if keyword is not None:
+            params['keyword'] = keyword
+        url = f"/datasets/{self._get_dataset_id()}/documents"
+        return self._send_request("GET", url, params=params, **kwargs)
+
+    def add_segments(self, document_id, segments, **kwargs):
+        """
+        Add segments to a document.
+
+        :param document_id: ID of the document
+        :param segments: List of segments to add, example: [{"content": "1", "answer": "1", "keyword": ["a"]}]
+        :return: Response from the API
+        """
+        data = {"segments": segments}
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
+        return self._send_request("POST", url, json=data, **kwargs)
+
+    def query_segments(self, document_id, keyword: str = None, status: str = None, **kwargs):
+        """
+        Query segments in this document.
+
+        :param document_id: ID of the document
+        :param keyword: query keyword, optional
+        :param status: status of the segment, optional, e.g. completed
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments"
+        params = {}
+        if keyword is not None:
+            params['keyword'] = keyword
+        if status is not None:
+            params['status'] = status
+        if "params" in kwargs:
+            params.update(kwargs["params"])
+        return self._send_request("GET", url, params=params, **kwargs)
+
+    def delete_document_segment(self, document_id, segment_id):
+        """
+        Delete a segment from a document.
+
+        :param document_id: ID of the document
+        :param segment_id: ID of the segment
+        :return: Response from the API
+        """
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
+        return self._send_request("DELETE", url)
+
+    def update_document_segment(self, document_id, segment_id, segment_data, **kwargs):
+        """
+        Update a segment in a document.
+
+        :param document_id: ID of the document
+        :param segment_id: ID of the segment
+        :param segment_data: Data of the segment, example: {"content": "1", "answer": "1", "keyword": ["a"], "enabled": True}
+        :return: Response from the API
+        """
+        data = {"segment": segment_data}
+        url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}"
+        return self._send_request("POST", url, json=data, **kwargs)

+ 1 - 1
sdks/python-client/setup.py

@@ -5,7 +5,7 @@ with open("README.md", "r", encoding="utf-8") as fh:
 
 setup(
     name="dify-client",
-    version="0.1.11",
+    version="0.1.12",
     author="Dify",
     author_email="hello@dify.ai",
     description="A package for interacting with the Dify Service-API",

+ 148 - 1
sdks/python-client/tests/test_client.py

@@ -1,10 +1,157 @@
 import os
+import time
 import unittest
 
-from dify_client.client import ChatClient, CompletionClient, DifyClient
+from dify_client.client import ChatClient, CompletionClient, DifyClient, KnowledgeBaseClient
 
 API_KEY = os.environ.get("API_KEY")
 APP_ID = os.environ.get("APP_ID")
+API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.dify.ai/v1")
+FILE_PATH_BASE = os.path.dirname(__file__)
+
+
+class TestKnowledgeBaseClient(unittest.TestCase):
+    def setUp(self):
+        self.knowledge_base_client = KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL)
+        self.README_FILE_PATH = os.path.abspath(os.path.join(FILE_PATH_BASE, "../README.md"))
+        self.dataset_id = None
+        self.document_id = None
+        self.segment_id = None
+        self.batch_id = None
+
+    def _get_dataset_kb_client(self):
+        self.assertIsNotNone(self.dataset_id)
+        return KnowledgeBaseClient(API_KEY, base_url=API_BASE_URL, dataset_id=self.dataset_id)
+
+    def test_001_create_dataset(self):
+        response = self.knowledge_base_client.create_dataset(name="test_dataset")
+        data = response.json()
+        self.assertIn("id", data)
+        self.dataset_id = data["id"]
+        self.assertEqual("test_dataset", data["name"])
+
+        # the following tests require to be executed in order because they use
+        # the dataset/document/segment ids from the previous test
+        self._test_002_list_datasets()
+        self._test_003_create_document_by_text()
+        time.sleep(1)
+        self._test_004_update_document_by_text()
+        # self._test_005_batch_indexing_status()
+        time.sleep(1)
+        self._test_006_update_document_by_file()
+        time.sleep(1)
+        self._test_007_list_documents()
+        self._test_008_delete_document()
+        self._test_009_create_document_by_file()
+        time.sleep(1)
+        self._test_010_add_segments()
+        self._test_011_query_segments()
+        self._test_012_update_document_segment()
+        self._test_013_delete_document_segment()
+        self._test_014_delete_dataset()
+
+    def _test_002_list_datasets(self):
+        response = self.knowledge_base_client.list_datasets()
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertIn("total", data)
+
+    def _test_003_create_document_by_text(self):
+        client = self._get_dataset_kb_client()
+        response = client.create_document_by_text("test_document", "test_text")
+        data = response.json()
+        self.assertIn("document", data)
+        self.document_id = data["document"]["id"]
+        self.batch_id = data["batch"]
+
+    def _test_004_update_document_by_text(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.update_document_by_text(self.document_id, "test_document_updated", "test_text_updated")
+        data = response.json()
+        self.assertIn("document", data)
+        self.assertIn("batch", data)
+        self.batch_id = data["batch"]
+
+    def _test_005_batch_indexing_status(self):
+        client = self._get_dataset_kb_client()
+        response = client.batch_indexing_status(self.batch_id)
+        data = response.json()
+        self.assertEqual(response.status_code, 200)
+
+    def _test_006_update_document_by_file(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.update_document_by_file(self.document_id, self.README_FILE_PATH)
+        data = response.json()
+        self.assertIn("document", data)
+        self.assertIn("batch", data)
+        self.batch_id = data["batch"]
+
+    def _test_007_list_documents(self):
+        client = self._get_dataset_kb_client()
+        response = client.list_documents()
+        data = response.json()
+        self.assertIn("data", data)
+
+    def _test_008_delete_document(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.document_id)
+        response = client.delete_document(self.document_id)
+        data = response.json()
+        self.assertIn("result", data)
+        self.assertEqual("success", data["result"])
+
+    def _test_009_create_document_by_file(self):
+        client = self._get_dataset_kb_client()
+        response = client.create_document_by_file(self.README_FILE_PATH)
+        data = response.json()
+        self.assertIn("document", data)
+        self.document_id = data["document"]["id"]
+        self.batch_id = data["batch"]
+
+    def _test_010_add_segments(self):
+        client = self._get_dataset_kb_client()
+        response = client.add_segments(self.document_id, [
+            {"content": "test text segment 1"}
+        ])
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+        segment = data["data"][0]
+        self.segment_id = segment["id"]
+
+    def _test_011_query_segments(self):
+        client = self._get_dataset_kb_client()
+        response = client.query_segments(self.document_id)
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+
+    def _test_012_update_document_segment(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.segment_id)
+        response = client.update_document_segment(self.document_id, self.segment_id,
+                                                  {"content": "test text segment 1 updated"}
+                                                  )
+        data = response.json()
+        self.assertIn("data", data)
+        self.assertGreater(len(data["data"]), 0)
+        segment = data["data"]
+        self.assertEqual("test text segment 1 updated", segment["content"])
+
+    def _test_013_delete_document_segment(self):
+        client = self._get_dataset_kb_client()
+        self.assertIsNotNone(self.segment_id)
+        response = client.delete_document_segment(self.document_id, self.segment_id)
+        data = response.json()
+        self.assertIn("result", data)
+        self.assertEqual("success", data["result"])
+
+    def _test_014_delete_dataset(self):
+        client = self._get_dataset_kb_client()
+        response = client.delete_dataset()
+        self.assertEqual(204, response.status_code)
 
 
 class TestChatClient(unittest.TestCase):