Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions docs/tutorial/builder/graph_builder.md
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,11 @@ The subgraph automatically inherits the opset version from the parent

### Type annotations for subgraph inputs and outputs

`subgraph()` accepts `input_types` and `output_types` lists that describe
the types and shapes of each input and output. Each element can be either an
`ir.TypeAndShape` object or — more conveniently — an
`subgraph()` accepts `inputs` and `outputs` that describe
the types and shapes of each input and output. They can be provided as a
`list` of type specs (names are auto-generated) **or** as a
`dict` mapping explicit names to type specs. Each type spec can be
either an `ir.TypeAndShape` object or — more conveniently — an
`onnxscript` tensor-type expression:

| Expression | Meaning |
Expand Down Expand Up @@ -408,8 +410,8 @@ def cumsum_body(op, state, x_i):

body = builder.subgraph(
cumsum_body,
input_types=[FLOAT[D], FLOAT[D]], # state, x_i
output_types=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
inputs=[FLOAT[D], FLOAT[D]], # state, x_i
outputs=[FLOAT[D], FLOAT[D]], # new_state, scan_out_i
name="cumsum_body",
)

Expand All @@ -430,7 +432,7 @@ model = ir.Model(graph=graph, ir_version=10)

Key points:

- `builder.subgraph(fn, input_types, output_types)` creates a fresh
- `builder.subgraph(fn, inputs, outputs)` creates a fresh
`ir.Graph`, calls `fn(op, *inputs)` to trace the body, and wires up the
declared input/output types.
- The `fn` receives an `OpBuilder` as its first argument — exactly the same
Expand All @@ -450,8 +452,8 @@ def outer_body(op, state, x_i):
# Build a nested subgraph inside the scan body
inner = op.builder.subgraph(
lambda iop, v: iop.Relu(v),
input_types=[FLOAT[D]],
output_types=[FLOAT[D]],
inputs=[FLOAT[D]],
outputs=[FLOAT[D]],
name="relu_body",
)
# ... use inner as a graph attribute of a nested op ...
Expand Down
193 changes: 143 additions & 50 deletions onnxscript/_internal/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from __future__ import annotations

from typing import Any, Callable, Sequence, Union
from typing import Any, Callable, Mapping, Sequence, Union

