Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions python/lsst/pipe/base/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,7 +803,11 @@ def _addConfigImpl(self, label: str, newConfig: pipelineIR.ConfigIR) -> None:
return
if label not in self._pipelineIR.tasks:
raise LookupError(f"There are no tasks labeled '{label}' in the pipeline")
self._pipelineIR.tasks[label].add_or_update_config(newConfig)
match self._pipelineIR.tasks[label]:
case pipelineIR.TaskIR() as task:
task.add_or_update_config(newConfig)
case pipelineIR._AmbigousTask() as ambig_task:
ambig_task.tasks[-1].add_or_update_config(newConfig)

def write_to_uri(self, uri: ResourcePathExpression) -> None:
"""Write the pipeline to a file or directory.
Expand Down Expand Up @@ -845,6 +849,7 @@ def to_graph(
graph : `pipeline_graph.PipelineGraph`
Representation of the pipeline as a graph.
"""
self._pipelineIR.resolve_task_ambiguity()
instrument_class_name = self._pipelineIR.instrument
data_id = {}
if instrument_class_name is not None:
Expand Down Expand Up @@ -888,7 +893,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) ->
"""
if (taskIR := self._pipelineIR.tasks.get(label)) is None:
raise NameError(f"Label {label} does not appear in this pipeline")
taskClass: type[PipelineTask] = doImportType(taskIR.klass)
# type ignore here because all ambiguity should be resolved
taskClass: type[PipelineTask] = doImportType(taskIR.klass) # type: ignore
config = taskClass.ConfigClass()
instrument: PipeBaseInstrument | None = None
if (instrumentName := self._pipelineIR.instrument) is not None:
Expand All @@ -897,7 +903,8 @@ def _add_task_to_graph(self, label: str, graph: pipeline_graph.PipelineGraph) ->
config.applyConfigOverrides(
instrument,
getattr(taskClass, "_DefaultName", ""),
taskIR.config,
# type ignore here because all ambiguity should be resolved
taskIR.config, # type: ignore
self._pipelineIR.parameters,
label,
)
Expand Down
68 changes: 60 additions & 8 deletions python/lsst/pipe/base/pipelineIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@
from collections import Counter
from collections.abc import Generator, Hashable, Iterable, MutableMapping
from dataclasses import dataclass, field
from typing import Any, Literal
from typing import Any, Literal, cast

import yaml

from lsst.resources import ResourcePath, ResourcePathExpression
from lsst.utils import doImportType
from lsst.utils.introspection import find_outside_stacklevel


Expand Down Expand Up @@ -443,6 +444,34 @@ def __eq__(self, other: object) -> bool:
return all(getattr(self, attr) == getattr(other, attr) for attr in ("label", "klass", "config"))


@dataclass
class _AmbigousTask:
    """Representation of tasks which may have conflicting task classes.

    Created during pipeline merging when the same label is declared with
    different class strings; the expensive comparison (actually importing
    the classes) is deferred until `resolve` is called.
    """

    tasks: list[TaskIR]
    """TaskIR objects that need to be compared later."""

    def resolve(self) -> TaskIR:
        # Walk the accumulated tasks in merge order.  A later entry that
        # imports to the *same* class contributes only its configs; a later
        # entry with a genuinely different class replaces the current winner
        # outright, so the last distinct class in merge order wins.
        true_taskIR = self.tasks[0]
        task_class = doImportType(true_taskIR.klass)
        # need to find out if they are all actually the same
        for tmp_taskIR in self.tasks[1:]:
            tmp_task_class = doImportType(tmp_taskIR.klass)
            if tmp_task_class is task_class:
                if tmp_taskIR.config is None:
                    continue
                for config in tmp_taskIR.config:
                    # NOTE(review): this mutates the stored TaskIR in place;
                    # presumably acceptable because resolution is the final
                    # step before use — confirm no caller reuses the list.
                    true_taskIR.add_or_update_config(config)
            else:
                true_taskIR = tmp_taskIR
                task_class = tmp_task_class
        return true_taskIR

    def to_primitives(self) -> dict[str, str | list[dict]]:
        # Serialization resolves the ambiguity first so the primitive form
        # always describes a single concrete task.
        true_task = self.resolve()
        return true_task.to_primitives()


@dataclass
class ImportIR:
"""An intermediate representation of imported pipelines."""
Expand Down Expand Up @@ -778,7 +807,7 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
existing in this object.
"""
# integrate any imported pipelines
accumulate_tasks: dict[str, TaskIR] = {}
accumulate_tasks: dict[str, TaskIR | _AmbigousTask] = {}
accumulate_labeled_subsets: dict[str, LabeledSubset] = {}
accumulated_parameters = ParametersIR({})
accumulated_steps: dict[str, StepIR] = {}
Expand Down Expand Up @@ -842,17 +871,39 @@ def merge_pipelines(self, pipelines: Iterable[PipelineIR]) -> None:
for label, task in self.tasks.items():
if label not in accumulate_tasks:
accumulate_tasks[label] = task
elif accumulate_tasks[label].klass == task.klass:
if task.config is not None:
for config in task.config:
accumulate_tasks[label].add_or_update_config(config)
else:
accumulate_tasks[label] = task
self.tasks: dict[str, TaskIR] = accumulate_tasks
match (accumulate_tasks[label], task):
case (TaskIR() as taskir_obj, TaskIR() as ctask) if taskir_obj.klass == ctask.klass:
if ctask.config is not None:
for config in ctask.config:
taskir_obj.add_or_update_config(config)
case (TaskIR(klass=klass) as taskir_obj, TaskIR() as ctask) if klass != ctask.klass:
accumulate_tasks[label] = _AmbigousTask([taskir_obj, ctask])
case (_AmbigousTask(ambig_list), TaskIR() as ctask):
ambig_list.append(ctask)
case (TaskIR() as taskir_obj, _AmbigousTask(ambig_list)):
accumulate_tasks[label] = _AmbigousTask([taskir_obj] + ambig_list)
case (_AmbigousTask(existing_ambig_list), _AmbigousTask(new_ambig_list)):
existing_ambig_list.extend(new_ambig_list)

