diff --git a/backends/cadence/aot/quantizer/fusion_pass.py b/backends/cadence/aot/quantizer/fusion_pass.py index 7093ef19c3d..5be90a4f52e 100644 --- a/backends/cadence/aot/quantizer/fusion_pass.py +++ b/backends/cadence/aot/quantizer/fusion_pass.py @@ -432,12 +432,26 @@ def get_args_and_kwargs_mixed_w8a32_conv( torch.ops.aten.permute.default, (other_inputs[0], [0, 2, 1]), # NCL -> NLC ) + assert "val" in other_inputs[0].meta, "Missing val metadata on input node" + original_val = other_inputs[0].meta["val"] + assert original_val.fake_mode is not None, "fake_mode is None on input node" + with original_val.fake_mode: + transposed_inputs.meta["val"] = torch.ops.aten.permute.default( + original_val, [0, 2, 1] + ) copy_node_metadata(transposed_inputs, other_inputs[0]) transposed_weights = graph_module.graph.call_function( torch.ops.aten.permute.default, (weights_inputs[0], [2, 0, 1]), # NCL -> LNC ) + assert "val" in weights_inputs[0].meta, "Missing val metadata on weight node" + original_val = weights_inputs[0].meta["val"] + assert original_val.fake_mode is not None, "fake_mode is None on weight node" + with original_val.fake_mode: + transposed_weights.meta["val"] = torch.ops.aten.permute.default( + original_val, [2, 0, 1] + ) copy_node_metadata(transposed_weights, weights_inputs[0]) args = (