diff --git a/mypy.ini b/mypy.ini
index 986f62b2d..6326a120f 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -72,13 +72,6 @@ disallow_untyped_defs = False
# is just not worth the effort.
disallow_untyped_defs = False
-# These modules are deprecated (maybe implicitly, as being Gen2-only). Not
-# worth adding new annotations to them.
-[mypy-lsst.pipe.base.argumentParser.*]
-disallow_untyped_defs = False
-[mypy-lsst.pipe.base.shims.*]
-disallow_untyped_defs = False
-
# ConfigOverrides uses the Python built-in ast module, and figuring out how
# to correctly type its visitation interface doesn't seem worth the effort
# right now.
diff --git a/python/lsst/pipe/base/blocking_limited_butler.py b/python/lsst/pipe/base/blocking_limited_butler.py
new file mode 100644
index 000000000..3fe194645
--- /dev/null
+++ b/python/lsst/pipe/base/blocking_limited_butler.py
@@ -0,0 +1,173 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ["BlockingLimitedButler"]
+
+import logging
+import time
+from collections.abc import Iterable, Mapping
+from typing import Any
+
+from lsst.daf.butler import (
+ ButlerMetrics,
+ DatasetProvenance,
+ DatasetRef,
+ DeferredDatasetHandle,
+ DimensionUniverse,
+ LimitedButler,
+ StorageClass,
+)
+
+_LOG = logging.getLogger(__name__)
+
+
+class BlockingLimitedButler(LimitedButler):
+ """A `LimitedButler` that blocks until certain dataset types exist.
+
+ Parameters
+ ----------
+ wrapped : `LimitedButler`
+ The butler to wrap.
+ timeouts : `~collections.abc.Mapping` [ `str`, `float` or `None` ]
+ Timeouts in seconds to wait for different dataset types. Dataset types
+ not included are not blocked on (i.e. their timeout is ``0.0``).
+
+ Notes
+ -----
+ When a timeout is exceeded, `get` will raise `FileNotFoundError` (as usual
+ for a dataset that does not exist) and `stored_many` will mark the dataset
+ as non-existent. `getDeferred` does not block.
+ """
+
+ def __init__(
+ self,
+ wrapped: LimitedButler,
+ timeouts: Mapping[str, float | None],
+ ):
+ self._wrapped = wrapped
+ self._timeouts = timeouts
+
+ def close(self) -> None:
+ self._wrapped.close()
+
+ @property
+ def _metrics(self) -> ButlerMetrics:
+ # Need to always forward from the wrapped metrics object.
+ return self._wrapped._metrics
+
+ @_metrics.setter
+ def _metrics(self, metrics: ButlerMetrics) -> None:
+ # Allow record_metrics() context manager to override the wrapped
+ # butler.
+ self._wrapped._metrics = metrics
+
+ def get(
+ self,
+ ref: DatasetRef,
+ /,
+ *,
+ parameters: dict[str, Any] | None = None,
+ storageClass: StorageClass | str | None = None,
+ ) -> Any:
+ parent_dataset_type_name = ref.datasetType.nameAndComponent()[0]
+ timeout = self._timeouts.get(parent_dataset_type_name, 0.0)
+ start = time.time()
+ while True:
+ try:
+ return self._wrapped.get(ref, parameters=parameters, storageClass=storageClass)
+ except FileNotFoundError as err:
+ if timeout is not None:
+ elapsed = time.time() - start
+ if elapsed > timeout:
+ err.add_note(f"Timed out after {elapsed:.3f}s.")
+ raise
+ _LOG.info(f"Dataset {ref.datasetType} not immediately available for {ref.id}, waiting {timeout}s")
+ time.sleep(0.5)
+
+ def getDeferred(
+ self,
+ ref: DatasetRef,
+ /,
+ *,
+ parameters: dict[str, Any] | None = None,
+ storageClass: str | StorageClass | None = None,
+ ) -> DeferredDatasetHandle:
+ # note that this does not block at all
+ return self._wrapped.getDeferred(ref, parameters=parameters, storageClass=storageClass)
+
+ def stored_many(self, refs: Iterable[DatasetRef]) -> dict[DatasetRef, bool]:
+ start = time.time()
+ result = self._wrapped.stored_many(refs)
+ timeouts = {ref.id: self._timeouts.get(ref.datasetType.nameAndComponent()[0], 0.0) for ref in result}
+ while True:
+ elapsed = time.time() - start
+ remaining: list[DatasetRef] = []
+ for ref, exists in result.items():
+ timeout = timeouts[ref.id]
+ if not exists and (timeout is None or elapsed <= timeout):
+ _LOG.info(
+ f"Dataset {ref.datasetType} not immediately available for {ref.id}, "
+ f"waiting {timeout}s"
+ )
+ remaining.append(ref)
+ if not remaining:
+ return result
+ result.update(self._wrapped.stored_many(remaining))
+ time.sleep(0.5)
+
+ def isWriteable(self) -> bool:
+ return self._wrapped.isWriteable()
+
+ def put(self, obj: Any, ref: DatasetRef, /, *, provenance: DatasetProvenance | None = None) -> DatasetRef:
+ return self._wrapped.put(obj, ref, provenance=provenance)
+
+ def pruneDatasets(
+ self,
+ refs: Iterable[DatasetRef],
+ *,
+ disassociate: bool = True,
+ unstore: bool = False,
+ tags: Iterable[str] = (),
+ purge: bool = False,
+ ) -> None:
+ return self._wrapped.pruneDatasets(
+ refs, disassociate=disassociate, unstore=unstore, tags=tags, purge=purge
+ )
+
+ @property
+ def dimensions(self) -> DimensionUniverse:
+ return self._wrapped.dimensions
+
+ @property
+ def _datastore(self) -> Any:
+ return self._wrapped._datastore
+
+ @_datastore.setter # demanded by MyPy since we declare it to be an instance attribute in LimitedButler.
+ def _datastore(self, value: Any) -> None:
+ self._wrapped._datastore = value
diff --git a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
index 7c6b012d7..7294cd450 100644
--- a/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
+++ b/python/lsst/pipe/base/pipeline_graph/_task_subsets.py
@@ -157,7 +157,7 @@ def discard(self, value: str) -> None:
self._members.discard(value)
@classmethod
- def _from_iterable(cls, iterable: Iterable[str]) -> set[str]:
+ def _from_iterable[S](cls, iterable: Iterable[S]) -> set[S]:
# This is the hook used by collections.abc.Set when implementing
# operators that return new sets. In this case, we want those to be
# regular `set` (builtin) objects, not `TaskSubset` instances.
diff --git a/python/lsst/pipe/base/single_quantum_executor.py b/python/lsst/pipe/base/single_quantum_executor.py
index 613bfa75e..97c2ad5bd 100644
--- a/python/lsst/pipe/base/single_quantum_executor.py
+++ b/python/lsst/pipe/base/single_quantum_executor.py
@@ -72,8 +72,9 @@ class SingleQuantumExecutor(QuantumExecutor):
Parameters
----------
- butler : `~lsst.daf.butler.Butler` or `None`, optional
- Data butler, `None` means that a limited butler should be used instead.
+ butler : `~lsst.daf.butler.LimitedButler` or `None`, optional
+ Data butler; `None` means that ``limited_butler_factory`` should be
+ used instead.
task_factory : `.TaskFactory`, optional
Instance of a task factory. Defaults to a new instance of
`lsst.pipe.base.TaskFactory`.
@@ -94,8 +95,8 @@ class SingleQuantumExecutor(QuantumExecutor):
Enable debugging with ``lsstDebug`` facility for a task.
limited_butler_factory : `~collections.abc.Callable`, optional
A method that creates a `~lsst.daf.butler.LimitedButler` instance for a
- given Quantum. This parameter must be defined if ``butler`` is `None`.
- If ``butler`` is not `None` then this parameter is ignored.
+ given Quantum. This parameter must be provided if ``butler`` is
+ `None`. If ``butler`` is not `None` then this parameter is ignored.
resources : `.ExecutionResources`, optional
The resources available to this quantum when executing.
skip_existing : `bool`, optional
@@ -115,15 +116,15 @@ class SingleQuantumExecutor(QuantumExecutor):
continuing to run downstream tasks.
job_metadata : `~collections.abc.Mapping`
Mapping with extra metadata to embed within the quantum metadata under
- the "job" key. This is intended to correspond to information common
- to all quanta being executed in a single process, such as the time
- taken to load the quantum graph in a BPS job.
+ the "job" key. This is intended to correspond to information common to
+ all quanta being executed in a single process, such as the time taken
+ to load the quantum graph in a BPS job.
"""
def __init__(
self,
*,
- butler: Butler | None = None,
+ butler: LimitedButler | None = None,
task_factory: TaskFactory | None = None,
skip_existing_in: Any = None,
clobber_outputs: bool = False,
@@ -135,7 +136,17 @@ def __init__(
raise_on_partial_outputs: bool = True,
job_metadata: Mapping[str, int | str | float] | None = None,
):
- self._butler = butler
+ self._butler: Butler | None = None
+ self._limited_butler: LimitedButler | None = None
+ match butler:
+ case Butler():
+ self._butler = butler
+ self._limited_butler = butler
+ case LimitedButler():
+ self._limited_butler = butler
+ case None:
+ if limited_butler_factory is None:
+ raise ValueError("limited_butler_factory is needed when butler is None")
self._task_factory = task_factory if task_factory is not None else TaskFactory()
self._clobber_outputs = clobber_outputs
self._enable_lsst_debug = enable_lsst_debug
@@ -144,10 +155,6 @@ def __init__(
self._assume_no_existing_outputs = assume_no_existing_outputs
self._raise_on_partial_outputs = raise_on_partial_outputs
self._job_metadata = job_metadata
-
- if self._butler is None:
- assert limited_butler_factory is not None, "limited_butler_factory is needed when butler is None"
-
# Find whether output run is in skip_existing_in.
self._skip_existing = skip_existing
if self._butler is not None and skip_existing_in and not self._skip_existing:
@@ -190,9 +197,12 @@ def _execute(
limited_butler = self._butler
else:
# We check this in constructor, but mypy needs this check here.
- assert self._limited_butler_factory is not None
- limited_butler = self._limited_butler_factory(quantum)
- used_butler_factory = True
+ if self._limited_butler is not None:
+ limited_butler = self._limited_butler
+ else:
+ assert self._limited_butler_factory is not None
+ limited_butler = self._limited_butler_factory(quantum)
+ used_butler_factory = True
try:
return self._execute_with_limited_butler(
diff --git a/python/lsst/pipe/base/trivial_quantum_graph_builder.py b/python/lsst/pipe/base/trivial_quantum_graph_builder.py
new file mode 100644
index 000000000..12343abb6
--- /dev/null
+++ b/python/lsst/pipe/base/trivial_quantum_graph_builder.py
@@ -0,0 +1,183 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (http://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+__all__ = ("TrivialQuantumGraphBuilder",)
+
+from collections.abc import Mapping, Sequence
+from typing import TYPE_CHECKING, Any, final
+
+from lsst.daf.butler import Butler, DataCoordinate, DatasetIdGenEnum, DatasetRef, DimensionGroup
+from lsst.utils.timer import timeMethod
+
+from .quantum_graph_builder import QuantumGraphBuilder
+from .quantum_graph_skeleton import QuantumGraphSkeleton
+
+if TYPE_CHECKING:
+ from .pipeline_graph import PipelineGraph
+
+
+@final
+class TrivialQuantumGraphBuilder(QuantumGraphBuilder):
+ """An optimized quantum-graph builder for pipelines that operate on only
+ a single data ID or a closely related set of data IDs.
+
+ Parameters
+ ----------
+ pipeline_graph
+ Pipeline to build a quantum graph from, as a graph. Will be resolved
+ in-place with the given butler (any existing resolution is ignored).
+ butler
+ Client for the data repository. Should be read-only.
+ data_ids
+ Mapping from dimension group to the data ID to use for that dimension
+ group. This is intended to allow the pipeline to switch between
+ effectively-equivalent dimensions (e.g. ``group``, ``visit``,
+ ``exposure``).
+ input_refs
+ References for input datasets, keyed by task label and then connection
+ name. This should include all regular overall-input datasets whose
+ data IDs are not included in ``data_ids``. It may (but need not)
+ include prerequisite inputs. Existing intermediate datasets should
+ also be provided when they need to be clobbered or used in skip logic.
+ dataset_id_modes
+ Mapping from dataset type name to the ID generation mode for that
+ dataset type. The default is to generate random UUIDs.
+ **kwargs
+ Forwarded to the base `.quantum_graph_builder.QuantumGraphBuilder`.
+
+ Notes
+ -----
+ If ``dataset_id_modes`` is provided, ``clobber=True`` will be passed to
+the base builder's constructor, as this is necessary to avoid spurious
+ errors about the affected datasets already existing. The only effect of
+ this to silence *other* errors about datasets in the output run existing
+ unexpectedly.
+ """
+
+ def __init__(
+ self,
+ pipeline_graph: PipelineGraph,
+ butler: Butler,
+ *,
+ data_ids: Mapping[DimensionGroup, DataCoordinate],
+ input_refs: Mapping[str, Mapping[str, Sequence[DatasetRef]]] | None = None,
+ dataset_id_modes: Mapping[str, DatasetIdGenEnum] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(pipeline_graph, butler, **kwargs)
+ if dataset_id_modes:
+ self.clobber = True
+ self.data_ids = dict(data_ids)
+ self.data_ids[self.empty_data_id.dimensions] = self.empty_data_id
+ self.input_refs = input_refs or {}
+ self.dataset_id_modes = dataset_id_modes or {}
+
+ def _get_data_id(self, dimensions: DimensionGroup, context: str) -> DataCoordinate:
+ try:
+ return self.data_ids[dimensions]
+ except KeyError as e:
+ e.add_note(context)
+ raise
+
+ @timeMethod
+ def process_subgraph(self, subgraph: PipelineGraph) -> QuantumGraphSkeleton:
+ skeleton = QuantumGraphSkeleton(subgraph.tasks)
+ for task_node in subgraph.tasks.values():
+ quantum_key = skeleton.add_quantum_node(
+ task_node.label, self._get_data_id(task_node.dimensions, context=f"task {task_node.label!r}")
+ )
+ input_refs_for_task = self.input_refs.get(task_node.label, {})
+
+ for read_edge in task_node.iter_all_inputs():
+ if (input_refs := input_refs_for_task.get(read_edge.connection_name)) is not None:
+ for input_ref in input_refs:
+ if read_edge.is_prerequisite:
+ prereq_key = skeleton.add_prerequisite_node(input_ref)
+ skeleton.add_input_edge(quantum_key, prereq_key)
+ self.log.info(
+ f"Added prereq {task_node.label}.{read_edge.connection_name} "
+ f"for {input_ref.dataId} from input_refs"
+ )
+ else:
+ input_key = skeleton.add_dataset_node(
+ read_edge.parent_dataset_type_name,
+ input_ref.dataId,
+ ref=input_ref,
+ )
+ skeleton.add_input_edge(quantum_key, input_key)
+ self.log.info(
+ f"Added regular input {task_node.label}.{read_edge.connection_name} "
+ f"for {input_ref.dataId} from input_refs"
+ )
+
+ if read_edge.is_prerequisite:
+ continue
+ dataset_type_node = subgraph.dataset_types[read_edge.parent_dataset_type_name]
+ data_id = self._get_data_id(
+ dataset_type_node.dimensions,
+ context=f"input {task_node.label}.{read_edge.connection_name}",
+ )
+ input_key = skeleton.add_dataset_node(
+ read_edge.parent_dataset_type_name,
+ data_id,
+ )
+ skeleton.add_input_edge(quantum_key, input_key)
+ if subgraph.producer_of(read_edge.parent_dataset_type_name) is None:
+ if skeleton.get_dataset_ref(input_key) is None:
+ ref = self.butler.find_dataset(dataset_type_node.dataset_type, data_id)
+ if ref is not None:
+ skeleton.set_dataset_ref(ref)
+ self.log.info(
+ f"Added regular input {task_node.label}.{read_edge.connection_name} for {data_id}"
+ )
+
+ for write_edge in task_node.iter_all_outputs():
+ dataset_type_node = subgraph.dataset_types[write_edge.parent_dataset_type_name]
+ data_id = self._get_data_id(
+ dataset_type_node.dimensions,
+ context=f"output {task_node.label}.{write_edge.connection_name}",
+ )
+ output_key = skeleton.add_dataset_node(write_edge.parent_dataset_type_name, data_id)
+ skeleton.add_output_edge(quantum_key, output_key)
+ self.log.info(f"Added output {task_node.label}.{write_edge.connection_name} for {data_id}")
+ if mode := self.dataset_id_modes.get(write_edge.parent_dataset_type_name):
+ ref = DatasetRef(
+ dataset_type_node.dataset_type,
+ data_id,
+ run=self.output_run,
+ id_generation_mode=mode,
+ )
+ skeleton.set_dataset_ref(ref)
+ skeleton.set_output_in_the_way(ref)
+ self.log.info(
+ f"Added ref for output {task_node.label}.{write_edge.connection_name} for "
+ f"{data_id} with {mode=}"
+ )
+
+ return skeleton
diff --git a/tests/test_blocking_limited_butler.py b/tests/test_blocking_limited_butler.py
new file mode 100644
index 000000000..21dd58602
--- /dev/null
+++ b/tests/test_blocking_limited_butler.py
@@ -0,0 +1,101 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+"""Tests for BlockingLimitedButler."""
+
+import logging
+import os
+import unittest
+
+import lsst.utils.tests
+from lsst.daf.butler import DataCoordinate, DatasetRef
+from lsst.pipe.base.blocking_limited_butler import _LOG, BlockingLimitedButler
+from lsst.pipe.base.tests.mocks import InMemoryRepo
+
+TESTDIR = os.path.abspath(os.path.dirname(__file__))
+
+
+class BlockingLimitedButlerTestCase(unittest.TestCase):
+ """Unit tests for BlockingLimitedButler"""
+
+ def test_no_block_nonexistent(self):
+ """Test checking/getting with no dataset and blocking disabled."""
+ helper = InMemoryRepo("base.yaml")
+ helper.add_task()
+ helper.pipeline_graph.resolve(helper.butler.registry)
+ ref = DatasetRef(
+ helper.pipeline_graph.dataset_types["dataset_auto0"].dataset_type,
+ DataCoordinate.make_empty(helper.butler.dimensions),
+ run="input_run",
+ )
+ helper.pipeline_graph.register_dataset_types(helper.butler)
+ in_memory_butler = helper.make_limited_butler()
+ blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={})
+ with self.assertNoLogs(_LOG, level=logging.INFO):
+ self.assertFalse(blocking_butler.stored_many([ref])[ref])
+ with self.assertRaises(FileNotFoundError):
+ blocking_butler.get(ref)
+
+ def test_timeout_nonexistent(self):
+ """Test checking/getting with no dataset, leading to a timeout."""
+ helper = InMemoryRepo("base.yaml")
+ helper.add_task()
+ helper.pipeline_graph.resolve(helper.butler.registry)
+ ref = DatasetRef(
+ helper.pipeline_graph.dataset_types["dataset_auto0"].dataset_type,
+ DataCoordinate.make_empty(helper.butler.dimensions),
+ run="input_run",
+ )
+ helper.pipeline_graph.register_dataset_types(helper.butler)
+ in_memory_butler = helper.make_limited_butler()
+ blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={"dataset_auto0": 0.1})
+ with self.assertLogs(_LOG, level=logging.INFO) as cm:
+ self.assertFalse(blocking_butler.stored_many([ref])[ref])
+ self.assertIn("not immediately available", cm.output[0])
+ with self.assertLogs(_LOG, level=logging.INFO) as cm:
+ with self.assertRaises(FileNotFoundError):
+ blocking_butler.get(ref)
+ self.assertIn("not immediately available", cm.output[0])
+
+ def test_no_waiting_if_exists(self):
+ """Test checking/getting with dataset present immediately, so no
+ waiting should be necessary.
+ """
+ helper = InMemoryRepo("base.yaml")
+ helper.add_task()
+ (ref,) = helper.insert_datasets("dataset_auto0")
+ helper.pipeline_graph.register_dataset_types(helper.butler)
+ in_memory_butler = helper.make_limited_butler()
+ blocking_butler = BlockingLimitedButler(in_memory_butler, timeouts={})
+ with self.assertNoLogs(_LOG, level=logging.INFO):
+ self.assertTrue(blocking_butler.stored_many([ref])[ref])
+ self.assertIsNotNone(blocking_butler.get(ref))
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()
diff --git a/tests/test_trivial_qg_builder.py b/tests/test_trivial_qg_builder.py
new file mode 100644
index 000000000..39e229327
--- /dev/null
+++ b/tests/test_trivial_qg_builder.py
@@ -0,0 +1,171 @@
+# This file is part of pipe_base.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This software is dual licensed under the GNU General Public License and also
+# under a 3-clause BSD license. Recipients may choose which of these licenses
+# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
+# respectively. If you choose the GPL option then the following text applies
+# (but note that there is still no warranty even if you opt for BSD instead):
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import annotations
+
+import unittest
+
+from lsst.daf.butler import DatasetIdGenEnum, DatasetRef
+from lsst.pipe.base.tests.mocks import DynamicConnectionConfig, InMemoryRepo
+from lsst.pipe.base.trivial_quantum_graph_builder import TrivialQuantumGraphBuilder
+
+
+class TrivialQuantumGraphBuilderTestCase(unittest.TestCase):
+ """Tests for the TrivialQuantumGraphBuilder class."""
+
+ def test_trivial_qg_builder(self) -> None:
+ # Make a test helper with a mock task appropriate for the QG builder:
+ # - the QG will have no branching
+ # - while the tasks have different dimensions, they can be 1-1 related
+ # (for the purposes of this test, at least).
+ helper = InMemoryRepo("base.yaml")
+ helper.add_task(
+ "a",
+ dimensions=["band", "detector"],
+ prerequisite_inputs={
+ "prereq_connection": DynamicConnectionConfig(
+ dataset_type_name="dataset_prereq0", dimensions=["detector"]
+ )
+ },
+ )
+ helper.add_task(
+ "b",
+ dimensions=["physical_filter", "detector"],
+ inputs={
+ "input_connection": DynamicConnectionConfig(
+ dataset_type_name="dataset_auto1", dimensions=["band", "detector"]
+ ),
+ "extra_input_connection": DynamicConnectionConfig(
+ dataset_type_name="dataset_extra1", dimensions=["physical_filter", "detector"]
+ ),
+ },
+ )
+ # Use the helper to make a quantum graph using the general-purpose
+ # builder. This will cover all data IDs in the test dataset, which
+ # includes 4 detectors, 3 physical_filters, and 2 bands.
+ # This also has useful side-effects: it inserts the input datasets
+ # and registers all dataset types.
+ general_qg = helper.make_quantum_graph()
+ # Make the trivial QG builder we want to test giving it only one
+ # detector and one band (the one that corresponds to only one
+ # physical_filter).
+ (a_data_id,) = [
+ data_id
+ for data_id in general_qg.quanta_by_task["a"]
+ if data_id["detector"] == 1 and data_id["band"] == "g"
+ ]
+ (b_data_id,) = [
+ data_id
+ for data_id in general_qg.quanta_by_task["b"]
+ if data_id["detector"] == 1 and data_id["band"] == "g"
+ ]
+ prereq_data_id = a_data_id.subset(["detector"])
+ dataset_auto0_ref = helper.butler.get_dataset(general_qg.datasets_by_type["dataset_auto0"][a_data_id])
+ assert dataset_auto0_ref is not None, "Input dataset should have been inserted above."
+ dataset_prereq0_ref = helper.butler.get_dataset(
+ general_qg.datasets_by_type["dataset_prereq0"][prereq_data_id]
+ )
+ assert dataset_prereq0_ref is not None, "Input dataset should have been inserted above."
+ trivial_builder = TrivialQuantumGraphBuilder(
+ helper.pipeline_graph,
+ helper.butler,
+ data_ids={a_data_id.dimensions: a_data_id, b_data_id.dimensions: b_data_id},
+ input_refs={
+ "a": {"input_connection": [dataset_auto0_ref], "prereq_connection": [dataset_prereq0_ref]}
+ },
+ dataset_id_modes={"dataset_auto2": DatasetIdGenEnum.DATAID_TYPE_RUN},
+ output_run="trivial_output_run",
+ input_collections=general_qg.header.inputs,
+ )
+ trivial_qg = trivial_builder.finish(attach_datastore_records=False).assemble()
+ self.assertEqual(len(trivial_qg.quanta_by_task), 2)
+ self.assertEqual(trivial_qg.quanta_by_task["a"].keys(), {a_data_id})
+ self.assertEqual(trivial_qg.quanta_by_task["b"].keys(), {b_data_id})
+ self.assertEqual(trivial_qg.datasets_by_type["dataset_prereq0"].keys(), {prereq_data_id})
+ self.assertEqual(
+ trivial_qg.datasets_by_type["dataset_prereq0"][prereq_data_id],
+ general_qg.datasets_by_type["dataset_prereq0"][prereq_data_id],
+ )
+ self.assertEqual(trivial_qg.datasets_by_type["dataset_auto0"].keys(), {a_data_id})
+ self.assertEqual(
+ trivial_qg.datasets_by_type["dataset_auto0"][a_data_id],
+ general_qg.datasets_by_type["dataset_auto0"][a_data_id],
+ )
+ self.assertEqual(trivial_qg.datasets_by_type["dataset_extra1"].keys(), {b_data_id})
+ self.assertEqual(
+ trivial_qg.datasets_by_type["dataset_extra1"][b_data_id],
+ general_qg.datasets_by_type["dataset_extra1"][b_data_id],
+ )
+ self.assertEqual(trivial_qg.datasets_by_type["dataset_auto1"].keys(), {a_data_id})
+ self.assertNotEqual(
+ trivial_qg.datasets_by_type["dataset_auto1"][a_data_id],
+ general_qg.datasets_by_type["dataset_auto1"][a_data_id],
+ )
+ self.assertEqual(trivial_qg.datasets_by_type["dataset_auto2"].keys(), {b_data_id})
+ self.assertNotEqual(
+ trivial_qg.datasets_by_type["dataset_auto2"][b_data_id],
+ general_qg.datasets_by_type["dataset_auto2"][b_data_id],
+ )
+ self.assertEqual(
+ trivial_qg.datasets_by_type["dataset_auto2"][b_data_id],
+ DatasetRef(
+ helper.pipeline_graph.dataset_types["dataset_auto2"].dataset_type,
+ b_data_id,
+ run="trivial_output_run",
+ id_generation_mode=DatasetIdGenEnum.DATAID_TYPE_RUN,
+ ).id,
+ )
+ qo_xg = trivial_qg.quantum_only_xgraph
+ self.assertEqual(len(qo_xg.nodes), 2)
+ self.assertEqual(len(qo_xg.edges), 1)
+ bp_xg = trivial_qg.bipartite_xgraph
+ self.assertEqual(
+ set(bp_xg.predecessors(trivial_qg.quanta_by_task["a"][a_data_id])),
+ set(trivial_qg.datasets_by_type["dataset_auto0"].values())
+ | set(trivial_qg.datasets_by_type["dataset_prereq0"].values()),
+ )
+ self.assertEqual(
+ set(bp_xg.successors(trivial_qg.quanta_by_task["a"][a_data_id])),
+ set(trivial_qg.datasets_by_type["dataset_auto1"].values())
+ | set(trivial_qg.datasets_by_type["a_metadata"].values())
+ | set(trivial_qg.datasets_by_type["a_log"].values()),
+ )
+ self.assertEqual(
+ set(bp_xg.predecessors(trivial_qg.quanta_by_task["b"][b_data_id])),
+ set(trivial_qg.datasets_by_type["dataset_auto1"].values())
+ | set(trivial_qg.datasets_by_type["dataset_extra1"].values()),
+ )
+ self.assertEqual(
+ set(bp_xg.successors(trivial_qg.quanta_by_task["b"][b_data_id])),
+ set(trivial_qg.datasets_by_type["dataset_auto2"].values())
+ | set(trivial_qg.datasets_by_type["b_metadata"].values())
+ | set(trivial_qg.datasets_by_type["b_log"].values()),
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()