[Relax][Frontend][ONNX] Add GroupNormalization support#19907
Conversation
There was a problem hiding this comment.
Code Review
This pull request adds support for converting ONNX GroupNormalization nodes (opsets 18 and 21) into equivalent Relax expressions, along with comprehensive unit tests to verify correctness. The feedback suggests optimizing the opset 18 implementation by resolving scale and bias to constants at import time using get_constant. If they are constants, they can be expanded directly using NumPy (_np.repeat), which avoids adding redundant reshape and broadcast_to operators to the generated Relax graph.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| data = inputs[0] | ||
| scale = inputs[1] | ||
| bias = inputs[2] |
There was a problem hiding this comment.
We can use get_constant to resolve scale and bias to constants if they are initializers. This allows us to perform the per-group to per-channel expansion at import time using NumPy, avoiding redundant reshape and broadcast_to operators in the Relax graph.
| data = inputs[0] | |
| scale = inputs[1] | |
| bias = inputs[2] | |
| data = inputs[0] | |
| scale = get_constant(inputs[1], params) | |
| bias = get_constant(inputs[2], params) |
| scale = relax.op.reshape(scale, [num_groups, 1]) | ||
| scale = relax.op.broadcast_to(scale, [num_groups, channels_per_group]) | ||
| scale = relax.op.reshape(scale, [channels]) | ||
|
|
||
| bias = relax.op.reshape(bias, [num_groups, 1]) | ||
| bias = relax.op.broadcast_to(bias, [num_groups, channels_per_group]) | ||
| bias = relax.op.reshape(bias, [channels]) |
There was a problem hiding this comment.
If scale and bias are resolved to constants, we can expand them directly using _np.repeat at import time. This simplifies the generated Relax graph by eliminating unnecessary reshape and broadcast_to operations.
| scale = relax.op.reshape(scale, [num_groups, 1]) | |
| scale = relax.op.broadcast_to(scale, [num_groups, channels_per_group]) | |
| scale = relax.op.reshape(scale, [channels]) | |
| bias = relax.op.reshape(bias, [num_groups, 1]) | |
| bias = relax.op.broadcast_to(bias, [num_groups, channels_per_group]) | |
| bias = relax.op.reshape(bias, [channels]) | |
| if isinstance(scale, relax.Constant): | |
| scale = relax.const(_np.repeat(scale.data.numpy(), channels_per_group), scale.ty.dtype) | |
| else: | |
| scale = relax.op.reshape(scale, [num_groups, 1]) | |
| scale = relax.op.broadcast_to(scale, [num_groups, channels_per_group]) | |
| scale = relax.op.reshape(scale, [channels]) | |
| if isinstance(bias, relax.Constant): | |
| bias = relax.const(_np.repeat(bias.data.numpy(), channels_per_group), bias.ty.dtype) | |
| else: | |
| bias = relax.op.reshape(bias, [num_groups, 1]) | |
| bias = relax.op.broadcast_to(bias, [num_groups, channels_per_group]) | |
| bias = relax.op.reshape(bias, [channels]) |
Summary
Adds ONNX frontend support for
GroupNormalizationby mapping it to the existingrelax.op.nn.group_norm.Supports opset 18 per-group scale/bias expansion, opset 21 per-channel scale/bias, and
stash_typecast behavior.Testing
Includes structural checks for opset 18, opset 21, rank-3 inputs, and fp16
stash_typepaths.