device_elementwise_scale.hpp Source File

device_elementwise_scale.hpp Source File

Composable Kernel: device_elementwise_scale.hpp Source File
device_elementwise_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

12namespace ck {
13namespace tensor_operation {
14namespace device {
15
20template <typename InDataTypeTuple,
21 typename OutDataTypeTuple,
22 typename ElementwiseOperation,
23 typename UnaryOperation,
24 typename Scale,
25 index_t NumDim>
26struct DeviceElementwise : public BaseOperator
27{
28 static constexpr int NumInput = InDataTypeTuple::Size();
29 static constexpr int NumOutput = OutDataTypeTuple::Size();
30
31 virtual std::unique_ptr<BaseArgument>
32 MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
33 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
34 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
35 const std::array<const void*, NumInput> in_dev_buffers,
36 const std::array<void*, NumOutput> out_dev_buffers,
37 ElementwiseOperation elementwise_op,
38 UnaryOperation unary_op,
39 Scale scale_op) = 0;
40
41 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
42}; // namespace device
43
44template <typename InDataTypeTuple,
45 typename OutDataTypeTuple,
46 typename ElementwiseOperation,
47 typename UnaryOperation,
48 typename Scale,
49 index_t NumDim>
50using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
51 OutDataTypeTuple,
52 ElementwiseOperation,
53 UnaryOperation,
54 Scale,
55 NumDim>>;
56
57} // namespace device
58} // namespace tensor_operation
59} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceElementwise< InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim > > DeviceElementwisePtr
Definition device_elementwise.hpp:40
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_base.hpp:223
Definition device_elementwise.hpp:21
static constexpr int NumInput
Definition device_elementwise.hpp:22
static constexpr int NumOutput
Definition device_elementwise.hpp:23
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0