backends/cadence/aot/quantizer/fusion_pass.py (39 additions, 55 deletions)
@@ -160,20 +160,15 @@ def get_args_and_kwargs_layer_norm(
         ),
         {"dtype": torch.float32},
     )
-    if len(inputs_inputs) > 0:
-        if "val" in inputs_inputs[0].meta:
-            fake_mode = inputs_inputs[0].meta["val"].fake_mode
-            if fake_mode is not None:
-                with fake_mode:
-                    fake_weight = torch.full(
-                        other_inputs[0], 1, dtype=torch.float32
-                    )
-                    weight.meta["val"] = fake_weight
-            else:
-                weight.meta["val"] = torch.full(
-                    other_inputs[0], 1, dtype=torch.float32
-                )
-        copy_node_metadata(weight, inputs_inputs[0])
+    assert (
+        len(inputs_inputs) == 1
+    ), f"Expected 1 input for layer norm weight, got {len(inputs_inputs)}"
+    assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+    fake_mode = inputs_inputs[0].meta["val"].fake_mode
+    assert fake_mode is not None, "fake_mode is None on input node"
+    with fake_mode:
+        weight.meta["val"] = torch.full(other_inputs[0], 1, dtype=torch.float32)
+    copy_node_metadata(weight, inputs_inputs[0])
 
     bias = other_inputs[2] if len(other_inputs) > 2 else None
 
@@ -186,18 +181,15 @@ def get_args_and_kwargs_layer_norm(
         ),
         {"dtype": torch.float32},
     )
-    if len(inputs_inputs) > 0:
-        if "val" in inputs_inputs[0].meta:
-            fake_mode = inputs_inputs[0].meta["val"].fake_mode
-            if fake_mode is not None:
-                with fake_mode:
-                    fake_bias = torch.full(other_inputs[0], 0, dtype=torch.float32)
-                    bias.meta["val"] = fake_bias
-            else:
-                bias.meta["val"] = torch.full(
-                    other_inputs[0], 0, dtype=torch.float32
-                )
-        copy_node_metadata(bias, inputs_inputs[0])
+    assert (
+        len(inputs_inputs) == 1
+    ), f"Expected 1 input for layer norm bias, got {len(inputs_inputs)}"
+    assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+    fake_mode = inputs_inputs[0].meta["val"].fake_mode
+    assert fake_mode is not None, "fake_mode is None on input node"
+    with fake_mode:
+        bias.meta["val"] = torch.full(other_inputs[0], 0, dtype=torch.float32)
+    copy_node_metadata(bias, inputs_inputs[0])
 
     # Make the args and kwargs for the replacement op
     args = tuple(inputs_inputs + [scale, zero_point])
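Note on the pattern above: in export-style FX graphs, a node's `meta["val"]` holds a FakeTensor created under a `FakeTensorMode`, so the fake value for a newly inserted constant node should be materialized under that same mode to keep shape/dtype propagation consistent. A minimal standalone sketch of the idea (the shape `(4, 8)` and variable names here are invented for illustration):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    # Factory calls made inside the mode return FakeTensors: they carry
    # shape and dtype metadata but allocate no real storage.
    fake_weight = torch.full((4, 8), 1, dtype=torch.float32)

print(type(fake_weight).__name__)  # FakeTensor
print(fake_weight.shape, fake_weight.dtype)  # torch.Size([4, 8]) torch.float32
```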
@@ -373,16 +365,15 @@ def get_args_and_kwargs_softmax(
         ),
         {"dtype": torch.int32},
     )
-    if len(inputs_inputs) > 0:
-        if "val" in inputs_inputs[0].meta:
-            fake_mode = inputs_inputs[0].meta["val"].fake_mode
-            if fake_mode is not None:
-                with fake_mode:
-                    fake_mask = torch.full(mask_shape, 0.0, dtype=torch.int32)
-                    mask_tensor.meta["val"] = fake_mask
-            else:
-                mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
-        copy_node_metadata(mask_tensor, inputs_inputs[0])
+    assert (
+        len(inputs_inputs) == 1
+    ), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
+    assert "val" in inputs_inputs[0].meta, "Missing val metadata on input node"
+    fake_mode = inputs_inputs[0].meta["val"].fake_mode
+    assert fake_mode is not None, "fake_mode is None on input node"
+    with fake_mode:
+        mask_tensor.meta["val"] = torch.full(mask_shape, 0.0, dtype=torch.int32)
+    copy_node_metadata(mask_tensor, inputs_inputs[0])
     # Make the scale and zero_point tensors
     in_scale = dequants_inputs[0].args[1]
    in_zero_point = dequants_inputs[0].args[2]
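Behavioral note: the old guards silently skipped metadata propagation when the pattern match was malformed; the asserts make the fusion pass fail fast instead. For example, an empty `inputs_inputs` (hypothetical here) now raises immediately rather than producing a node with no `meta["val"]`:

```python
inputs_inputs = []  # hypothetical malformed match
assert (
    len(inputs_inputs) == 1
), f"Expected 1 input for softmax, got {len(inputs_inputs)}"
# AssertionError: Expected 1 input for softmax, got 0
```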
@@ -636,25 +627,18 @@ def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
                     torch.ops.aten.transpose.int,
                     (weights_inputs[0], 0, 1),
                 )
-                if "val" in weights_inputs[0].meta:
-                    original_val = weights_inputs[0].meta["val"]
-                    fake_mode = original_val.fake_mode
-                    if fake_mode is not None:
-                        with fake_mode:
-                            transposed_val = torch.ops.aten.transpose.int(
-                                original_val, 0, 1
-                            )
-                            transposed_weights.meta["val"] = transposed_val
-                    else:
-                        transposed_shape = list(original_val.shape)
-                        transposed_shape[0], transposed_shape[1] = (
-                            transposed_shape[1],
-                            transposed_shape[0],
-                        )
-                        transposed_weights.meta["val"] = torch.zeros(
-                            transposed_shape, dtype=original_val.dtype
-                        )
-                    copy_node_metadata(transposed_weights, weights_inputs[0])
+                assert (
+                    "val" in weights_inputs[0].meta
+                ), "Missing val metadata on weight node"
+                original_val = weights_inputs[0].meta["val"]
+                assert (
+                    original_val.fake_mode is not None
+                ), "fake_mode is None on weight node"
+                with original_val.fake_mode:
+                    transposed_weights.meta["val"] = (
+                        torch.ops.aten.transpose.int(original_val, 0, 1)
+                    )
+                copy_node_metadata(transposed_weights, weights_inputs[0])
 
                 # Call linear with transposed weight
                 args, kwargs = get_args_and_kwargs_linear(
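For the linear-weight case above, the transposed node's fake value is derived by running the same aten op on the original FakeTensor under its own fake mode, so the recorded shape matches what the real `transpose` will produce; this replaces the old manual shape-swapping fallback. A rough standalone sketch (shapes and names are invented):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    original_val = torch.zeros(16, 32)  # stands in for weights_inputs[0].meta["val"]

# FakeTensors keep a reference to the mode that created them.
with original_val.fake_mode:
    transposed_val = torch.ops.aten.transpose.int(original_val, 0, 1)

print(transposed_val.shape)  # torch.Size([32, 16])
```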