diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py
index 5be90a4f52e..349659d1e7f 100644
--- a/backends/cadence/aot/quantizer/fusion_pass.py
+++ b/backends/cadence/aot/quantizer/fusion_pass.py
@@ -160,20 +160,15 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
-        if len(inputs_inputs) > 0:
-            if "val" in inputs_inputs[0].meta:
-                fake_mode = inputs_inputs[0].meta["val"].fake_mode
-                if fake_mode is not None:
-                    with fake_mode:
-                        fake_weight = torch.full(
-                            other_inputs[0], 1, dtype=torch.float32
-                        )
-                        weight.meta["val"] = fake_weight
-                else:
-                    weight.meta["val"] = torch.full(
-                        other_inputs[0], 1, dtype=torch.float32
-                    )
-            copy_node_metadata(weight, inputs_inputs[0])
+        assert (
+            len(inputs_inputs) == 1
+        ), f"Expected 1 input for layer norm weight, got {len(inputs_inputs)}"
+        assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+        fake_mode = inputs_inputs[0].meta["val"].fake_mode
+        assert fake_mode is not None, "fake_mode is None on input node"
+        with fake_mode:
+            weight.meta["val"] = torch.full(other_inputs[0], 1, dtype=torch.float32)
+        copy_node_metadata(weight, inputs_inputs[0])
 
     bias = other_inputs[2] if len(other_inputs) > 2 else None
 
@@ -186,18 +181,15 @@ def get_args_and_kwargs_layer_norm(
             ),
             {"dtype": torch.float32},
         )
-        if len(inputs_inputs) > 0:
-            if "val" in inputs_inputs[0].meta:
-                fake_mode = inputs_inputs[0].meta["val"].fake_mode
-                if fake_mode is not None:
-                    with fake_mode:
-                        fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
-                        bias.meta["val"] = fake_bias
-                else:
-                    bias.meta["val"] = torch.full(
-                        other_inputs[0], 0, dtype=torch.float32
-                    )
-            copy_node_metadata(bias, inputs_inputs[0])
+        assert (
+            len(inputs_inputs) == 1
+        ), f"Expected 1 input for layer norm bias, got {len(inputs_inputs)}"
+        assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+        fake_mode = inputs_inputs[0].meta["val"].fake_mode
+        assert fake_mode is not None, "fake_mode is None on input node"
+        with fake_mode:
+            bias.meta["val"] = torch.full(other_inputs[0], 0, dtype=torch.float32)
+        copy_node_metadata(bias, inputs_inputs[0])
 
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + [scale, zero_point])
@@ -373,16 +365,15 @@ def get_args_and_kwargs_softmax(
         ),
         {"dtype": torch.int32},
     )
-    if len(inputs_inputs) > 0:
-        if "val" in inputs_inputs[0].meta:
-            fake_mode = inputs_inputs[0].meta["val"].fake_mode
-            if fake_mode is not None:
-                with fake_mode:
-                    fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
-                    mask_tensor.meta["val"] = fake_mask
-            else:
-                mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
-        copy_node_metadata(mask_tensor, inputs_inputs[0])
+    assert (
+        len(inputs_inputs) == 1
+    ), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
+    assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+    fake_mode = inputs_inputs[0].meta["val"].fake_mode
+    assert fake_mode is not None, "fake_mode is None on input node"
+    with fake_mode:
+        mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
+    copy_node_metadata(mask_tensor, inputs_inputs[0])
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
     in_zero_point = dequants_inputs[0].args[2]
@@ -636,25 +627,18 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                         torch.ops.aten.transpose.int,
                         (weights_inputs[0], 0, 1),
                     )
-                    if "val" in weights_inputs[0].meta:
-                        original_val = weights_inputs[0].meta["val"]
-                        fake_mode = original_val.fake_mode
-                        if fake_mode is not None:
-                            with fake_mode:
-                                transposed_val = torch.ops.aten.transpose.int(
-                                    original_val, 0, 1
-                                )
-                                transposed_weights.meta["val"] = transposed_val
-                        else:
-                            transposed_shape = list(original_val.shape)
-                            transposed_shape[0], transposed_shape[1] = (
-                                transposed_shape[1],
-                                transposed_shape[0],
-                            )
-                            transposed_weights.meta["val"] = torch.zeros(
-                                transposed_shape, dtype=original_val.dtype
-                            )
-                        copy_node_metadata(transposed_weights, weights_inputs[0])
+                    assert (
+                        "val" in weights_inputs[0].meta
+                    ), "Missing val metadata on weight node"
+                    original_val = weights_inputs[0].meta["val"]
+                    assert (
+                        original_val.fake_mode is not None
+                    ), "fake_mode is None on weight node"
+                    with original_val.fake_mode:
+                        transposed_weights.meta["val"] = (
+                            torch.ops.aten.transpose.int(original_val, 0, 1)
+                        )
+                    copy_node_metadata(transposed_weights, weights_inputs[0])
 
                     # Call linear with transposed weight
                     args, kwargs = get_args_and_kwargs_linear(
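
For reviewers unfamiliar with the `meta["val"]` convention this diff relies on, here is a minimal standalone sketch (not part of the change, and using PyTorch's private `torch._subclasses` API, so treat it as illustrative only) of why the new asserts insist on a non-None `fake_mode`: metadata tensors must be created inside the same `FakeTensorMode` that produced the input's `val`, so shape/dtype propagate without allocating real storage.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    # Under the mode, factory ops like torch.full return FakeTensors:
    # shape/dtype metadata only, no real storage is allocated.
    val = torch.full((4, 8), 1, dtype=torch.float32)

# FakeTensors remember their creating mode; this is the attribute the pass
# reads back via inputs_inputs[0].meta["val"].fake_mode before asserting.
assert val.fake_mode is fake_mode
print(type(val).__name__, val.shape, val.dtype)
# FakeTensor torch.Size([4, 8]) torch.float32
```

This is also why the old `else` fallbacks (building a real `torch.full` / `torch.zeros` tensor when no mode was found) could be dropped: the asserts make a missing or mode-less `val` a hard error instead of silently materializing real tensors into node metadata.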