import onnx
import onnx_ir as ir
Expand Down Expand Up @@ -74,31 +74,135 @@ def _constant_name(
return f"const_1d_{num}"


# Type accepted as an element of *input_types* / *output_types* by
# Type accepted as an element of *inputs* / *outputs* by
# :meth:`GraphBuilder.subgraph`. Can be an already-resolved
# :class:`ir.TypeAndShape`, or a
# :class:`~onnxscript.onnx_types.TensorType` subclass such as ``FLOAT[1024]``.
TypeSpec = Union[ir.TypeAndShape, Any]

# Acceptable collection forms for *inputs* / *outputs* in
# :meth:`GraphBuilder.subgraph`. A :class:`Sequence` of :data:`TypeSpec`
# auto-names entries (``input_0``, ``input_1``, …), while a :class:`Mapping`
# from :class:`str` to :data:`TypeSpec` uses the keys as explicit names.
InputOutputSpec = Union[Sequence[TypeSpec], Mapping[str, TypeSpec]]


def _resolve_type_spec(spec: TypeSpec) -> ir.TypeAndShape:
    """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`.

    Accepts an :class:`ir.TypeAndShape` directly, or any object with a
    ``to_ir_type_and_shape()`` method (e.g. a
    :class:`~onnxscript.onnx_types.TensorType` subclass such as
    ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).

    Args:
        spec: The type spec to resolve.

    Returns:
        The resolved :class:`ir.TypeAndShape`.

    Raises:
        TypeError: If *spec* is neither an :class:`ir.TypeAndShape` nor an
            object exposing ``to_ir_type_and_shape()``, or if that method
            returns something other than an :class:`ir.TypeAndShape`.
    """
    if isinstance(spec, ir.TypeAndShape):
        return spec
    # Duck-typed path: anything exposing to_ir_type_and_shape() is accepted.
    # This avoids importing onnxscript.onnx_types here, which would create a
    # circular dependency between this module and onnxscript.__init__.
    if hasattr(spec, "to_ir_type_and_shape"):
        result = spec.to_ir_type_and_shape()
        if not isinstance(result, ir.TypeAndShape):
            raise TypeError(
                f"{type(spec)!r}.to_ir_type_and_shape() returned {type(result)!r}, "
                f"expected ir.TypeAndShape."
            )
        return result
    raise TypeError(
        f"Expected ir.TypeAndShape or an object with a to_ir_type_and_shape() method, "
        f"got {type(spec)!r}."
    )


def _normalize_io_spec(
    spec: InputOutputSpec, default_prefix: str
) -> list[tuple[str, ir.TypeAndShape]]:
    """Turn an *InputOutputSpec* into explicit ``(name, TypeAndShape)`` pairs.

    A :class:`Mapping` keeps its keys as the names. A plain
    :class:`Sequence` gets generated names of the form
    ``{default_prefix}_0``, ``{default_prefix}_1``, and so on.
    """
    if isinstance(spec, Mapping):
        named_specs = spec.items()
    else:
        named_specs = (
            (f"{default_prefix}_{index}", ts) for index, ts in enumerate(spec)
        )
    return [(name, _resolve_type_spec(ts)) for name, ts in named_specs]


def build_graph(
    trace_function: Callable,
    inputs: InputOutputSpec,
    outputs: InputOutputSpec,
    *,
    opset_imports: dict[str, int] | None = None,
    name: str = "subgraph",
) -> ir.Graph:
    """Trace *trace_function* into a fresh :class:`ir.Graph`.

    Module-level helper for constructing body graphs of control-flow ops
    such as ``Scan``, ``Loop``, and ``If``. The resulting graph can be used
    directly as a graph-valued attribute.

    Example - a Scan body that adds two sequences element-wise::

        body = build_graph(
            lambda op, x, y: op.Add(x, y),
            inputs={"x": FLOAT[...], "y": FLOAT[...]},
            outputs={"sum": FLOAT[...]},
        )

    Args:
        trace_function: A callable with signature
            ``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
            It is invoked exactly once with placeholder input values to
            record the graph topology.
        inputs: Types (and optionally names) for each graph input. Either a
            :class:`Sequence` of :data:`TypeSpec` values (names are
            auto-generated as ``input_0``, ``input_1``, …) or a
            :class:`Mapping` from :class:`str` names to :data:`TypeSpec`
            values. Each :data:`TypeSpec` may be an :class:`ir.TypeAndShape`
            or a :class:`~onnxscript.onnx_types.TensorType` subclass
            (e.g. ``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
        outputs: Types (and optionally names) for each graph output, in the
            same format as *inputs*.
        opset_imports: Opset version map for the subgraph
            (e.g. ``{"": 23}``). When *None*, defaults to ``{"": 23}``.
        name: Name of the resulting :class:`ir.Graph`.

    Returns:
        An :class:`ir.Graph` with populated inputs/outputs whose nodes record
        the operations traced by *trace_function*.

    Raises:
        ValueError: If *trace_function* returns a different number of values
            than were declared in *outputs*.
    """
    input_specs = _normalize_io_spec(inputs, "input")
    output_specs = _normalize_io_spec(outputs, "output")

    graph = ir.Graph(
        name=name,
        inputs=[],
        outputs=[],
        nodes=[],
        # Fall back to the default opset ("": 23) when none was given.
        opset_imports={"": 23} if opset_imports is None else opset_imports,
    )

    # Create one placeholder value per declared input; these are handed to
    # trace_function so it can record the body's topology against them.
    for value_name, spec in input_specs:
        graph.inputs.append(ir.Value(name=value_name, type=spec.type, shape=spec.shape))

    builder = GraphBuilder(graph)
    produced = trace_function(builder.op, *graph.inputs)
    # A single returned value is treated as a one-output graph.
    if not isinstance(produced, Sequence):
        produced = [produced]
    if len(produced) != len(output_specs):
        raise ValueError(
            f"trace_function returned {len(produced)} output(s), "
            f"but {len(output_specs)} were declared in outputs."
        )
    # Stamp each traced output with its declared name, type, and shape.
    for value, (value_name, spec) in zip(produced, output_specs):
        value.name = value_name
        value.type = spec.type
        value.merge_shapes(spec.shape)

    graph.outputs.extend(produced)
    return graph


class GraphBuilder:
Expand Down Expand Up @@ -332,8 +436,8 @@ def add_node(self, node: ir.Node) -> None:
def subgraph(
self,
trace_function: Callable,
input_types: Sequence[TypeSpec],
output_types: Sequence[TypeSpec],
inputs: InputOutputSpec,
outputs: InputOutputSpec,
*,
name: str = "subgraph",
) -> ir.Graph:
Expand All @@ -347,21 +451,33 @@ def subgraph(

body = graph_builder.subgraph(
lambda op, x, y: op.Add(x, y),
input_types=[FLOAT[...], FLOAT[...]],
output_types=[FLOAT[...]],
inputs=[FLOAT[...], FLOAT[...]],
outputs=[FLOAT[...]],
)

Inputs and outputs can also be given as a :class:`dict` to assign
explicit names::

body = graph_builder.subgraph(
lambda op, x, y: op.Add(x, y),
inputs={"x": FLOAT[...], "y": FLOAT[...]},
outputs={"sum": FLOAT[...]},
)

Args:
trace_function: A callable with signature
``(op: OpBuilder, *inputs: ir.Value) -> ir.Value | Sequence[ir.Value]``.
It is called once with freshly created placeholder inputs to record the
graph topology.
input_types: Types for each graph input. Each element may be an
:class:`ir.TypeAndShape` **or** a
inputs: Types (and optionally names) for each graph input. May be a
:class:`Sequence` of :data:`TypeSpec` values (names are auto-generated
as ``input_0``, ``input_1``, …) **or** a :class:`Mapping` from
:class:`str` names to :data:`TypeSpec` values. Each :data:`TypeSpec`
can be an :class:`ir.TypeAndShape` or a
:class:`~onnxscript.onnx_types.TensorType` subclass (e.g.
``FLOAT[1024]`` or ``FLOAT['M', 'N']``).
output_types: Types for each graph output, in the same format as
*input_types*.
outputs: Types (and optionally names) for each graph output, in the
same format as *inputs*.
name: Name of the resulting :class:`ir.Graph`.

Returns:
Expand All @@ -370,37 +486,14 @@ def subgraph(
passed directly as a graph-valued attribute (e.g. the ``body`` attribute of
a ``Scan`` or ``Loop`` node).
"""
opset_version = self._graph.opset_imports[""]
resolved_inputs = [_resolve_type_spec(t) for t in input_types]
resolved_outputs = [_resolve_type_spec(t) for t in output_types]

subgraph = ir.Graph(
return build_graph(
trace_function,
inputs,
outputs,
opset_imports=dict(self._graph.opset_imports),
name=name,
inputs=[],
outputs=[],
nodes=[],
opset_imports={"": opset_version},
)

for i, ts in enumerate(resolved_inputs):
subgraph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape))

sub_builder = GraphBuilder(subgraph)
outputs = trace_function(sub_builder.op, *subgraph.inputs)
if not isinstance(outputs, Sequence):
outputs = [outputs]
if len(outputs) != len(resolved_outputs):
raise ValueError(
f"trace_function returned {len(outputs)} output(s), "
f"but {len(resolved_outputs)} were declared in output_types."
)
for output, ts in zip(outputs, resolved_outputs):
output.type = ts.type
output.merge_shapes(ts.shape)

subgraph.outputs.extend(outputs)
return subgraph

def call_op(
self,
op_type: str,
Expand Down
Loading
Loading