Skip to content

[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass#19650

Open
ConvolutedDog wants to merge 1 commit into
apache:mainfrom
ConvolutedDog:fix-adjust-mm-order
Open

[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass#19650
ConvolutedDog wants to merge 1 commit into
apache:mainfrom
ConvolutedDog:fix-adjust-mm-order

Conversation

@ConvolutedDog
Copy link
Copy Markdown
Contributor

Fix a crash (#19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results.

Changes:

  • Replace full 2D transpose with permute_last_two_dims for permuted matmul patterns, swapping only the last two axes for ND tensors.
  • Remove hard ndim==2 checks in the permuted rewrite path.
  • Account for batch prefixes when comparing naive matmul FLOPs, so reorder decisions reflect batched vs. weight-only inner matmuls.
  • Skip reorder when neither evaluation order is provably cheaper.
  • Add regression tests for symbolic/concrete batched LoRA shapes.
  • Add a numerics test covering a minimal attention block with ND permute_dims.

Fix a crash (apache#19576) when AdjustMatmulOrder
encounters mixed-dimension matmul chains common in transformer models (e.g.
matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands
in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D
intermediate results.

Changes:
- Replace full 2D transpose with permute_last_two_dims for permuted matmul
  patterns, swapping only the last two axes for ND tensors.
- Remove hard ndim==2 checks in the permuted rewrite path.
- Account for batch prefixes when comparing naive matmul FLOPs, so reorder
  decisions reflect batched vs. weight-only inner matmuls.
- Skip reorder when neither evaluation order is provably cheaper.
- Add regression tests for symbolic/concrete batched LoRA shapes.
- Add a numerics test covering a minimal attention block with ND permute_dims.
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request enhances the adjust_matmul_order pass to support ND tensors by introducing helper functions to permute and transpose the last two dimensions, and updates the FLOP calculation to account for batch dimensions. It also adds several tests covering symbolic and concrete batch sizes, as well as correctness on a batched attention block. A review comment highlights a potential underflow issue in transpose_shape_last_two_dims when handling 1D tensors, which could lead to out-of-bounds memory access, and suggests skipping the optimization for operands with fewer than two dimensions.

Comment thread src/relax/transform/adjust_matmul_order.cc
@ConvolutedDog ConvolutedDog marked this pull request as ready for review June 1, 2026 01:18
Copy link
Copy Markdown
Member

@tlopex tlopex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for fixing the ND crash path. I think the direction is right, but I’d request changes for the new batched FLOP model.

The current estimate multiplies batch_A * batch_B * batch_C, but relax.matmul broadcasts batch prefixes; it does not evaluate all pairwise combinations of equal batch axes. This can flip the chosen order for normal batched chains. For example, with A: [2, 1, 1], B: [2, 1, 2], and C: [2, 2, 3], the actual naive costs are:

  • LHS first: 2*1*1*2 + 2*1*2*3 = 16
  • RHS first: 2*1*2*3 + 2*1*1*3 = 18

so LHS first is cheaper. The current formula estimates 56 vs 48, and would choose RHS first instead.

I think the cost model should use the product of the broadcasted batch prefix for each matmul, e.g. batch_AB, batch_BC, and the common broadcast batch for the outer matmul, rather than multiplying the independent prefix products. It would also be good to add structural tests where all three operands share a nontrivial batch prefix, plus a broadcast case like [B, ...] with [1, ...].

Separately, I agree with the existing comment that the permuted path should guard rank-1 operands before calling permute_last_two_dims / transpose_shape_last_two_dims, since the later 1D padding does not run before that code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants