-
Notifications
You must be signed in to change notification settings - Fork 269
Optimize sequence_gen and uniform_sequence_gen to reduce template instantiation depth #3585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -199,55 +199,113 @@ template <index_t N> | |
| using make_index_sequence = | ||
| typename __make_integer_seq<impl::__integer_sequence, index_t, N>::seq_type; | ||
|
|
||
| // merge sequence | ||
| template <typename Seq, typename... Seqs> | ||
| struct sequence_merge | ||
| // merge sequence - optimized to avoid recursive instantiation | ||
| // | ||
| // Note: Unlike sequence_gen and uniform_sequence_gen which use __make_integer_seq for O(1) | ||
| // instantiation depth, sequence_merge cannot achieve O(1) depth. Here's why: | ||
| // | ||
| // - sequence_gen and uniform_sequence_gen generate a SINGLE output sequence where each | ||
| // element can be computed independently: output[i] = f(i) | ||
| // | ||
| // - sequence_merge takes MULTIPLE input sequences with different, unknown lengths. | ||
| // To compute output[i], we need to know: | ||
| // 1. Which input sequence contains this index | ||
| // 2. The offset within that sequence | ||
| // This requires computing cumulative sequence lengths, which requires recursion/iteration. | ||
| // | ||
| // Instead, we hard-code the small cases and peel off two sequences per recursion step: | ||
| // - Base cases handle 1-4 sequences directly (O(1) for common cases) | ||
| // - Recursive case merges the first pair, then the remainder: merge(s1,s2) + merge(s3,s4,...) | ||
| // - Depth is still O(N) for N sequences (roughly N/2 levels), but with a smaller constant | ||
| // factor than merging one sequence per level | ||
| // | ||
| // Alternative considered: Fold expressions (... + sequences) would give O(N) depth due to | ||
| // linear dependency chain, so they offer no improvement over this scheme. | ||
| // | ||
| namespace detail { | ||
|
|
||
| // Helper to concatenate multiple sequences; up to four are joined in one step via pack expansion | ||
| template <typename... Seqs> | ||
| struct sequence_merge_impl; | ||
|
|
||
| // Base case: single sequence | ||
| template <index_t... Is> | ||
| struct sequence_merge_impl<Sequence<Is...>> | ||
| { | ||
| using type = typename sequence_merge<Seq, typename sequence_merge<Seqs...>::type>::type; | ||
| using type = Sequence<Is...>; | ||
| }; | ||
|
|
||
| // Two sequences: direct concatenation | ||
| template <index_t... Xs, index_t... Ys> | ||
| struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>> | ||
| struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>> | ||
| { | ||
| using type = Sequence<Xs..., Ys...>; | ||
| }; | ||
|
|
||
| template <typename Seq> | ||
| struct sequence_merge<Seq> | ||
| // Three sequences: direct concatenation (avoids one level of recursion) | ||
| template <index_t... Xs, index_t... Ys, index_t... Zs> | ||
| struct sequence_merge_impl<Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>> | ||
| { | ||
| using type = Seq; | ||
| using type = Sequence<Xs..., Ys..., Zs...>; | ||
| }; | ||
|
|
||
| // generate sequence | ||
| template <index_t NSize, typename F> | ||
| struct sequence_gen | ||
| // Four sequences: direct concatenation | ||
| template <index_t... As, index_t... Bs, index_t... Cs, index_t... Ds> | ||
| struct sequence_merge_impl<Sequence<As...>, Sequence<Bs...>, Sequence<Cs...>, Sequence<Ds...>> | ||
| { | ||
| template <index_t IBegin, index_t NRemain, typename G> | ||
| struct sequence_gen_impl | ||
| { | ||
| static constexpr index_t NRemainLeft = NRemain / 2; | ||
| static constexpr index_t NRemainRight = NRemain - NRemainLeft; | ||
| static constexpr index_t IMiddle = IBegin + NRemainLeft; | ||
| using type = Sequence<As..., Bs..., Cs..., Ds...>; | ||
| }; | ||
|
|
||
| using type = typename sequence_merge< | ||
| typename sequence_gen_impl<IBegin, NRemainLeft, G>::type, | ||
| typename sequence_gen_impl<IMiddle, NRemainRight, G>::type>::type; | ||
| }; | ||
| // General case: merge the first pair, recurse on the remainder, then join (about N/2 levels deep) | ||
| template <typename S1, typename S2, typename S3, typename S4, typename... Rest> | ||
| struct sequence_merge_impl<S1, S2, S3, S4, Rest...> | ||
| { | ||
| // Merge pairs first, then recurse | ||
| using left = typename sequence_merge_impl<S1, S2>::type; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I may be misreading this, but since we have three hard-coded, it would be faster to use <S1, S2, S3> for left and <S4, Rest...> for right. Also, my reading is that this is still O(N) when N is a large number of sequences, but we're cutting down the prefactor and hardcoding smaller numbers of sequences.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, good catch! Covering ~70% of the instantiations with the base cases is probably the actual impact. We can follow up on it if it shows up in the build traces again. |
||
| using right = typename sequence_merge_impl<S3, S4, Rest...>::type; | ||
| using type = typename sequence_merge_impl<left, right>::type; | ||
| }; | ||
|
|
||
| template <index_t I, typename G> | ||
| struct sequence_gen_impl<I, 1, G> | ||
| { | ||
| static constexpr index_t Is = G{}(Number<I>{}); | ||
| using type = Sequence<Is>; | ||
| }; | ||
| } // namespace detail | ||
|
|
||
| template <index_t I, typename G> | ||
| struct sequence_gen_impl<I, 0, G> | ||
| { | ||
| using type = Sequence<>; | ||
| }; | ||
| template <typename... Seqs> | ||
| struct sequence_merge | ||
| { | ||
| using type = typename detail::sequence_merge_impl<Seqs...>::type; | ||
| }; | ||
|
|
||
| template <> | ||
| struct sequence_merge<> | ||
| { | ||
| using type = Sequence<>; | ||
| }; | ||
|
|
||
| // generate sequence - optimized using __make_integer_seq to avoid recursive instantiation | ||
| namespace detail { | ||
|
|
||
| // Helper that applies functor F to indices and produces a Sequence | ||
| // __make_integer_seq<sequence_gen_helper, index_t, N> produces sequence_gen_helper<index_t, 0, 1, | ||
| // ..., N-1> | ||
| template <typename T, T... Is> | ||
| struct sequence_gen_helper | ||
| { | ||
| // Apply a functor F to all indices at once via pack expansion (O(1) depth) | ||
| template <typename F> | ||
| using apply = Sequence<F{}(Number<Is>{})...>; | ||
| }; | ||
|
|
||
| } // namespace detail | ||
|
|
||
| using type = typename sequence_gen_impl<0, NSize, F>::type; | ||
| template <index_t NSize, typename F> | ||
| struct sequence_gen | ||
| { | ||
| using type = | ||
| typename __make_integer_seq<detail::sequence_gen_helper, index_t, NSize>::template apply<F>; | ||
| }; | ||
|
|
||
| template <typename F> | ||
| struct sequence_gen<0, F> | ||
| { | ||
| using type = Sequence<>; | ||
| }; | ||
|
|
||
| // arithmetic sequence | ||
|
|
@@ -283,16 +341,30 @@ struct arithmetic_sequence_gen<0, IEnd, 1> | |
| using type = typename __make_integer_seq<WrapSequence, index_t, IEnd>::type; | ||
| }; | ||
|
|
||
| // uniform sequence | ||
| // uniform sequence - optimized using __make_integer_seq | ||
| namespace detail { | ||
|
|
||
| template <typename T, T... Is> | ||
| struct uniform_sequence_helper | ||
| { | ||
| // Apply a constant value to all indices via pack expansion | ||
| template <index_t Value> | ||
| using apply = Sequence<((void)Is, Value)...>; | ||
| }; | ||
|
|
||
| } // namespace detail | ||
|
|
||
| template <index_t NSize, index_t I> | ||
| struct uniform_sequence_gen | ||
| { | ||
| struct F | ||
| { | ||
| __host__ __device__ constexpr index_t operator()(index_t) const { return I; } | ||
| }; | ||
| using type = typename __make_integer_seq<detail::uniform_sequence_helper, index_t, NSize>:: | ||
| template apply<I>; | ||
| }; | ||
|
|
||
| using type = typename sequence_gen<NSize, F>::type; | ||
| template <index_t I> | ||
| struct uniform_sequence_gen<0, I> | ||
| { | ||
| using type = Sequence<>; | ||
| }; | ||
|
|
||
| // reverse inclusive scan (with init) sequence | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like these specializations. It will be interesting to get a survey of the code to see how often the specializations are used and if these four smallest cases are the most impactful ones.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm using the build traces to drive the optimizations. Maybe removing the unused code is another aspect that could help with parsing times.