diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 02958e6f0e..9f8c03a284 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -781,6 +781,227 @@ async def test_large_inserts_unordered(self): self.assertEqual(6, result.inserted_count) self.assertEqual(6, await self.coll.count_documents({})) + async def test_bulk_write_with_comment(self): + """Test bulk write operations with comment parameter.""" + requests = [ + InsertOne({"x": 1}), + UpdateOne({"x": 1}, {"$set": {"y": 1}}), + DeleteOne({"x": 1}), + ] + result = await self.coll.bulk_write(requests, comment="bulk_comment") + self.assertEqual(1, result.inserted_count) + self.assertEqual(1, result.modified_count) + self.assertEqual(1, result.deleted_count) + + async def test_bulk_write_with_let(self): + """Test bulk write operations with let parameter.""" + if not async_client_context.version.at_least(5, 0): + self.skipTest("let parameter requires MongoDB 5.0+") + + await self.coll.insert_one({"x": 1}) + requests = [ + UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}), + ] + result = await self.coll.bulk_write(requests, let={"targetVal": 1}) + self.assertEqual(1, result.modified_count) + + async def test_bulk_write_all_operation_types(self): + """Test bulk write with all operation types combined.""" + await self.coll.insert_many([{"x": i} for i in range(5)]) + + requests = [ + InsertOne({"x": 100}), + UpdateOne({"x": 0}, {"$set": {"updated": True}}), + UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}), + ReplaceOne({"x": 3}, {"x": 3, "replaced": True}), + DeleteOne({"x": 4}), + DeleteMany({"x": {"$gt": 50}}), + ] + result = await self.coll.bulk_write(requests) + + self.assertEqual(1, result.inserted_count) + self.assertGreaterEqual(result.modified_count, 1) + self.assertGreaterEqual(result.deleted_count, 1) + + async def test_bulk_write_unordered(self): + """Test unordered bulk write continues after error.""" 
+ await self.coll.create_index([("x", 1)], unique=True) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + requests = [ + InsertOne({"x": 1}), + InsertOne({"x": 1}), # Duplicate - will error + InsertOne({"x": 2}), + InsertOne({"x": 3}), + ] + + with self.assertRaises(BulkWriteError) as ctx: + await self.coll.bulk_write(requests, ordered=False) + + # With unordered, should have inserted 3 documents + self.assertEqual(3, ctx.exception.details["nInserted"]) + + async def test_bulk_write_ordered(self): + """Test ordered bulk write stops on first error.""" + await self.coll.create_index([("x", 1)], unique=True) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + requests = [ + InsertOne({"x": 1}), + InsertOne({"x": 1}), # Duplicate - will error + InsertOne({"x": 2}), + InsertOne({"x": 3}), + ] + + with self.assertRaises(BulkWriteError) as ctx: + await self.coll.bulk_write(requests, ordered=True) + + # With ordered, should have inserted only 1 document + self.assertEqual(1, ctx.exception.details["nInserted"]) + + async def test_bulk_write_bypass_document_validation(self): + """Test bulk write with bypass_document_validation.""" + if not async_client_context.version.at_least(3, 2): + self.skipTest("bypass_document_validation requires MongoDB 3.2+") + + # Create collection with validator + await self.coll.drop() + await self.db.create_collection( + self.coll.name, validator={"$jsonSchema": {"required": ["name"]}} + ) + + # Without bypass, should fail + with self.assertRaises(BulkWriteError): + await self.coll.bulk_write([InsertOne({"x": 1})]) + + # With bypass, should succeed + result = await self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True) + self.assertEqual(1, result.inserted_count) + + async def test_bulk_write_result_properties(self): + """Test all BulkWriteResult properties.""" + await self.coll.insert_one({"x": 1}) + + requests = [ + InsertOne({"x": 2}), + UpdateOne({"x": 1}, {"$set": {"updated": True}}), + 
ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True), + DeleteOne({"x": 1}), + ] + result = await self.coll.bulk_write(requests) + + # Check all properties + self.assertTrue(result.acknowledged) + self.assertEqual(1, result.inserted_count) + self.assertGreaterEqual(result.matched_count, 0) + self.assertGreaterEqual(result.modified_count, 0) + self.assertEqual(1, result.deleted_count) + self.assertIsInstance(result.upserted_count, int) + self.assertIsInstance(result.upserted_ids, dict) + + async def test_bulk_write_with_upsert(self): + """Test bulk write upsert operations.""" + requests = [ + UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True), + UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True), + ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True), + ] + result = await self.coll.bulk_write(requests) + + self.assertEqual(3, result.upserted_count) + self.assertEqual(3, len(result.upserted_ids)) + + async def test_update_one_with_hint(self): + """Test UpdateOne with hint parameter.""" + await self.coll.create_index([("x", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + await self.coll.insert_one({"x": 1}) + + requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])] + result = await self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + async def test_update_many_with_hint(self): + """Test UpdateMany with hint parameter.""" + await self.coll.create_index([("x", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + await self.coll.insert_many([{"x": 1}, {"x": 1}]) + + requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])] + result = await self.coll.bulk_write(requests) + self.assertEqual(2, result.modified_count) + + async def test_delete_one_with_hint(self): + """Test DeleteOne with hint parameter.""" + await self.coll.create_index([("x", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + await self.coll.insert_one({"x": 1}) + + requests = [DeleteOne({"x": 
1}, hint=[("x", 1)])] + result = await self.coll.bulk_write(requests) + self.assertEqual(1, result.deleted_count) + + async def test_delete_many_with_hint(self): + """Test DeleteMany with hint parameter.""" + await self.coll.create_index([("x", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + await self.coll.insert_many([{"x": 1}, {"x": 1}]) + + requests = [DeleteMany({"x": 1}, hint=[("x", 1)])] + result = await self.coll.bulk_write(requests) + self.assertEqual(2, result.deleted_count) + + async def test_update_one_with_array_filters(self): + """Test UpdateOne with array_filters parameter.""" + await self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]}) + + requests = [ + UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}]) + ] + result = await self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + doc = await self.coll.find_one() + # Elements with y > 1 should have z = 1 + for elem in doc["x"]: + if elem["y"] > 1: + self.assertEqual(1, elem.get("z")) + + async def test_replace_one_with_hint(self): + """Test ReplaceOne with hint parameter.""" + await self.coll.create_index([("x", 1)]) + self.addAsyncCleanup(self.coll.drop_index, [("x", 1)]) + + await self.coll.insert_one({"x": 1}) + + requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])] + result = await self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + async def test_update_with_collation(self): + """Test update operations with collation.""" + await self.coll.insert_many( + [ + {"name": "cafe"}, + {"name": "Cafe"}, + ] + ) + + requests = [ + UpdateMany( + {"name": "cafe"}, + {"$set": {"updated": True}}, + collation={"locale": "en", "strength": 2}, + ) + ] + result = await self.coll.bulk_write(requests) + # With case-insensitive collation, both docs should match + self.assertEqual(2, result.modified_count) + class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase): 
@async_client_context.require_auth diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 2e1f83884b..d3bfebcd9e 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -1152,5 +1152,122 @@ def asyncTearDown(self): ) +class AsyncTestChangeStreamCoverage(TestAsyncCollectionAsyncChangeStream): + """Additional tests to improve code coverage for AsyncChangeStream.""" + + async def test_change_stream_alive_property(self): + """Test alive property state transitions.""" + async with await self.change_stream() as cs: + self.assertTrue(cs.alive) + # After context exit, should be closed + self.assertFalse(cs.alive) + + async def test_change_stream_idempotent_close(self): + """Test that close() can be called multiple times safely.""" + cs = await self.change_stream() + await cs.close() + # Second close should not raise + await cs.close() + self.assertFalse(cs.alive) + + async def test_change_stream_resume_token_deepcopy(self): + """Test that resume_token returns a deep copy.""" + coll = self.watched_collection() + async with await self.change_stream() as cs: + await coll.insert_one({"x": 1}) + await anext(cs) # Consume the change event + token1 = cs.resume_token + token2 = cs.resume_token + # Should be equal but different objects + self.assertEqual(token1, token2) + self.assertIsNot(token1, token2) + + async def test_change_stream_with_comment(self): + """Test change stream with comment parameter.""" + client, listener = await self.client_with_listener("aggregate") + try: + async with await self.change_stream_with_client(client, comment="test_comment"): + pass + finally: + await client.close() + + # Check that comment was in the aggregate command + self.assertGreater(len(listener.started_events), 0) + cmd = listener.started_events[0].command + self.assertEqual("test_comment", cmd.get("comment")) + + async def test_change_stream_with_show_expanded_events(self): + """Test change stream 
with show_expanded_events parameter.""" + if not async_client_context.version.at_least(6, 0): + self.skipTest("show_expanded_events requires MongoDB 6.0+") + + async with await self.change_stream(show_expanded_events=True) as cs: + # Just verify it doesn't error + self.assertTrue(cs.alive) + + @async_client_context.require_version_min(6, 0) + async def test_change_stream_with_full_document_before_change(self): + """Test change stream with full_document_before_change parameter.""" + coll = self.watched_collection() + # Need to ensure collection exists with changeStreamPreAndPostImages enabled + await coll.drop() + await self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True}) + await coll.insert_one({"x": 1}) + + async with await self.change_stream(full_document_before_change="whenAvailable") as cs: + await coll.update_one({"x": 1}, {"$set": {"x": 2}}) + change = await anext(cs) + self.assertEqual("update", change["operationType"]) + # fullDocumentBeforeChange should be present + self.assertIn("fullDocumentBeforeChange", change) + + async def test_change_stream_next_after_close(self): + """Test that next() on closed stream raises StopAsyncIteration.""" + cs = await self.change_stream() + await cs.close() + with self.assertRaises(StopAsyncIteration): + await anext(cs) + + async def test_change_stream_try_next_after_close(self): + """Test that try_next() on closed stream raises StopAsyncIteration.""" + cs = await self.change_stream() + await cs.close() + with self.assertRaises(StopAsyncIteration): + await cs.try_next() + + async def test_change_stream_pipeline_construction(self): + """Test change stream pipeline is properly constructed.""" + pipeline = [{"$match": {"operationType": "insert"}}] + client, listener = await self.client_with_listener("aggregate") + try: + async with await self.change_stream_with_client(client, pipeline=pipeline): + pass + finally: + await client.close() + + cmd = listener.started_events[0].command + agg_pipeline 
= cmd["pipeline"] + # First stage should be $changeStream + self.assertIn("$changeStream", agg_pipeline[0]) + # Second stage should be our match + self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1]) + + async def test_change_stream_empty_pipeline(self): + """Test change stream with empty pipeline.""" + async with await self.change_stream(pipeline=[]) as cs: + self.assertTrue(cs.alive) + + async def test_change_stream_context_manager_exception(self): + """Test change stream context manager closes on exception.""" + cs = None + try: + async with await self.change_stream() as cs: + raise ValueError("test exception") + except ValueError: + pass + # Stream should be closed + self.assertFalse(cs.alive) + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 3232650487..bfc395ed69 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -2260,5 +2260,264 @@ async def afind(*args, **kwargs): await helper(*args, let={}) # type: ignore +class AsyncTestCollectionCoverage(AsyncIntegrationTest): + """Additional tests to improve code coverage for AsyncCollection.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.drop() + await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)]) + + async def test_collection_full_name(self): + """Test full_name property.""" + expected = f"{self.db.name}.test" + self.assertEqual(expected, self.db.test.full_name) + + async def test_collection_name(self): + """Test name property.""" + self.assertEqual("test", self.db.test.name) + + async def test_collection_database(self): + """Test database property.""" + self.assertEqual(self.db, self.db.test.database) + + async def test_collection_equality(self): + """Test collection equality.""" + coll1 = self.db.test + coll2 = self.db.test + coll3 = self.db.other + self.assertEqual(coll1, coll2) + 
self.assertNotEqual(coll1, coll3) + + async def test_collection_hash(self): + """Test collection hashability.""" + coll1 = self.db.test + coll2 = self.db.test + # Same collection should have same hash + self.assertEqual(hash(coll1), hash(coll2)) + # Collections can be used in sets + s = {coll1, coll2} + self.assertEqual(1, len(s)) + + async def test_collection_repr(self): + """Test collection repr.""" + coll = self.db.test + repr_str = repr(coll) + self.assertIn("test", repr_str) + self.assertIn("AsyncCollection", repr_str) + + async def test_collection_getattr(self): + """Test sub-collection access via attribute.""" + subcoll = self.db.test.subcollection + self.assertEqual("test.subcollection", subcoll.name) + + async def test_collection_getitem(self): + """Test sub-collection access via indexing.""" + subcoll = self.db.test["subcollection"] + self.assertEqual("test.subcollection", subcoll.name) + + async def test_collection_with_options(self): + """Test with_options creates new collection with options.""" + from pymongo.read_concern import ReadConcern + from pymongo.write_concern import WriteConcern + + coll = self.db.test.with_options( + read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1) + ) + self.assertEqual("majority", coll.read_concern.level) + self.assertEqual({"w": 1}, coll.write_concern.document) + # Original should be unchanged + self.assertNotEqual("majority", self.db.test.read_concern.level) + + async def test_collection_drop(self): + """Test collection drop.""" + await self.db.test_drop.insert_one({"x": 1}) + await self.db.test_drop.drop() + names = await self.db.list_collection_names() + self.assertNotIn("test_drop", names) + + async def test_collection_drop_with_comment(self): + """Test collection drop with comment.""" + await self.db.test_drop_comment.insert_one({"x": 1}) + await self.db.test_drop_comment.drop(comment="test_comment") + names = await self.db.list_collection_names() + self.assertNotIn("test_drop_comment", names) + 
+ async def test_find_raw_batches(self): + """Test find_raw_batches returns raw BSON.""" + from bson import decode_all + + cursor = self.db.test.find_raw_batches(batch_size=5) + batch_count = 0 + async for batch in cursor: + self.assertIsInstance(batch, bytes) + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + async def test_aggregate_raw_batches(self): + """Test aggregate_raw_batches returns raw BSON.""" + from bson import decode_all + + cursor = await self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5) + batch_count = 0 + async for batch in cursor: + self.assertIsInstance(batch, bytes) + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + async def test_distinct_with_collation(self): + """Test distinct with collation.""" + await self.db.test.drop() + await self.db.test.insert_many( + [ + {"name": "abc"}, + {"name": "ABC"}, + {"name": "def"}, + ] + ) + # Case-insensitive distinct + values = await self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2}) + # abc and ABC should be considered the same + self.assertEqual(2, len(values)) + + async def test_count_documents_with_options(self): + """Test count_documents with skip, limit, hint.""" + await self.db.test.create_index([("x", 1)]) + + count = await self.db.test.count_documents( + {"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)] + ) + self.assertEqual(5, count) + + async def test_estimated_document_count(self): + """Test estimated_document_count.""" + count = await self.db.test.estimated_document_count() + self.assertEqual(10, count) + + async def test_estimated_document_count_with_options(self): + """Test estimated_document_count with maxTimeMS and comment.""" + count = await self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment") + self.assertEqual(10, count) + + async def 
test_find_one_and_delete_with_options(self): + """Test find_one_and_delete with projection, sort.""" + doc = await self.db.test.find_one_and_delete( + {"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)] + ) + self.assertEqual(9, doc["x"]) + self.assertNotIn("y", doc) + + async def test_find_one_and_replace_with_options(self): + """Test find_one_and_replace with various options.""" + from pymongo import ReturnDocument + + doc = await self.db.test.find_one_and_replace( + {"x": 0}, + {"x": 0, "replaced": True}, + projection={"x": 1, "replaced": 1}, + return_document=ReturnDocument.AFTER, + ) + self.assertEqual(0, doc["x"]) + self.assertTrue(doc.get("replaced")) + + async def test_find_one_and_update_with_options(self): + """Test find_one_and_update with various options.""" + from pymongo import ReturnDocument + + doc = await self.db.test.find_one_and_update( + {"x": 0}, + {"$set": {"updated": True}}, + projection={"x": 1, "updated": 1}, + return_document=ReturnDocument.AFTER, + ) + self.assertEqual(0, doc["x"]) + self.assertTrue(doc.get("updated")) + + async def test_update_one_with_array_filters(self): + """Test update_one with array_filters.""" + await self.db.test.drop() + await self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]}) + + result = await self.db.test.update_one( + {}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}] + ) + self.assertEqual(1, result.modified_count) + + async def test_update_many_with_hint(self): + """Test update_many with hint.""" + await self.db.test.create_index([("x", 1)]) + + result = await self.db.test.update_many( + {"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)] + ) + self.assertEqual(10, result.modified_count) + + async def test_delete_one_with_hint(self): + """Test delete_one with hint.""" + await self.db.test.create_index([("x", 1)]) + + result = await self.db.test.delete_one({"x": 0}, hint=[("x", 1)]) + self.assertEqual(1, result.deleted_count) + + 
async def test_delete_many_with_hint(self): + """Test delete_many with hint.""" + await self.db.test.create_index([("x", 1)]) + + result = await self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)]) + self.assertEqual(5, result.deleted_count) + + async def test_aggregate_with_let(self): + """Test aggregate with let parameter.""" + if not async_client_context.version.at_least(5, 0): + self.skipTest("let parameter requires MongoDB 5.0+") + + pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}] + cursor = await self.db.test.aggregate(pipeline, let={"targetVal": 5}) + docs = await cursor.to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(5, docs[0]["x"]) + + async def test_aggregate_with_batch_size(self): + """Test aggregate with batchSize.""" + cursor = await self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2) + docs = await cursor.to_list() + self.assertEqual(10, len(docs)) + + async def test_list_indexes(self): + """Test list_indexes returns cursor.""" + await self.db.test.create_index([("x", 1)]) + cursor = await self.db.test.list_indexes() + + # Should get at least the _id index + indexes = await cursor.to_list() + self.assertGreaterEqual(len(indexes), 1) + index_names = [idx["name"] for idx in indexes] + self.assertIn("_id_", index_names) + + async def test_index_information(self): + """Test index_information returns dict.""" + await self.db.test.create_index([("x", 1)], name="x_index") + info = await self.db.test.index_information() + + self.assertIsInstance(info, dict) + self.assertIn("_id_", info) + self.assertIn("x_index", info) + + async def test_options_method(self): + """Test options() returns collection options.""" + # Create a capped collection + await self.db.drop_collection("test_capped") + await self.db.create_collection("test_capped", capped=True, size=10000) + + opts = await self.db.test_capped.options() + self.assertTrue(opts.get("capped")) + + await self.db.drop_collection("test_capped") + + if __name__ == 
"__main__": unittest.main() diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 08da82762c..b0c2ab035d 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1864,5 +1864,404 @@ async def test_exhaust_cursor_db_set(self): self.assertEqual(cmd.command["$db"], "pymongo_test") +class AsyncTestCursorCoverage(AsyncIntegrationTest): + """Additional tests to improve code coverage for AsyncCursor.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.drop() + await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)]) + + async def test_get_namespace(self): + """Test _get_namespace() method.""" + cursor = self.db.test.find() + expected_ns = f"{self.db.name}.test" + self.assertEqual(expected_ns, cursor._get_namespace()) + + async def test_cursor_alive_property_states(self): + """Test cursor alive property in different states.""" + cursor = self.db.test.find() + # Cursor is alive even before starting (has potential to return data) + self.assertTrue(cursor.alive) + + # Start the cursor + await anext(cursor) + self.assertTrue(cursor.alive) + + # Exhaust the cursor + await cursor.to_list() + self.assertFalse(cursor.alive) + + async def test_cursor_closed_property(self): + """Test cursor behavior after close.""" + cursor = self.db.test.find() + await anext(cursor) + self.assertTrue(cursor.alive) + + await cursor.close() + # After close, cursor is killed (check internal _killed flag) + self.assertTrue(cursor._killed) + + async def test_retrieved_property(self): + """Test the retrieved property tracking.""" + cursor = self.db.test.find().batch_size(2) + self.assertEqual(0, cursor.retrieved) + + await anext(cursor) + self.assertGreater(cursor.retrieved, 0) + + async def test_cursor_with_let_parameter(self): + """Test cursor with let parameter.""" + # let parameter allows variables to be used in the filter + cursor = self.db.test.find( + {"$expr": {"$eq": ["$x", 
"$$targetValue"]}}, let={"targetValue": 5} + ) + docs = await cursor.to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(5, docs[0]["x"]) + + async def test_cursor_with_invalid_let_parameter(self): + """Test cursor raises error for invalid let parameter.""" + with self.assertRaises(TypeError): + self.db.test.find(let="invalid") # type: ignore[arg-type] + + async def test_cursor_with_show_record_id(self): + """Test cursor with show_record_id option.""" + cursor = self.db.test.find(show_record_id=True) + doc = await anext(cursor) + self.assertIn("$recordId", doc) + + async def test_cursor_with_return_key(self): + """Test cursor with return_key option.""" + await self.db.test.create_index([("x", ASCENDING)]) + cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)]) + doc = await anext(cursor) + # return_key returns only index keys + self.assertIn("x", doc) + self.assertNotIn("y", doc) + + async def test_check_okay_to_chain_after_iteration(self): + """Test that cursor configuration methods raise after iteration.""" + cursor = self.db.test.find() + await anext(cursor) # Start iteration + + # All these should raise InvalidOperation + with self.assertRaises(InvalidOperation): + cursor.limit(5) + with self.assertRaises(InvalidOperation): + cursor.skip(2) + with self.assertRaises(InvalidOperation): + cursor.sort("x") + with self.assertRaises(InvalidOperation): + cursor.hint([("x", ASCENDING)]) + with self.assertRaises(InvalidOperation): + cursor.max([("x", 10)]) + with self.assertRaises(InvalidOperation): + cursor.min([("x", 0)]) + with self.assertRaises(InvalidOperation): + await cursor.add_option(2) + with self.assertRaises(InvalidOperation): + cursor.remove_option(2) + with self.assertRaises(InvalidOperation): + cursor.batch_size(10) + with self.assertRaises(InvalidOperation): + cursor.max_time_ms(1000) + with self.assertRaises(InvalidOperation): + cursor.collation(Collation("en_US")) + with self.assertRaises(InvalidOperation): + 
cursor.allow_disk_use(True) + with self.assertRaises(InvalidOperation): + cursor.where("this.x > 5") + with self.assertRaises(InvalidOperation): + cursor.comment("test") + + async def test_cursor_context_manager(self): + """Test cursor as async context manager.""" + async with self.db.test.find() as cursor: + doc = await anext(cursor) + self.assertIsNotNone(doc) + # Cursor should be killed after context (check _killed flag) + self.assertTrue(cursor._killed) + + async def test_cursor_context_manager_with_exception(self): + """Test cursor context manager closes on exception.""" + cursor = None + try: + async with self.db.test.find() as cursor: + await anext(cursor) + raise ValueError("test exception") + except ValueError: + pass + # Cursor should be killed after exception + self.assertTrue(cursor._killed) + + async def test_cursor_collation(self): + """Test cursor with collation.""" + await self.db.test.drop() + await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}]) + # Case-insensitive sort + cursor = ( + self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING) + ) + docs = await cursor.to_list() + self.assertEqual(3, len(docs)) + + async def test_cursor_collation_type_error(self): + """Test cursor raises error for invalid collation.""" + with self.assertRaises(TypeError): + self.db.test.find().collation("invalid") # type: ignore[arg-type] + + async def test_cursor_getitem_not_supported(self): + """Test that AsyncCursor does not support indexing.""" + cursor = self.db.test.find() + with self.assertRaises(IndexError) as ctx: + cursor[5] + self.assertIn("does not support indexing", str(ctx.exception)) + + async def test_cursor_next_after_close(self): + """Test that next() raises StopAsyncIteration after close.""" + cursor = self.db.test.find() + await cursor.close() + with self.assertRaises(StopAsyncIteration): + await anext(cursor) + + async def test_cursor_rewind_resets_state(self): + """Test that rewind 
properly resets cursor state.""" + cursor = self.db.test.find().limit(3) + + # Iterate fully + docs1 = await cursor.to_list() + self.assertEqual(3, len(docs1)) + self.assertEqual(0, len(cursor._data)) + + # Rewind and iterate again + await cursor.rewind() + docs2 = await cursor.to_list() + self.assertEqual(3, len(docs2)) + self.assertEqual(docs1, docs2) + + async def test_cursor_clone_with_session(self): + """Test that clone preserves explicit session.""" + async with self.client.start_session() as session: + cursor = self.db.test.find(session=session) + cloned = cursor.clone() + # Clone should reference the same session + self.assertEqual(cursor.session, cloned.session) + + async def test_cursor_clone_without_session(self): + """Test that clone without session doesn't add one.""" + cursor = self.db.test.find() + cloned = cursor.clone() + # Clone should have no session if original had none + self.assertIsNone(cloned.session) + + async def test_cursor_distinct_with_collation(self): + """Test distinct with collation.""" + await self.db.test.drop() + await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}]) + # Case-insensitive distinct + cursor = self.db.test.find().collation(Collation("en_US", strength=2)) + # distinct() on cursor with collation + values = await cursor.distinct("name") + # Should have 2 distinct values (abc/ABC treated as same) + self.assertEqual(2, len(values)) + + async def test_cursor_explain_with_options(self): + """Test explain with cursor options set.""" + cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1) + explanation = await cursor.explain() + self.assertIn("queryPlanner", explanation) + + async def test_cursor_max_time_ms_type_errors(self): + """Test max_time_ms raises TypeError for invalid input.""" + cursor = self.db.test.find() + with self.assertRaises(TypeError): + cursor.max_time_ms("invalid") # type: ignore[arg-type] + + async def 
test_cursor_max_await_time_ms_type_errors(self): + """Test max_await_time_ms raises TypeError for invalid input.""" + cursor = self.db.test.find() + with self.assertRaises(TypeError): + cursor.max_await_time_ms("invalid") # type: ignore[arg-type] + + async def test_cursor_comment_type(self): + """Test cursor with comment of various types.""" + # String comment + cursor1 = self.db.test.find().comment("test comment") + docs1 = await cursor1.to_list() + self.assertGreater(len(docs1), 0) + + # Dict comment + cursor2 = self.db.test.find().comment({"key": "value"}) + docs2 = await cursor2.to_list() + self.assertGreater(len(docs2), 0) + + async def test_cursor_batch_size_validation(self): + """Test batch_size validation.""" + with self.assertRaises(TypeError): + self.db.test.find(batch_size="invalid") # type: ignore[arg-type] + with self.assertRaises(ValueError): + self.db.test.find(batch_size=-1) + + async def test_cursor_skip_validation(self): + """Test skip validation.""" + with self.assertRaises(TypeError): + self.db.test.find(skip="invalid") # type: ignore[arg-type] + + async def test_cursor_limit_validation(self): + """Test limit validation.""" + with self.assertRaises(TypeError): + self.db.test.find(limit="invalid") # type: ignore[arg-type] + + async def test_cursor_filter_validation(self): + """Test filter validation.""" + with self.assertRaises(TypeError): + self.db.test.find(filter="invalid") # type: ignore[arg-type] + + async def test_cursor_type_validation(self): + """Test cursor_type validation.""" + with self.assertRaises(ValueError): + self.db.test.find(cursor_type=999) + + async def test_cursor_query_spec_with_modifiers(self): + """Test _query_spec includes modifiers.""" + cursor = ( + self.db.test.find() + .sort("x", ASCENDING) + .hint([("x", ASCENDING)]) + .max_time_ms(1000) + .comment("test") + ) + spec = cursor._query_spec() + self.assertIsInstance(spec, dict) + + async def test_cursor_copy(self): + """Test cursor __copy__ returns clone.""" + cursor = 
self.db.test.find().limit(5) + copied = copy.copy(cursor) + self.assertIsNot(cursor, copied) + self.assertEqual(cursor._limit, copied._limit) + + async def test_cursor_deepcopy(self): + """Test cursor __deepcopy__ returns deep clone.""" + cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5) + copied = copy.deepcopy(cursor) + self.assertIsNot(cursor, copied) + self.assertEqual(cursor._limit, copied._limit) + self.assertEqual(cursor._spec, copied._spec) + # Spec should be a different object + self.assertIsNot(cursor._spec, copied._spec) + + async def test_cursor_iteration_protocol(self): + """Test cursor async iteration protocol.""" + cursor = self.db.test.find().limit(3) + + # Test __aiter__ returns self + self.assertIs(cursor, cursor.__aiter__()) + + # Test __anext__ returns documents + doc1 = await cursor.__anext__() + self.assertIsNotNone(doc1) + + async def test_cursor_to_list_with_limit(self): + """Test to_list respects cursor limit.""" + cursor = self.db.test.find().limit(3) + docs = await cursor.to_list() + self.assertEqual(3, len(docs)) + + async def test_cursor_to_list_with_length(self): + """Test to_list with length parameter.""" + cursor = self.db.test.find() + docs = await cursor.to_list(length=3) + self.assertEqual(3, len(docs)) + + async def test_min_max_require_hint(self): + """Test that min/max require hint for proper execution.""" + await self.db.test.create_index([("x", ASCENDING)]) + + # min without hint should work when index exists + cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)]) + docs = await cursor.to_list() + self.assertTrue(all(doc["x"] >= 5 for doc in docs)) + + # max without hint should work when index exists + cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)]) + docs = await cursor.to_list() + self.assertTrue(all(doc["x"] < 5 for doc in docs)) + + async def test_cursor_address_property(self): + """Test cursor address is set after first batch.""" + cursor = self.db.test.find() + 
self.assertIsNone(cursor.address) + await anext(cursor) + # Address should be set after query + self.assertIsNotNone(cursor.address) + + async def test_cursor_session_property(self): + """Test cursor session property.""" + # Cursor without explicit session + cursor1 = self.db.test.find() + self.assertIsNone(cursor1.session) + + # Cursor with explicit session + async with self.client.start_session() as session: + cursor2 = self.db.test.find(session=session) + self.assertEqual(session, cursor2.session) + + async def test_cursor_allow_disk_use_type_error(self): + """Test allow_disk_use raises TypeError for invalid input.""" + with self.assertRaises(TypeError): + self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type] + + +class AsyncTestRawBatchCursorCoverage(AsyncIntegrationTest): + """Additional tests for AsyncRawBatchCursor coverage.""" + + async def asyncSetUp(self): + await super().asyncSetUp() + await self.db.test.drop() + await self.db.test.insert_many([{"x": i} for i in range(20)]) + + async def test_raw_batch_cursor_iteration(self): + """Test raw batch cursor returns raw BSON.""" + cursor = self.db.test.find_raw_batches(batch_size=5) + batch_count = 0 + async for batch in cursor: + self.assertIsInstance(batch, bytes) + # Decode the batch to verify it's valid BSON + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + async def test_raw_batch_cursor_explain(self): + """Test raw batch cursor explain.""" + cursor = self.db.test.find_raw_batches() + explanation = await cursor.explain() + self.assertIn("queryPlanner", explanation) + + async def test_raw_batch_cursor_getitem_raises(self): + """Test raw batch cursor __getitem__ raises InvalidOperation.""" + cursor = self.db.test.find_raw_batches() + with self.assertRaises(InvalidOperation): + cursor[0] + + async def test_raw_batch_cursor_with_sort(self): + """Test raw batch cursor with sort.""" + cursor = 
self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING) + first_batch = await anext(cursor) + docs = decode_all(first_batch) + # First doc should have highest x value + self.assertEqual(19, docs[0]["x"]) + + async def test_raw_batch_cursor_with_limit(self): + """Test raw batch cursor with limit.""" + cursor = self.db.test.find_raw_batches(batch_size=5).limit(7) + all_docs = [] + async for batch in cursor: + all_docs.extend(decode_all(batch)) + self.assertEqual(7, len(all_docs)) + + if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 404a69fdee..22da658882 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -53,6 +53,8 @@ from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference +from pymongo.write_concern import WriteConcern _IS_SYNC = False @@ -1361,5 +1363,284 @@ async def next_heartbeat(): self.assertEqual(started.command["$clusterTime"], cluster_time) +class AsyncTestClientSessionCoverage(AsyncIntegrationTest): + """Additional tests to improve code coverage for AsyncClientSession.""" + + @async_client_context.require_sessions + async def test_session_has_ended_property(self): + """Test has_ended property state transitions.""" + session = self.client.start_session() + self.assertFalse(session.has_ended) + await session.end_session() + self.assertTrue(session.has_ended) + + @async_client_context.require_sessions + async def test_session_session_id_property(self): + """Test session_id property returns correct value.""" + async with self.client.start_session() as session: + session_id = session.session_id + self.assertIsInstance(session_id, dict) + self.assertIn("id", session_id) + + @async_client_context.require_sessions + async def 
test_session_cluster_time_operations(self): + """Test cluster time advance operations.""" + async with self.client.start_session() as session: + # Initially None + self.assertIsNone(session.cluster_time) + + # Perform operation to get cluster time + await self.db.test.find_one({}, session=session) + + # Cluster time should be set after operation + # (may still be None on some server versions) + + @async_client_context.require_sessions + async def test_session_operation_time_operations(self): + """Test operation time advance operations.""" + async with self.client.start_session() as session: + # Initially None + self.assertIsNone(session.operation_time) + + # Perform operation to get operation time + await self.db.test.find_one({}, session=session) + + @async_client_context.require_sessions + async def test_session_options_property(self): + """Test session options property.""" + async with self.client.start_session(causal_consistency=True) as session: + self.assertTrue(session.options.causal_consistency) + + @async_client_context.require_sessions + async def test_session_client_property(self): + """Test session client property.""" + async with self.client.start_session() as session: + self.assertEqual(self.client, session.client) + + @async_client_context.require_sessions + async def test_session_in_transaction_property(self): + """Test in_transaction property.""" + if async_client_context.is_rs or async_client_context.is_mongos: + async with self.client.start_session() as session: + self.assertFalse(session.in_transaction) + session.start_transaction() + self.assertTrue(session.in_transaction) + await session.abort_transaction() + self.assertFalse(session.in_transaction) + + @async_client_context.require_sessions + async def test_session_context_manager(self): + """Test session async context manager.""" + async with self.client.start_session() as session: + self.assertFalse(session.has_ended) + await self.db.test.find_one({}, session=session) + 
self.assertTrue(session.has_ended) + + @async_client_context.require_sessions + async def test_session_context_manager_exception(self): + """Test session context manager closes on exception.""" + session = None + try: + async with self.client.start_session() as session: + raise ValueError("test exception") + except ValueError: + pass + self.assertTrue(session.has_ended) + + @async_client_context.require_sessions + async def test_session_operations_after_end(self): + """Test operations on ended session raise InvalidOperation.""" + session = self.client.start_session() + await session.end_session() + + with self.assertRaises(InvalidOperation): + await self.db.test.find_one({}, session=session) + + @async_client_context.require_sessions + async def test_session_end_session_idempotent(self): + """Test that end_session can be called multiple times.""" + session = self.client.start_session() + await session.end_session() + # Second call should not raise + await session.end_session() + self.assertTrue(session.has_ended) + + @async_client_context.require_transactions + async def test_transaction_start_without_prior_transaction(self): + """Test start_transaction on fresh session.""" + async with self.client.start_session() as session: + session.start_transaction() + self.assertTrue(session.in_transaction) + await session.abort_transaction() + + @async_client_context.require_transactions + async def test_transaction_start_twice_raises(self): + """Test starting transaction twice raises error.""" + async with self.client.start_session() as session: + session.start_transaction() + with self.assertRaises(InvalidOperation): + session.start_transaction() + await session.abort_transaction() + + @async_client_context.require_transactions + async def test_transaction_abort_without_transaction_raises(self): + """Test aborting without transaction raises error.""" + async with self.client.start_session() as session: + with self.assertRaises(InvalidOperation): + await 
session.abort_transaction() + + @async_client_context.require_transactions + async def test_transaction_commit_without_transaction_raises(self): + """Test committing without transaction raises error.""" + async with self.client.start_session() as session: + with self.assertRaises(InvalidOperation): + await session.commit_transaction() + + @async_client_context.require_sessions + async def test_session_advance_cluster_time_validation(self): + """Test advance_cluster_time with invalid input.""" + async with self.client.start_session() as session: + with self.assertRaises(TypeError): + session.advance_cluster_time("invalid") # type: ignore + with self.assertRaises(ValueError): + session.advance_cluster_time({}) + + @async_client_context.require_sessions + async def test_session_advance_operation_time_validation(self): + """Test advance_operation_time with invalid input.""" + from bson import Timestamp + + async with self.client.start_session() as session: + with self.assertRaises(TypeError): + session.advance_operation_time("invalid") # type: ignore + # Valid Timestamp should work + session.advance_operation_time(Timestamp(1, 1)) + + @async_client_context.require_transactions + async def test_with_transaction_callback_success(self): + """Test with_transaction with successful callback.""" + async with self.client.start_session() as session: + + async def callback(session): + await self.db.test.insert_one({"x": 1}, session=session) + return "success" + + result = await session.with_transaction(callback) + self.assertEqual("success", result) + + @async_client_context.require_transactions + async def test_with_transaction_callback_exception(self): + """Test with_transaction with callback exception.""" + async with self.client.start_session() as session: + + async def callback(session): + await self.db.test.insert_one({"x": 1}, session=session) + raise ValueError("callback error") + + with self.assertRaises(ValueError): + await session.with_transaction(callback) + # 
Transaction should be aborted + self.assertFalse(session.in_transaction) + + +class AsyncTestSessionOptionsCoverage(AsyncUnitTest): + """Tests for SessionOptions coverage.""" + + def test_session_options_defaults(self): + """Test SessionOptions default values.""" + from pymongo.asynchronous.client_session import SessionOptions + + options = SessionOptions() + self.assertTrue(options.causal_consistency) + self.assertIsNone(options.default_transaction_options) + self.assertFalse(options.snapshot) + + def test_session_options_snapshot_disables_causal_consistency(self): + """Test snapshot=True forces causal_consistency=False.""" + from pymongo.asynchronous.client_session import SessionOptions + + options = SessionOptions(snapshot=True) + self.assertFalse(options.causal_consistency) + self.assertTrue(options.snapshot) + + def test_session_options_snapshot_with_causal_raises(self): + """Test snapshot=True with causal_consistency=True raises error.""" + from pymongo.asynchronous.client_session import SessionOptions + + with self.assertRaises(ConfigurationError): + SessionOptions(snapshot=True, causal_consistency=True) + + def test_session_options_invalid_transaction_options(self): + """Test SessionOptions with invalid transaction options type.""" + from pymongo.asynchronous.client_session import SessionOptions + + with self.assertRaises(TypeError): + SessionOptions(default_transaction_options="invalid") # type: ignore + + +class AsyncTestTransactionOptionsCoverage(AsyncUnitTest): + """Tests for TransactionOptions coverage.""" + + def test_transaction_options_defaults(self): + """Test TransactionOptions default values.""" + from pymongo.asynchronous.client_session import TransactionOptions + + options = TransactionOptions() + self.assertIsNone(options.read_concern) + self.assertIsNone(options.write_concern) + self.assertIsNone(options.read_preference) + self.assertIsNone(options.max_commit_time_ms) + + def test_transaction_options_with_values(self): + """Test 
TransactionOptions with all values set.""" + from pymongo.asynchronous.client_session import TransactionOptions + + options = TransactionOptions( + read_concern=ReadConcern("majority"), + write_concern=WriteConcern(w="majority"), + read_preference=ReadPreference.PRIMARY, + max_commit_time_ms=5000, + ) + self.assertEqual("majority", options.read_concern.level) + self.assertEqual("majority", options.write_concern.document.get("w")) + self.assertEqual(ReadPreference.PRIMARY, options.read_preference) + self.assertEqual(5000, options.max_commit_time_ms) + + def test_transaction_options_invalid_read_concern(self): + """Test TransactionOptions with invalid read_concern type.""" + from pymongo.asynchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(read_concern="invalid") # type: ignore + + def test_transaction_options_invalid_write_concern(self): + """Test TransactionOptions with invalid write_concern type.""" + from pymongo.asynchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(write_concern="invalid") # type: ignore + + def test_transaction_options_invalid_read_preference(self): + """Test TransactionOptions with invalid read_preference type.""" + from pymongo.asynchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(read_preference="invalid") # type: ignore + + def test_transaction_options_invalid_max_commit_time(self): + """Test TransactionOptions with invalid max_commit_time_ms type.""" + from pymongo.asynchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(max_commit_time_ms="invalid") # type: ignore + + def test_transaction_options_unacknowledged_write_concern(self): + """Test TransactionOptions rejects unacknowledged write concern.""" + from pymongo.asynchronous.client_session import TransactionOptions + + with 
self.assertRaises(ConfigurationError): + TransactionOptions(write_concern=WriteConcern(w=0)) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_bulk.py b/test/test_bulk.py index 1de406fca5..65017ccdb1 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -779,6 +779,225 @@ def test_large_inserts_unordered(self): self.assertEqual(6, result.inserted_count) self.assertEqual(6, self.coll.count_documents({})) + def test_bulk_write_with_comment(self): + """Test bulk write operations with comment parameter.""" + requests = [ + InsertOne({"x": 1}), + UpdateOne({"x": 1}, {"$set": {"y": 1}}), + DeleteOne({"x": 1}), + ] + result = self.coll.bulk_write(requests, comment="bulk_comment") + self.assertEqual(1, result.inserted_count) + self.assertEqual(1, result.modified_count) + self.assertEqual(1, result.deleted_count) + + def test_bulk_write_with_let(self): + """Test bulk write operations with let parameter.""" + if not client_context.version.at_least(5, 0): + self.skipTest("let parameter requires MongoDB 5.0+") + + self.coll.insert_one({"x": 1}) + requests = [ + UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}), + ] + result = self.coll.bulk_write(requests, let={"targetVal": 1}) + self.assertEqual(1, result.modified_count) + + def test_bulk_write_all_operation_types(self): + """Test bulk write with all operation types combined.""" + self.coll.insert_many([{"x": i} for i in range(5)]) + + requests = [ + InsertOne({"x": 100}), + UpdateOne({"x": 0}, {"$set": {"updated": True}}), + UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}), + ReplaceOne({"x": 3}, {"x": 3, "replaced": True}), + DeleteOne({"x": 4}), + DeleteMany({"x": {"$gt": 50}}), + ] + result = self.coll.bulk_write(requests) + + self.assertEqual(1, result.inserted_count) + self.assertGreaterEqual(result.modified_count, 1) + self.assertGreaterEqual(result.deleted_count, 1) + + def test_bulk_write_unordered(self): + """Test unordered bulk write continues 
after error.""" + self.coll.create_index([("x", 1)], unique=True) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + requests = [ + InsertOne({"x": 1}), + InsertOne({"x": 1}), # Duplicate - will error + InsertOne({"x": 2}), + InsertOne({"x": 3}), + ] + + with self.assertRaises(BulkWriteError) as ctx: + self.coll.bulk_write(requests, ordered=False) + + # With unordered, should have inserted 3 documents + self.assertEqual(3, ctx.exception.details["nInserted"]) + + def test_bulk_write_ordered(self): + """Test ordered bulk write stops on first error.""" + self.coll.create_index([("x", 1)], unique=True) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + requests = [ + InsertOne({"x": 1}), + InsertOne({"x": 1}), # Duplicate - will error + InsertOne({"x": 2}), + InsertOne({"x": 3}), + ] + + with self.assertRaises(BulkWriteError) as ctx: + self.coll.bulk_write(requests, ordered=True) + + # With ordered, should have inserted only 1 document + self.assertEqual(1, ctx.exception.details["nInserted"]) + + def test_bulk_write_bypass_document_validation(self): + """Test bulk write with bypass_document_validation.""" + if not client_context.version.at_least(3, 2): + self.skipTest("bypass_document_validation requires MongoDB 3.2+") + + # Create collection with validator + self.coll.drop() + self.db.create_collection(self.coll.name, validator={"$jsonSchema": {"required": ["name"]}}) + + # Without bypass, should fail + with self.assertRaises(BulkWriteError): + self.coll.bulk_write([InsertOne({"x": 1})]) + + # With bypass, should succeed + result = self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True) + self.assertEqual(1, result.inserted_count) + + def test_bulk_write_result_properties(self): + """Test all BulkWriteResult properties.""" + self.coll.insert_one({"x": 1}) + + requests = [ + InsertOne({"x": 2}), + UpdateOne({"x": 1}, {"$set": {"updated": True}}), + ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True), + DeleteOne({"x": 1}), + ] + 
result = self.coll.bulk_write(requests) + + # Check all properties + self.assertTrue(result.acknowledged) + self.assertEqual(1, result.inserted_count) + self.assertGreaterEqual(result.matched_count, 0) + self.assertGreaterEqual(result.modified_count, 0) + self.assertEqual(1, result.deleted_count) + self.assertIsInstance(result.upserted_count, int) + self.assertIsInstance(result.upserted_ids, dict) + + def test_bulk_write_with_upsert(self): + """Test bulk write upsert operations.""" + requests = [ + UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True), + UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True), + ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True), + ] + result = self.coll.bulk_write(requests) + + self.assertEqual(3, result.upserted_count) + self.assertEqual(3, len(result.upserted_ids)) + + def test_update_one_with_hint(self): + """Test UpdateOne with hint parameter.""" + self.coll.create_index([("x", 1)]) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + self.coll.insert_one({"x": 1}) + + requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])] + result = self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + def test_update_many_with_hint(self): + """Test UpdateMany with hint parameter.""" + self.coll.create_index([("x", 1)]) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + self.coll.insert_many([{"x": 1}, {"x": 1}]) + + requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])] + result = self.coll.bulk_write(requests) + self.assertEqual(2, result.modified_count) + + def test_delete_one_with_hint(self): + """Test DeleteOne with hint parameter.""" + self.coll.create_index([("x", 1)]) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + self.coll.insert_one({"x": 1}) + + requests = [DeleteOne({"x": 1}, hint=[("x", 1)])] + result = self.coll.bulk_write(requests) + self.assertEqual(1, result.deleted_count) + + def test_delete_many_with_hint(self): + """Test DeleteMany with hint 
parameter.""" + self.coll.create_index([("x", 1)]) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + self.coll.insert_many([{"x": 1}, {"x": 1}]) + + requests = [DeleteMany({"x": 1}, hint=[("x", 1)])] + result = self.coll.bulk_write(requests) + self.assertEqual(2, result.deleted_count) + + def test_update_one_with_array_filters(self): + """Test UpdateOne with array_filters parameter.""" + self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]}) + + requests = [ + UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}]) + ] + result = self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + doc = self.coll.find_one() + # Elements with y > 1 should have z = 1 + for elem in doc["x"]: + if elem["y"] > 1: + self.assertEqual(1, elem.get("z")) + + def test_replace_one_with_hint(self): + """Test ReplaceOne with hint parameter.""" + self.coll.create_index([("x", 1)]) + self.addCleanup(self.coll.drop_index, [("x", 1)]) + + self.coll.insert_one({"x": 1}) + + requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])] + result = self.coll.bulk_write(requests) + self.assertEqual(1, result.modified_count) + + def test_update_with_collation(self): + """Test update operations with collation.""" + self.coll.insert_many( + [ + {"name": "cafe"}, + {"name": "Cafe"}, + ] + ) + + requests = [ + UpdateMany( + {"name": "cafe"}, + {"$set": {"updated": True}}, + collation={"locale": "en", "strength": 2}, + ) + ] + result = self.coll.bulk_write(requests) + # With case-insensitive collation, both docs should match + self.assertEqual(2, result.modified_count) + class BulkAuthorizationTestBase(BulkTestBase): @client_context.require_auth diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 792b39cc29..15fa76454f 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -1132,5 +1132,122 @@ def tearDown(self): ) +class TestChangeStreamCoverage(TestCollectionChangeStream): + 
"""Additional tests to improve code coverage for ChangeStream.""" + + def test_change_stream_alive_property(self): + """Test alive property state transitions.""" + with self.change_stream() as cs: + self.assertTrue(cs.alive) + # After context exit, should be closed + self.assertFalse(cs.alive) + + def test_change_stream_idempotent_close(self): + """Test that close() can be called multiple times safely.""" + cs = self.change_stream() + cs.close() + # Second close should not raise + cs.close() + self.assertFalse(cs.alive) + + def test_change_stream_resume_token_deepcopy(self): + """Test that resume_token returns a deep copy.""" + coll = self.watched_collection() + with self.change_stream() as cs: + coll.insert_one({"x": 1}) + next(cs) # Consume the change event + token1 = cs.resume_token + token2 = cs.resume_token + # Should be equal but different objects + self.assertEqual(token1, token2) + self.assertIsNot(token1, token2) + + def test_change_stream_with_comment(self): + """Test change stream with comment parameter.""" + client, listener = self.client_with_listener("aggregate") + try: + with self.change_stream_with_client(client, comment="test_comment"): + pass + finally: + client.close() + + # Check that comment was in the aggregate command + self.assertGreater(len(listener.started_events), 0) + cmd = listener.started_events[0].command + self.assertEqual("test_comment", cmd.get("comment")) + + def test_change_stream_with_show_expanded_events(self): + """Test change stream with show_expanded_events parameter.""" + if not client_context.version.at_least(6, 0): + self.skipTest("show_expanded_events requires MongoDB 6.0+") + + with self.change_stream(show_expanded_events=True) as cs: + # Just verify it doesn't error + self.assertTrue(cs.alive) + + @client_context.require_version_min(6, 0) + def test_change_stream_with_full_document_before_change(self): + """Test change stream with full_document_before_change parameter.""" + coll = self.watched_collection() + # Need to 
ensure collection exists with changeStreamPreAndPostImages enabled + coll.drop() + self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True}) + coll.insert_one({"x": 1}) + + with self.change_stream(full_document_before_change="whenAvailable") as cs: + coll.update_one({"x": 1}, {"$set": {"x": 2}}) + change = next(cs) + self.assertEqual("update", change["operationType"]) + # fullDocumentBeforeChange should be present + self.assertIn("fullDocumentBeforeChange", change) + + def test_change_stream_next_after_close(self): + """Test that next() on closed stream raises StopIteration.""" + cs = self.change_stream() + cs.close() + with self.assertRaises(StopIteration): + next(cs) + + def test_change_stream_try_next_after_close(self): + """Test that try_next() on closed stream raises StopIteration.""" + cs = self.change_stream() + cs.close() + with self.assertRaises(StopIteration): + cs.try_next() + + def test_change_stream_pipeline_construction(self): + """Test change stream pipeline is properly constructed.""" + pipeline = [{"$match": {"operationType": "insert"}}] + client, listener = self.client_with_listener("aggregate") + try: + with self.change_stream_with_client(client, pipeline=pipeline): + pass + finally: + client.close() + + cmd = listener.started_events[0].command + agg_pipeline = cmd["pipeline"] + # First stage should be $changeStream + self.assertIn("$changeStream", agg_pipeline[0]) + # Second stage should be our match + self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1]) + + def test_change_stream_empty_pipeline(self): + """Test change stream with empty pipeline.""" + with self.change_stream(pipeline=[]) as cs: + self.assertTrue(cs.alive) + + def test_change_stream_context_manager_exception(self): + """Test change stream context manager closes on exception.""" + cs = None + try: + with self.change_stream() as cs: + raise ValueError("test exception") + except ValueError: + pass + # Stream should be closed + 
self.assertFalse(cs.alive) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_collection.py b/test/test_collection.py index ac469782e9..b3ee795b5e 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -2238,5 +2238,262 @@ def afind(*args, **kwargs): helper(*args, let={}) # type: ignore +class TestCollectionCoverage(IntegrationTest): + """Additional tests to improve code coverage for Collection.""" + + def setUp(self): + super().setUp() + self.db.test.drop() + self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)]) + + def test_collection_full_name(self): + """Test full_name property.""" + expected = f"{self.db.name}.test" + self.assertEqual(expected, self.db.test.full_name) + + def test_collection_name(self): + """Test name property.""" + self.assertEqual("test", self.db.test.name) + + def test_collection_database(self): + """Test database property.""" + self.assertEqual(self.db, self.db.test.database) + + def test_collection_equality(self): + """Test collection equality.""" + coll1 = self.db.test + coll2 = self.db.test + coll3 = self.db.other + self.assertEqual(coll1, coll2) + self.assertNotEqual(coll1, coll3) + + def test_collection_hash(self): + """Test collection hashability.""" + coll1 = self.db.test + coll2 = self.db.test + # Same collection should have same hash + self.assertEqual(hash(coll1), hash(coll2)) + # Collections can be used in sets + s = {coll1, coll2} + self.assertEqual(1, len(s)) + + def test_collection_repr(self): + """Test collection repr.""" + coll = self.db.test + repr_str = repr(coll) + self.assertIn("test", repr_str) + self.assertIn("Collection", repr_str) + + def test_collection_getattr(self): + """Test sub-collection access via attribute.""" + subcoll = self.db.test.subcollection + self.assertEqual("test.subcollection", subcoll.name) + + def test_collection_getitem(self): + """Test sub-collection access via indexing.""" + subcoll = self.db.test["subcollection"] + 
self.assertEqual("test.subcollection", subcoll.name) + + def test_collection_with_options(self): + """Test with_options creates new collection with options.""" + from pymongo.read_concern import ReadConcern + from pymongo.write_concern import WriteConcern + + coll = self.db.test.with_options( + read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1) + ) + self.assertEqual("majority", coll.read_concern.level) + self.assertEqual({"w": 1}, coll.write_concern.document) + # Original should be unchanged + self.assertNotEqual("majority", self.db.test.read_concern.level) + + def test_collection_drop(self): + """Test collection drop.""" + self.db.test_drop.insert_one({"x": 1}) + self.db.test_drop.drop() + names = self.db.list_collection_names() + self.assertNotIn("test_drop", names) + + def test_collection_drop_with_comment(self): + """Test collection drop with comment.""" + self.db.test_drop_comment.insert_one({"x": 1}) + self.db.test_drop_comment.drop(comment="test_comment") + names = self.db.list_collection_names() + self.assertNotIn("test_drop_comment", names) + + def test_find_raw_batches(self): + """Test find_raw_batches returns raw BSON.""" + from bson import decode_all + + cursor = self.db.test.find_raw_batches(batch_size=5) + batch_count = 0 + for batch in cursor: + self.assertIsInstance(batch, bytes) + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + def test_aggregate_raw_batches(self): + """Test aggregate_raw_batches returns raw BSON.""" + from bson import decode_all + + cursor = self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5) + batch_count = 0 + for batch in cursor: + self.assertIsInstance(batch, bytes) + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + def test_distinct_with_collation(self): + """Test distinct with collation.""" + self.db.test.drop() + self.db.test.insert_many( + [ + 
{"name": "abc"}, + {"name": "ABC"}, + {"name": "def"}, + ] + ) + # Case-insensitive distinct + values = self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2}) + # abc and ABC should be considered the same + self.assertEqual(2, len(values)) + + def test_count_documents_with_options(self): + """Test count_documents with skip, limit, hint.""" + self.db.test.create_index([("x", 1)]) + + count = self.db.test.count_documents({"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)]) + self.assertEqual(5, count) + + def test_estimated_document_count(self): + """Test estimated_document_count.""" + count = self.db.test.estimated_document_count() + self.assertEqual(10, count) + + def test_estimated_document_count_with_options(self): + """Test estimated_document_count with maxTimeMS and comment.""" + count = self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment") + self.assertEqual(10, count) + + def test_find_one_and_delete_with_options(self): + """Test find_one_and_delete with projection, sort.""" + doc = self.db.test.find_one_and_delete( + {"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)] + ) + self.assertEqual(9, doc["x"]) + self.assertNotIn("y", doc) + + def test_find_one_and_replace_with_options(self): + """Test find_one_and_replace with various options.""" + from pymongo import ReturnDocument + + doc = self.db.test.find_one_and_replace( + {"x": 0}, + {"x": 0, "replaced": True}, + projection={"x": 1, "replaced": 1}, + return_document=ReturnDocument.AFTER, + ) + self.assertEqual(0, doc["x"]) + self.assertTrue(doc.get("replaced")) + + def test_find_one_and_update_with_options(self): + """Test find_one_and_update with various options.""" + from pymongo import ReturnDocument + + doc = self.db.test.find_one_and_update( + {"x": 0}, + {"$set": {"updated": True}}, + projection={"x": 1, "updated": 1}, + return_document=ReturnDocument.AFTER, + ) + self.assertEqual(0, doc["x"]) + self.assertTrue(doc.get("updated")) + + def 
test_update_one_with_array_filters(self): + """Test update_one with array_filters.""" + self.db.test.drop() + self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]}) + + result = self.db.test.update_one( + {}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}] + ) + self.assertEqual(1, result.modified_count) + + def test_update_many_with_hint(self): + """Test update_many with hint.""" + self.db.test.create_index([("x", 1)]) + + result = self.db.test.update_many( + {"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)] + ) + self.assertEqual(10, result.modified_count) + + def test_delete_one_with_hint(self): + """Test delete_one with hint.""" + self.db.test.create_index([("x", 1)]) + + result = self.db.test.delete_one({"x": 0}, hint=[("x", 1)]) + self.assertEqual(1, result.deleted_count) + + def test_delete_many_with_hint(self): + """Test delete_many with hint.""" + self.db.test.create_index([("x", 1)]) + + result = self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)]) + self.assertEqual(5, result.deleted_count) + + def test_aggregate_with_let(self): + """Test aggregate with let parameter.""" + if not client_context.version.at_least(5, 0): + self.skipTest("let parameter requires MongoDB 5.0+") + + pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}] + cursor = self.db.test.aggregate(pipeline, let={"targetVal": 5}) + docs = cursor.to_list() + self.assertEqual(1, len(docs)) + self.assertEqual(5, docs[0]["x"]) + + def test_aggregate_with_batch_size(self): + """Test aggregate with batchSize.""" + cursor = self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2) + docs = cursor.to_list() + self.assertEqual(10, len(docs)) + + def test_list_indexes(self): + """Test list_indexes returns cursor.""" + self.db.test.create_index([("x", 1)]) + cursor = self.db.test.list_indexes() + + # Should get at least the _id index + indexes = cursor.to_list() + self.assertGreaterEqual(len(indexes), 1) + 
index_names = [idx["name"] for idx in indexes] + self.assertIn("_id_", index_names) + + def test_index_information(self): + """Test index_information returns dict.""" + self.db.test.create_index([("x", 1)], name="x_index") + info = self.db.test.index_information() + + self.assertIsInstance(info, dict) + self.assertIn("_id_", info) + self.assertIn("x_index", info) + + def test_options_method(self): + """Test options() returns collection options.""" + # Create a capped collection + self.db.drop_collection("test_capped") + self.db.create_collection("test_capped", capped=True, size=10000) + + opts = self.db.test_capped.options() + self.assertTrue(opts.get("capped")) + + self.db.drop_collection("test_capped") + + if __name__ == "__main__": unittest.main() diff --git a/test/test_cursor.py b/test/test_cursor.py index b63638bfab..406f6618ac 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1853,5 +1853,404 @@ def test_exhaust_cursor_db_set(self): self.assertEqual(cmd.command["$db"], "pymongo_test") +class TestCursorCoverage(IntegrationTest): + """Additional tests to improve code coverage for Cursor.""" + + def setUp(self): + super().setUp() + self.db.test.drop() + self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)]) + + def test_get_namespace(self): + """Test _get_namespace() method.""" + cursor = self.db.test.find() + expected_ns = f"{self.db.name}.test" + self.assertEqual(expected_ns, cursor._get_namespace()) + + def test_cursor_alive_property_states(self): + """Test cursor alive property in different states.""" + cursor = self.db.test.find() + # Cursor is alive even before starting (has potential to return data) + self.assertTrue(cursor.alive) + + # Start the cursor + next(cursor) + self.assertTrue(cursor.alive) + + # Exhaust the cursor + cursor.to_list() + self.assertFalse(cursor.alive) + + def test_cursor_closed_property(self): + """Test cursor behavior after close.""" + cursor = self.db.test.find() + next(cursor) + 
self.assertTrue(cursor.alive)

        cursor.close()
        # After close, cursor is killed (check internal _killed flag)
        self.assertTrue(cursor._killed)

    def test_retrieved_property(self):
        """Test the retrieved property tracking."""
        cursor = self.db.test.find().batch_size(2)
        # Nothing fetched before iteration starts.
        self.assertEqual(0, cursor.retrieved)

        next(cursor)
        self.assertGreater(cursor.retrieved, 0)

    def test_cursor_with_let_parameter(self):
        """Test cursor with let parameter."""
        # let parameter allows variables to be used in the filter
        cursor = self.db.test.find(
            {"$expr": {"$eq": ["$x", "$$targetValue"]}}, let={"targetValue": 5}
        )
        docs = cursor.to_list()
        self.assertEqual(1, len(docs))
        self.assertEqual(5, docs[0]["x"])

    def test_cursor_with_invalid_let_parameter(self):
        """Test cursor raises error for invalid let parameter."""
        # let must be a mapping; a plain string is rejected client-side.
        with self.assertRaises(TypeError):
            self.db.test.find(let="invalid")  # type: ignore[arg-type]

    def test_cursor_with_show_record_id(self):
        """Test cursor with show_record_id option."""
        cursor = self.db.test.find(show_record_id=True)
        doc = next(cursor)
        self.assertIn("$recordId", doc)

    def test_cursor_with_return_key(self):
        """Test cursor with return_key option."""
        self.db.test.create_index([("x", ASCENDING)])
        cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)])
        doc = next(cursor)
        # return_key returns only index keys
        self.assertIn("x", doc)
        self.assertNotIn("y", doc)

    def test_check_okay_to_chain_after_iteration(self):
        """Test that cursor configuration methods raise after iteration."""
        cursor = self.db.test.find()
        next(cursor)  # Start iteration

        # All these should raise InvalidOperation
        with self.assertRaises(InvalidOperation):
            cursor.limit(5)
        with self.assertRaises(InvalidOperation):
            cursor.skip(2)
        with self.assertRaises(InvalidOperation):
            cursor.sort("x")
        with self.assertRaises(InvalidOperation):
            cursor.hint([("x", ASCENDING)])
        with 
self.assertRaises(InvalidOperation):
            cursor.max([("x", 10)])
        with self.assertRaises(InvalidOperation):
            cursor.min([("x", 0)])
        with self.assertRaises(InvalidOperation):
            cursor.add_option(2)
        with self.assertRaises(InvalidOperation):
            cursor.remove_option(2)
        with self.assertRaises(InvalidOperation):
            cursor.batch_size(10)
        with self.assertRaises(InvalidOperation):
            cursor.max_time_ms(1000)
        with self.assertRaises(InvalidOperation):
            cursor.collation(Collation("en_US"))
        with self.assertRaises(InvalidOperation):
            cursor.allow_disk_use(True)
        with self.assertRaises(InvalidOperation):
            cursor.where("this.x > 5")
        with self.assertRaises(InvalidOperation):
            cursor.comment("test")

    def test_cursor_context_manager(self):
        """Test cursor as a context manager (cursor is killed on exit)."""
        with self.db.test.find() as cursor:
            doc = next(cursor)
            self.assertIsNotNone(doc)
        # Cursor should be killed after context (check _killed flag)
        self.assertTrue(cursor._killed)

    def test_cursor_context_manager_with_exception(self):
        """Test cursor context manager closes on exception."""
        cursor = None
        try:
            with self.db.test.find() as cursor:
                next(cursor)
                raise ValueError("test exception")
        except ValueError:
            pass
        # Cursor should be killed after exception
        self.assertTrue(cursor._killed)

    def test_cursor_collation(self):
        """Test cursor with collation."""
        self.db.test.drop()
        self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
        # Case-insensitive sort (collation strength 2 ignores case)
        cursor = (
            self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING)
        )
        docs = cursor.to_list()
        self.assertEqual(3, len(docs))

    def test_cursor_collation_type_error(self):
        """Test cursor raises error for invalid collation."""
        with self.assertRaises(TypeError):
            self.db.test.find().collation("invalid")  # type: ignore[arg-type]

    def test_cursor_getitem_not_supported(self):
        """Test that Cursor does not support indexing."""
        cursor = 
self.db.test.find()
        with self.assertRaises(IndexError) as ctx:
            cursor[5]
        self.assertIn("does not support indexing", str(ctx.exception))

    def test_cursor_next_after_close(self):
        """Test that next() raises StopIteration after close."""
        cursor = self.db.test.find()
        cursor.close()
        with self.assertRaises(StopIteration):
            next(cursor)

    def test_cursor_rewind_resets_state(self):
        """Test that rewind properly resets cursor state."""
        cursor = self.db.test.find().limit(3)

        # Iterate fully
        docs1 = cursor.to_list()
        self.assertEqual(3, len(docs1))
        # Internal buffer is drained once exhausted.
        self.assertEqual(0, len(cursor._data))

        # Rewind and iterate again
        cursor.rewind()
        docs2 = cursor.to_list()
        self.assertEqual(3, len(docs2))
        self.assertEqual(docs1, docs2)

    def test_cursor_clone_with_session(self):
        """Test that clone preserves explicit session."""
        with self.client.start_session() as session:
            cursor = self.db.test.find(session=session)
            cloned = cursor.clone()
            # Clone should reference the same session
            self.assertEqual(cursor.session, cloned.session)

    def test_cursor_clone_without_session(self):
        """Test that clone without session doesn't add one."""
        cursor = self.db.test.find()
        cloned = cursor.clone()
        # Clone should have no session if original had none
        self.assertIsNone(cloned.session)

    def test_cursor_distinct_with_collation(self):
        """Test distinct with collation."""
        self.db.test.drop()
        self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
        # Case-insensitive distinct
        cursor = self.db.test.find().collation(Collation("en_US", strength=2))
        # distinct() on cursor with collation
        values = cursor.distinct("name")
        # Should have 2 distinct values (abc/ABC treated as same)
        self.assertEqual(2, len(values))

    def test_cursor_explain_with_options(self):
        """Test explain with cursor options set."""
        cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1)
        explanation = cursor.explain()
