From 9b8779450aa0c6ed8aa42193cf21cecbbaafec02 Mon Sep 17 00:00:00 2001 From: javierdejesusda Date: Sat, 30 May 2026 21:01:55 +0200 Subject: [PATCH] [Relax][PyTorch] Cast non-bool inputs to bool in logical_not converter torch.logical_not returns a bool tensor for any input dtype, but the frontend lowered it with a plain unary op that passes the input dtype through, so a float32 input produced a float32 result instead of bool. Add a shared _logical_not converter in BaseFXGraphImporter that casts non-bool inputs to bool before applying relax.op.logical_not, and wire up both the FX and ExportedProgram frontends. Update the tests to assert the corrected bool output. --- .../torch/base_fx_graph_translator.py | 8 +++++++ .../torch/exported_program_translator.py | 2 +- .../tvm/relax/frontend/torch/fx_translator.py | 2 +- .../test_frontend_from_exported_program.py | 23 +++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 7 +++--- 5 files changed, 37 insertions(+), 5 deletions(-) diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e9bddc4500bb..a2ebed04807e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -389,6 +389,14 @@ def _log_softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + def _logical_not(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + # torch.logical_not accepts any dtype (treating nonzero as True) and returns bool, but + # relax.op.logical_not requires a boolean input, so cast non-bool inputs to bool first. + if x.struct_info.dtype != "bool": + x = self.block_builder.emit(relax.op.astype(x, "bool")) + return self.block_builder.emit(relax.op.logical_not(x)) + def _prelu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = self.env[node.args[1]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 596dc60f555e..26f5a5918ca9 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1551,7 +1551,7 @@ def create_convert_map( "log2.default": self._log2, "log10.default": self._log10, "log1p.default": self._log1p, - "logical_not.default": self._unary_op(relax.op.logical_not), + "logical_not.default": self._logical_not, "logical_and.default": self._binary_op(relax.op.logical_and, operator.and_), "log_softmax.int": self._log_softmax, "_log_softmax.default": self._log_softmax, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d4dd6902ae54..9d27f62b423d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -875,7 +875,7 @@ def create_convert_map( "log2": self._log2, "log10": self._log10, "log1p": self._log1p, - "logical_not": self._unary_op(relax.op.logical_not), + "logical_not": self._logical_not, "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), "pad": self._pad, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6b758c1ba7ec..d1bdad757807 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1062,6 +1062,29 @@ def main( verify_model(LogAddExp(), example_args, {}, expected) +def test_logical_not(): + class LogicalNot(Module): + def forward(self, input): + return torch.logical_not(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="bool") + ): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(input, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LogicalNot(), example_args, {}, expected) + + def test_logsoftmax(): class LogSoftmax(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 410875985e42..1bf71fb6eb03 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3195,11 +3195,12 @@ def forward(self, input): class expected_logical_not: @R.function def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor( - (1, 3, 10, 10), dtype="float32" + (1, 3, 10, 10), dtype="bool" ): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.logical_not(inp_0) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(inp_0, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv1 R.output(gv) return gv