self.tasks: MutableMapping[str, TaskIR | _AmbigousTask] = accumulate_tasks
accumulated_parameters.update(self.parameters)
self.parameters = accumulated_parameters
self.steps = list(accumulated_steps.values())

def resolve_task_ambiguity(self) -> None:
    """Collapse any `_AmbigousTask` entries in ``self.tasks`` into concrete
    `TaskIR` objects, leaving entries that are already concrete untouched.
    """
    resolved: dict[str, TaskIR] = {}
    for label, entry in self.tasks.items():
        if isinstance(entry, TaskIR):
            resolved[label] = entry
        elif isinstance(entry, _AmbigousTask):
            resolved[label] = entry.resolve()
    # Within this function we want the guarantee that every value is a
    # TaskIR, but the attribute itself is declared with the mixed value
    # type so merging can store ambiguous entries — cast back on assignment.
    self.tasks = cast(dict[str, TaskIR | _AmbigousTask], resolved)

def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
"""Process the tasks portion of the loaded yaml document

Expand All @@ -870,6 +921,7 @@ def _read_tasks(self, loaded_yaml: dict[str, Any]) -> None:
if "parameters" in tmp_tasks:
raise ValueError("parameters is a reserved word and cannot be used as a task label")

definition: str | dict[str, Any]
for label, definition in tmp_tasks.items():
if isinstance(definition, str):
definition = {"class": definition}
Expand Down
47 changes: 47 additions & 0 deletions python/lsst/pipe/base/tests/pipelineIRTestClasses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.

"""Module defining PipelineIR test classes."""

from __future__ import annotations

__all__ = ("ModuleA", "ModuleAAlias", "ModuleAReplace")


class ModuleA:
    """PipelineIR test class for importing."""


# A second module-level name bound to the very same class, used to test
# that two task definitions naming the same class are not ambiguous.
ModuleAAlias = ModuleA


class ModuleAReplace:
    """PipelineIR test class for importing, distinct from `ModuleA`."""
97 changes: 97 additions & 0 deletions python/lsst/pipe/base/tests/simpleQGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,103 @@ def makeTask(
return task


class SubTaskConnections(
    PipelineTaskConnections,
    dimensions=("instrument", "detector"),
    defaultTemplates={"in_tmpl": "_in", "out_tmpl": "_out"},
):
    """Connections for SubTask, has one input and two outputs,
    plus one init output.

    Dataset type names embed the ``in_tmpl``/``out_tmpl`` templates so tests
    can rewire the graph by overriding ``connections.in_tmpl`` etc.
    """

    # Single per-(instrument, detector) input consumed by SubTask.run.
    input = cT.Input(
        name="add_dataset{in_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Input dataset type for this task",
    )
    # Primary output produced by SubTask.run (its ``output`` struct field).
    output = cT.Output(
        name="add_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Secondary output (SubTask.run's ``output2`` struct field).
    output2 = cT.Output(
        name="add2_dataset{out_tmpl}",
        dimensions=["instrument", "detector"],
        storageClass="NumpyArray",
        doc="Output dataset type for this task",
    )
    # Dimensionless init-output written once at task construction time.
    initout = cT.InitOutput(
        name="add_init_output{out_tmpl}",
        storageClass="NumpyArray",
        doc="Init Output dataset type for this task",
    )


class SubTaskConfig(PipelineTaskConfig, pipelineConnections=SubTaskConnections):
    """Config for SubTask."""

    # Amount removed from the input by SubTask.run; tests override this to
    # verify that config merging/ambiguity resolution picked the right value.
    subtract = pexConfig.Field[int](doc="amount to subtract", default=3)


class SubTask(PipelineTask):
    """Trivial PipelineTask for testing, has some extras useful for specific
    unit tests.

    ``run`` subtracts ``config.subtract`` from the input and also returns a
    second value with the subtraction undone, logging both.
    """

    ConfigClass = SubTaskConfig
    _DefaultName = "sub_task"

    initout = numpy.array([999])
    """InitOutputs for this task"""

    taskFactory: SubTaskFactoryMock | None = None
    """Factory that makes instances"""

    def run(self, input: int) -> Struct:
        factory = self.taskFactory
        if factory:
            # Test bookkeeping: fail deliberately once the countdown reaches
            # the configured stop point, otherwise record one more execution.
            if factory.stopAt == factory.countExec:
                raise RuntimeError("pretend something bad happened")
            factory.countExec -= 1

        self.config = cast(SubTaskConfig, self.config)
        config = self.config
        self.metadata.add("sub", config.subtract)
        result = input - config.subtract
        restored = result + config.subtract
        _LOG.info("input = %s, output = %s, output2 = %s", input, result, restored)
        return Struct(output=result, output2=restored)


class SubTaskFactoryMock(TaskFactory):
    """Special task factory that instantiates SubTask.

    It also defines some bookkeeping variables used by SubTask to report
    progress to unit tests.

    Parameters
    ----------
    stopAt : `int`, optional
        Number of times to call `run` before stopping.
    """

    def __init__(self, stopAt: int = -1):
        self.countExec = 100  # reduced by SubTask
        self.stopAt = stopAt  # SubTask raises exception at this call to run()

    def makeTask(
        self,
        task_node: TaskNode,
        /,
        butler: LimitedButler,
        initInputRefs: Iterable[DatasetRef] | None,
    ) -> PipelineTask:
        # ``butler`` and ``initInputRefs`` are accepted to satisfy the
        # TaskFactory interface but are not used by this mock.
        task = task_node.task_class(config=task_node.config, initInputs=None, name=task_node.label)
        # Hand the task a reference back to this factory so run() can do
        # the countExec/stopAt bookkeeping above.
        task.taskFactory = self  # type: ignore
        return task


def registerDatasetTypes(registry: Registry, pipeline: Pipeline | PipelineGraph) -> None:
"""Register all dataset types used by tasks in a registry.

