[Relax][IR] Skip in-place multiply when two operands are views of the same tensor#19644
ConvolutedDog wants to merge 2 commits into
This PR fixes apache#19577. In that issue, the IRModule before any pass is applied looks like:

```
%x: Tensor[(4,), float32]                 // function param
with R.dataflow():
    %lv  = expand_dims(%x, axis=1)        // (4, 1)
    %lv1 = expand_dims(%x, axis=1)        // (4, 1) second call, new Var
    %lv2 = multiply(%lv, %lv1)            // (4, 1)
    %lv3 = concat(%lv2, %lv1, axis=1)     // (4, 2)
    ...
```

When a user manually applies the `DataflowUseInplaceCalls` pass, the pass rewrites the statement `%lv2 = multiply(%lv, %lv1)` into `%lv = multiply(%lv, %lv1); %lv3 = concat(%lv, %lv1, axis=1)`, which reuses the `%lv` buffer to avoid allocating extra storage.

However, this rewrite changes the contents of the `%lv` buffer, and in the LLVM-generated code `%lv1` shares the same storage as `%lv`. By the time `%lv3 = concat(%lv, %lv1, axis=1)` executes, the contents of `%lv1` have therefore also been overwritten with the result of `multiply(%lv, %lv1)`. The failure is caused by different views of the same tensor `%x` sharing storage.

During execution, `%lv1` holds `x^2` instead of `x` after `multiply`. `concat` reads `%lv1` for the right column, so its result is `[[1,1],[4,4],[9,9],[16,16]]` instead of `[[1,1],[4,2],[9,3],[16,4]]` (in the correct result the left column is `x^2` and the right column stays `x`).

Change: View-like ops (`expand_dims`, `squeeze`, `reshape`, `permute_dims`, `memory.view`, `ensure_zero_offset`) take the input's alias set in alias analysis instead of a new id, so `%lv` and `%lv1` share an alias set with `%x`. The pass then rejects the in-place rewrite of `multiply(%lv, %lv1)`: `%lv` and `%lv1` are different vars but their alias ids intersect, so neither operand may be reused in-place.
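The corruption described above can be reproduced outside TVM. The following is a minimal sketch using only the Python standard library, not the code from the issue: an `array` stands in for the storage of `%x`, two `memoryview` objects stand in for the two `expand_dims` results, and the `(4, 1)` shape is dropped for simplicity. The function name `run_graph` is hypothetical.

```python
from array import array


def run_graph(inplace: bool) -> list:
    """Evaluate the bindings from the issue on x = [1, 2, 3, 4]."""
    x = array("f", [1.0, 2.0, 3.0, 4.0])  # storage backing %x
    lv = memoryview(x)   # stand-in for expand_dims(%x): a view, no copy
    lv1 = memoryview(x)  # second view over the same storage

    if inplace:
        # Rewritten form: %lv = multiply(%lv, %lv1) writes into the shared
        # storage, so %lv1 (and %x) change as well.
        for i in range(len(lv)):
            lv[i] = lv[i] * lv1[i]
        product = lv
    else:
        # Original form: %lv2 gets a fresh buffer and the views stay intact.
        product = array("f", (a * b for a, b in zip(lv, lv1)))

    # concat(product, %lv1, axis=1): one [left, right] row per element.
    return [[product[i], lv1[i]] for i in range(len(lv1))]


print(run_graph(inplace=False))  # [[1.0, 1.0], [4.0, 2.0], [9.0, 3.0], [16.0, 4.0]]
print(run_graph(inplace=True))   # [[1.0, 1.0], [4.0, 4.0], [9.0, 9.0], [16.0, 16.0]]
```

The two printed results match the correct and the corrupted outputs quoted in the description.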
Code Review
This pull request enhances the Relax in-place transformation pass by tracking alias sets for view-like memory operations (such as reshape, squeeze, and expand_dims). This prevents invalid in-place optimizations when multiple view operations share the same underlying input storage. The feedback identifies a correctness bug in the newly introduced InplaceArgDisjointFromOtherCallArgs function, where checking for -1 (unknown alias) in other_set can prematurely skip checking other valid alias indices, potentially leading to incorrect in-place rewrites. A simplification to use same_as is also suggested.
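To make the alias-set idea and the `-1` concern concrete, here is a toy model in Python. It is a sketch under stated assumptions, not the pass's C++ implementation: the binding encoding, the names `analyze_aliases` and `arg_disjoint_from_others`, and the choice to let the same variable appear twice are all hypothetical. The one point it illustrates is that an unknown alias (`-1`) should make the check fail conservatively rather than cause remaining operands to be skipped.

```python
UNKNOWN = -1  # assumed sentinel meaning "may alias anything"

# View-like ops listed in the PR description.
VIEW_OPS = {
    "expand_dims", "squeeze", "reshape",
    "permute_dims", "memory.view", "ensure_zero_offset",
}


def analyze_aliases(params, bindings):
    """Map each variable to a set of alias ids.

    bindings is a list of (var, op, args) tuples (hypothetical encoding).
    """
    alias = {}
    fresh = 0
    for p in params:
        alias[p] = {fresh}
        fresh += 1
    for var, op, args in bindings:
        if op in VIEW_OPS:
            # A view inherits the alias set of the tensor it views.
            alias[var] = set(alias[args[0]])
        else:
            # Any other op is modelled as producing fresh storage.
            alias[var] = {fresh}
            fresh += 1
    return alias


def arg_disjoint_from_others(idx, args, alias):
    """Return True only if args[idx] provably shares no storage with the rest."""
    target = args[idx]
    target_set = alias[target]
    if UNKNOWN in target_set:
        return False
    for j, other in enumerate(args):
        if j == idx or other == target:
            continue  # same variable: assumed acceptable in this sketch
        other_set = alias[other]
        # Unknown aliases are treated as overlapping, never skipped.
        if UNKNOWN in other_set or target_set & other_set:
            return False
    return True


alias = analyze_aliases(
    ["x", "y"],
    [("lv", "expand_dims", ["x"]), ("lv1", "expand_dims", ["x"])],
)
print(arg_disjoint_from_others(0, ["lv", "lv1"], alias))  # False: both view %x
print(arg_disjoint_from_others(0, ["lv", "y"], alias))    # True: separate storage
```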
tlopex left a comment
Thanks for the fix. The alias-set approach looks like the right direction, but I think the view-op list is still incomplete.
relax.flatten and likely relax.nn.batch_flatten are also reshape-like/view-like ops: they legalize to reshape-style TIR and can later become runtime reshape/CreateView instead of a real copy. The same corruption pattern can still happen with two flatten(x) results used as operands to an in-place binary op and then one view reused later.
Could you include these reshape-like ops in IsViewMemoryOp and add regression coverage for them? I would also avoid using tensor.numpy().__array_interface__["data"] as a runtime storage check, since Tensor.numpy() copies into a NumPy buffer rather than exposing the TVM tensor storage.
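The point about pointer-based storage checks can be illustrated without TVM. The sketch below uses only the Python standard library; `to_host_copy` is a hypothetical stand-in for any export that copies (as the comment says `Tensor.numpy()` does), and it is not a TVM API. Two views that really share storage yield copies at different addresses, so comparing the copies' data pointers says nothing about whether the originals alias.

```python
from array import array

storage = array("f", [1.0, 2.0, 3.0, 4.0])
view_a = memoryview(storage)  # two views over the same storage
view_b = memoryview(storage)


def to_host_copy(view):
    """Hypothetical copying export: returns a fresh array with the same data."""
    return array("f", view.tolist())


copy_a = to_host_copy(view_a)
copy_b = to_host_copy(view_b)

# buffer_info() returns (address, length) for an array's own buffer.
addr_a = copy_a.buffer_info()[0]
addr_b = copy_b.buffer_info()[0]

shares_storage = view_a.obj is view_b.obj  # ground truth for the views
pointers_match = addr_a == addr_b          # what a check on the copies sees

print(shares_storage)   # True
print(pointers_match)   # False
```

A regression test for this PR would therefore be better served by comparing computed values against a reference, as in the expected output given in the description.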