-
Notifications
You must be signed in to change notification settings - Fork 270
Replace O(N) recursive sequence_map_inverse with O(1) pack expansion #3596
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Conversation
59f0c32 to
5190578
Compare
6d792da to
f5ada17
Compare
5190578 to
887bdf2
Compare
887bdf2 to
02e42dc
Compare
f5ada17 to
9942fd6
Compare
9d67d0d to
c4d95f7
Compare
82b6016 to
602c127
Compare
c4d95f7 to
631df4f
Compare
602c127 to
1713ea7
Compare
cbaf07b to
3b8b37d
Compare
3b8b37d to
7c9cdf0
Compare
d162e26 to
f8d808e
Compare
e921e01 to
bd98bd1
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This PR optimizes sequence_map_inverse by replacing O(N) recursive template instantiation with O(1) template depth using pack expansion and constexpr loops, reducing compilation overhead.
Changes:
- Replaced recursive
sequence_map_inverse_implwithConstexprArrayand constexpr loop-basedfind_inverse - Added detailed comments explaining the compilation performance benefits
- Achieved 1.6% reduction in template instantiations (126,896 fewer instantiations)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
include/ck/utility/sequence.hpp
Outdated
| if(values[i] == target) | ||
| return i; | ||
| } | ||
| return -1; // should not reach for valid permutation |
Copilot
AI
Jan 23, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return value -1 for an invalid permutation is misleading since index_t is likely unsigned. Consider using a static_assert or explicit error handling to catch invalid permutations at compile-time, or document that this path should never execute for valid inputs.
| return -1; // should not reach for valid permutation | |
| return static_cast<index_t>(-1); // should not reach for valid permutation |
shumway
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we want to remove the O(N^2) loop and probably add memoization, too.
include/ck/utility/sequence.hpp
Outdated
| if(values[i] == target) | ||
| return i; | ||
| } | ||
| return -1; // should not reach for valid permutation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we make this a compile-time error, just to catch anything that is really broken?
The two patterns I know are to use a static_assert or to call a consteval function that throws an error. We don't want to have this logic silently fail.
| // to expand over. Sequence<0,1,2> gives us Positions = 0,1,2, which expands to: | ||
| // Sequence<find_inverse(0), find_inverse(1), find_inverse(2)> | ||
| // Without a pack, we'd need recursion to generate each element - defeating our goal. | ||
| template <index_t... Positions> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like we have an O(N^2) evaluation of the permutation inverse. That's much better than a recursive template instantiation, but we can probably be faster:
- Write a direct permutation inverse:
template <index_t... Positions>
static constexpr auto compute(Sequence<Positions...>)
{
// Build result array in one pass
detail::ConstexprArray<index_t, sizeof...(Is)> result = {};
index_t pos = 0;
((result[values[pos++]] = Positions), ...); // fold expression
return Sequence<result[Positions]...>{};
}
- Maybe use also static constexpr templated variable to cache the inverted permutations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking into the instantiation patterns for a conv device instances target, I see the maximum sequence length is 9 and all of them are valid permutations.
There are 12 unique instances. Most of them are identity permutations. Maybe we can start with a shortcut for inverting identity and rolling back the separate array struct, doesn't seem worth the introduced maintenance complexity and the constant factor effect may be worse than the overhead of local small arrays
cgmillette
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this handle repeated indices?
bd98bd1 to
7a427d0
Compare
Summary
Replace the O(N) recursive
sequence_map_inverseimplementation with O(1) template depth using pack expansion.Approach
constexprloop infind_source_indexto locate permutation inverse indicesWhy It Works
Template recursion requires N template instantiations for N iterations, each with its own overhead. Constexpr loops execute within a single template instantiation, avoiding per-instantiation overhead.
Build Performance Impact
Template Instantiation Reduction (measured on
device_grouped_conv3d_fwd_bias_bnorm_clamp_instancetarget, 248 files):This confirms the optimization successfully reduces template instantiation overhead by eliminating recursive template patterns in favor of pack expansion.
Test Plan
SequenceMapInverse.InverseMapandSequenceMapInverse.InverseIdentityMaptests validate correctnessNotes
sequence_mergeoptimization removed from this PR (handled in Optimize sequence_gen and uniform_sequence_gen to reduce template instantiation depth #3585)is_valid_sequence_mapbefore callingsequence_map_inverse