Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 23 additions & 36 deletions include/ck/tensor_description/tensor_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,28 +45,29 @@ struct TensorAdaptor
return BottomDimensionHiddenIds{};
}

__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
// Helper to get length of a top dimension from transforms
template <index_t I>
__host__ __device__ static constexpr auto
GetTopDimLengthFromTransforms(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_top) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top);

constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];

static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");

const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
constexpr auto result = find_in_tuple_of_sequences<TopDimensionHiddenIds::At(Number<I>{})>(
UpperDimensionHiddenIdss{});
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
}

return length;
},
Number<ndim_top_>{});
// Compute element size using pack expansion instead of generate_tuple with lambda
template <index_t... Is>
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
Sequence<Is...>)
{
return (GetTopDimLengthFromTransforms<Is>(transforms) * ...);
}

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
return ComputeElementSizeImpl(transforms,
typename arithmetic_sequence_gen<0, ndim_top_, 1>::type{});
}

template <index_t IDim>
Expand All @@ -76,24 +77,10 @@ struct TensorAdaptor

constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top);

index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;

static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran];

static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
// Use compile-time search helper instead of nested static_for with lambdas
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionHiddenIdss{});

return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}

__host__ __device__ static constexpr index_t GetNumOfBottomDimension()
Expand Down
62 changes: 26 additions & 36 deletions include/ck/tensor_description/tensor_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,28 +51,29 @@ struct TensorDescriptor
return unique_sort_all_dim_ids::Size();
}

__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
// Helper to get length of a visible dimension from transforms
template <index_t I>
__host__ __device__ static constexpr auto
GetVisibleDimLengthFromTransforms(const Transforms& transforms)
{
const auto lengths = generate_tuple(
[&](auto idim_visible) {
constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible);

constexpr index_t itran = tmp[Number<0>{}];
constexpr index_t idim_up = tmp[Number<1>{}];
constexpr bool found = tmp[Number<2>{}];

static_assert(found == true,
"wrong! not found matching transformation and upper-dimension");

const auto length =
transforms[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
constexpr auto result =
find_in_tuple_of_sequences<VisibleDimensionIds::At(Number<I>{})>(UpperDimensionIdss{});
static_assert(result.found, "wrong! not found matching transformation and upper-dimension");
return transforms[Number<result.itran>{}].GetUpperLengths()[Number<result.idim_up>{}];
}

return length;
},
Number<ndim_visible_>{});
// Compute element size using pack expansion instead of generate_tuple with lambda
template <index_t... Is>
__host__ __device__ static constexpr auto ComputeElementSizeImpl(const Transforms& transforms,
Sequence<Is...>)
{
return (GetVisibleDimLengthFromTransforms<Is>(transforms) * ...);
}

// TODO: make container_reduce support tuple of Number and index_t
return container_reduce(lengths, math::multiplies{}, Number<1>{});
__host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms)
{
return ComputeElementSizeImpl(
transforms, typename arithmetic_sequence_gen<0, ndim_visible_, 1>::type{});
}

template <index_t IDim>
Expand All @@ -82,24 +83,13 @@ struct TensorDescriptor

constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible);

index_t itran_found = 0;
index_t idim_up_found = 0;
bool found = false;

static_for<0, ntransform_, 1>{}([&](auto itran) {
constexpr auto up_dim_ids = UpperDimensionIdss{}[itran];

static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) {
if constexpr(up_dim_ids[idim_up] == idim_hidden)
{
itran_found = itran;
idim_up_found = idim_up;
found = true;
}
});
});
// Use compile-time search helper instead of nested static_for loops
// This optimization significantly reduces applier::operator() template instantiations
// by replacing nested lambda-based loops with a single constexpr search function.
// See sequence_helper.hpp::find_in_tuple_of_sequences for implementation details.
constexpr auto result = find_in_tuple_of_sequences<idim_hidden>(UpperDimensionIdss{});

return make_tuple(itran_found, idim_up_found, found);
return make_tuple(result.itran, result.idim_up, result.found);
}

constexpr static index_t ntransform_ = GetNumOfTransform();
Expand Down
130 changes: 130 additions & 0 deletions include/ck/utility/sequence_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,134 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
return Sequence<Is...>{};
}

