Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,25 @@
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/testing/testing.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/filter_extent.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/validation.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"

/// This file implements common functionality for invoking/testing grouped
/// forward convolutions created through the CK Builder API. The main item
/// of it is the ConvArgs structure - which contains a complete description
/// of it is the Args structure - which contains a complete description
/// of a convolution operation.
///
/// It is not intended that this file contains implementation details for
/// actually launching a convolution operation. As this can be done
/// through different APIs depending on the kernel (CK, CK Tile, or a
/// reference implementation), the code dealing with that is split out
/// into a separate header for each implementation.
/// into a separate header for each implementation. Nor does this file
/// deal with details for defining the data types (`Inputs` and `Outputs`)
/// for different conv directions, that is also split out into separate
/// headers to keep this one small.

namespace ck_tile::builder::test {

Expand Down Expand Up @@ -56,7 +55,7 @@ struct ConvTensorLengths
///
/// @see Args
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
requires ValidConvSignature<SIGNATURE>
struct Args<SIGNATURE>
{
constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim;
Expand Down Expand Up @@ -204,53 +203,4 @@ struct Args<SIGNATURE>
}
};

/// @brief `Inputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Inputs<SIGNATURE>
{
void* input;
void* weight;

static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
}
};

/// @brief `Outputs` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
struct Outputs<SIGNATURE>
{
void* output;

static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
}
};

/// @brief `init_inputs()` specialization for forward convolution.
///
/// @tparam SIGNATURE Forward convolution signature.
///
/// @see alloc_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
}

} // namespace ck_tile::builder::test
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/builder/testing/tensor_initialization.hpp"
#include "ck_tile/builder/testing/testing_reflect.hpp"
#include "ck_tile/builder/testing/conv/args.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/error.hpp"

/// This file deals with the backward weight-specific details of running grouped
/// convolution backwards weight operations. It mainly defines the data
/// structures (`Input` and `Output`), initialization, and validation. Note
/// that for this operation specifically, many of the operations are
/// implemented automatically via testing_reflect.hpp.

namespace ck_tile::builder::test {

/// @brief `Inputs` specialization for backwards weight convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see Inputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
struct Inputs<SIGNATURE>
{
void* input;
void* output;

// See testing_reflect.hpp
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
}
};

/// @brief `Outputs` specialization for backwards weight convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see Outputs
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
struct Outputs<SIGNATURE>
{
void* weight;

// See testing_reflect.hpp
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
{
inspect("weight", args.make_weight_descriptor(), &Outputs<SIGNATURE>::weight);
}
};

/// @brief `init_inputs()` specialization for backwards convolution.
///
/// @tparam SIGNATURE Backwards weight convolution signature.
///
/// @see init_inputs()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
{
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f);
}

} // namespace ck_tile::builder::test
Loading