self.assertIn("queryPlanner", explanation) + + def test_cursor_max_time_ms_type_errors(self): + """Test max_time_ms raises TypeError for invalid input.""" + cursor = self.db.test.find() + with self.assertRaises(TypeError): + cursor.max_time_ms("invalid") # type: ignore[arg-type] + + def test_cursor_max_await_time_ms_type_errors(self): + """Test max_await_time_ms raises TypeError for invalid input.""" + cursor = self.db.test.find() + with self.assertRaises(TypeError): + cursor.max_await_time_ms("invalid") # type: ignore[arg-type] + + def test_cursor_comment_type(self): + """Test cursor with comment of various types.""" + # String comment + cursor1 = self.db.test.find().comment("test comment") + docs1 = cursor1.to_list() + self.assertGreater(len(docs1), 0) + + # Dict comment + cursor2 = self.db.test.find().comment({"key": "value"}) + docs2 = cursor2.to_list() + self.assertGreater(len(docs2), 0) + + def test_cursor_batch_size_validation(self): + """Test batch_size validation.""" + with self.assertRaises(TypeError): + self.db.test.find(batch_size="invalid") # type: ignore[arg-type] + with self.assertRaises(ValueError): + self.db.test.find(batch_size=-1) + + def test_cursor_skip_validation(self): + """Test skip validation.""" + with self.assertRaises(TypeError): + self.db.test.find(skip="invalid") # type: ignore[arg-type] + + def test_cursor_limit_validation(self): + """Test limit validation.""" + with self.assertRaises(TypeError): + self.db.test.find(limit="invalid") # type: ignore[arg-type] + + def test_cursor_filter_validation(self): + """Test filter validation.""" + with self.assertRaises(TypeError): + self.db.test.find(filter="invalid") # type: ignore[arg-type] + + def test_cursor_type_validation(self): + """Test cursor_type validation.""" + with self.assertRaises(ValueError): + self.db.test.find(cursor_type=999) + + def test_cursor_query_spec_with_modifiers(self): + """Test _query_spec includes modifiers.""" + cursor = ( + self.db.test.find() + .sort("x", 
ASCENDING)
            .hint([("x", ASCENDING)])
            .max_time_ms(1000)
            .comment("test")
        )
        spec = cursor._query_spec()
        self.assertIsInstance(spec, dict)

    def test_cursor_copy(self):
        """Test cursor __copy__ returns clone."""
        cursor = self.db.test.find().limit(5)
        copied = copy.copy(cursor)
        self.assertIsNot(cursor, copied)
        self.assertEqual(cursor._limit, copied._limit)

    def test_cursor_deepcopy(self):
        """Test cursor __deepcopy__ returns deep clone."""
        cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5)
        copied = copy.deepcopy(cursor)
        self.assertIsNot(cursor, copied)
        self.assertEqual(cursor._limit, copied._limit)
        self.assertEqual(cursor._spec, copied._spec)
        # Spec should be a different object
        self.assertIsNot(cursor._spec, copied._spec)

    def test_cursor_iteration_protocol(self):
        """Test the cursor iteration protocol (__iter__/__next__)."""
        cursor = self.db.test.find().limit(3)

        # Test __iter__ returns self
        self.assertIs(cursor, cursor.__iter__())

        # Test __next__ returns documents
        doc1 = cursor.__next__()
        self.assertIsNotNone(doc1)

    def test_cursor_to_list_with_limit(self):
        """Test to_list respects cursor limit."""
        cursor = self.db.test.find().limit(3)
        docs = cursor.to_list()
        self.assertEqual(3, len(docs))

    def test_cursor_to_list_with_length(self):
        """Test to_list with length parameter."""
        cursor = self.db.test.find()
        docs = cursor.to_list(length=3)
        self.assertEqual(3, len(docs))

    def test_min_max_require_hint(self):
        """Test that min/max require hint for proper execution."""
        self.db.test.create_index([("x", ASCENDING)])

        # min (inclusive lower index bound) paired with an explicit hint
        cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)])
        docs = cursor.to_list()
        self.assertTrue(all(doc["x"] >= 5 for doc in docs))

        # max (exclusive upper index bound) paired with an explicit hint
        cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)])
        docs = cursor.to_list()
        self.assertTrue(all(doc["x"] < 5 for 
doc in docs)) + + def test_cursor_address_property(self): + """Test cursor address is set after first batch.""" + cursor = self.db.test.find() + self.assertIsNone(cursor.address) + next(cursor) + # Address should be set after query + self.assertIsNotNone(cursor.address) + + def test_cursor_session_property(self): + """Test cursor session property.""" + # Cursor without explicit session + cursor1 = self.db.test.find() + self.assertIsNone(cursor1.session) + + # Cursor with explicit session + with self.client.start_session() as session: + cursor2 = self.db.test.find(session=session) + self.assertEqual(session, cursor2.session) + + def test_cursor_allow_disk_use_type_error(self): + """Test allow_disk_use raises TypeError for invalid input.""" + with self.assertRaises(TypeError): + self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type] + + +class TestRawBatchCursorCoverage(IntegrationTest): + """Additional tests for RawBatchCursor coverage.""" + + def setUp(self): + super().setUp() + self.db.test.drop() + self.db.test.insert_many([{"x": i} for i in range(20)]) + + def test_raw_batch_cursor_iteration(self): + """Test raw batch cursor returns raw BSON.""" + cursor = self.db.test.find_raw_batches(batch_size=5) + batch_count = 0 + for batch in cursor: + self.assertIsInstance(batch, bytes) + # Decode the batch to verify it's valid BSON + docs = decode_all(batch) + self.assertGreater(len(docs), 0) + batch_count += 1 + self.assertGreater(batch_count, 0) + + def test_raw_batch_cursor_explain(self): + """Test raw batch cursor explain.""" + cursor = self.db.test.find_raw_batches() + explanation = cursor.explain() + self.assertIn("queryPlanner", explanation) + + def test_raw_batch_cursor_getitem_raises(self): + """Test raw batch cursor __getitem__ raises InvalidOperation.""" + cursor = self.db.test.find_raw_batches() + with self.assertRaises(InvalidOperation): + cursor[0] + + def test_raw_batch_cursor_with_sort(self): + """Test raw batch cursor with sort.""" + 
cursor = self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING) + first_batch = next(cursor) + docs = decode_all(first_batch) + # First doc should have highest x value + self.assertEqual(19, docs[0]["x"]) + + def test_raw_batch_cursor_with_limit(self): + """Test raw batch cursor with limit.""" + cursor = self.db.test.find_raw_batches(batch_size=5).limit(7) + all_docs = [] + for batch in cursor: + all_docs.extend(decode_all(batch)) + self.assertEqual(7, len(all_docs)) + + if __name__ == "__main__": unittest.main() diff --git a/test/test_session.py b/test/test_session.py index 3963f88da0..b45cd45ea0 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -50,9 +50,11 @@ from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import IndexModel, InsertOne, UpdateOne from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.helpers import next +from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -1345,5 +1347,284 @@ def next_heartbeat(): self.assertEqual(started.command["$clusterTime"], cluster_time) +class TestClientSessionCoverage(IntegrationTest): + """Additional tests to improve code coverage for ClientSession.""" + + @client_context.require_sessions + def test_session_has_ended_property(self): + """Test has_ended property state transitions.""" + session = self.client.start_session() + self.assertFalse(session.has_ended) + session.end_session() + self.assertTrue(session.has_ended) + + @client_context.require_sessions + def test_session_session_id_property(self): + """Test session_id property returns correct value.""" + with self.client.start_session() as session: + session_id = session.session_id + self.assertIsInstance(session_id, dict) + self.assertIn("id", session_id) + + @client_context.require_sessions + def 
test_session_cluster_time_operations(self):
        """Test cluster time advance operations."""
        with self.client.start_session() as session:
            # Initially None
            self.assertIsNone(session.cluster_time)

            # Perform operation to get cluster time
            self.db.test.find_one({}, session=session)

            # Cluster time should be set after operation
            # (may still be None on some server versions)

    @client_context.require_sessions
    def test_session_operation_time_operations(self):
        """Test operation time advance operations."""
        with self.client.start_session() as session:
            # Initially None
            self.assertIsNone(session.operation_time)

            # Perform operation to get operation time
            self.db.test.find_one({}, session=session)

    @client_context.require_sessions
    def test_session_options_property(self):
        """Test session options property."""
        with self.client.start_session(causal_consistency=True) as session:
            self.assertTrue(session.options.causal_consistency)

    @client_context.require_sessions
    def test_session_client_property(self):
        """Test session client property."""
        with self.client.start_session() as session:
            self.assertEqual(self.client, session.client)

    @client_context.require_sessions
    def test_session_in_transaction_property(self):
        """Test in_transaction property."""
        # Only exercised on replica sets / mongos, where transactions are available.
        if client_context.is_rs or client_context.is_mongos:
            with self.client.start_session() as session:
                self.assertFalse(session.in_transaction)
                session.start_transaction()
                self.assertTrue(session.in_transaction)
                session.abort_transaction()
                self.assertFalse(session.in_transaction)

    @client_context.require_sessions
    def test_session_context_manager(self):
        """Test session as a context manager (session ends on exit)."""
        with self.client.start_session() as session:
            self.assertFalse(session.has_ended)
            self.db.test.find_one({}, session=session)
        self.assertTrue(session.has_ended)

    @client_context.require_sessions
    def test_session_context_manager_exception(self):
        """Test session context 
manager closes on exception."""
        session = None
        try:
            with self.client.start_session() as session:
                raise ValueError("test exception")
        except ValueError:
            pass
        # Session must end even when the with-body raised.
        self.assertTrue(session.has_ended)

    @client_context.require_sessions
    def test_session_operations_after_end(self):
        """Test operations on ended session raise InvalidOperation."""
        session = self.client.start_session()
        session.end_session()

        with self.assertRaises(InvalidOperation):
            self.db.test.find_one({}, session=session)

    @client_context.require_sessions
    def test_session_end_session_idempotent(self):
        """Test that end_session can be called multiple times."""
        session = self.client.start_session()
        session.end_session()
        # Second call should not raise
        session.end_session()
        self.assertTrue(session.has_ended)

    @client_context.require_transactions
    def test_transaction_start_without_prior_transaction(self):
        """Test start_transaction on fresh session."""
        with self.client.start_session() as session:
            session.start_transaction()
            self.assertTrue(session.in_transaction)
            session.abort_transaction()

    @client_context.require_transactions
    def test_transaction_start_twice_raises(self):
        """Test starting transaction twice raises error."""
        with self.client.start_session() as session:
            session.start_transaction()
            with self.assertRaises(InvalidOperation):
                session.start_transaction()
            session.abort_transaction()

    @client_context.require_transactions
    def test_transaction_abort_without_transaction_raises(self):
        """Test aborting without transaction raises error."""
        with self.client.start_session() as session:
            with self.assertRaises(InvalidOperation):
                session.abort_transaction()

    @client_context.require_transactions
    def test_transaction_commit_without_transaction_raises(self):
        """Test committing without transaction raises error."""
        with self.client.start_session() as session:
            with self.assertRaises(InvalidOperation):
