diff --git a/docs/tutorial/builder/graph_builder.md b/docs/tutorial/builder/graph_builder.md index 27eeeb4f02..37866f9990 100644 --- a/docs/tutorial/builder/graph_builder.md +++ b/docs/tutorial/builder/graph_builder.md @@ -339,9 +339,11 @@ The subgraph automatically inherits the opset version from the parent ### Type annotations for subgraph inputs and outputs -`subgraph()` accepts `input_types` and `output_types` lists that describe -the types and shapes of each input and output. Each element can be either an -`ir.TypeAndShape` object or — more conveniently — an +`subgraph()` accepts `inputs` and `outputs` that describe +the types and shapes of each input and output. They can be provided as a +:class:`list` of type specs (names are auto-generated) **or** as a +:class:`dict` mapping explicit names to type specs. Each type spec can be +either an `ir.TypeAndShape` object or — more conveniently — an `onnxscript` tensor-type expression: | Expression | Meaning | @@ -408,8 +410,8 @@ def cumsum_body(op, state, x_i): body = builder.subgraph( cumsum_body, - input_types=[FLOAT[D], FLOAT[D]], # state, x_i - output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i + inputs=[FLOAT[D], FLOAT[D]], # state, x_i + outputs=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i name="cumsum_body", ) @@ -430,7 +432,7 @@ model = ir.Model(graph=graph, ir_version=10) Key points: -- `builder.subgraph(fn, input_types, output_types)` creates a fresh +- `builder.subgraph(fn, inputs, outputs)` creates a fresh `ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the declared input/output types. - The `fn` receives an `OpBuilder` as its first argument — exactly the same @@ -450,8 +452,8 @@ def outer_body(op, state, x_i): # Build a nested subgraph inside the scan body inner = op.builder.subgraph( lambda iop, v: iop.Relu(v), - input_types=[FLOAT[D]], - output_types=[FLOAT[D]], + inputs=[FLOAT[D]], + outputs=[FLOAT[D]], name="relu_body", ) # ... use inner as a graph attribute of a nested op ... diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 06b6edaa85..5630e7249d 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Any, Callable, Sequence, Union +from typing import Any, Callable, Mapping, Sequence, Union import onnx import onnx_ir as ir @@ -74,31 +74,135 @@ def _constant_name( return f"const_1d_{num}" -# Type accepted as an element of *input_types* / *output_types* by +# Type accepted as an element of *inputs* / *outputs* by # :meth:`GraphBuilder.subgraph`. Can be an already-resolved # :class:`ir.TypeAndShape`, or a # :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``. TypeSpec = Union[ir.TypeAndShape, Any] +# Acceptable collection forms for *inputs* / *outputs* in +# :meth:`GraphBuilder.subgraph`. A :class:`Sequence` of :data:`TypeSpec` +# auto-names entries (``input_0``, ``input_1``, …), while a :class:`Mapping` +# from :class:`str` to :data:`TypeSpec` uses the keys as explicit names. +InputOutputSpec = Union[Sequence[TypeSpec], Mapping[str, TypeSpec]] + def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape: """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`. - Accepts either an :class:`ir.TypeAndShape` directly, or a - :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]`` - or ``FLOAT['M', 'N']``). + Accepts an :class:`ir.TypeAndShape` directly, or any object with a + ``to_ir_type_and_shape()`` method (e.g. a + :class:`~onnxscript.onnx_types.TensorType` subclass such as + ``FLOAT[1024]`` or ``FLOAT['M', 'N']``). """ - # Lazy import to avoid a circular dependency: onnxscript.__init__ imports - # onnx_types (line ~106) before builder (line ~132), so by the time any - # call reaches here the module is fully initialised — but a top-level - # import in builder.py could break if builder is ever imported first. - from onnxscript.onnx_types import TensorType # pylint: disable=import-outside-toplevel - if isinstance(spec, ir.TypeAndShape): return spec - if isinstance(spec, type) and issubclass(spec, TensorType): - return spec.to_ir() - raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.") + if hasattr(spec, "to_ir_type_and_shape"): + result = spec.to_ir_type_and_shape() + if not isinstance(result, ir.TypeAndShape): + raise TypeError( + f"{type(spec)!r}.to_ir_type_and_shape() returned {type(result)!r}, " + f"expected ir.TypeAndShape." + ) + return result + raise TypeError( + f"Expected ir.TypeAndShape or an object with a to_ir_type_and_shape() method, " + f"got {type(spec)!r}." + ) + + +def _normalize_io_spec( + spec: InputOutputSpec, default_prefix: str +) -> list[tuple[str, ir.TypeAndShape]]: + """Normalize an *InputOutputSpec* into a list of ``(name, TypeAndShape)`` pairs. + + When *spec* is a :class:`Mapping`, the keys are used as names. When it is + a plain :class:`Sequence`, names are generated as + ``{default_prefix}_0``, ``{default_prefix}_1``, etc. + """ + if isinstance(spec, Mapping): + return [(name, _resolve_type_spec(ts)) for name, ts in spec.items()] + return [(f"{default_prefix}_{i}", _resolve_type_spec(ts)) for i, ts in enumerate(spec)] + + +def build_graph( + trace_function: Callable, + inputs: InputOutputSpec, + outputs: InputOutputSpec, + *, + opset_imports: dict[str, int] | None = None, + name: str = "subgraph", +) -> ir.Graph: + """Build an :class:`ir.Graph` suitable for use as a graph-valued attribute. + + This is a module-level utility that constructs a subgraph by tracing + *trace_function*. It is useful for building body graphs of control-flow ops + such as ``Scan``, ``Loop``, and ``If``. + + Example - building a Scan body that adds two sequences element-wise:: + + body = build_graph( + lambda op, x, y: op.Add(x, y), + inputs={"x": FLOAT[...], "y": FLOAT[...]}, + outputs={"sum": FLOAT[...]}, + ) + + Args: + trace_function: A callable with signature + ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``. + It is called once with freshly created placeholder inputs to record the + graph topology. + inputs: Types (and optionally names) for each graph input. May be a + :class:`Sequence` of :data:`TypeSpec` values (names are auto-generated + as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from + :class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec` + can be an :class:`ir.TypeAndShape` or a + :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. + ``FLOAT[1024]`` or ``FLOAT['M', 'N']``). + outputs: Types (and optionally names) for each graph output, in the + same format as *inputs*. + opset_imports: Opset version map for the subgraph (e.g. + ``{"": 23}``). Defaults to ``{"": 23}`` when *None*. + name: Name of the resulting :class:`ir.Graph`. + + Returns: + An :class:`ir.Graph` whose inputs and outputs are populated and whose + nodes record the operations traced by *trace_function*. This graph can be + passed directly as a graph-valued attribute (e.g. the ``body`` attribute of + a ``Scan`` or ``Loop`` node). + """ + if opset_imports is None: + opset_imports = {"": 23} + resolved_inputs = _normalize_io_spec(inputs, "input") + resolved_outputs = _normalize_io_spec(outputs, "output") + + subgraph = ir.Graph( + name=name, + inputs=[], + outputs=[], + nodes=[], + opset_imports=opset_imports, + ) + + for input_name, ts in resolved_inputs: + subgraph.inputs.append(ir.Value(name=input_name, type=ts.type, shape=ts.shape)) + + sub_builder = GraphBuilder(subgraph) + trace_outputs = trace_function(sub_builder.op, *subgraph.inputs) + if not isinstance(trace_outputs, Sequence): + trace_outputs = [trace_outputs] + if len(trace_outputs) != len(resolved_outputs): + raise ValueError( + f"trace_function returned {len(trace_outputs)} output(s), " + f"but {len(resolved_outputs)} were declared in outputs." + ) + for output, (output_name, ts) in zip(trace_outputs, resolved_outputs): + output.name = output_name + output.type = ts.type + output.merge_shapes(ts.shape) + + subgraph.outputs.extend(trace_outputs) + return subgraph class GraphBuilder: @@ -332,8 +436,8 @@ def add_node(self, node: ir.Node) -> None: def subgraph( self, trace_function: Callable, - input_types: Sequence[TypeSpec], - output_types: Sequence[TypeSpec], + inputs: InputOutputSpec, + outputs: InputOutputSpec, *, name: str = "subgraph", ) -> ir.Graph: @@ -347,8 +451,17 @@ def subgraph( body = graph_builder.subgraph( lambda op, x, y: op.Add(x, y), - input_types=[FLOAT[...], FLOAT[...]], - output_types=[FLOAT[...]], + inputs=[FLOAT[...], FLOAT[...]], + outputs=[FLOAT[...]], + ) + + Inputs and outputs can also be given as a :class:`dict` to assign + explicit names:: + + body = graph_builder.subgraph( + lambda op, x, y: op.Add(x, y), + inputs={"x": FLOAT[...], "y": FLOAT[...]}, + outputs={"sum": FLOAT[...]}, ) Args: @@ -356,12 +469,15 @@ def subgraph( ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``. It is called once with freshly created placeholder inputs to record the graph topology. - input_types: Types for each graph input. Each element may be an - :class:`ir.TypeAndShape` **or** a + inputs: Types (and optionally names) for each graph input. May be a + :class:`Sequence` of :data:`TypeSpec` values (names are auto-generated + as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from + :class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec` + can be an :class:`ir.TypeAndShape` or a :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]`` or ``FLOAT['M', 'N']``). - output_types: Types for each graph output, in the same format as - *input_types*. + outputs: Types (and optionally names) for each graph output, in the + same format as *inputs*. name: Name of the resulting :class:`ir.Graph`. Returns: @@ -370,37 +486,14 @@ def subgraph( passed directly as a graph-valued attribute (e.g. the ``body`` attribute of a ``Scan`` or ``Loop`` node). """ - opset_version = self._graph.opset_imports[""] - resolved_inputs = [_resolve_type_spec(t) for t in input_types] - resolved_outputs = [_resolve_type_spec(t) for t in output_types] - - subgraph = ir.Graph( + return build_graph( + trace_function, + inputs, + outputs, + opset_imports=dict(self._graph.opset_imports), name=name, - inputs=[], - outputs=[], - nodes=[], - opset_imports={"": opset_version}, ) - for i, ts in enumerate(resolved_inputs): - subgraph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape)) - - sub_builder = GraphBuilder(subgraph) - outputs = trace_function(sub_builder.op, *subgraph.inputs) - if not isinstance(outputs, Sequence): - outputs = [outputs] - if len(outputs) != len(resolved_outputs): - raise ValueError( - f"trace_function returned {len(outputs)} output(s), " - f"but {len(resolved_outputs)} were declared in output_types." - ) - for output, ts in zip(outputs, resolved_outputs): - output.type = ts.type - output.merge_shapes(ts.shape) - - subgraph.outputs.extend(outputs) - return subgraph - def call_op( self, op_type: str, diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index ffc1ab44a4..9a98e2f425 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -869,8 +869,8 @@ def _add(op, x, y): gb = self._make_builder() graph = gb.subgraph( _add, - input_types=[FLOAT[3, 4], FLOAT[3, 4]], - output_types=[FLOAT[3, 4]], + inputs=[FLOAT[3, 4], FLOAT[3, 4]], + outputs=[FLOAT[3, 4]], ) self.assertIsInstance(graph, ir.Graph) self.assertEqual(len(graph.inputs), 2) @@ -883,8 +883,8 @@ def test_subgraph_inherits_opset_version(self): gb = self._make_builder(opset_version=17) graph = gb.subgraph( lambda op, x: op.Identity(x), - input_types=[FLOAT[...]], - output_types=[FLOAT[...]], + inputs=[FLOAT[...]], + outputs=[FLOAT[...]], ) self.assertEqual(graph.opset_imports[""], 17) @@ -898,8 +898,8 @@ def _mul(op, x, y): gb = self._make_builder() graph = gb.subgraph( _mul, - input_types=[float_2d, float_2d], - output_types=[float_2d], + inputs=[float_2d, float_2d], + outputs=[float_2d], ) self.assertIsInstance(graph, ir.Graph) self.assertEqual(len(list(graph)), 1) @@ -915,8 +915,8 @@ def _add_and_mul(op, x, y): gb = self._make_builder() graph = gb.subgraph( _add_and_mul, - input_types=[ts, ts], - output_types=[ts, ts], + inputs=[ts, ts], + outputs=[ts, ts], ) self.assertEqual(len(graph.outputs), 2) @@ -930,8 +930,8 @@ def _returns_one(op, x, y): with self.assertRaises(ValueError): gb.subgraph( _returns_one, - input_types=[FLOAT[...], FLOAT[...]], - output_types=[FLOAT[...], FLOAT[...]], # expects 2, gets 1 + inputs=[FLOAT[...], FLOAT[...]], + outputs=[FLOAT[...], FLOAT[...]], # expects 2, gets 1 ) def test_subgraph_custom_name(self): @@ -943,8 +943,8 @@ def _id(op, x): gb = self._make_builder() graph = gb.subgraph( _id, - input_types=[DOUBLE[...]], - output_types=[DOUBLE[...]], + inputs=[DOUBLE[...]], + outputs=[DOUBLE[...]], name="scan_body", ) self.assertEqual(graph.name, "scan_body") @@ -959,10 +959,72 @@ def _id(op, x): with self.assertRaises(TypeError): gb.subgraph( _id, - input_types=["not_a_type_spec"], - output_types=["not_a_type_spec"], + inputs=["not_a_type_spec"], + outputs=["not_a_type_spec"], ) + def test_subgraph_dict_inputs_outputs(self): + """Subgraph accepts a dict to name inputs and outputs.""" + + def _add(op, x, y): + return op.Add(x, y) + + gb = self._make_builder() + graph = gb.subgraph( + _add, + inputs={"x": FLOAT[3, 4], "y": FLOAT[3, 4]}, + outputs={"sum": FLOAT[3, 4]}, + ) + self.assertIsInstance(graph, ir.Graph) + self.assertEqual(len(graph.inputs), 2) + self.assertEqual(graph.inputs[0].name, "x") + self.assertEqual(graph.inputs[1].name, "y") + self.assertEqual(len(graph.outputs), 1) + self.assertEqual(graph.outputs[0].name, "sum") + + def test_subgraph_list_auto_names(self): + """List-based inputs/outputs get auto-generated names.""" + + def _id(op, x): + return op.Identity(x) + + gb = self._make_builder() + graph = gb.subgraph( + _id, + inputs=[FLOAT[...]], + outputs=[FLOAT[...]], + ) + self.assertEqual(graph.inputs[0].name, "input_0") + self.assertEqual(graph.outputs[0].name, "output_0") + + +class BuildGraphFunctionTest(unittest.TestCase): + """Tests for the module-level build_graph() utility.""" + + def test_build_graph_basic(self): + """build_graph works without a parent GraphBuilder.""" + graph = builder.build_graph( + lambda op, x, y: op.Add(x, y), + inputs={"x": FLOAT[3, 4], "y": FLOAT[3, 4]}, + outputs={"sum": FLOAT[3, 4]}, + opset_imports={"": 20}, + ) + self.assertIsInstance(graph, ir.Graph) + self.assertEqual(graph.opset_imports[""], 20) + self.assertEqual(graph.inputs[0].name, "x") + self.assertEqual(graph.inputs[1].name, "y") + self.assertEqual(graph.outputs[0].name, "sum") + + def test_build_graph_custom_name(self): + """build_graph passes name to the ir.Graph.""" + graph = builder.build_graph( + lambda op, x: op.Identity(x), + inputs=[FLOAT[...]], + outputs=[FLOAT[...]], + name="loop_body", + ) + self.assertEqual(graph.name, "loop_body") + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/onnx_types.py b/onnxscript/onnx_types.py index b0a4006329..d03649b9cf 100644 --- a/onnxscript/onnx_types.py +++ b/onnxscript/onnx_types.py @@ -100,13 +100,13 @@ def to_type_proto(cls) -> onnx.TypeProto: return onnx.helper.make_tensor_type_proto(cls.dtype, shape) # noqa: TID251 @classmethod - def to_ir(cls) -> ir.TypeAndShape: + def to_ir_type_and_shape(cls) -> ir.TypeAndShape: """Return an :class:`ir.TypeAndShape` representing this tensor type and shape. This enables using ONNX Script tensor-type notation (e.g. ``FLOAT[1024]`` or ``FLOAT['M', 'N']``) wherever an :class:`ir.TypeAndShape` is expected, - such as the *input_types* / *output_types* arguments of - :func:`onnxscript._internal.builder.build_subgraph`. + such as the *inputs* / *outputs* arguments of + :func:`onnxscript._internal.builder.build_graph`. """ ir_type = ir.TensorType(cls.dtype) if cls.shape is None: diff --git a/onnxscript/onnx_types_test.py b/onnxscript/onnx_types_test.py index 9898188209..9103abfbd8 100644 --- a/onnxscript/onnx_types_test.py +++ b/onnxscript/onnx_types_test.py @@ -11,11 +11,11 @@ class TensorTypeToIrTest(unittest.TestCase): - """Tests for TensorType.to_ir().""" + """Tests for TensorType.to_ir_type_and_shape().""" def test_scalar_type(self): """FLOAT (no subscript) maps to rank-0 tensor (empty shape).""" - ts = FLOAT.to_ir() + ts = FLOAT.to_ir_type_and_shape() self.assertIsInstance(ts, ir.TypeAndShape) self.assertEqual(ts.type, ir.TensorType(ir.DataType.FLOAT)) self.assertIsNotNone(ts.shape) @@ -23,20 +23,20 @@ def test_scalar_type(self): def test_unknown_rank(self): """FLOAT[...] maps to unknown-rank (shape=None).""" - ts = FLOAT[...].to_ir() + ts = FLOAT[...].to_ir_type_and_shape() self.assertIsInstance(ts, ir.TypeAndShape) self.assertIsNone(ts.shape) def test_single_dim(self): """FLOAT[1024] maps to a 1-D tensor with dimension 1024.""" - ts = FLOAT[1024].to_ir() + ts = FLOAT[1024].to_ir_type_and_shape() self.assertIsNotNone(ts.shape) self.assertEqual(len(ts.shape), 1) self.assertEqual(ts.shape[0], 1024) def test_multi_dim_int(self): """FLOAT[3, 4] maps to a 2-D tensor with dims (3, 4).""" - ts = FLOAT[3, 4].to_ir() + ts = FLOAT[3, 4].to_ir_type_and_shape() self.assertIsNotNone(ts.shape) self.assertEqual(len(ts.shape), 2) self.assertEqual(ts.shape[0], 3) @@ -44,13 +44,13 @@ def test_multi_dim_int(self): def test_symbolic_dims(self): """FLOAT['M', 'N'] maps to a 2-D tensor with symbolic dims.""" - ts = FLOAT["M", "N"].to_ir() + ts = FLOAT["M", "N"].to_ir_type_and_shape() self.assertIsNotNone(ts.shape) self.assertEqual(len(ts.shape), 2) def test_other_dtype(self): """INT64[...] preserves the correct dtype.""" - ts = INT64[...].to_ir() + ts = INT64[...].to_ir_type_and_shape() self.assertEqual(ts.type, ir.TensorType(ir.DataType.INT64))