diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index e9bddc4500bb..a2ebed04807e 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -389,6 +389,14 @@ def _log_softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.log_softmax(x, dim)) + def _logical_not(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + # torch.logical_not accepts any dtype (treating nonzero as True) and returns bool, but + # relax.op.logical_not requires a boolean input, so cast non-bool inputs to bool first. + if x.struct_info.dtype != "bool": + x = self.block_builder.emit(relax.op.astype(x, "bool")) + return self.block_builder.emit(relax.op.logical_not(x)) + def _prelu(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] alpha = self.env[node.args[1]] diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 596dc60f555e..26f5a5918ca9 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -1551,7 +1551,7 @@ def create_convert_map( "log2.default": self._log2, "log10.default": self._log10, "log1p.default": self._log1p, - "logical_not.default": self._unary_op(relax.op.logical_not), + "logical_not.default": self._logical_not, "logical_and.default": self._binary_op(relax.op.logical_and, operator.and_), "log_softmax.int": self._log_softmax, "_log_softmax.default": self._log_softmax, diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index d4dd6902ae54..9d27f62b423d 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -875,7 +875,7 @@ def create_convert_map( "log2": self._log2, "log10": self._log10, "log1p": self._log1p, - "logical_not": self._unary_op(relax.op.logical_not), + "logical_not": self._logical_not, "log_softmax": self._log_softmax, "neg": self._unary_op(relax.op.negative), "pad": self._pad, diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 6b758c1ba7ec..d1bdad757807 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -1062,6 +1062,29 @@ def main( verify_model(LogAddExp(), example_args, {}, expected) +def test_logical_not(): + class LogicalNot(Module): + def forward(self, input): + return torch.logical_not(input) + + @tvm.script.ir_module + class expected: + @R.function + def main(input: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tuple( + R.Tensor((1, 3, 10, 10), dtype="bool") + ): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(input, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="bool")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(LogicalNot(), example_args, {}, expected) + + def test_logsoftmax(): class LogSoftmax(Module): def __init__(self): diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index 410875985e42..1bf71fb6eb03 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -3195,11 +3195,12 @@ def forward(self, input): class expected_logical_not: @R.function def main(inp_0: R.Tensor((1, 3, 10, 10), dtype="float32")) -> R.Tensor( - (1, 3, 10, 10), dtype="float32" + (1, 3, 10, 10), dtype="bool" ): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.logical_not(inp_0) - gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv + lv: R.Tensor((1, 3, 10, 10), dtype="bool") = R.astype(inp_0, dtype="bool") + lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.logical_not(lv) + gv: R.Tensor((1, 3, 10, 10), dtype="bool") = lv1 R.output(gv) return gv