From acc50902ab5451780463a9cf938191b8fd3fff7f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 3 Mar 2026 13:39:37 -0800 Subject: [PATCH 1/4] Implement GraphBuilder._partition_inputs_attributes using OpSignature - Convert onnx.defs.OpSchema to ir.schemas.OpSignature via from_op_schema and delegate to separate_input_attributes_from_arguments - Add allow_extra_args parameter to separate_input_attributes_from_arguments for rejecting unexpected positional arguments (default True for compat) - Builder uses strict mode: allow_extra_kwargs=False, allow_extra_args=False - Refactor _build test helper: accept TypeSpec, optional trace_function, return ir.Graph directly - Add comprehensive tests for input/attribute partitioning Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder.py | 15 +- onnxscript/_internal/builder_test.py | 219 ++++++++++++++++----- onnxscript/_internal/param_manipulation.py | 13 ++ 3 files changed, 196 insertions(+), 51 deletions(-) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py index 06b6edaa85..5eca908de4 100644 --- a/onnxscript/_internal/builder.py +++ b/onnxscript/_internal/builder.py @@ -17,7 +17,7 @@ import onnxscript._internal._inference as inference import onnxscript.optimizer -from onnxscript._internal import _inliner +from onnxscript._internal import _inliner, param_manipulation # A permissible value for an op input, which can be converted to an ir.Value. VALUE_LIKE = Union[ @@ -255,9 +255,16 @@ def _partition_inputs_attributes( inputs: Sequence[ir.Value | ir.TensorProtocol], kwargs: dict[str, Any], ) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]: - # Not implemented yet - del schema - return inputs, kwargs + if schema is None: + return inputs, kwargs + op_signature = ir.schemas.OpSignature.from_op_schema(schema) + return param_manipulation.separate_input_attributes_from_arguments( + op_signature, + list(inputs), + kwargs, + fill_defaults=False, + allow_extra_args=False, + ) def _cast_inputs( self, diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index ffc1ab44a4..6b15a40dbd 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -12,16 +12,16 @@ import onnxscript._internal.builder as builder import onnxscript.testing from onnxscript import script -from onnxscript.onnx_types import DOUBLE, FLOAT +from onnxscript.onnx_types import DOUBLE, FLOAT, INT64 _default_opset_version = 23 def _build( - trace_function, - input_types: Sequence[ir.TypeAndShape], - output_types: Sequence[ir.TypeAndShape], -) -> ir.Model: + input_types: Sequence[builder.TypeSpec], + trace_function=None, + output_types: Sequence[builder.TypeSpec] | None = None, +) -> ir.Graph: graph = ir.Graph( name="test_model", inputs=[], @@ -30,25 +30,29 @@ def _build( opset_imports={"": _default_opset_version}, ) - onnx_model = ir.Model(graph=graph, ir_version=10) + resolved_inputs = [builder._resolve_type_spec(t) for t in input_types] + for i, ts in enumerate(resolved_inputs): + graph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape)) - for i, input_type in enumerate(input_types): - input_name = f"input_{i}" - graph.inputs.append(ir.Value(name=input_name, type=input_type)) + if trace_function is not None: + graph_builder = builder.GraphBuilder(graph) + outputs = trace_function(graph_builder.op, *graph.inputs) + if not isinstance(outputs, Sequence): + outputs = [outputs] - graph_builder = builder.GraphBuilder(graph) - outputs = trace_function(graph_builder.op, *graph.inputs) - if not isinstance(outputs, Sequence): - outputs = [outputs] - if len(outputs) != len(output_types): - raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.") - for output, output_type in zip(outputs, output_types): - output.type = output_type.type # TODO: need merge_type method in ir.Value - output.merge_shapes(output_type.shape) + if output_types is not None: + resolved_outputs = [builder._resolve_type_spec(t) for t in output_types] + if len(outputs) != len(resolved_outputs): + raise ValueError( + f"Expected {len(resolved_outputs)} outputs, but got {len(outputs)}." + ) + for output, ts in zip(outputs, resolved_outputs): + output.type = ts.type + output.merge_shapes(ts.shape) - graph.outputs.extend(outputs) + graph.outputs.extend(outputs) - return onnx_model + return graph def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]: @@ -57,24 +61,7 @@ def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value Returns: A tuple of (op_builder, input_x, input_y). """ - graph = ir.Graph( - name="test_model", - inputs=[], - outputs=[], - nodes=[], - opset_imports={"": 23}, - ) - - for i in range(2): - input_name = f"input_{i}" - graph.inputs.append( - ir.Value( - name=input_name, - type=ir.TensorType(ir.DataType.FLOAT), - shape=ir.Shape([2, 3, 4]), - ) - ) - + graph = _build(input_types=[FLOAT[2, 3, 4], FLOAT[2, 3, 4]]) graph_builder = builder.GraphBuilder(graph) x, y = graph.inputs return graph_builder.op, x, y @@ -89,12 +76,11 @@ def _add_mul_add(op: builder.OpBuilder, x: ir.Value, y: ir.Value) -> ir.Value: return z float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) - model = _build( - _add_mul_add, + graph = _build( input_types=[float_2d, float_2d], + trace_function=_add_mul_add, output_types=[float_2d], ) - graph = model.graph # Expect exactly 3 nodes: Add, Mul, Add op_types = [node.op_type for node in graph] self.assertEqual(op_types, ["Add", "Mul", "Add"]) @@ -121,12 +107,11 @@ def _add_with_custom_names( return z float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) - model = _build( - _add_with_custom_names, + graph = _build( input_types=[float_2d, float_2d], + trace_function=_add_with_custom_names, output_types=[float_2d], ) - graph = model.graph # Verify that the nodes have outputs with the specified names nodes = list(graph) @@ -207,12 +192,11 @@ def _ops_with_default_names( return z float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4])) - model = _build( - _ops_with_default_names, + graph = _build( input_types=[float_2d, float_2d], + trace_function=_ops_with_default_names, output_types=[float_2d], ) - graph = model.graph # Verify the nodes use the new naming strategy nodes = list(graph) @@ -964,5 +948,146 @@ def _id(op, x): ) +class PartitionInputsAttributesTest(unittest.TestCase): + """Tests for GraphBuilder._partition_inputs_attributes.""" + + def test_unknown_op_passes_inputs_and_kwargs_through(self): + """An unknown op has no schema, so inputs and kwargs pass through unchanged.""" + + def _dummy(op, x, y): + return op.DummyOp(x, y, alpha=1.0) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[3, 4]], + trace_function=_dummy, + ) + x, y = graph.inputs + node = list(graph)[0] + self.assertEqual(node.op_type, "DummyOp") + self.assertEqual(list(node.inputs), [x, y]) + self.assertEqual(node.attributes["alpha"].as_float(), 1.0) + + def test_op_with_only_inputs(self): + """Add has two inputs and no attributes.""" + + def _add(op, x, y): + return op.Add(x, y) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[3, 4]], + trace_function=_add, + ) + x, y = graph.inputs + node = list(graph)[0] + self.assertEqual(node.op_type, "Add") + self.assertEqual(list(node.inputs), [x, y]) + self.assertEqual(len(node.attributes), 0) + + def test_op_with_inputs_and_attributes_in_kwargs(self): + """Gemm has 3 inputs (A, B, C) and attributes (alpha, beta, transA, transB).""" + + def _gemm(op, a, b, c): + return op.Gemm(a, b, c, alpha=2.0, transB=1) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[4, 5], FLOAT[3, 5]], + trace_function=_gemm, + ) + a, b, c = graph.inputs + node = list(graph)[0] + self.assertEqual(node.op_type, "Gemm") + self.assertEqual(list(node.inputs), [a, b, c]) + self.assertEqual(node.attributes["alpha"].as_float(), 2.0) + self.assertEqual(node.attributes["transB"].as_int(), 1) + + def test_op_with_optional_input_omitted(self): + """Gemm's third input (C) is optional. Omitting it should work.""" + + def _gemm_no_c(op, a, b): + return op.Gemm(a, b, alpha=2.0) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[4, 5]], + trace_function=_gemm_no_c, + ) + a, b = graph.inputs + node = list(graph)[0] + self.assertEqual(node.op_type, "Gemm") + self.assertEqual(list(node.inputs), [a, b]) + self.assertEqual(node.attributes["alpha"].as_float(), 2.0) + + def test_does_not_fill_attribute_defaults(self): + """Attribute defaults should not be filled in (fill_defaults=False).""" + + def _gemm_no_attrs(op, a, b): + return op.Gemm(a, b) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[4, 5]], + trace_function=_gemm_no_attrs, + ) + node = list(graph)[0] + # alpha, beta, transA, transB all have defaults but should NOT appear + self.assertFalse(node.attributes) + + def test_variadic_inputs_with_attribute(self): + """Concat has variadic inputs and an axis attribute.""" + + def _concat(op, x, y, z): + return op.Concat(x, y, z, axis=0) + + graph = _build( + input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]], + trace_function=_concat, + ) + x, y, z = graph.inputs + node = list(graph)[0] + self.assertEqual(node.op_type, "Concat") + self.assertEqual(list(node.inputs), [x, y, z]) + self.assertEqual(node.attributes["axis"].as_int(), 0) + + def test_slice_kwargs_are_correctly_ordered_as_inputs(self): + """Calling op.Slice with keyword arguments should place them in schema order.""" + + def _slice(op, data, starts, ends, axes, steps): + # Pass optional inputs as kwargs in non-schema order + return op.Slice(data, ends=ends, steps=steps, starts=starts, axes=axes) + + graph = _build( + input_types=[FLOAT[20, 10], INT64[2], INT64[2], INT64[2], INT64[2]], + trace_function=_slice, + ) + data, starts, ends, axes, steps = graph.inputs + + slice_node = list(graph)[0] + self.assertEqual(slice_node.op_type, "Slice") + # Schema order: data, starts, ends, axes, steps + self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps]) + + def test_omitting_required_input_raises(self): + """Omitting a required input should raise TypeError.""" + + def _add_missing_input(op, x): + return op.Add(x) + + with self.assertRaises(TypeError): + _build( + input_types=[FLOAT[3, 4]], + trace_function=_add_missing_input, + ) + + def test_extra_inputs_raises(self): + """Extra positional inputs beyond the schema should raise TypeError.""" + + def _add_extra_input(op, x, y, z): + return op.Add(x, y, z) + + with self.assertRaises(TypeError): + _build( + input_types=[FLOAT[3, 4], FLOAT[3, 4], FLOAT[3, 4]], + trace_function=_add_extra_input, + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/_internal/param_manipulation.py b/onnxscript/_internal/param_manipulation.py index f3325f22e7..c75d42504b 100644 --- a/onnxscript/_internal/param_manipulation.py +++ b/onnxscript/_internal/param_manipulation.py @@ -16,6 +16,7 @@ def separate_input_attributes_from_arguments( kwargs, fill_defaults: bool = True, allow_extra_kwargs: bool = False, + allow_extra_args: bool = True, ) -> tuple[list[Any], OrderedDict[str, Any]]: """Separate Python args and kwargs into ONNX inputs and attributes. @@ -26,6 +27,9 @@ def separate_input_attributes_from_arguments( fill_defaults: Whether to fill the default values for attributes. allow_extra_kwargs: Whether to allow extra keyword arguments. When set to True, extra/unknown arguments will be ignored. + allow_extra_args: Whether to allow extra positional arguments beyond + what the schema declares (when no variadic parameter exists). + When set to False, a TypeError is raised for extra args. Returns: A tuple of two elements: @@ -34,6 +38,7 @@ def separate_input_attributes_from_arguments( Raises: TypeError: When allow_extra_kwargs is False and there are unknown kwargs. + TypeError: When allow_extra_args is False and there are extra positional args. TypeError: When a required input is not provided. """ # args, kwargs and op_signature.params should be all in order @@ -46,12 +51,14 @@ def separate_input_attributes_from_arguments( onnx_inputs = [] onnx_attributes = collections.OrderedDict() + has_variadic = False for i, param in enumerate(op_signature.params): is_input = param.is_param() is_variadic = is_input and param.variadic if is_variadic: + has_variadic = True # Exhaust all remaining args onnx_inputs.extend(args[i:]) args = [] @@ -74,6 +81,12 @@ def separate_input_attributes_from_arguments( elif param.required: raise TypeError(f"Required input '{param}' was not provided") + if not allow_extra_args and not has_variadic and len(args) > len(op_signature.params): + raise TypeError( + f"Too many positional arguments: expected {len(op_signature.params)}, " + f"got {len(args)}" + ) + return onnx_inputs, onnx_attributes From 03ea4d21a2f3daecfd4d855eec0198cc2eb01d37 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 4 Mar 2026 21:33:43 -0800 Subject: [PATCH 2/4] Copy _resolve_type_spec to builder_test.py to avoid referencing internal member Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder_test.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 6b15a40dbd..2030a9328f 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -17,6 +17,25 @@ _default_opset_version = 23 +def _resolve_type_spec(spec: builder.TypeSpec) -> ir.TypeAndShape: + """Convert a *TypeSpec* to an :class:`ir.TypeAndShape`. + + Accepts either an :class:`ir.TypeAndShape` directly, or a + :class:`~onnxscript.onnx_types.TensorType` subclass (e.g. ``FLOAT[1024]`` + or ``FLOAT['M', 'N']``). + + NOTE: This is a local copy of :func:`builder._resolve_type_spec` so that + tests do not reference a private helper directly. + """ + from onnxscript.onnx_types import TensorType # pylint: disable=import-outside-toplevel + + if isinstance(spec, ir.TypeAndShape): + return spec + if isinstance(spec, type) and issubclass(spec, TensorType): + return spec.to_ir() + raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.") + + def _build( input_types: Sequence[builder.TypeSpec], trace_function=None, @@ -30,7 +49,7 @@ def _build( opset_imports={"": _default_opset_version}, ) - resolved_inputs = [builder._resolve_type_spec(t) for t in input_types] + resolved_inputs = [_resolve_type_spec(t) for t in input_types] for i, ts in enumerate(resolved_inputs): graph.inputs.append(ir.Value(name=f"input_{i}", type=ts.type, shape=ts.shape)) @@ -41,7 +60,7 @@ def _build( outputs = [outputs] if output_types is not None: - resolved_outputs = [builder._resolve_type_spec(t) for t in output_types] + resolved_outputs = [_resolve_type_spec(t) for t in output_types] if len(outputs) != len(resolved_outputs): raise ValueError( f"Expected {len(resolved_outputs)} outputs, but got {len(outputs)}." From 398487f20c74b4958f4209c98feabc27fdd47634 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 5 Mar 2026 08:49:30 -0800 Subject: [PATCH 3/4] Replace list(graph)[0] with graph.node(0) in builder_test.py Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder_test.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 2030a9328f..41cff7c878 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -981,7 +981,7 @@ def _dummy(op, x, y): trace_function=_dummy, ) x, y = graph.inputs - node = list(graph)[0] + node = graph.node(0) self.assertEqual(node.op_type, "DummyOp") self.assertEqual(list(node.inputs), [x, y]) self.assertEqual(node.attributes["alpha"].as_float(), 1.0) @@ -997,7 +997,7 @@ def _add(op, x, y): trace_function=_add, ) x, y = graph.inputs - node = list(graph)[0] + node = graph.node(0) self.assertEqual(node.op_type, "Add") self.assertEqual(list(node.inputs), [x, y]) self.assertEqual(len(node.attributes), 0) @@ -1013,7 +1013,7 @@ def _gemm(op, a, b, c): trace_function=_gemm, ) a, b, c = graph.inputs - node = list(graph)[0] + node = graph.node(0) self.assertEqual(node.op_type, "Gemm") self.assertEqual(list(node.inputs), [a, b, c]) self.assertEqual(node.attributes["alpha"].as_float(), 2.0) @@ -1030,7 +1030,7 @@ def _gemm_no_c(op, a, b): trace_function=_gemm_no_c, ) a, b = graph.inputs - node = list(graph)[0] + node = graph.node(0) self.assertEqual(node.op_type, "Gemm") self.assertEqual(list(node.inputs), [a, b]) self.assertEqual(node.attributes["alpha"].as_float(), 2.0) @@ -1045,7 +1045,7 @@ def _gemm_no_attrs(op, a, b): input_types=[FLOAT[3, 4], FLOAT[4, 5]], trace_function=_gemm_no_attrs, ) - node = list(graph)[0] + node = graph.node(0) # alpha, beta, transA, transB all have defaults but should NOT appear self.assertFalse(node.attributes) @@ -1060,7 +1060,7 @@ def _concat(op, x, y, z): trace_function=_concat, ) x, y, z = graph.inputs - node = list(graph)[0] + node = graph.node(0) self.assertEqual(node.op_type, "Concat") self.assertEqual(list(node.inputs), [x, y, z]) self.assertEqual(node.attributes["axis"].as_int(), 0) @@ -1078,7 +1078,7 @@ def _slice(op, data, starts, ends, axes, steps): ) data, starts, ends, axes, steps = graph.inputs - slice_node = list(graph)[0] + slice_node = graph.node(0) self.assertEqual(slice_node.op_type, "Slice") # Schema order: data, starts, ends, axes, steps self.assertEqual(list(slice_node.inputs), [data, starts, ends, axes, steps]) From e330026aa12c98773fa01bda1a3a3d09e8b748f3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 6 Mar 2026 09:30:12 -0800 Subject: [PATCH 4/4] Fix: replace .to_ir() with .to_ir_type_and_shape() in builder_test Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxscript/_internal/builder_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py index 72da47b32b..f6f301954b 100644 --- a/onnxscript/_internal/builder_test.py +++ b/onnxscript/_internal/builder_test.py @@ -32,7 +32,7 @@ def _resolve_type_spec(spec: builder.TypeSpec) -> ir.TypeAndShape: if isinstance(spec, ir.TypeAndShape): return spec if isinstance(spec, type) and issubclass(spec, TensorType): - return spec.to_ir() + return spec.to_ir_type_and_shape() raise TypeError(f"Expected ir.TypeAndShape or a TensorType subclass, got {type(spec)!r}.")