Expand Down
2 changes: 1 addition & 1 deletion tests/testPipeline2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ parameters:
value3: valueC
tasks:
modA:
class: "test.moduleA"
class: "lsst.pipe.base.tests.pipelineIRTestClasses.ModuleA"
config:
value1: 1
subsets:
Expand Down
44 changes: 43 additions & 1 deletion tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import lsst.utils.tests
from lsst.pipe.base import LabelSpecifier, Pipeline, TaskDef
from lsst.pipe.base.pipelineIR import LabeledSubset
from lsst.pipe.base.tests.simpleQGraph import AddTask, makeSimplePipeline
from lsst.pipe.base.tests.simpleQGraph import AddTask, SubTask, makeSimplePipeline


class PipelineTestCase(unittest.TestCase):
Expand Down Expand Up @@ -130,6 +130,48 @@ def testMergingPipelines(self):
pipeline1.mergePipeline(pipeline2)
self.assertEqual(pipeline1._pipelineIR.tasks.keys(), {"task0", "task1", "task2", "task3"})

# Test merging pipelines with ambiguous tasks
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.mergePipeline(pipeline1)

# Now merge in another pipeline with a config applied.
pipeline3 = makeSimplePipeline(2)
pipeline3.addTask(SubTask, "task1")
pipeline3.addConfigOverride("task1", "subtract", 10)
pipeline3.mergePipeline(pipeline2)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 10)

# Now change the order of the merging
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline3 = makeSimplePipeline(2)
pipeline3.mergePipeline(pipeline2)
pipeline3.mergePipeline(pipeline1)
graph = pipeline3.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.addend, 3)

# Now do two ambiguous chains
pipeline1 = makeSimplePipeline(2)
pipeline2 = makeSimplePipeline(2)
pipeline2.addTask(SubTask, "task1")
pipeline2.addConfigOverride("task1", "subtract", 10)
pipeline2.mergePipeline(pipeline1)

pipeline3 = makeSimplePipeline(2)
pipeline4 = makeSimplePipeline(2)
pipeline4.addTask(SubTask, "task1")
pipeline4.addConfigOverride("task1", "subtract", 7)
pipeline4.mergePipeline(pipeline3)
graph = pipeline4.to_graph()
# assert equality from the graph to trigger ambiquity resolution
self.assertEqual(graph.tasks["task1"].config.subtract, 7)

def testFindingSubset(self):
pipeline = makeSimplePipeline(2)
pipeline._pipelineIR.labeled_subsets["test1"] = LabeledSubset("test1", set(), None)
Expand Down
Loading