Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions datafaker/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from datafaker.base import FileUploader, TableGenerator
from datafaker.settings import get_destination_dsn, get_destination_schema
from datafaker.utils import (
create_db_engine,
create_db_engine_dst,
get_sync_engine,
get_vocabulary_table_names,
logger,
Expand Down Expand Up @@ -61,9 +61,15 @@ def remove_on_delete_cascade(element: CreateTable, compiler: Any, **kw: Any) ->
def create_db_tables(metadata: MetaData) -> None:
"""Create tables described by the sqlalchemy metadata object."""
dst_dsn = get_destination_dsn()
engine = get_sync_engine(create_db_engine(dst_dsn))
schema_name = get_destination_schema()
assert dst_dsn != "", "Missing DST_DSN setting."
create_db_tables_into(metadata, dst_dsn, get_destination_schema())


def create_db_tables_into(
metadata: MetaData, dst_dsn: str, schema_name: str | None = None
) -> None:
"""Create tables described by the sqlalchemy metadata object with explicit DSN."""
engine = get_sync_engine(create_db_engine_dst(dst_dsn))
# Create schema, if necessary.
if schema_name is not None:
with engine.connect() as connection:
Expand All @@ -75,9 +81,11 @@ def create_db_tables(metadata: MetaData) -> None:
connection.commit()

# Recreate the engine, this time with a schema specified
engine = get_sync_engine(create_db_engine(dst_dsn, schema_name=schema_name))
engine.dispose()
engine = get_sync_engine(create_db_engine_dst(dst_dsn, schema_name=schema_name))

metadata.create_all(engine)
engine.dispose()


def create_db_vocab(
Expand All @@ -95,7 +103,7 @@ def create_db_vocab(
:return: List of table names loaded.
"""
dst_engine = get_sync_engine(
create_db_engine(
create_db_engine_dst(
get_destination_dsn(),
schema_name=get_destination_schema(),
)
Expand Down Expand Up @@ -165,7 +173,7 @@ def create_db_data_into(
:param db_dsn: Connection string for the destination database.
:param schema_name: Destination schema name.
"""
dst_engine = get_sync_engine(create_db_engine(db_dsn, schema_name=schema_name))
dst_engine = get_sync_engine(create_db_engine_dst(db_dsn, schema_name=schema_name))

row_counts: Counter[str] = Counter()
with dst_engine.connect() as dst_conn:
Expand All @@ -177,6 +185,7 @@ def create_db_data_into(
df_module.story_generator_list,
metadata,
)
dst_engine.dispose()
return row_counts


Expand Down
29 changes: 29 additions & 0 deletions datafaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,35 @@ def connect(dbapi_connection: DBAPIConnection, _: Any) -> None:
return engine


def create_db_engine_dst(
    db_dsn: str,
    schema_name: Optional[str] = None,
    use_asyncio: bool = False,
) -> MaybeAsyncEngine:
    """
    Create a SQLAlchemy engine suitable for writing to the destination.

    For DuckDB destinations external access is disabled, so the output
    connection cannot read parquet files; this prevents accidental leakage
    of existing source files into the destination database.

    :param db_dsn: The database connection string.
    :param schema_name: The name of the schema within the database to use.
    :param use_asyncio: True if an asynchronous connection is required.
    :return: The ``Engine`` or ``AsyncEngine``.
    """
    extra_kwargs: dict[str, Any] = {}
    if db_dsn.startswith("duckdb:"):
        # Lock down DuckDB so this connection cannot touch external files
        # (e.g. parquet) lying around in the working directory.
        extra_kwargs["connect_args"] = {
            "config": {
                "enable_external_access": False,
            }
        }
    return create_db_engine(db_dsn, schema_name, use_asyncio, **extra_kwargs)


def set_search_path(connection: DBAPIConnection, schema: str) -> None:
"""Set the SEARCH_PATH for a PostgreSQL connection."""
# https://docs.sqlalchemy.org/en/20/dialects/postgresql.html#remote-schema-table-introspection-and-postgresql-search-path
Expand Down
171 changes: 163 additions & 8 deletions tests/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,24 @@
import itertools as itt
import os
import random
import tempfile
from collections import Counter
from pathlib import Path
from typing import Any, Generator, Mapping, Tuple
from unittest.mock import MagicMock, call, patch

from sqlalchemy import Connection, select
import duckdb
import pandas as pd
from sqlalchemy import Connection, Engine, select
from sqlalchemy.schema import MetaData, Table

from datafaker.base import TableGenerator
from datafaker.create import create_db_vocab, populate
from datafaker.remove import remove_db_vocab
from datafaker.create import (
create_db_data_into,
create_db_tables,
create_db_vocab,
populate,
)
from datafaker.serialize_metadata import metadata_to_dict
from tests.utils import DatafakerTestCase, GeneratesDBTestCase

Expand All @@ -28,7 +35,7 @@ def test_create_vocab(self) -> None:
"""Test the create_db_vocab function."""
with patch.dict(
os.environ,
{"DST_DSN": self.dsn, "DST_SCHEMA": self.schema_name},
{"DST_DSN": self.dst_dsn},
clear=True,
):
config = {
Expand All @@ -42,10 +49,9 @@ def test_create_vocab(self) -> None:
meta_dict = metadata_to_dict(
self.metadata, self.schema_name, self.sync_engine
)
self.remove_data(config)
remove_db_vocab(self.metadata, meta_dict, config)
create_db_tables(self.metadata)
create_db_vocab(self.metadata, meta_dict, config, Path("./tests/examples"))
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
stmt = select(self.metadata.tables["player"])
rows = list(conn.execute(stmt).mappings().fetchall())
self.assertEqual(len(rows), 3)
Expand All @@ -64,7 +70,7 @@ def test_make_table_generators(self) -> None:
random.seed(56)
config: Mapping[str, Any] = {}
self.generate_data(config, num_passes=2)
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
stmt = select(self.metadata.tables["string"])
rows = list(conn.execute(stmt).mappings().fetchall())
a = rows[0]
Expand Down Expand Up @@ -183,3 +189,152 @@ def test_populate_diff_length(self, mock_insert: MagicMock) -> None:

mock_gen_two.assert_called_once()
mock_gen_three.assert_called_once()


class MockFunctionUsingConnection:
    """Base mock callable that should not be permitted to read parquet files."""

    def __init__(self) -> None:
        """Start in the not-yet-called state."""
        self.called = False

    @classmethod
    def is_parquet_permitted(cls, connection: Any) -> bool:
        """Return True if *connection* can read the ``fruit.parquet`` file."""
        try:
            connection.execute("SELECT * FROM fruit.parquet")
        except duckdb.PermissionException:
            # External file access has been disabled on this connection.
            return False
        else:
            return True

    def do_call(self, connection: Any) -> None:
        """Record the call after verifying parquet access is denied."""
        assert not self.is_parquet_permitted(connection)
        self.called = True


class CreateReadsNoParquetTestCase(DatafakerTestCase):
    """
    Output to the database should not have access to parquet files.

    Otherwise there is a risk of leakage of source data.
    """

    # Directory holding example DuckDB fixtures for these tests.
    examples_dir = Path("tests/examples/duckdb")
    # Name of the parquet file planted in the working directory; a locked-down
    # destination connection must not be able to read it.
    parquet_name = "fruit.parquet"

    def setUp(self) -> None:
        """Go to the directory where there are parquet files."""
        super().setUp()
        # Remember where we started so tearDown can restore the CWD.
        self.start_dir = os.getcwd()
        self.parquet_dir = Path(tempfile.mkdtemp("parq"))
        os.chdir(self.parquet_dir)
        self.write_parquet()
        # Sanity check: an ordinary (unrestricted) DuckDB connection CAN read
        # the parquet file, so a later denial proves the restriction works.
        assert MockFunctionUsingConnection.is_parquet_permitted(duckdb.connect())

    def tearDown(self) -> None:
        """Return to the start directory."""
        os.chdir(self.start_dir)
        return super().tearDown()

    def write_parquet(self) -> None:
        """Write a parquet file into the current directory."""
        # Small arbitrary table; only its existence matters to the tests.
        fruit: dict[str, list[Any]] = {
            "id": [1, 2, 3],
            "orange": [True, True, False],
            "banana": ["one", "two", "three"],
        }
        pd.DataFrame.from_dict(fruit).to_parquet(self.parquet_name)

    class MockCreateAll(MockFunctionUsingConnection):
        """Mock for the MetaData.create_all function."""

        def __call__(self, engine: Engine) -> None:
            # create_all receives the engine; check its raw DBAPI connection
            # has external access disabled.
            self.do_call(engine.raw_connection())

    def test_create_db_tables_cannot_access_parquet(self) -> None:
        """Test the database connection cannot access parquet file."""
        meta_data = MagicMock()
        meta_data.create_all = self.MockCreateAll()
        with patch.dict(
            os.environ,
            {"DST_DSN": "duckdb:///:memory:tables"},
            clear=True,
        ):
            create_db_tables(meta_data)
        # Ensure the mock actually ran (and therefore performed its check).
        assert meta_data.create_all.called

    def test_create_db_tables_cannot_access_parquet_with_schema(self) -> None:
        """
        Test the database connection cannot access parquet file.

        We use a schema because this activates a different code path.
        """
        meta_data = MagicMock()
        meta_data.create_all = self.MockCreateAll()
        # Pre-create the schema in an on-disk DuckDB database so the
        # schema-aware branch of create_db_tables is exercised.
        testdb = duckdb.connect("./test.db")
        testdb.execute("CREATE SCHEMA fruity")
        testdb.close()
        with patch.dict(
            os.environ,
            {"DST_SCHEMA": "fruity", "DST_DSN": "duckdb:///./test.db"},
            clear=True,
        ):
            create_db_tables(meta_data)
        assert meta_data.create_all.called

    @patch("datafaker.create.populate")
    def test_create_db_data_cannot_access_parquet(
        self, mock_populate: MagicMock
    ) -> None:
        """Test the database connection cannot access parquet file while creating data."""

        class MockPopulate(MockFunctionUsingConnection):
            """Mock ``populate`` function."""

            def __call__(
                self, connection: Connection, _a2: Any, _a3: Any, _a4: Any, _a5: Any
            ) -> dict[str, Any]:
                # Unwrap the SQLAlchemy connection down to the DBAPI level
                # before checking parquet access is denied.
                super().do_call(connection.connection.dbapi_connection)
                return {"vocab1": 1}

        mock_populate.side_effect = MockPopulate()
        create_db_data_into(
            [MagicMock()],
            MagicMock(),
            1,
            "duckdb:///:memory:data",
            None,
            MagicMock(),
        )
        assert mock_populate.side_effect.called

    @patch("datafaker.create.FileUploader")
    def test_create_db_vocab_cannot_access_parquet(
        self, file_uploader: MagicMock
    ) -> None:
        """Test we cannot access parquet file while populating vocabulary tables."""

        class MockLoader(MockFunctionUsingConnection):
            """Mock ``FileUploader.load`` function."""

            def __call__(self, connection: Connection, base_path: Path) -> None:
                # The base path given to create_db_vocab must be forwarded.
                assert str(base_path) == "base"
                super().do_call(connection.connection.dbapi_connection)

        file_uploader.return_value.load = MockLoader()
        assert not file_uploader.return_value.load.called
        # Minimal metadata: a single vocabulary table with no columns.
        meta_data = MetaData()
        Table("table1", meta_data)
        with patch.dict(
            os.environ,
            {"DST_DSN": "duckdb:///:memory:vocab"},
            clear=True,
        ):
            create_db_vocab(
                meta_data,
                {"tables": {"table1": {"columns": {}}}},
                {"tables": {"table1": {"vocabulary_table": True}}},
                base_path=Path("base"),
            )
        assert file_uploader.return_value.load.called
1 change: 1 addition & 0 deletions tests/test_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def test_end_to_end_parquet(self) -> None:
# Dump the fake tables
outdir = Path(tempfile.mkdtemp("dump"))
result = runner.invoke(app, ["dump-data", "--output", str(outdir), "--parquet"])
print(result)
self.assertSuccess(result)

# Check the dumped files
Expand Down
21 changes: 10 additions & 11 deletions tests/test_interactive_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ def test_create_with_sampled_choice(self) -> None:
gc.do_quit("")
self.generate_data(gc.config, num_passes=200)
# all generation possibilities should be present
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
stats = ChoiceMeasurementTableStats(self.metadata, conn)
self.assertSetEqual(stats.ones, {1, 4})
self.assertSetEqual(stats.twos, {2, 3})
Expand All @@ -654,7 +654,7 @@ def test_create_with_choice(self) -> None:
gc.do_set(str(proposals["dist_gen.zipf_choice"][0]))
gc.do_quit("")
self.generate_data(gc.config, num_passes=200)
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
stmt = select(self.metadata.tables[table_name])
rows = conn.execute(stmt).fetchall()
ones = set()
Expand Down Expand Up @@ -733,13 +733,12 @@ def test_create_with_weighted_choice(self) -> None:
gc.do_set(str(prop[0]))
gc.do_quit("")
self.generate_data(gc.config, num_passes=200)
with self.sync_engine.connect() as conn:
with self.sync_engine.connect() as conn:
stats = ChoiceMeasurementTableStats(self.metadata, conn)
# all generation possibilities should be present
self.assertSetEqual(stats.ones, {1, 4})
self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})
with self.dst_sync_engine.connect() as conn:
stats = ChoiceMeasurementTableStats(self.metadata, conn)
# all generation possibilities should be present
self.assertSetEqual(stats.ones, {1, 4})
self.assertSetEqual(stats.twos, {1, 2, 3, 4, 5})
self.assertSetEqual(stats.threes, {1, 2, 3, 4, 5})


class GeneratorsOutputTestsDuckDb(GeneratorsOutputTests):
Expand Down Expand Up @@ -780,7 +779,7 @@ def test_set_null(self) -> None:
config = gc.config
self.generate_data(config, num_passes=3)
# Test that each missingness pattern is present in the database
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
# select(self.metadata.tables["string"].c["position", "frequency"]) would be nicer
# but mypy doesn't like it
stmt = select(
Expand Down Expand Up @@ -866,7 +865,7 @@ def test_varchar_ns_are_truncated(self) -> None:
gc.do_quit("")
config = gc.config
self.generate_data(config, num_passes=15)
with self.sync_engine.connect() as conn:
with self.dst_sync_engine.connect() as conn:
stmt = select(self.metadata.tables[table].c[column])
rows = conn.execute(stmt).scalars().fetchall()
self.assert_are_truncated_to(rows, 20)
Expand Down
Loading