session.commit_transaction() + + @client_context.require_sessions + def test_session_advance_cluster_time_validation(self): + """Test advance_cluster_time with invalid input.""" + with self.client.start_session() as session: + with self.assertRaises(TypeError): + session.advance_cluster_time("invalid") # type: ignore + with self.assertRaises(ValueError): + session.advance_cluster_time({}) + + @client_context.require_sessions + def test_session_advance_operation_time_validation(self): + """Test advance_operation_time with invalid input.""" + from bson import Timestamp + + with self.client.start_session() as session: + with self.assertRaises(TypeError): + session.advance_operation_time("invalid") # type: ignore + # Valid Timestamp should work + session.advance_operation_time(Timestamp(1, 1)) + + @client_context.require_transactions + def test_with_transaction_callback_success(self): + """Test with_transaction with successful callback.""" + with self.client.start_session() as session: + + def callback(session): + self.db.test.insert_one({"x": 1}, session=session) + return "success" + + result = session.with_transaction(callback) + self.assertEqual("success", result) + + @client_context.require_transactions + def test_with_transaction_callback_exception(self): + """Test with_transaction with callback exception.""" + with self.client.start_session() as session: + + def callback(session): + self.db.test.insert_one({"x": 1}, session=session) + raise ValueError("callback error") + + with self.assertRaises(ValueError): + session.with_transaction(callback) + # Transaction should be aborted + self.assertFalse(session.in_transaction) + + +class TestSessionOptionsCoverage(UnitTest): + """Tests for SessionOptions coverage.""" + + def test_session_options_defaults(self): + """Test SessionOptions default values.""" + from pymongo.synchronous.client_session import SessionOptions + + options = SessionOptions() + self.assertTrue(options.causal_consistency) + 
self.assertIsNone(options.default_transaction_options)
        self.assertFalse(options.snapshot)

    def test_session_options_snapshot_disables_causal_consistency(self):
        """Test snapshot=True forces causal_consistency=False."""
        from pymongo.synchronous.client_session import SessionOptions

        options = SessionOptions(snapshot=True)
        self.assertFalse(options.causal_consistency)
        self.assertTrue(options.snapshot)

    def test_session_options_snapshot_with_causal_raises(self):
        """Test snapshot=True with causal_consistency=True raises error."""
        from pymongo.synchronous.client_session import SessionOptions

        # The two options are mutually exclusive.
        with self.assertRaises(ConfigurationError):
            SessionOptions(snapshot=True, causal_consistency=True)

    def test_session_options_invalid_transaction_options(self):
        """Test SessionOptions with invalid transaction options type."""
        from pymongo.synchronous.client_session import SessionOptions

        with self.assertRaises(TypeError):
            SessionOptions(default_transaction_options="invalid")  # type: ignore


