From fede2a848e3985f1655aaa34235e5f8dc12572a0 Mon Sep 17 00:00:00 2001
From: Tim Band
Date: Thu, 26 Feb 2026 18:21:48 +0000
Subject: [PATCH 1/2] Destination DuckDB should not have access to parquet
 files

---
 datafaker/create.py                           |  21 ++-
 datafaker/utils.py                            |  29 +++
 tests/test_create.py                          | 171 +++++++++++++++++-
 tests/test_dump.py                            |   1 +
 tests/test_interactive_generators.py          |  21 +--
 ...test_interactive_generators_partitioned.py |  18 +-
 tests/test_make.py                            |   2 +-
 tests/utils.py                                |  47 +++--
 8 files changed, 263 insertions(+), 47 deletions(-)

diff --git a/datafaker/create.py b/datafaker/create.py
index bfd9c9df..b3aef3f3 100644
--- a/datafaker/create.py
+++ b/datafaker/create.py
@@ -13,7 +13,7 @@
 from datafaker.base import FileUploader, TableGenerator
 from datafaker.settings import get_destination_dsn, get_destination_schema
 from datafaker.utils import (
-    create_db_engine,
+    create_db_engine_dst,
     get_sync_engine,
     get_vocabulary_table_names,
     logger,
@@ -61,9 +61,15 @@ def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) ->
 
 def create_db_tables(metadata: MetaData) -> None:
     """Create tables described by the sqlalchemy metadata object."""
     dst_dsn = get_destination_dsn()
-    engine = get_sync_engine(create_db_engine(dst_dsn))
-    schema_name = get_destination_schema()
+    assert dst_dsn != "", "Missing DST_DSN setting."
+    create_db_tables_into(metadata, dst_dsn, get_destination_schema())
+
+def create_db_tables_into(
+    metadata: MetaData, dst_dsn: str, schema_name: str | None = None
+) -> None:
+    """Create tables described by the sqlalchemy metadata object with explicit DSN."""
+    engine = get_sync_engine(create_db_engine_dst(dst_dsn))
     # Create schema, if necessary.
     if schema_name is not None:
         with engine.connect() as connection:
@@ -75,9 +81,11 @@
             connection.commit()
 
     # Recreate the engine, this time with a schema specified
-    engine = get_sync_engine(create_db_engine(dst_dsn, schema_name=schema_name))
+    engine.dispose()
+    engine = get_sync_engine(create_db_engine_dst(dst_dsn, schema_name=schema_name))
 
     metadata.create_all(engine)
+    engine.dispose()
 
 
 def create_db_vocab(
@@ -95,7 +103,7 @@
     :return: List of table names loaded.
     """
     dst_engine = get_sync_engine(
-        create_db_engine(
+        create_db_engine_dst(
             get_destination_dsn(),
             schema_name=get_destination_schema(),
         )
@@ -165,7 +173,7 @@
     :param db_dsn: Connection string for the destination database.
    :param schema_name: Destination schema name.
     """
-    dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))
+    dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name))
 
     row_counts: Counter[str] = Counter()
     with dst_engine.connect() as dst_conn:
@@ -177,6 +185,7 @@
             df_module.story_generator_list,
             metadata,
         )
+    dst_engine.dispose()
     return row_counts
 
 
diff --git a/datafaker/utils.py b/datafaker/utils.py
index a216ae9f..ae36ad03 100644
--- a/datafaker/utils.py
+++ b/datafaker/utils.py
@@ -207,6 +207,35 @@ def connect(dbapi_connection: DBAPIConnection, _: Any) -> None:
     return engine
 
 
+def create_db_engine_dst(
+    db_dsn: str,
+    schema_name: Optional[str] = None,
+    use_asyncio: bool = False,
+) -> MaybeAsyncEngine:
+    """
+    Create a SQLAlchemy Engine suitable for output.
+
+    This prevents DuckDB from reading any parquet files, avoiding any
+    possible leakage from existing source files into the destination database.
+    :param db_dsn: The database connection string.
+    :param schema_name: The name of the schema within the database to use.
+    :param use_asyncio: True if an asynchronous connection is required.
+    :return: The ``Engine`` or ``AsyncEngine``.
+    """
+    if db_dsn.startswith("duckdb:"):
+        return create_db_engine(
+            db_dsn,
+            schema_name,
+            use_asyncio,
+            connect_args={
+                "config": {
+                    "enable_external_access": False,
+                }
+            },
+        )
+    return create_db_engine(db_dsn, schema_name, use_asyncio)
+
+
 def set_search_path(connection: DBAPIConnection, schema: str) -> None:
     """Set the SEARCH_PATH for a PostgreSQL connection."""
     # https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#remote-schema-table-introspection-and-postgresql-search-path
diff --git a/tests/test_create.py b/tests/test_create.py
index 933fbe26..face6894 100644
--- a/tests/test_create.py
+++ b/tests/test_create.py
@@ -2,17 +2,24 @@
 import itertools as itt
 import os
 import random
+import tempfile
 from collections import Counter
 from pathlib import Path
 from typing import Any, Generator, Mapping, Tuple
 from unittest.mock import MagicMock, call, patch
 
-from sqlalchemy import Connection, select
+import duckdb
+import pandas as pd
+from sqlalchemy import Connection, Engine, select
 from sqlalchemy.schema import MetaData, Table
 
 from datafaker.base import TableGenerator
-from datafaker.create import create_db_vocab, populate
-from datafaker.remove import remove_db_vocab
+from datafaker.create import (
+    create_db_data_into,
+    create_db_tables,
+    create_db_vocab,
+    populate,
+)
 from datafaker.serialize_metadata import metadata_to_dict
 from tests.utils import DatafakerTestCase, GeneratesDBTestCase
 
@@ -28,7 +35,7 @@ def test_create_vocab(self) -> None:
         """Test the create_db_vocab function."""
         with patch.dict(
             os.environ,
-            {"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name},
+            {"DST_DSN": self.dst_dsn},
             clear=True,
         ):
             config = {
@@ -42,10 +49,9 @@
             meta_dict = metadata_to_dict(
                 self.metadata, self.schema_name, self.sync_engine
             )
-            self.remove_data(config)
-            remove_db_vocab(self.metadata, meta_dict, config)
+            create_db_tables(self.metadata)
             create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples"))
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables["player"])
             rows = list(conn.execute(stmt).mappings().fetchall())
             self.assertEqual(len(rows), 3)
@@ -64,7 +70,7 @@ def test_make_table_generators(self) -> None:
         random.seed(56)
         config: Mapping[str, Any] = {}
         self.generate_data(config, num_passes=2)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables["string"])
             rows = list(conn.execute(stmt).mappings().fetchall())
             a = rows[0]
@@ -183,3 +189,152 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:
 
         mock_gen_two.assert_called_once()
         mock_gen_three.assert_called_once()
+
+
+class MockFunctionUsingConnection:
+    """Base mock callable that should not be permitted to read parquet files."""
+
+    @classmethod
+    def is_parquet_permitted(cls, connection: Any) -> bool:
+        """Test if a normal DuckDB can access the ``fruit.parquet`` file."""
+        try:
+            connection.execute("SELECT * FROM fruit.parquet")
+        except duckdb.PermissionException:
+            return False
+        return True
+
+    def __init__(self) -> None:
+        """Initialize as uncalled."""
+        self.called = False
+
+    def do_call(self, connection: Any) -> None:
+        """Test for parquet access not being permitted."""
+        assert not self.is_parquet_permitted(connection)
+        self.called = True
+
+
+class CreateReadsNoParquetTestCase(DatafakerTestCase):
+    """
+    Output to the database should not have access to parquet files.
+
+    Otherwise there is a risk of leakage of source data.
+    """
+
+    examples_dir = Path("tests/examples/duckdb")
+    parquet_name = "fruit.parquet"
+
+    def setUp(self) -> None:
+        """Go to the directory where there are parquet files."""
+        super().setUp()
+        self.start_dir = os.getcwd()
+        self.parquet_dir = Path(tempfile.mkdtemp("parq"))
+        os.chdir(self.parquet_dir)
+        self.write_parquet()
+        assert MockFunctionUsingConnection.is_parquet_permitted(duckdb.connect())
+
+    def tearDown(self) -> None:
+        """Return to the start directory."""
+        os.chdir(self.start_dir)
+        return super().tearDown()
+
+    def write_parquet(self) -> None:
+        """Write a parquet file into the current directory."""
+        fruit: dict[str, list[Any]] = {
+            "id": [1, 2, 3],
+            "orange": [True, True, False],
+            "banana": ["one", "two", "three"],
+        }
+        pd.DataFrame.from_dict(fruit).to_parquet(self.parquet_name)
+
+    class MockCreateAll(MockFunctionUsingConnection):
+        """Mock for the MetaData.create_all function."""
+
+        def __call__(self, engine: Engine) -> None:
+            self.do_call(engine.raw_connection())
+
+    def test_create_db_tables_cannot_access_parquet(self) -> None:
+        """Test the database connection cannot access parquet file."""
+        meta_data = MagicMock()
+        meta_data.create_all = self.MockCreateAll()
+        with patch.dict(
+            os.environ,
+            {"DST_DSN": "duckdb:///:memory:tables"},
+            clear=True,
+        ):
+            create_db_tables(meta_data)
+        assert meta_data.create_all.called
+
+    def test_create_db_tables_cannot_access_parquet_with_schema(self) -> None:
+        """
+        Test the database connection cannot access parquet file.
+
+        We use a schema because this activates a different code path.
+        """
+        meta_data = MagicMock()
+        meta_data.create_all = self.MockCreateAll()
+        testdb = duckdb.connect("./test.db")
+        testdb.execute("CREATE SCHEMA fruity")
+        testdb.close()
+        with patch.dict(
+            os.environ,
+            {"DST_SCHEMA": "fruity", "DST_DSN": "duckdb:///./test.db"},
+            clear=True,
+        ):
+            create_db_tables(meta_data)
+        assert meta_data.create_all.called
+
+    @patch("datafaker.create.populate")
+    def test_create_db_data_cannot_access_parquet(
+        self, mock_populate: MagicMock
+    ) -> None:
+        """Test the database connection cannot access parquet file while creating data."""
+
+        class MockPopulate(MockFunctionUsingConnection):
+            """Mock ``populate`` function."""
+
+            def __call__(
+                self, connection: Connection, _a2: Any, _a3: Any, _a4: Any, _a5: Any
+            ) -> dict[str, Any]:
+                super().do_call(connection.connection.dbapi_connection)
+                return {"vocab1": 1}
+
+        mock_populate.side_effect = MockPopulate()
+        create_db_data_into(
+            [MagicMock()],
+            MagicMock(),
+            1,
+            "duckdb:///:memory:data",
+            None,
+            MagicMock(),
+        )
+        assert mock_populate.side_effect.called
+
+    @patch("datafaker.create.FileUploader")
+    def test_create_db_vocab_cannot_access_parquet(
+        self, file_uploader: MagicMock
+    ) -> None:
+        """Test we cannot access parquet file while populating vocabulary tables."""
+
+        class MockLoader(MockFunctionUsingConnection):
+            """Mock ``FileUploader.load`` function."""
+
+            def __call__(self, connection: Connection, base_path: Path) -> None:
+                assert str(base_path) == "base"
+                super().do_call(connection.connection.dbapi_connection)
+
+        file_uploader.return_value.load = MockLoader()
+        assert not file_uploader.return_value.load.called
+        meta_data = MetaData()
+        Table("table1", meta_data)
+        with patch.dict(
+            os.environ,
+            {"DST_DSN": "duckdb:///:memory:vocab"},
+            clear=True,
+        ):
+            create_db_vocab(
+                meta_data,
+                {"tables": {"table1": {"columns": {}}}},
+                {"tables": {"table1": {"vocabulary_table": True}}},
+                base_path=Path("base"),
+            )
+        assert file_uploader.return_value.load.called
diff --git a/tests/test_dump.py b/tests/test_dump.py
index ad0cecaa..7663bc71 100644
--- a/tests/test_dump.py
+++ b/tests/test_dump.py
@@ -146,6 +146,7 @@ def test_end_to_end_parquet(self) -> None:
         # Dump the fake tables
         outdir = Path(tempfile.mkdtemp("dump"))
         result = runner.invoke(app, ["dump-data", "--output", str(outdir), "--parquet"])
+        print(result)
         self.assertSuccess(result)
 
         # Check the dumped files
diff --git a/tests/test_interactive_generators.py b/tests/test_interactive_generators.py
index b7f0ed6d..807a6950 100644
--- a/tests/test_interactive_generators.py
+++ b/tests/test_interactive_generators.py
@@ -636,7 +636,7 @@ def test_create_with_sampled_choice(self) -> None:
         gc.do_quit("")
         self.generate_data(gc.config, num_passes=200)
         # all generation possibilities should be present
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stats = ChoiceMeasurementTableStats(self.metadata, conn)
             self.assertSetEqual(stats.ones, {1, 4})
             self.assertSetEqual(stats.twos, {2, 3})
@@ -654,7 +654,7 @@ def test_create_with_choice(self) -> None:
         gc.do_set(str(proposals["dist_gen.zipf_choice"][0]))
         gc.do_quit("")
         self.generate_data(gc.config, num_passes=200)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables[table_name])
             rows = conn.execute(stmt).fetchall()
             ones = set()
@@ -733,13 +733,12 @@ def test_create_with_weighted_choice(self) -> None:
             gc.do_set(str(prop[0]))
         gc.do_quit("")
         self.generate_data(gc.config, num_passes=200)
-        with self.sync_engine.connect() as conn:
-            with self.sync_engine.connect() as conn:
-                stats = ChoiceMeasurementTableStats(self.metadata, conn)
-                # all generation possibilities should be present
-                self.assertSetEqual(stats.ones, {1, 4})
-                self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
-                self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
+        with self.dst_sync_engine.connect() as conn:
+            stats = ChoiceMeasurementTableStats(self.metadata, conn)
+            # all generation possibilities should be present
+            self.assertSetEqual(stats.ones, {1, 4})
+            self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
+            self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
 
 
 class GeneratorsOutputTestsDuckDb(GeneratorsOutputTests):
@@ -780,7 +779,7 @@ def test_set_null(self) -> None:
         config = gc.config
         self.generate_data(config, num_passes=3)
         # Test that each missingness pattern is present in the database
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             # select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer
             # but mypy doesn't like it
             stmt = select(
@@ -866,7 +865,7 @@ def test_varchar_ns_are_truncated(self) -> None:
         gc.do_quit("")
         config = gc.config
         self.generate_data(config, num_passes=15)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables[table].c[column])
             rows = conn.execute(stmt).scalars().fetchall()
             self.assert_are_truncated_to(rows, 20)
diff --git a/tests/test_interactive_generators_partitioned.py b/tests/test_interactive_generators_partitioned.py
index b1357d82..d5319165 100644
--- a/tests/test_interactive_generators_partitioned.py
+++ b/tests/test_interactive_generators_partitioned.py
@@ -176,10 +176,10 @@ def test_create_with_null_partitioned_grouped_multivariate(self) -> None:
         self.set_configuration(gc.config)
         self.get_src_stats(gc.config)
         self.create_generators(gc.config)
-        self.remove_data(gc.config)
+        self.create_tables()
         self.populate_measurement_type_vocab()
         self.create_data(gc.config, num_passes=generate_count)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stats = EavMeasurementTableStats(conn, self.metadata, self)
             # type 1
             self.assertAlmostEqual(
@@ -224,7 +224,7 @@
     def populate_measurement_type_vocab(self) -> None:
         """Add a vocab table without messing around with files"""
         table = self.metadata.tables["measurement_type"]
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             conn.execute(insert(table).values({"id": 1, "name": "agreement"}))
             conn.execute(insert(table).values({"id": 2, "name": "acceleration"}))
             conn.execute(insert(table).values({"id": 3, "name": "velocity"}))
@@ -291,10 +291,10 @@ def test_create_with_null_partitioned_grouped_sampled_and_suppressed(self) -> No
         self.set_configuration(gc.config)
         self.get_src_stats(gc.config)
         self.create_generators(gc.config)
-        self.remove_data(gc.config)
+        self.create_tables()
         self.populate_measurement_type_vocab()
         self.create_data(gc.config, num_passes=generate_count)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stats = EavMeasurementTableStats(conn, self.metadata, self)
             stmt = select(self.metadata.tables["observation"])
             rows = conn.execute(stmt).fetchall()
@@ -406,10 +406,10 @@ def test_create_with_null_partitioned_grouped_sampled_only(self) -> None:
         self.set_configuration(gc.config)
         self.get_src_stats(gc.config)
         self.create_generators(gc.config)
-        self.remove_data(gc.config)
+        self.create_tables()
         self.populate_measurement_type_vocab()
         self.create_data(gc.config, num_passes=generate_count)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables[table_name])
             rows = conn.execute(stmt).fetchall()
             self.assert_subset({row.type for row in rows}, {1, 2, 3, 4, 5})
@@ -440,10 +440,10 @@ def test_create_with_null_partitioned_grouped_sampled_tiny(self) -> None:
         self.set_configuration(gc.config)
         self.get_src_stats(gc.config)
         self.create_generators(gc.config)
-        self.remove_data(gc.config)
+        self.create_tables()
         self.populate_measurement_type_vocab()
         self.create_data(gc.config, num_passes=generate_count)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables[table_name])
             rows = conn.execute(stmt).fetchall()
             # we should only have one or two of "ham", "eggs" and "cheese" represented
diff --git a/tests/test_make.py b/tests/test_make.py
index 49bb9e71..b890e4fa 100644
--- a/tests/test_make.py
+++ b/tests/test_make.py
@@ -45,7 +45,7 @@ def test_make_table_generators(self) -> None:
             },
         }
         self.generate_data(config, num_passes=3)
-        with self.sync_engine.connect() as conn:
+        with self.dst_sync_engine.connect() as conn:
             stmt = select(self.metadata.tables["player"])
             rows = conn.execute(stmt).mappings().fetchall()
             for row in rows:
diff --git a/tests/utils.py b/tests/utils.py
index dc54291b..c120ed6c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -22,14 +22,14 @@
 from sqlalchemy import Engine, MetaData
 
 from datafaker import settings
-from datafaker.create import create_db_data_into
+from datafaker.create import create_db_data_into, create_db_tables_into
 from datafaker.interactive.base import DbCmd
 from datafaker.make import make_src_stats, make_table_generators, make_tables_file
-from datafaker.remove import remove_db_data_from
 from datafaker.utils import (
     MaybeAsyncEngine,
     T,
     create_db_engine,
+    create_db_engine_dst,
     get_sync_engine,
     import_file,
     sorted_non_vocabulary_tables,
@@ -194,8 +194,8 @@ def open(self) -> None:
         self.close()
         self._make_con_string()
         # create the database (must be non-read-only)
-        duckdb_con = duckdb.connect(self._db_path)
-        duckdb_con.close()
+        # duckdb_con = duckdb.connect(self._db_path)
+        # duckdb_con.close()
 
     def close(self) -> None:
         """Tear down the test database."""
@@ -238,6 +238,9 @@ class DatafakerTestCase(TestCase):
     dump_file_path: str | None = None
     database_name: str | None = None
 
+    def setUp(self) -> None:
+        settings.get_settings.cache_clear()
+
     def assertReturnCode(  # pylint: disable=invalid-name
         self, result: Any, expected_code: int
     ) -> None:
@@ -350,7 +353,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.sync_engine: Engine
 
     def setUp(self) -> None:
-        settings.get_settings.cache_clear()
         super().setUp()
         if self.database is None:
             self.database = self.database_type()
@@ -378,9 +380,12 @@ def dsn(self) -> str:
         return self.database.get_dsn(self.database_name)
 
 
+# pylint: disable=too-many-instance-attributes
 class GeneratesDBTestCase(RequiresDBTestCase):
     """A test case for which a database is generated."""
 
+    dst_schema_name: str | None = None
+
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         """Initialise a GeneratedDB test case."""
         super().__init__(*args, **kwargs)
@@ -389,10 +394,29 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.stats_file_path = ""
         self.config_file_path = ""
         self.config_fd = 0
+        self.dst_database: TestDatabaseBase | None = None
+        self.dst_engine: MaybeAsyncEngine
+        self.dst_sync_engine: Engine
+
+    @property
+    def dst_dsn(self) -> str:
+        """Get the destination database connection string."""
+        assert self.dst_database is not None
+        return self.dst_database.get_dsn(None)
 
     def setUp(self) -> None:
         """Set up the test case with an actual orm.yaml file."""
         super().setUp()
+        if self.dst_database is None:
+            self.dst_database = self.database_type()
+        else:
+            self.dst_database.open()
+        self.dst_engine = create_db_engine_dst(
+            self.dst_dsn,
+            schema_name=self.dst_schema_name,
+            use_asyncio=self.use_asyncio,
+        )
+        self.dst_sync_engine = get_sync_engine(self.dst_engine)
         # Generate the `orm.yaml` from the database
         (self.orm_fd, self.orm_file_path) = mkstemp(".yaml", "orm_", text=True)
         with os.fdopen(self.orm_fd, "w", encoding="utf-8") as orm_fh:
@@ -435,10 +459,9 @@ def create_generators(self, config: Mapping[str, Any]) -> None:
         with os.fdopen(generators_fd, "w", encoding="utf-8") as datafaker_fh:
             datafaker_fh.write(datafaker_content)
 
-    def remove_data(self, config: Mapping[str, Any]) -> None:
-        """Remove source data from the DB."""
-        # `remove-data` so we don't have to use a separate database for the destination
-        remove_db_data_from(self.metadata, config, self.dsn, self.schema_name)
+    def create_tables(self) -> None:
+        """Create tables in the output DB."""
+        create_db_tables_into(self.metadata, self.dst_dsn, self.dst_schema_name)
 
     def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None:
         """Create fake data in the DB."""
@@ -448,8 +471,8 @@ def create_data(self, config: Mapping[str, Any], num_passes: int = 1) -> None:
             sorted_non_vocabulary_tables(self.metadata, config),
             datafaker_module,
             num_passes,
-            self.dsn,
-            self.schema_name,
+            self.dst_dsn,
+            self.dst_schema_name,
             self.metadata,
         )
 
@@ -463,7 +486,7 @@ def generate_data(
         self.set_configuration(config)
         src_stats = self.get_src_stats(config)
         self.create_generators(config)
-        self.remove_data(config)
+        self.create_tables()
         self.create_data(config, num_passes)
         return src_stats
 

From bc7a8b4f16b35d66f8e095c406816e934d98bad7 Mon Sep 17 00:00:00 2001
From: Tim Band
Date: Mon, 2 Mar 2026 12:15:20 +0000
Subject: [PATCH 2/2] Remove commented-out code

---
 tests/utils.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index c120ed6c..c5fef7f8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -193,9 +193,6 @@ def open(self) -> None:
         """Start the test database"""
         self.close()
         self._make_con_string()
-        # create the database (must be non-read-only)
-        # duckdb_con = duckdb.connect(self._db_path)
-        # duckdb_con.close()
 
     def close(self) -> None:
         """Tear down the test database."""
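
The behaviour both patches lean on can be checked in isolation with plain DuckDB. Below is a minimal sketch, not part of the patches, assuming a fruit.parquet file in the current directory (the same hypothetical file the new tests write); it uses only calls that appear in the patches themselves: duckdb.connect with a config dict, duckdb.PermissionException, and the connect_args structure that create_db_engine_dst forwards.

    import duckdb
    from sqlalchemy import create_engine

    # A DuckDB connection opened with external access disabled refuses to
    # read files from disk, so source parquet files cannot leak through it.
    con = duckdb.connect(":memory:", config={"enable_external_access": False})
    try:
        con.execute("SELECT * FROM fruit.parquet")
    except duckdb.PermissionException as exc:
        print(f"blocked as expected: {exc}")

    # The same option passed through SQLAlchemy, mirroring what
    # create_db_engine_dst does via connect_args in datafaker/utils.py:
    engine = create_engine(
        "duckdb:///:memory:",
        connect_args={"config": {"enable_external_access": False}},
    )

The duckdb:/// URL form assumes the duckdb_engine dialect is installed, as it evidently already is for the duckdb: DSNs the test suite uses; setting the option at connection time, rather than with a later SET statement, is what keeps the destination engine locked down for its whole lifetime.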