device_batched_gemm.hpp Source File

device_batched_gemm.hpp Source File

Composable Kernel: device_batched_gemm.hpp Source File
device_batched_gemm.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <vector>
8
9#include "device_base.hpp"
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
// Abstract interface template for a batched GEMM device operation.
//
// It declares two pure-virtual factory methods:
//   - MakeArgumentPointer(): returns a std::unique_ptr<BaseArgument> built from
//     the raw tensor pointers, the index_t size/stride values and the three
//     element-wise operation objects passed in.
//   - MakeInvokerPointer(): returns a std::unique_ptr<BaseInvoker>.
//
// Of the nine template parameters, only the three element-wise operation types
// appear in the visible member signatures; the layout and data-type parameters
// are not used by them.
//
// NOTE(review): std::unique_ptr is used here, but the visible includes are only
// <iostream>, <vector> and "device_base.hpp"; <memory> presumably arrives
// transitively -- confirm.
15template <typename ALayout,
16 typename BLayout,
17 typename CLayout,
18 typename ADataType,
19 typename BDataType,
20 typename CDataType,
21 typename AElementwiseOperation,
22 typename BElementwiseOperation,
23 typename CElementwiseOperation>
// NOTE(review): listing line 24 (the declaration header) is missing from this
// extract. The DeviceBatchedGemmPtr alias later in the file refers to a
// nine-parameter template named DeviceBatchedGemm, presumably this one --
// confirm the name and any base class against the upstream header.
25{
// Creates the argument object for one batched GEMM problem.
// p_a and p_b are read-only (const void*); p_c is writable (void*).
// Presumably StrideA/B/C are per-matrix strides and BatchStrideA/B/C are the
// offsets between consecutive batch entries, with Batch the batch count --
// TODO confirm; units and valid ranges are not visible here.
26 virtual std::unique_ptr<BaseArgument>
27 MakeArgumentPointer(const void* p_a,
28 const void* p_b,
29 void* p_c,
// NOTE(review): listing lines 30-32 are missing here. The member reference at
// the end of this page shows "ck::index_t M, ck::index_t N, ck::index_t K" in
// this position.
33 ck::index_t StrideA,
34 ck::index_t StrideB,
35 ck::index_t StrideC,
36 ck::index_t BatchStrideA,
37 ck::index_t BatchStrideB,
38 ck::index_t BatchStrideC,
39 ck::index_t Batch,
40 AElementwiseOperation a_element_op,
41 BElementwiseOperation b_element_op,
42 CElementwiseOperation c_element_op) = 0;
43
// Creates the invoker object; takes no arguments.
44 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
45};
46
// Abstract interface template for a batched GEMM variant that additionally
// carries a scale tensor for B.
//
// Compared with the nine-parameter interface above, the template list adds
// BScaleType and two compile-time index_t values (ScaleBlockN, ScaleBlockK),
// and MakeArgumentPointer() adds StrideScaleB, BatchStrideScaleB, p_b_scale
// and KBatch.
//
// Presumably ScaleBlockN/ScaleBlockK give the N and K extents covered by one
// scale value -- TODO confirm; nothing in the visible code uses them.
47template <typename ALayout,
48 typename BLayout,
49 typename CLayout,
50 typename ADataType,
51 typename BDataType,
52 typename BScaleType,
53 typename CDataType,
54 index_t ScaleBlockN,
55 index_t ScaleBlockK,
56 typename AElementwiseOperation,
57 typename BElementwiseOperation,
58 typename CElementwiseOperation>
// NOTE(review): listing line 59 (the declaration header, including the type's
// name and any base class) is missing from this extract -- take it from the
// upstream header.
60{
// Creates the argument object for one batched, B-scaled GEMM problem.
// p_a, p_b and p_b_scale are read-only (const void*); p_c is writable (void*).
// Presumably KBatch is a split-K factor -- TODO confirm against implementers.
61 virtual std::unique_ptr<BaseArgument>
62 MakeArgumentPointer(const void* p_a,
63 const void* p_b,
64 void* p_c,
// NOTE(review): listing lines 65-67 are missing here. The member reference at
// the end of this page shows "ck::index_t M, ck::index_t N, ck::index_t K" in
// this position.
68 ck::index_t StrideA,
69 ck::index_t StrideB,
70 ck::index_t StrideC,
71 ck::index_t StrideScaleB,
72 ck::index_t BatchStrideA,
73 ck::index_t BatchStrideB,
74 ck::index_t BatchStrideC,
75 ck::index_t BatchStrideScaleB,
76 const void* p_b_scale,
77 ck::index_t Batch,
78 ck::index_t KBatch,
79 AElementwiseOperation a_element_op,
80 BElementwiseOperation b_element_op,
81 CElementwiseOperation c_element_op) = 0;
82
// Creates the invoker object; takes no arguments.
83 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
84
// Pure-virtual boolean query; the name suggests it reports whether the B
// operand is stored permuted -- TODO confirm with an implementation.
85 virtual bool GetPermuteB() = 0;
// NOTE(review): listing line 86 is missing from this extract; its content
// cannot be determined from this page.
87};
88
// Alias template: DeviceBatchedGemmPtr<...> is a std::unique_ptr to
// DeviceBatchedGemm instantiated with the same nine template arguments,
// forwarded in the same order.
89template <typename ALayout,
90 typename BLayout,
91 typename CLayout,
92 typename ADataType,
93 typename BDataType,
94 typename CDataType,
95 typename AElementwiseOperation,
96 typename BElementwiseOperation,
97 typename CElementwiseOperation>
98using DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout,
99 BLayout,
100 CLayout,
101 ADataType,
102 BDataType,
103 CDataType,
104 AElementwiseOperation,
105 BElementwiseOperation,
106 CElementwiseOperation>>;
107
108} // namespace device
109} // namespace tensor_operation
110} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceBatchedGemm< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > DeviceBatchedGemmPtr
Definition device_batched_gemm.hpp:98
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batched_gemm.hpp:25
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t BatchStrideA, ck::index_t BatchStrideB, ck::index_t BatchStrideC, ck::index_t Batch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition device_batched_gemm.hpp:60
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t StrideScaleB, ck::index_t BatchStrideA, ck::index_t BatchStrideB, ck::index_t BatchStrideC, ck::index_t BatchStrideScaleB, const void *p_b_scale, ck::index_t Batch, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0