class TestTransactionOptionsCoverage(UnitTest):
    """Tests for TransactionOptions coverage."""

    def test_transaction_options_defaults(self):
        """Test TransactionOptions default values."""
        from pymongo.synchronous.client_session import TransactionOptions

        # All four options default to None (inherit from the client/session).
        options = TransactionOptions()
        self.assertIsNone(options.read_concern)
        self.assertIsNone(options.write_concern)
        self.assertIsNone(options.read_preference)
        self.assertIsNone(options.max_commit_time_ms)

    def test_transaction_options_with_values(self):
        """Test TransactionOptions with all values set."""
        from pymongo.synchronous.client_session import TransactionOptions

        options = TransactionOptions(
            read_concern=ReadConcern("majority"),
            write_concern=WriteConcern(w="majority"),
            read_preference=ReadPreference.PRIMARY,
            max_commit_time_ms=5000,
        )
        self.assertEqual("majority", options.read_concern.level)
        self.assertEqual("majority", options.write_concern.document.get("w"))
self.assertEqual(ReadPreference.PRIMARY, options.read_preference) + self.assertEqual(5000, options.max_commit_time_ms) + + def test_transaction_options_invalid_read_concern(self): + """Test TransactionOptions with invalid read_concern type.""" + from pymongo.synchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(read_concern="invalid") # type: ignore + + def test_transaction_options_invalid_write_concern(self): + """Test TransactionOptions with invalid write_concern type.""" + from pymongo.synchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(write_concern="invalid") # type: ignore + + def test_transaction_options_invalid_read_preference(self): + """Test TransactionOptions with invalid read_preference type.""" + from pymongo.synchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(read_preference="invalid") # type: ignore + + def test_transaction_options_invalid_max_commit_time(self): + """Test TransactionOptions with invalid max_commit_time_ms type.""" + from pymongo.synchronous.client_session import TransactionOptions + + with self.assertRaises(TypeError): + TransactionOptions(max_commit_time_ms="invalid") # type: ignore + + def test_transaction_options_unacknowledged_write_concern(self): + """Test TransactionOptions rejects unacknowledged write concern.""" + from pymongo.synchronous.client_session import TransactionOptions + + with self.assertRaises(ConfigurationError): + TransactionOptions(write_concern=WriteConcern(w=0)) + + if __name__ == "__main__": unittest.main()