// sequence_find_value - O(1) template depth constexpr search
//
// Optimization: Constexpr loop with array lookup instead of recursive template pattern
//
// Why this approach:
// - Recursive template (OLD): template instantiation for each recursion level → O(N) instantiations
// Example: Finding value in Sequence<1,2,3,4,5> requires 5 recursive instantiations
//
// - Constexpr loop (NEW): Single function instantiation with runtime loop → O(1) instantiation
// Same search requires only 1 function instantiation, loop executes at compile-time
//
// Implementation details:
// 1. Pack expansion creates constexpr array: {(Is == Target)...}
// 2. Constexpr for loop searches the array
// 3. Entire function evaluates at compile-time (no runtime cost)
//
// Impact:
// - Significantly reduces template instantiation depth for sequence search operations
// - Dramatically improves compilation time vs recursive template approach
// - Pattern applies to any compile-time search/lookup operation
//
// Trade-off: Uses constexpr evaluation instead of pure template metaprogramming.
// Requires C++14 constexpr but results in dramatically better compile times.
//
template <index_t Target, index_t... Is>
__host__ __device__ constexpr index_t sequence_find_value(Sequence<Is...>)
{
if constexpr(sizeof...(Is) == 0)
{
return -1;
}
else
{
constexpr bool matches[] = {(Is == Target)...};
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Is)); ++i)
{
if(matches[i])
return i;
}
return -1;
}
}

// Result type for find_in_tuple_of_sequences
template <index_t ITran, index_t IDimUp, bool Found>
struct FindTransformResult
{
static constexpr index_t itran = ITran;
static constexpr index_t idim_up = IDimUp;
static constexpr bool found = Found;
};

// find_in_tuple_of_sequences - finds which sequence contains a target value
//
// Optimization: Pack expansion with constexpr search instead of nested static_for loops
//
// Why this approach:
// - Nested static_for (OLD): Creates lambda closure for each iteration level
// Example: Searching Tuple<Seq<0,1>, Seq<2,3>, Seq<4,5>> creates multiple applier::operator()
// instantiations Result: Many applier instantiations for typical tensor descriptor operations
//
// - Pack expansion + constexpr (NEW): Single function with compile-time array search
// Example: Same search creates constexpr array, single search function
// Result: 1 function instantiation regardless of tuple size
//
// Implementation:
// 1. Pack expansion: sequence_find_value<Target>(Seqs{})... applies search to each sequence
// 2. Results collected in constexpr array
// 3. Linear search finds first non-negative result (sequence containing target)
//
// Impact:
// - Significantly reduces applier::operator() instantiations in tensor descriptor transforms
// - O(1) template depth instead of O(N*M) for N sequences of length M
//
// Use case: Finding which dimension index contains a specific value (common in tensor reordering)
//
template <index_t Target, typename... Seqs>
struct FindInTupleOfSequencesCompute
{
private:
// Result struct for constexpr computation
struct ResultData
{
index_t itran;
index_t idim_up;
bool found;
};

// Compute result using constexpr function with array lookup
static constexpr ResultData compute()
{
if constexpr(sizeof...(Seqs) == 0)
{
return {0, 0, false};
}
else
{
// Pack expansion creates array - O(1) template depth
constexpr index_t indices[] = {sequence_find_value<Target>(Seqs{})...};

// Find first matching sequence
for(index_t i = 0; i < static_cast<index_t>(sizeof...(Seqs)); ++i)
{
if(indices[i] >= 0)
{
return {i, indices[i], true};
}
}
return {0, 0, false};
}
}

static constexpr ResultData result_ = compute();

public:
static constexpr index_t itran = result_.itran;
static constexpr index_t idim_up = result_.idim_up;
static constexpr bool found = result_.found;

using type = FindTransformResult<itran, idim_up, found>;
};

// Find target value in a tuple of sequences
// Returns FindTransformResult<itran, idim_up, found>
// Uses O(1) template depth via pack expansion (no recursion)
template <index_t Target, typename... Seqs>
__host__ __device__ constexpr auto find_in_tuple_of_sequences(Tuple<Seqs...>)
{
return typename FindInTupleOfSequencesCompute<Target, Seqs...>::type{};
}
} // namespace ck
5 changes: 5 additions & 0 deletions test/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,8 @@ add_gtest_executable(unit_sequence unit_sequence.cpp)
if(result EQUAL 0)
target_link_libraries(unit_sequence PRIVATE utility)
endif()

