diff --git a/mypy.ini b/mypy.ini index 986f62b2d..6326a120f 100644 --- a/mypy.ini +++ b/mypy.ini @@ -72,13 +72,6 @@ disallow_untyped_defs = False # is just not worth the effort. disallow_untyped_defs = False -# These modules are deprecated (maybe implicitly, as being Gen2-only). Not -# worth adding new annotations to them. -[mypy-lsst.pipe.base.argumentParser.*] -disallow_untyped_defs = False -[mypy-lsst.pipe.base.shims.*] -disallow_untyped_defs = False - # ConfigOverrides uses the Python built-in ast module, and figuring out how # to correctly type its visitation interface doesn't seem worth the effort # right now. diff --git a/python/lsst/pipe/base/blocking_limited_butler.py b/python/lsst/pipe/base/blocking_limited_butler.py new file mode 100644 index 000000000..3fe194645 --- /dev/null +++ b/python/lsst/pipe/base/blocking_limited_butler.py @@ -0,0 +1,173 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ["BlockingLimitedButler"]
+
+import logging
+import time
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+from lsst.daf.butler import (
+    ButlerMetrics,
+    DatasetProvenance,
+    DatasetRef,
+    DeferredDatasetHandle,
+    DimensionUniverse,
+    LimitedButler,
+    StorageClass,
+)
+
+_LOG = logging.getLogger(__name__)
+
+
+class BlockingLimitedButler(LimitedButler):
+    """A `LimitedButler` that blocks until certain dataset types exist.
+
+    Parameters
+    ----------
+    wrapped : `LimitedButler`
+        The butler to wrap.
+    timeouts : `~collections.abc.Mapping` [ `str`, `float` or `None` ]
+        Timeouts in seconds to wait for different dataset types. Dataset types
+        not included are not blocked on (i.e. their timeout is ``0.0``).
+
+    Notes
+    -----
+    When a timeout is exceeded, `get` will raise `FileNotFoundError` (as usual
+    for a dataset that does not exist) and `stored_many` will mark the dataset
+    as non-existent. `getDeferred` does not block.
+    """
+
+    def __init__(
+        self,
+        wrapped: LimitedButler,
+        timeouts: Mapping[str, float | None],
+    ):
+        self._wrapped = wrapped
+        self._timeouts = timeouts
+
+    def close(self) -> None:
+        self._wrapped.close()
+
+    @property
+    def _metrics(self) -> ButlerMetrics:
+        # Need to always forward from the wrapped metrics object.
+        return self._wrapped._metrics
+
+    @_metrics.setter
+    def _metrics(self, metrics: ButlerMetrics) -> None:
+        # Allow record_metrics() context manager to override the wrapped
+        # butler.
+        self._wrapped._metrics = metrics
+
+    def get(
+        self,
+        ref: DatasetRef,
+        /,
+        *,
+        parameters: dict[str, Any] | None = None,
+        storageClass: StorageClass | str | None = None,
+    ) -> Any:
+        parent_dataset_type_name = ref.datasetType.nameAndComponent()[0]
+        timeout = self._timeouts.get(parent_dataset_type_name, 0.0)
+        start = time.time()
+        while True:
+            try:
+                return self._wrapped.get(ref, parameters=parameters, storageClass=storageClass)
+            except FileNotFoundError as err:
+                if timeout is not None:
+                    elapsed = time.time() - start
+                    if elapsed > timeout:
+                        err.add_note(f"Timed out after {elapsed:.3f}s.")
+                        raise
+                _LOG.info(f"Dataset {ref.datasetType} not immediately available for {ref.id}, waiting {timeout}s")
+                time.sleep(0.5)
+
+    def getDeferred(
+        self,
+        ref: DatasetRef,
+        /,
+        *,
+        parameters: dict[str, Any] | None = None,
+        storageClass: str | StorageClass | None = None,
+    ) -> DeferredDatasetHandle:
+        # note that this does not use the block at all
+        return self._wrapped.getDeferred(ref, parameters=parameters, storageClass=storageClass)
+
+    def stored_many(self, refs: Iterable[DatasetRef]) -> dict[DatasetRef, bool]:
+        start = time.time()
+        result = self._wrapped.stored_many(refs)
+        timeouts = {ref.id: self._timeouts.get(ref.datasetType.nameAndComponent()[0], 0.0) for ref in result}
+        while True:
+            elapsed = time.time() - start
+            remaining: list[DatasetRef] = []
+            for ref, exists in result.items():
+                timeout = timeouts[ref.id]
+                if not exists and (timeout is None or elapsed <= timeout):
+                    _LOG.info(
+                        f"Dataset {ref.datasetType} not immediately available for {ref.id}, "
+                        f"waiting {timeout}s"
+                    )
+                    remaining.append(ref)
+            if not remaining:
+                return result
+            result.update(self._wrapped.stored_many(remaining))
+            time.sleep(0.5)
+
+    def isWriteable(self) -> bool:
+        return self._wrapped.isWriteable()
+
+    def put(self, obj: Any, ref: DatasetRef, /, *, provenance: DatasetProvenance | None = None) -> DatasetRef:
+        return self._wrapped.put(obj, ref,
provenance=provenance) + + def pruneDatasets( + self, + refs: Iterable[DatasetRef], + *, + disassociate: bool = True, + unstore: bool = False, + tags: Iterable[str] = (), + purge: bool = False, + ) -> None: + return self._wrapped.pruneDatasets( + refs, disassociate=disassociate, unstore=unstore, tags=tags, purge=purge + ) + + @property + def dimensions(self) -> DimensionUniverse: + return self._wrapped.dimensions + + @property + def _datastore(self) -> Any: + return self._wrapped._datastore + + @_datastore.setter # demanded by MyPy since we declare it to be an instance attribute in LimitedButler. + def _datastore(self, value: Any) -> None: + self._wrapped._datastore = value diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py index 7c6b012d7..7294cd450 100644 --- a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py +++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py @@ -157,7 +157,7 @@ def discard(self, value: str) -> None: self._members.discard(value) @classmethod - def _from_iterable(cls, iterable: Iterable[str]) -> set[str]: + def _from_iterable[S](cls, iterable: Iterable[S]) -> set[S]: # This is the hook used by collections.abc.Set when implementing # operators that return new sets. In this case, we want those to be # regular `set` (builtin) objects, not `TaskSubset` instances. diff --git a/python/lsst/pipe/base/single_quantum_executor.py b/python/lsst/pipe/base/single_quantum_executor.py index 613bfa75e..97c2ad5bd 100644 --- a/python/lsst/pipe/base/single_quantum_executor.py +++ b/python/lsst/pipe/base/single_quantum_executor.py @@ -72,8 +72,9 @@ class SingleQuantumExecutor(QuantumExecutor): Parameters ---------- - butler : `~lsst.daf.butler.Butler` or `None`, optional - Data butler, `None` means that a limited butler should be used instead. 
+ butler : `~lsst.daf.butler.LimitedButler` or `None`, optional + Data butler; `None` means that ``limited_butler_factory`` should be + used instead. task_factory : `.TaskFactory`, optional Instance of a task factory. Defaults to a new instance of `lsst.pipe.base.TaskFactory`. @@ -94,8 +95,8 @@ class SingleQuantumExecutor(QuantumExecutor): Enable debugging with ``lsstDebug`` facility for a task. limited_butler_factory : `~collections.abc.Callable`, optional A method that creates a `~lsst.daf.butler.LimitedButler` instance for a - given Quantum. This parameter must be defined if ``butler`` is `None`. - If ``butler`` is not `None` then this parameter is ignored. + given Quantum. This parameter must be provided if ``butler`` is + `None`. If ``butler`` is not `None` then this parameter is ignored. resources : `.ExecutionResources`, optional The resources available to this quantum when executing. skip_existing : `bool`, optional @@ -115,15 +116,15 @@ class SingleQuantumExecutor(QuantumExecutor): continuing to run downstream tasks. job_metadata : `~collections.abc.Mapping` Mapping with extra metadata to embed within the quantum metadata under - the "job" key. This is intended to correspond to information common - to all quanta being executed in a single process, such as the time - taken to load the quantum graph in a BPS job. + the "job" key. This is intended to correspond to information common to + all quanta being executed in a single process, such as the time taken + to load the quantum graph in a BPS job. 
""" def __init__( self, *, - butler: Butler | None = None, + butler: LimitedButler | None = None, task_factory: TaskFactory | None = None, skip_existing_in: Any = None, clobber_outputs: bool = False, @@ -135,7 +136,17 @@ def __init__( raise_on_partial_outputs: bool = True, job_metadata: Mapping[str, int | str | float] | None = None, ): - self._butler = butler + self._butler: Butler | None = None + self._limited_butler: LimitedButler | None = None + match butler: + case Butler(): + self._butler = butler + self._limited_butler = butler + case LimitedButler(): + self._limited_butler = butler + case None: + if limited_butler_factory is None: + raise ValueError("limited_butler_factory is needed when butler is None") self._task_factory = task_factory if task_factory is not None else TaskFactory() self._clobber_outputs = clobber_outputs self._enable_lsst_debug = enable_lsst_debug @@ -144,10 +155,6 @@ def __init__( self._assume_no_existing_outputs = assume_no_existing_outputs self._raise_on_partial_outputs = raise_on_partial_outputs self._job_metadata = job_metadata - - if self._butler is None: - assert limited_butler_factory is not None, "limited_butler_factory is needed when butler is None" - # Find whether output run is in skip_existing_in. self._skip_existing = skip_existing if self._butler is not None and skip_existing_in and not self._skip_existing: @@ -190,9 +197,12 @@ def _execute( limited_butler = self._butler else: # We check this in constructor, but mypy needs this check here. 
- assert self._limited_butler_factory is not None - limited_butler = self._limited_butler_factory(quantum) - used_butler_factory = True + if self._limited_butler is not None: + limited_butler = self._limited_butler + else: + assert self._limited_butler_factory is not None + limited_butler = self._limited_butler_factory(quantum) + used_butler_factory = True try: return self._execute_with_limited_butler( diff --git a/python/lsst/pipe/base/trivial_quantum_graph_builder.py b/python/lsst/pipe/base/trivial_quantum_graph_builder.py new file mode 100644 index 000000000..12343abb6 --- /dev/null +++ b/python/lsst/pipe/base/trivial_quantum_graph_builder.py @@ -0,0 +1,183 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+
+from __future__ import annotations
+
+__all__ = ["TrivialQuantumGraphBuilder"]
+
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, final
+
+from lsst.daf.butler import Butler, DataCoordinate, DatasetIdGenEnum, DatasetRef, DimensionGroup
+from lsst.utils.timer import timeMethod
+
+from .quantum_graph_builder import QuantumGraphBuilder
+from .quantum_graph_skeleton import QuantumGraphSkeleton
+
+if TYPE_CHECKING:
+    from .pipeline_graph import PipelineGraph
+
+
+@final
+class TrivialQuantumGraphBuilder(QuantumGraphBuilder):
+    """An optimized quantum-graph builder for pipelines that operate on only
+    a single data ID or a closely related set of data IDs.
+
+    Parameters
+    ----------
+    pipeline_graph
+        Pipeline to build a quantum graph from, as a graph. Will be resolved
+        in-place with the given butler (any existing resolution is ignored).
+    butler
+        Client for the data repository. Should be read-only.
+    data_ids
+        Mapping from dimension group to the data ID to use for that dimension
+        group. This is intended to allow the pipeline to switch between
+        effectively-equivalent dimensions (e.g. ``group``, ``visit``,
+        ``exposure``).
+    input_refs
+        References for input datasets, keyed by task label and then connection
+        name. This should include all regular overall-input datasets whose
+        data IDs are not included in ``data_ids``. It may (but need not)
+        include prerequisite inputs. Existing intermediate datasets should
+        also be provided when they need to be clobbered or used in skip logic.
+    dataset_id_modes
+        Mapping from dataset type name to the ID generation mode for that
+        dataset type. The default is to generate random UUIDs.
+    **kwargs
+        Forwarded to the base `.quantum_graph_builder.QuantumGraphBuilder`.
+
+    Notes
+    -----
+    If ``dataset_id_modes`` is provided, ``clobber=True`` will be passed to
+    the base builder's constructor, as this is necessary to avoid spurious
+    errors about the affected datasets already existing.
The only effect of + this to silence *other* errors about datasets in the output run existing + unexpectedly. + """ + + def __init__( + self, + pipeline_graph: PipelineGraph, + butler: Butler, + *, + data_ids: Mapping[DimensionGroup, DataCoordinate], + input_refs: Mapping[str, Mapping[str, Sequence[DatasetRef]]] | None = None, + dataset_id_modes: Mapping[str, DatasetIdGenEnum] | None = None, + **kwargs: Any, + ) -> None: + super().__init__(pipeline_graph, butler, **kwargs) + if dataset_id_modes: + self.clobber = True + self.data_ids = dict(data_ids) + self.data_ids[self.empty_data_id.dimensions] = self.empty_data_id + self.input_refs = input_refs or {} + self.dataset_id_modes = dataset_id_modes or {} + + def _get_data_id(self, dimensions: DimensionGroup, context: str) -> DataCoordinate: + try: + return self.data_ids[dimensions] + except KeyError as e: + e.add_note(context) + raise + + @timeMethod + def process_subgraph(self, subgraph: PipelineGraph) -> QuantumGraphSkeleton: + skeleton = QuantumGraphSkeleton(subgraph.tasks) + for task_node in subgraph.tasks.values(): + quantum_key = skeleton.add_quantum_node( + task_node.label, self._get_data_id(task_node.dimensions, context=f"task {task_node.label!r}") + ) + input_refs_for_task = self.input_refs.get(task_node.label, {}) + + for read_edge in task_node.iter_all_inputs(): + if (input_refs := input_refs_for_task.get(read_edge.connection_name)) is not None: + for input_ref in input_refs: + if read_edge.is_prerequisite: + prereq_key = skeleton.add_prerequisite_node(input_ref) + skeleton.add_input_edge(quantum_key, prereq_key) + self.log.info( + f"Added prereq {task_node.label}.{read_edge.connection_name} " + f"for {input_ref.dataId} from input_refs" + ) + else: + input_key = skeleton.add_dataset_node( + read_edge.parent_dataset_type_name, + input_ref.dataId, + ref=input_ref, + ) + skeleton.add_input_edge(quantum_key, input_key) + self.log.info( + f"Added regular input {task_node.label}.{read_edge.connection_name} " + 
f"for {input_ref.dataId} from input_refs" + ) + + if read_edge.is_prerequisite: + continue + dataset_type_node = subgraph.dataset_types[read_edge.parent_dataset_type_name] + data_id = self._get_data_id( + dataset_type_node.dimensions, + context=f"input {task_node.label}.{read_edge.connection_name}", + ) + input_key = skeleton.add_dataset_node( + read_edge.parent_dataset_type_name, + data_id, + ) + skeleton.add_input_edge(quantum_key, input_key) + if subgraph.producer_of(read_edge.parent_dataset_type_name) is None: + if skeleton.get_dataset_ref(input_key) is None: + ref = self.butler.find_dataset(dataset_type_node.dataset_type, data_id) + if ref is not None: + skeleton.set_dataset_ref(ref) + self.log.info( + f"Added regular input {task_node.label}.{read_edge.connection_name} for {data_id}" + ) + + for write_edge in task_node.iter_all_outputs(): + dataset_type_node = subgraph.dataset_types[write_edge.parent_dataset_type_name] + data_id = self._get_data_id( + dataset_type_node.dimensions, + context=f"output {task_node.label}.{write_edge.connection_name}", + ) + output_key = skeleton.add_dataset_node(write_edge.parent_dataset_type_name, data_id) + skeleton.add_output_edge(quantum_key, output_key) + self.log.info(f"Added output {task_node.label}.{write_edge.connection_name} for {data_id}") + if mode := self.dataset_id_modes.get(write_edge.parent_dataset_type_name): + ref = DatasetRef( + dataset_type_node.dataset_type, + data_id, + run=self.output_run, + id_generation_mode=mode, + ) + skeleton.set_dataset_ref(ref) + skeleton.set_output_in_the_way(ref) + self.log.info( + f"Added ref for output {task_node.label}.{write_edge.connection_name} for " + f"{data_id} with {mode=}" + ) + + return skeleton diff --git a/tests/test_blocking_limited_butler.py b/tests/test_blocking_limited_butler.py new file mode 100644 index 000000000..21dd58602 --- /dev/null +++ b/tests/test_blocking_limited_butler.py @@ -0,0 +1,101 @@ +# This file is part of pipe_base. 
+# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +"""Tests for execution butler.""" + +import logging +import os +import unittest + +import lsst.utils.tests +from lsst.daf.butler import DataCoordinate, DatasetRef +from lsst.pipe.base.blocking_limited_butler import _LOG, BlockingLimitedButler +from lsst.pipe.base.tests.mocks import InMemoryRepo + +TESTDIR = os.path.abspath(os.path.dirname(__file__)) + + +class BlockingLimitedButlerTestCase(unittest.TestCase): + """Unit tests for BlockingLimitedButler""" + + def test_no_block_nonexistent(self): + """Test checking/getting with no dataset and blocking disabled.""" + helper = InMemoryRepo("base.yaml") + helper.add_task() + helper.pipeline_graph.resolve(helper.butler.registry) + ref = DatasetRef( + helper.pipeline_graph.dataset_types["dataset_auto0"].dataset_type, + DataCoordinate.make_empty(helper.butler.dimensions), + run="input_run", + ) + helper.pipeline_graph.register_dataset_types(helper.butler) + in_memory_butler = helper.make_limited_butler() + blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={}) + with self.assertNoLogs(_LOG, level=logging.INFO): + self.assertFalse(blocking_butler.stored_many([ref])[ref]) + with self.assertRaises(FileNotFoundError): + blocking_butler.get(ref) + + def test_timeout_nonexistent(self): + """Test checking/getting with no dataset, leading to a timeout.""" + helper = InMemoryRepo("base.yaml") + helper.add_task() + helper.pipeline_graph.resolve(helper.butler.registry) + ref = DatasetRef( + helper.pipeline_graph.dataset_types["dataset_auto0"].dataset_type, + DataCoordinate.make_empty(helper.butler.dimensions), + run="input_run", + ) + helper.pipeline_graph.register_dataset_types(helper.butler) + in_memory_butler = helper.make_limited_butler() + blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={"dataset_auto0": 0.1}) + with self.assertLogs(_LOG, level=logging.INFO) as cm: + self.assertFalse(blocking_butler.stored_many([ref])[ref]) + self.assertIn("not immediately available", cm.output[0]) + with 
self.assertLogs(_LOG, level=logging.INFO) as cm: + with self.assertRaises(FileNotFoundError): + blocking_butler.get(ref) + self.assertIn("not immediately available", cm.output[0]) + + def test_no_waiting_if_exists(self): + """Test checking/getting with dataset present immediately, so no + waiting should be necessary. + """ + helper = InMemoryRepo("base.yaml") + helper.add_task() + (ref,) = helper.insert_datasets("dataset_auto0") + helper.pipeline_graph.register_dataset_types(helper.butler) + in_memory_butler = helper.make_limited_butler() + blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={}) + with self.assertNoLogs(_LOG, level=logging.INFO): + self.assertTrue(blocking_butler.stored_many([ref])[ref]) + self.assertIsNotNone(blocking_butler.get(ref)) + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main() diff --git a/tests/test_trivial_qg_builder.py b/tests/test_trivial_qg_builder.py new file mode 100644 index 000000000..39e229327 --- /dev/null +++ b/tests/test_trivial_qg_builder.py @@ -0,0 +1,171 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import annotations + +import unittest + +from lsst.daf.butler import DatasetIdGenEnum, DatasetRef +from lsst.pipe.base.tests.mocks import DynamicConnectionConfig, InMemoryRepo +from lsst.pipe.base.trivial_quantum_graph_builder import TrivialQuantumGraphBuilder + + +class TrivialQuantumGraphBuilderTestCase(unittest.TestCase): + """Tests for the TrivialQuantumGraphBuilder class.""" + + def test_trivial_qg_builder(self) -> None: + # Make a test helper with a mock task appropriate for the QG builder: + # - the QG will have no branching + # - while the task have different dimensions, they can be 1-1 related + # (for the purposes of this test, at least). + helper = InMemoryRepo("base.yaml") + helper.add_task( + "a", + dimensions=["band", "detector"], + prerequisite_inputs={ + "prereq_connection": DynamicConnectionConfig( + dataset_type_name="dataset_prereq0", dimensions=["detector"] + ) + }, + ) + helper.add_task( + "b", + dimensions=["physical_filter", "detector"], + inputs={ + "input_connection": DynamicConnectionConfig( + dataset_type_name="dataset_auto1", dimensions=["band", "detector"] + ), + "extra_input_connection": DynamicConnectionConfig( + dataset_type_name="dataset_extra1", dimensions=["physical_filter", "detector"] + ), + }, + ) + # Use the helper to make a quantum graph using the general-purpose + # builder. This will cover all data IDs in the test dataset, which + # includes 4 detectors, 3 physical_filters, and 2 bands. + # This also has useful side-effects: it inserts the input datasets + # and registers all dataset types. 
+ general_qg = helper.make_quantum_graph() + # Make the trivial QG builder we want to test giving it only one + # detector and one band (is the one that corresponds to only one + # physical_filter). + (a_data_id,) = [ + data_id + for data_id in general_qg.quanta_by_task["a"] + if data_id["detector"] == 1 and data_id["band"] == "g" + ] + (b_data_id,) = [ + data_id + for data_id in general_qg.quanta_by_task["b"] + if data_id["detector"] == 1 and data_id["band"] == "g" + ] + prereq_data_id = a_data_id.subset(["detector"]) + dataset_auto0_ref = helper.butler.get_dataset(general_qg.datasets_by_type["dataset_auto0"][a_data_id]) + assert dataset_auto0_ref is not None, "Input dataset should have been inserted above." + dataset_prereq0_ref = helper.butler.get_dataset( + general_qg.datasets_by_type["dataset_prereq0"][prereq_data_id] + ) + assert dataset_prereq0_ref is not None, "Input dataset should have been inserted above." + trivial_builder = TrivialQuantumGraphBuilder( + helper.pipeline_graph, + helper.butler, + data_ids={a_data_id.dimensions: a_data_id, b_data_id.dimensions: b_data_id}, + input_refs={ + "a": {"input_connection": [dataset_auto0_ref], "prereq_connection": [dataset_prereq0_ref]} + }, + dataset_id_modes={"dataset_auto2": DatasetIdGenEnum.DATAID_TYPE_RUN}, + output_run="trivial_output_run", + input_collections=general_qg.header.inputs, + ) + trivial_qg = trivial_builder.finish(attach_datastore_records=False).assemble() + self.assertEqual(len(trivial_qg.quanta_by_task), 2) + self.assertEqual(trivial_qg.quanta_by_task["a"].keys(), {a_data_id}) + self.assertEqual(trivial_qg.quanta_by_task["b"].keys(), {b_data_id}) + self.assertEqual(trivial_qg.datasets_by_type["dataset_prereq0"].keys(), {prereq_data_id}) + self.assertEqual( + trivial_qg.datasets_by_type["dataset_prereq0"][prereq_data_id], + general_qg.datasets_by_type["dataset_prereq0"][prereq_data_id], + ) + self.assertEqual(trivial_qg.datasets_by_type["dataset_auto0"].keys(), {a_data_id}) + self.assertEqual( 
+ trivial_qg.datasets_by_type["dataset_auto0"][a_data_id], + general_qg.datasets_by_type["dataset_auto0"][a_data_id], + ) + self.assertEqual(trivial_qg.datasets_by_type["dataset_extra1"].keys(), {b_data_id}) + self.assertEqual( + trivial_qg.datasets_by_type["dataset_extra1"][b_data_id], + general_qg.datasets_by_type["dataset_extra1"][b_data_id], + ) + self.assertEqual(trivial_qg.datasets_by_type["dataset_auto1"].keys(), {a_data_id}) + self.assertNotEqual( + trivial_qg.datasets_by_type["dataset_auto1"][a_data_id], + general_qg.datasets_by_type["dataset_auto1"][a_data_id], + ) + self.assertEqual(trivial_qg.datasets_by_type["dataset_auto2"].keys(), {b_data_id}) + self.assertNotEqual( + trivial_qg.datasets_by_type["dataset_auto2"][b_data_id], + general_qg.datasets_by_type["dataset_auto2"][b_data_id], + ) + self.assertEqual( + trivial_qg.datasets_by_type["dataset_auto2"][b_data_id], + DatasetRef( + helper.pipeline_graph.dataset_types["dataset_auto2"].dataset_type, + b_data_id, + run="trivial_output_run", + id_generation_mode=DatasetIdGenEnum.DATAID_TYPE_RUN, + ).id, + ) + qo_xg = trivial_qg.quantum_only_xgraph + self.assertEqual(len(qo_xg.nodes), 2) + self.assertEqual(len(qo_xg.edges), 1) + bp_xg = trivial_qg.bipartite_xgraph + self.assertEqual( + set(bp_xg.predecessors(trivial_qg.quanta_by_task["a"][a_data_id])), + set(trivial_qg.datasets_by_type["dataset_auto0"].values()) + | set(trivial_qg.datasets_by_type["dataset_prereq0"].values()), + ) + self.assertEqual( + set(bp_xg.successors(trivial_qg.quanta_by_task["a"][a_data_id])), + set(trivial_qg.datasets_by_type["dataset_auto1"].values()) + | set(trivial_qg.datasets_by_type["a_metadata"].values()) + | set(trivial_qg.datasets_by_type["a_log"].values()), + ) + self.assertEqual( + set(bp_xg.predecessors(trivial_qg.quanta_by_task["b"][b_data_id])), + set(trivial_qg.datasets_by_type["dataset_auto1"].values()) + | set(trivial_qg.datasets_by_type["dataset_extra1"].values()), + ) + self.assertEqual( + 
set(bp_xg.successors(trivial_qg.quanta_by_task["b"][b_data_id])), + set(trivial_qg.datasets_by_type["dataset_auto2"].values()) + | set(trivial_qg.datasets_by_type["b_metadata"].values()) + | set(trivial_qg.datasets_by_type["b_log"].values()), + ) + + +if __name__ == "__main__": + unittest.main()