[TOPI] Use branchless boundary index for reflect/replicate pad#19928
[TOPI] Use branchless boundary index for reflect/replicate pad#19928junghyunpark2001 wants to merge 2 commits into
Conversation
reflect_pad and replicate_pad computed the boundary source index with nested if_then_else, which lowers to per-element branches on CUDA and made both modes ~1.4-1.5x slower than constant/circular since v0.21.0. Replace them with branchless integer expressions that are bit-identical over the valid pad domain: reflect-101: m = size - 1; idx = m - abs(m - abs(orig_idx)) replicate: idx = max(0, min(size - 1, orig_idx)) Measured on RTX 4060 Ti (sm_89, CUDA 11.8): reflect 7.37->4.73us (1.50x), replicate 7.14->4.31us (1.43x), output bit-identical. Also removes the now-unused if_then_else import. Fixes apache#19848
There was a problem hiding this comment.
Code Review
This pull request replaces nested if_then_else branches with branchless integer arithmetic for both the reflect-101 boundary index and edge clamp padding calculations in python/tvm/topi/nn/pad.py. The review feedback highlights potential type mismatch issues when mixing different integer types (such as int32 and int64) during lowering. Specifically, it suggests casting size - 1 and the constant 0 to match orig_idx.dtype to ensure type safety and prevent compilation failures.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| m = size - 1 | ||
| reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx)) |
There was a problem hiding this comment.
If size (which is data.shape[i]) and orig_idx have different integer types (e.g., int32 vs int64), the subtraction m - tvm.tirx.abs(orig_idx) can cause a type mismatch error during lowering. To ensure type safety and robustness, we should explicitly cast m to orig_idx.dtype using tvm.tirx.Cast.
| m = size - 1 | |
| reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx)) | |
| m = tvm.tirx.Cast(orig_idx.dtype, size - 1) | |
| reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx)) |
| # Branchless edge clamp. This is bit-identical to the nested | ||
| # if_then_else form (0 below 0, size-1 at or above size, orig_idx | ||
| # otherwise) but lowers to min/max instead of per-element branches. | ||
| clamped_idx = tvm.tirx.max(tvm.tirx.const(0, "int32"), tvm.tirx.min(size - 1, orig_idx)) |
There was a problem hiding this comment.
Using a hardcoded "int32" for the constant 0 can lead to type mismatch errors if orig_idx is of type int64 (which is common when 64-bit indexing is enabled or for large tensors). To ensure robustness and prevent compilation failures, we should use orig_idx.dtype instead of "int32" and cast size - 1 to orig_idx.dtype as well.
| clamped_idx = tvm.tirx.max(tvm.tirx.const(0, "int32"), tvm.tirx.min(size - 1, orig_idx)) | |
| clamped_idx = tvm.tirx.max(tvm.tirx.const(0, orig_idx.dtype), tvm.tirx.min(tvm.tirx.Cast(orig_idx.dtype, size - 1), orig_idx)) |
TestAutopad.test_edge pinned the exact TIR of replicate_pad, which changed when the edge clamp was rewritten from nested if_then_else to max(0, min(size - 1, idx)). Regenerate the golden; output is unchanged. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
reflect_pad and replicate_pad computed the boundary source index with nested if_then_else, which lowers to per-element branches on CUDA and made both modes ~1.4-1.5x slower than constant/circular since v0.21.0.
Replace them with branchless integer expressions that are bit-identical over the valid pad domain:
reflect-101: m = size - 1; idx = m - abs(m - abs(orig_idx))
replicate: idx = max(0, min(size - 1, orig_idx))
Measured on RTX 4060 Ti (sm_89, CUDA 11.8): reflect 7.37->4.73us (1.50x), replicate 7.14->4.31us (1.43x), output bit-identical. Also removes the now-unused if_then_else import.
Fixes #19848