add_gtest_executable(unit_sequence_helper unit_sequence_helper.cpp)
if(result EQUAL 0)
target_link_libraries(unit_sequence_helper PRIVATE utility)
endif()
92 changes: 92 additions & 0 deletions test/util/unit_sequence_helper.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include <gtest/gtest.h>
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/tuple_helper.hpp"

using namespace ck;

// Tests for sequence_find_value (PR #3600)
TEST(SequenceFindValue, FindExistingElement)
{
constexpr auto result = sequence_find_value<17>(Sequence<5, 11, 17, 23, 29>{});
EXPECT_EQ(result, 2); // 17 is at index 2
}

TEST(SequenceFindValue, FindFirstElement)
{
constexpr auto result = sequence_find_value<7>(Sequence<7, 13, 19, 31>{});
EXPECT_EQ(result, 0);
}

TEST(SequenceFindValue, FindLastElement)
{
constexpr auto result = sequence_find_value<41>(Sequence<3, 11, 23, 41>{});
EXPECT_EQ(result, 3);
}

TEST(SequenceFindValue, ElementNotFound)
{
constexpr auto result = sequence_find_value<50>(Sequence<2, 8, 14, 26>{});
EXPECT_EQ(result, -1);
}

TEST(SequenceFindValue, EmptySequence)
{
constexpr auto result = sequence_find_value<1>(Sequence<>{});
EXPECT_EQ(result, -1);
}

// Tests for find_in_tuple_of_sequences (PR #3600)
TEST(FindInTupleOfSequences, FindInFirstSequence)
{
constexpr auto tuple_of_seqs =
make_tuple(Sequence<5, 11>{}, Sequence<17, 23>{}, Sequence<29, 37>{});
constexpr auto result = find_in_tuple_of_sequences<11>(tuple_of_seqs);
EXPECT_EQ(result.itran, 0); // Found in first sequence (index 0)
EXPECT_EQ(result.idim_up, 1); // At position 1 within that sequence
EXPECT_TRUE(result.found);
}

TEST(FindInTupleOfSequences, FindInMiddleSequence)
{
constexpr auto tuple_of_seqs =
make_tuple(Sequence<2, 4, 6>{}, Sequence<8, 10>{}, Sequence<12>{});
constexpr auto result = find_in_tuple_of_sequences<10>(tuple_of_seqs);
EXPECT_EQ(result.itran, 1); // Found in second sequence (index 1)
EXPECT_EQ(result.idim_up, 1); // At position 1 within that sequence
EXPECT_TRUE(result.found);
}

TEST(FindInTupleOfSequences, FindInLastSequence)
{
constexpr auto tuple_of_seqs = make_tuple(Sequence<3>{}, Sequence<7>{}, Sequence<13, 19, 31>{});
constexpr auto result = find_in_tuple_of_sequences<31>(tuple_of_seqs);
EXPECT_EQ(result.itran, 2); // Found in third sequence (index 2)
EXPECT_EQ(result.idim_up, 2); // At position 2 within that sequence
EXPECT_TRUE(result.found);
}

TEST(FindInTupleOfSequences, NotFound)
{
constexpr auto tuple_of_seqs = make_tuple(Sequence<1, 3>{}, Sequence<5, 7, 9>{});
constexpr auto result = find_in_tuple_of_sequences<100>(tuple_of_seqs);
EXPECT_FALSE(result.found);
}

TEST(FindInTupleOfSequences, EmptyTuple)
{
constexpr auto tuple_of_seqs = make_tuple();
constexpr auto result = find_in_tuple_of_sequences<1>(tuple_of_seqs);
EXPECT_FALSE(result.found);
}

TEST(FindInTupleOfSequences, SingleSequence)
{
constexpr auto tuple_of_seqs = make_tuple(Sequence<41, 43, 47, 53>{});
constexpr auto result = find_in_tuple_of_sequences<47>(tuple_of_seqs);
EXPECT_EQ(result.itran, 0);
EXPECT_EQ(result.idim_up, 2);
EXPECT_TRUE(result.found);
}