[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass#19650
[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass#19650ConvolutedDog wants to merge 1 commit into
Conversation
Fix a crash (apache#19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results. Changes: - Replace full 2D transpose with permute_last_two_dims for permuted matmul patterns, swapping only the last two axes for ND tensors. - Remove hard ndim==2 checks in the permuted rewrite path. - Account for batch prefixes when comparing naive matmul FLOPs, so reorder decisions reflect batched vs. weight-only inner matmuls. - Skip reorder when neither evaluation order is provably cheaper. - Add regression tests for symbolic/concrete batched LoRA shapes. - Add a numerics test covering a minimal attention block with ND permute_dims.
There was a problem hiding this comment.
Code Review
This pull request enhances the adjust_matmul_order pass to support ND tensors by introducing helper functions to permute and transpose the last two dimensions, and updates the FLOP calculation to account for batch dimensions. It also adds several tests covering symbolic and concrete batch sizes, as well as correctness on a batched attention block. A review comment highlights a potential underflow issue in transpose_shape_last_two_dims when handling 1D tensors, which could lead to out-of-bounds memory access, and suggests skipping the optimization for operands with fewer than two dimensions.
tlopex
left a comment
There was a problem hiding this comment.
Thanks for fixing the ND crash path. I think the direction is right, but I’d request changes for the new batched FLOP model.
The current estimate multiplies batch_A * batch_B * batch_C, but relax.matmul broadcasts batch prefixes; it does not evaluate all pairwise combinations of equal batch axes. This can flip the chosen order for normal batched chains. For example, with A: [2, 1, 1], B: [2, 1, 2], and C: [2, 2, 3], the actual naive costs are:
- LHS first:
2*1*1*2 + 2*1*2*3 = 16 - RHS first:
2*1*2*3 + 2*1*1*3 = 18
so LHS first is cheaper. The current formula estimates 56 vs 48, and would choose RHS first instead.
I think the cost model should use the product of the broadcasted batch prefix for each matmul, e.g. batch_AB, batch_BC, and the common broadcast batch for the outer matmul, rather than multiplying the independent prefix products. It would also be good to add structural tests where all three operands share a nontrivial batch prefix, plus a broadcast case like [B, ...] with [1, ...].
Separately, I agree with the existing comment that the permuted path should guard rank-1 operands before calling permute_last_two_dims / transpose_shape_last_two_dims, since the later 1D padding does not run before that code.
Fix a crash (#19576) when AdjustMatmulOrder encounters mixed-dimension matmul chains common in transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass previously assumed all operands in a chained rewrite were 2D and asserted shape_c.size() == 2, failing on 3D intermediate results.
Changes: