Skip to content

[TOPI] Use branchless boundary index for reflect/replicate pad#19928

Open
junghyunpark2001 wants to merge 2 commits into
apache:mainfrom
junghyunpark2001:perf/pad-reflect-replicate-branchless
Open

[TOPI] Use branchless boundary index for reflect/replicate pad#19928
junghyunpark2001 wants to merge 2 commits into
apache:mainfrom
junghyunpark2001:perf/pad-reflect-replicate-branchless

Conversation

@junghyunpark2001

Copy link
Copy Markdown

reflect_pad and replicate_pad computed the boundary source index with nested if_then_else, which lowers to per-element branches on CUDA and made both modes ~1.4-1.5x slower than constant/circular since v0.21.0.

Replace them with branchless integer expressions that are bit-identical over the valid pad domain:
reflect-101: m = size - 1; idx = m - abs(m - abs(orig_idx))
replicate: idx = max(0, min(size - 1, orig_idx))

Measured on RTX 4060 Ti (sm_89, CUDA 11.8): reflect 7.37->4.73us (1.50x), replicate 7.14->4.31us (1.43x), output bit-identical. Also removes the now-unused if_then_else import.

Fixes #19848

reflect_pad and replicate_pad computed the boundary source index with
nested if_then_else, which lowers to per-element branches on CUDA and made
both modes ~1.4-1.5x slower than constant/circular since v0.21.0.

Replace them with branchless integer expressions that are bit-identical
over the valid pad domain:
  reflect-101: m = size - 1; idx = m - abs(m - abs(orig_idx))
  replicate:   idx = max(0, min(size - 1, orig_idx))

Measured on RTX 4060 Ti (sm_89, CUDA 11.8): reflect 7.37->4.73us (1.50x),
replicate 7.14->4.31us (1.43x), output bit-identical. Also removes the
now-unused if_then_else import.

Fixes apache#19848

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request replaces nested if_then_else branches with branchless integer arithmetic for both the reflect-101 boundary index and edge clamp padding calculations in python/tvm/topi/nn/pad.py. The review feedback highlights potential type mismatch issues when mixing different integer types (such as int32 and int64) during lowering. Specifically, it suggests casting size - 1 and the constant 0 to match orig_idx.dtype to ensure type safety and prevent compilation failures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread python/tvm/topi/nn/pad.py
Comment on lines +217 to +218
m = size - 1
reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If size (which is data.shape[i]) and orig_idx have different integer types (e.g., int32 vs int64), the subtraction m - tvm.tirx.abs(orig_idx) can cause a type mismatch error during lowering. To ensure type safety and robustness, we should explicitly cast m to orig_idx.dtype using tvm.tirx.Cast.

Suggested change
m = size - 1
reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx))
m = tvm.tirx.Cast(orig_idx.dtype, size - 1)
reflected_idx = m - tvm.tirx.abs(m - tvm.tirx.abs(orig_idx))

Comment thread python/tvm/topi/nn/pad.py
# Branchless edge clamp. This is bit-identical to the nested
# if_then_else form (0 below 0, size-1 at or above size, orig_idx
# otherwise) but lowers to min/max instead of per-element branches.
clamped_idx = tvm.tirx.max(tvm.tirx.const(0, "int32"), tvm.tirx.min(size - 1, orig_idx))

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Using a hardcoded "int32" for the constant 0 can lead to type mismatch errors if orig_idx is of type int64 (which is common when 64-bit indexing is enabled or for large tensors). To ensure robustness and prevent compilation failures, we should use orig_idx.dtype instead of "int32" and cast size - 1 to orig_idx.dtype as well.

Suggested change
clamped_idx = tvm.tirx.max(tvm.tirx.const(0, "int32"), tvm.tirx.min(size - 1, orig_idx))
clamped_idx = tvm.tirx.max(tvm.tirx.const(0, orig_idx.dtype), tvm.tirx.min(tvm.tirx.Cast(orig_idx.dtype, size - 1), orig_idx))

@yongwww yongwww closed this Jul 2, 2026
@yongwww yongwww reopened this Jul 2, 2026
TestAutopad.test_edge pinned the exact TIR of replicate_pad, which changed
when the edge clamp was rewritten from nested if_then_else to
max(0, min(size - 1, idx)). Regenerate the golden; output is unchanged.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

# [Performance] nn.pad reflect/replicate ~1.5x slower than needed since v0.21 (nested if_then_else boundary index)

2 participants