diff --git a/dify_client/async_client.py b/dify_client/async_client.py index 848c982..d92a639 100644 --- a/dify_client/async_client.py +++ b/dify_client/async_client.py @@ -178,6 +178,30 @@ async def get_file_preview(self, file_id: str): """Get file preview by file ID.""" return await self._send_request("GET", f"/files/{file_id}/preview") + async def get_app_feedbacks(self, page: int = 1, limit: int = 20): + """Get message feedbacks for the application. + + Args: + page: Page number (default: 1) + limit: Number of items per page (default: 20) + + Returns: + httpx.Response object + """ + params = {"page": page, "limit": limit} + return await self._send_request("GET", "/app/feedbacks", params=params) + + async def get_end_user_info(self, end_user_id: str): + """Get end user information. + + Args: + end_user_id: End user ID + + Returns: + httpx.Response object + """ + return await self._send_request("GET", f"/end-users/{end_user_id}") + # App Configuration APIs async def get_app_site_config(self, app_id: str): """Get app site configuration. @@ -281,6 +305,19 @@ async def create_completion_message( stream=(response_mode == "streaming"), ) + async def stop_completion_message(self, task_id: str, user: str): + """Stop a running completion message generation. + + Args: + task_id: Task ID from the completion message response + user: User identifier (must match the one used in the original request) + + Returns: + httpx.Response object + """ + data = {"user": user} + return await self._send_request("POST", f"/completion-messages/{task_id}/stop", data) + class AsyncChatClient(AsyncDifyClient): """Async client for Chat API operations.""" @@ -842,6 +879,45 @@ async def list_documents( url = f"/datasets/{self._get_dataset_id()}/documents" return await self._send_request("GET", url, params=params, **kwargs) + async def get_document(self, document_id: str, metadata: str = "all"): + """Get detailed information about a specific document. + + Args: + document_id: Document ID + metadata: Metadata inclusion mode ('all', 'only', 'without') + + Returns: + httpx.Response object + """ + params = {"metadata": metadata} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" + return await self._send_request("GET", url, params=params) + + async def download_document(self, document_id: str): + """Download a specific document. + + Args: + document_id: Document ID + + Returns: + httpx.Response object with the file download URL + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/download" + return await self._send_request("GET", url) + + async def get_document_segment(self, document_id: str, segment_id: str): + """Get detailed information about a specific segment. + + Args: + document_id: Document ID + segment_id: Segment ID + + Returns: + httpx.Response object + """ + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return await self._send_request("GET", url) + async def add_segments(self, document_id: str, segments: list[dict], **kwargs): """Add segments to a document.""" data = {"segments": segments} @@ -890,6 +966,34 @@ async def update_document_segment( return await self._send_request("POST", url, json=data, **kwargs) # Advanced Knowledge Base APIs + async def retrieve( + self, + query: str, + retrieval_model: Dict[str, Any] = None, + external_retrieval_model: Dict[str, Any] = None, + attachment_ids: List[str] = None, + ): + """Retrieve chunks from the knowledge base. + + Args: + query: Search query text + retrieval_model: Retrieval model configuration (optional) + external_retrieval_model: External retrieval model configuration (optional) + attachment_ids: List of attachment IDs to include in retrieval context (optional) + + Returns: + httpx.Response object + """ + data = {"query": query} + if retrieval_model: + data["retrieval_model"] = retrieval_model + if external_retrieval_model: + data["external_retrieval_model"] = external_retrieval_model + if attachment_ids: + data["attachment_ids"] = attachment_ids + url = f"/datasets/{self._get_dataset_id()}/retrieve" + return await self._send_request("POST", url, json=data) + async def hit_testing( self, query: str, diff --git a/dify_client/client.py b/dify_client/client.py index 3cf600e..ce3ee12 100644 --- a/dify_client/client.py +++ b/dify_client/client.py @@ -255,6 +255,32 @@ def get_file_preview(self, file_id: str): """Get file preview by file ID.""" return self._send_request("GET", f"/files/{file_id}/preview") + def get_app_feedbacks(self, page: int = 1, limit: int = 20): + """Get message feedbacks for the application. + + Args: + page: Page number (default: 1) + limit: Number of items per page (default: 20) + + Returns: + httpx.Response object + """ + self._validate_params(page=page, limit=limit) + params = {"page": page, "limit": limit} + return self._send_request("GET", "/app/feedbacks", params=params) + + def get_end_user_info(self, end_user_id: str): + """Get end user information. + + Args: + end_user_id: End user ID + + Returns: + httpx.Response object + """ + self._validate_params(end_user_id=end_user_id) + return self._send_request("GET", f"/end-users/{end_user_id}") + # App Configuration APIs def get_app_site_config(self, app_id: str): """Get app site configuration. @@ -353,6 +379,20 @@ def create_completion_message( stream=(response_mode == "streaming"), ) + def stop_completion_message(self, task_id: str, user: str): + """Stop a running completion message generation. + + Args: + task_id: Task ID from the completion message response + user: User identifier (must match the one used in the original request) + + Returns: + httpx.Response object + """ + self._validate_params(task_id=task_id, user=user) + data = {"user": user} + return self._send_request("POST", f"/completion-messages/{task_id}/stop", data) + class ChatClient(DifyClient): def create_chat_message( @@ -1000,6 +1040,48 @@ def list_documents( url = f"/datasets/{self._get_dataset_id()}/documents" return self._send_request("GET", url, params=params, **kwargs) + def get_document(self, document_id: str, metadata: str = "all"): + """Get detailed information about a specific document. + + Args: + document_id: Document ID + metadata: Metadata inclusion mode ('all', 'only', 'without') + + Returns: + httpx.Response object + """ + self._validate_params(document_id=document_id) + params = {"metadata": metadata} + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}" + return self._send_request("GET", url, params=params) + + def download_document(self, document_id: str): + """Download a specific document. + + Args: + document_id: Document ID + + Returns: + httpx.Response object with the file download URL + """ + self._validate_params(document_id=document_id) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/download" + return self._send_request("GET", url) + + def get_document_segment(self, document_id: str, segment_id: str): + """Get detailed information about a specific segment. + + Args: + document_id: Document ID + segment_id: Segment ID + + Returns: + httpx.Response object + """ + self._validate_params(document_id=document_id, segment_id=segment_id) + url = f"/datasets/{self._get_dataset_id()}/documents/{document_id}/segments/{segment_id}" + return self._send_request("GET", url) + def add_segments(self, document_id: str, segments: list[dict], **kwargs): """ Add segments to a document. @@ -1065,6 +1147,35 @@ def update_document_segment( return self._send_request("POST", url, json=data, **kwargs) # Advanced Knowledge Base APIs + def retrieve( + self, + query: str, + retrieval_model: Dict[str, Any] = None, + external_retrieval_model: Dict[str, Any] = None, + attachment_ids: List[str] = None, + ): + """Retrieve chunks from the knowledge base. + + Args: + query: Search query text + retrieval_model: Retrieval model configuration (optional) + external_retrieval_model: External retrieval model configuration (optional) + attachment_ids: List of attachment IDs to include in retrieval context (optional) + + Returns: + httpx.Response object + """ + self._validate_params(query=query) + data = {"query": query} + if retrieval_model: + data["retrieval_model"] = retrieval_model + if external_retrieval_model: + data["external_retrieval_model"] = external_retrieval_model + if attachment_ids: + data["attachment_ids"] = attachment_ids + url = f"/datasets/{self._get_dataset_id()}/retrieve" + return self._send_request("POST", url, json=data) + def hit_testing( self, query: str, diff --git a/tests/test_client.py b/tests/test_client.py index 797fe5a..fe50524 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -251,6 +251,75 @@ def _test_014_delete_dataset(self): response = client.delete_dataset() self.assertEqual(204, response.status_code) + @patch("dify_client.client.httpx.Client") + def test_retrieve(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"data": [{"id": "segment1", "content": "test content"}]}\n' + mock_response.json.return_value = {"data": [{"id": "segment1", "content": "test content"}]} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) + response = knowledge_base_client.retrieve("test query", "test_user") + self.assertIn("data", response.text) + + @patch("dify_client.client.httpx.Client") + def test_get_document(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"id": "doc1", "name": "Test Document"}' + mock_response.json.return_value = {"id": "doc1", "name": "Test Document"} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) + response = knowledge_base_client.get_document(self.document_id) + self.assertIn("id", response.text) + self.assertIn("name", response.text) + + @patch("dify_client.client.httpx.Client") + def test_download_document(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.content = b"test document content" + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) + response = knowledge_base_client.download_document(self.document_id) + self.assertEqual(200, response.status_code) + + @patch("dify_client.client.httpx.Client") + def test_get_document_segment(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"id": "segment1", "content": "test segment content"}' + mock_response.json.return_value = {"id": "segment1", "content": "test segment content"} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + knowledge_base_client = KnowledgeBaseClient(self.api_key, base_url=self.base_url, dataset_id=self.dataset_id) + response = knowledge_base_client.get_document_segment(self.document_id, self.segment_id) + self.assertIn("id", response.text) + self.assertIn("content", response.text) + class TestChatClient(unittest.TestCase): @patch("dify_client.client.httpx.Client") @@ -489,6 +558,24 @@ def test_create_completion_message_with_vision_model_by_local_file( ) self.assertIn("answer", response.text) + @patch("dify_client.client.httpx.Client") + def test_stop_completion_message(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"result": "success"}' + mock_response.json.return_value = {"result": "success"} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + completion_client = CompletionClient(self.api_key) + response = completion_client.stop_completion_message("test-task-id", "test_user") + self.assertIn("result", response.text) + self.assertEqual("success", response.json()["result"]) + class TestDifyClient(unittest.TestCase): @patch("dify_client.client.httpx.Client") @@ -568,6 +655,42 @@ def test_file_upload(self, mock_file_open, mock_httpx_client): response = dify_client.file_upload("test_user", files) self.assertIn("name", response.text) + @patch("dify_client.client.httpx.Client") + def test_get_app_feedbacks(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"data": [{"id": "feedback1", "rating": "like"}], "total": 1}' + mock_response.json.return_value = {"data": [{"id": "feedback1", "rating": "like"}], "total": 1} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + dify_client = DifyClient(self.api_key) + response = dify_client.get_app_feedbacks(page=1, limit=10) + self.assertIn("data", response.text) + self.assertIn("total", response.text) + + @patch("dify_client.client.httpx.Client") + def test_get_end_user_info(self, mock_httpx_client): + # Mock the HTTP response + mock_response = Mock() + mock_response.text = '{"id": "user1", "name": "Test User"}' + mock_response.json.return_value = {"id": "user1", "name": "Test User"} + mock_response.status_code = 200 + + mock_client_instance = Mock() + mock_client_instance.request.return_value = mock_response + mock_httpx_client.return_value = mock_client_instance + + # Create client with mocked httpx + dify_client = DifyClient(self.api_key) + response = dify_client.get_end_user_info("test_user") + self.assertIn("id", response.text) + self.assertIn("name", response.text) + if __name__ == "__main__": unittest.main()