
Commit 84d05e9

[LLVM][Codegen] Cast NaN to bool gives true (#18772)
Due to recent changes to CI, PR #18646 will be closed; this PR was created by opening a new pull request against the main branch.

- Updated the test code to use TVMScript instead of a TE schedule.

### Summary

Casting NaN to bool now gives `true`, for consistency with existing frameworks (C, C++, Python, Torch, NumPy, ONNX Runtime, ...).

### Steps to Reproduce

- Python:

  ```python
  bool(float("nan"))
  ```

  > True

- Torch:

  ```python
  torch.tensor(float("nan"), dtype=torch.float32).to(torch.bool)
  ```

  > tensor(True)

- NumPy:

  ```python
  import numpy as np
  bool(np.nan)
  ```

  > True

- TVM (actual, before this change):

  ```python
  class Module:
      def main(x: R.Tensor((), dtype="float32")) -> R.Tensor((), dtype="bool"):
          with R.dataflow():
              gv: R.Tensor((), dtype="bool") = R.astype(x, dtype="bool")
              R.output(gv)
          return gv

  x = np.array(float("nan"), dtype="float32")
  ```

  > False

### Expected

- TVM, running the same module on the same input:

  > True

### Resolved

- Replace the instruction `fcmp one` with `fcmp une` in the LLVM codegen. `one` ("ordered and not equal") yields false whenever an operand is NaN, while `une` ("unordered or not equal") yields true, so `NaN != 0.0` now casts to `true` as in the other frameworks.
- Citation: https://releases.llvm.org/20.1.0/docs/LangRef.html#fcmp-instruction

<img width="400" height="200" alt="PR1-18605" src="https://github.com/user-attachments/assets/cffeebd8-dfe6-436e-9c4c-61e1e84d5439" />

- Related:
  + https://stackoverflow.com/questions/9158567/nan-to-bool-conversion-true-or-false
  + https://stackoverflow.com/questions/15686318/why-do-not-a-number-values-equal-true-when-cast-as-boolean-in-python-numpy
- Fixed: #18605
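A minimal sketch of the two predicates' semantics in plain Python (the `fcmp_one` / `fcmp_une` helpers are illustrative models of the LLVM predicates, not LLVM API):

```python
import math

def fcmp_one(a: float, b: float) -> bool:
    # "ordered and not equal": false if either operand is NaN
    return not math.isnan(a) and not math.isnan(b) and a != b

def fcmp_une(a: float, b: float) -> bool:
    # "unordered or not equal": true if either operand is NaN
    return math.isnan(a) or math.isnan(b) or a != b

nan = float("nan")
print(fcmp_one(nan, 0.0))  # False -> old lowering cast NaN to bool as False
print(fcmp_une(nan, 0.0))  # True  -> new lowering matches bool(NaN) == True
```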
1 parent 783a9bf commit 84d05e9

2 files changed: 25 additions, 1 deletion


src/target/llvm/codegen_llvm.cc

Lines changed: 1 addition & 1 deletion

```diff
@@ -928,7 +928,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va
   } else if (to.is_bool()) {
     if (from.is_float()) {
      llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
-      return builder_->CreateFCmpONE(value, zero);
+      return builder_->CreateFCmpUNE(value, zero);
    } else {
      llvm::Constant* zero = llvm::ConstantInt::get(DTypeToLLVMType(from), 0);
      return builder_->CreateICmpNE(value, zero);
```

tests/python/codegen/test_target_codegen_llvm.py

Lines changed: 24 additions & 0 deletions

```diff
@@ -370,6 +370,30 @@ def main(A: T.Buffer((64,), "int32"), C: T.Buffer((64,), "float32")):
     tvm.testing.assert_allclose(c.numpy(), c_np)
 
 
+@tvm.testing.requires_llvm
+def test_llvm_cast_float_to_bool():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((4,), "float32"), C: T.Buffer((4,), "bool")):
+            T.func_attr({"tir.noalias": True})
+            for i in range(4):
+                with T.sblock("C"):
+                    v_i = T.axis.spatial(4, i)
+                    T.reads(A[v_i])
+                    T.writes(C[v_i])
+                    C[v_i] = T.Cast("bool", A[v_i])
+
+    n = 4
+    f = tvm.compile(Module, target="llvm")
+    dev = tvm.cpu(0)
+    a = tvm.runtime.tensor(np.array([0.0, 1.0, np.nan, np.inf], dtype="float32"), dev)
+    c = tvm.runtime.empty((n,), dtype="bool", device=dev)
+    f(a, c)
+    c_np = np.array([False, True, True, True], dtype="bool")
+    tvm.testing.assert_allclose(c.numpy(), c_np)
+
+
 
 @tvm.testing.requires_llvm
 def test_rank_zero():
```
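As a quick cross-check of the expected values outside TVM (a sketch using only NumPy), the same input vector casts the same way under NumPy's own float-to-bool conversion:

```python
import numpy as np

# Any nonzero value, including NaN and inf, becomes True; only 0.0 is False.
a = np.array([0.0, 1.0, np.nan, np.inf], dtype="float32")
print(a.astype("bool"))  # [False  True  True  True]
```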
