46 changes: 46 additions & 0 deletions demo/reference_mugraphs/group_query_attention_cluster.py
@@ -0,0 +1,46 @@
import mirage as mi
import numpy as np
import torch

torch.manual_seed(42)

if __name__ == "__main__":
graph = mi.new_kernel_graph()
Q = graph.new_input(dims=(2, 64, 64), dtype=mi.float16)
K = graph.new_input(dims=(2, 64, 2048), dtype=mi.float16)
V = graph.new_input(dims=(2, 2048, 64), dtype=mi.float16)
tbgraph1 = mi.new_threadblock_graph(grid_dim=(2,8,1),block_dim=(256,1,1), forloop_range=4, reduction_dimx=64, cluster_dim=(1, 8, 1))
bQ = tbgraph1.new_input(dtensor=Q, input_map=(0, -1, -1), forloop_dim=-1)
bK = tbgraph1.new_input(dtensor=K, input_map=(0, 2, -1), forloop_dim=2)
bV = tbgraph1.new_input(dtensor=V, input_map=(0, 1, -1), forloop_dim=1)
bA = tbgraph1.matmul(bQ, bK)
bE = tbgraph1.exp(bA)
bS = tbgraph1.matmul(bE, bV)
bV1 = tbgraph1.forloop_accum(bS)
bEs = tbgraph1.forloop_accum(bE, "sum")
bEss = tbgraph1.cluster_accum(bEs, "sum")
bO = tbgraph1.div(bV1, bEss)
tbgraph1.new_output(stensor=bO, output_map=(0, 1, -1))
O = graph.customized([Q, K, V], tbgraph1)

# torch.Size([2, 256, 1024])
# torch.Size([2, 256, 16])

graph.mark_output(O[0])

input_tensors = [
torch.full((2, 64, 64), 0.1, dtype=torch.float16, device='cuda:0'),
torch.full((2, 64, 2048), 0.1, dtype=torch.float16, device='cuda:0'),
torch.full((2, 2048, 64), 0.1, dtype=torch.float16, device='cuda:0')
]
# input_tensors = [
# torch.randn(2, 256, 64, dtype=torch.float16, device='cuda:0'),
# torch.randn(2, 64, 2048, dtype=torch.float16, device='cuda:0'),
# # torch.randn(2, 2048, 64, dtype=torch.float16, device='cuda:0')
# ]

# input_strides = [tensor.stride() for tensor in input_tensors]
# p = mi.generate_cuda_program(graph.cygraph, target_cc=90, input_strides=input_strides, num_warp_groups = 2, pipeline_stages = 2)
# print(p["code"])
outputs = graph(inputs=input_tensors)

34 changes: 34 additions & 0 deletions include/mirage/threadblock/cluster_accum.h
@@ -0,0 +1,34 @@
/* Copyright 2023-2024 CMU
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "mirage/threadblock/operator.h"

namespace mirage {
namespace threadblock {

class TBClusterAccumOp : public TBOperator {
public:
TBClusterAccumOp(Graph *_graph,
STensor const &input,
mirage::type::TBOperatorType type);
~TBClusterAccumOp();

operator json() const override;
};

} // namespace threadblock
} // namespace mirage
16 changes: 15 additions & 1 deletion include/mirage/threadblock/graph.h
@@ -35,6 +35,11 @@ class Graph {
public:
Graph();
Graph(dim3 grid_dim, dim3 block_dim, int forloop_range, int reduction_dimx);
Graph(dim3 grid_dim,
dim3 cluster_dim,
dim3 block_dim,
int forloop_range,
int reduction_dimx);
~Graph();
Graph(Graph const &) = delete;
Graph &operator=(Graph const &) = delete;
@@ -134,6 +139,14 @@ class Graph {
TBOperator *create_forloop_accum_op(STensor const &input,
mirage::type::TBOperatorType type);

STensor cluster_accum(STensor const &input,
mirage::type::TBOperatorType type);

STensor *cluster_accum(STensor const *input,
mirage::type::TBOperatorType type);
TBOperator *create_cluster_accum_op(STensor const &input,
mirage::type::TBOperatorType type);

// fingerprint related memory management
off_t allocate_fingerprint(STensor const &tensor);
void free_fingerprint(STensor const &tensor);
@@ -148,7 +161,8 @@
operator json() const;

public:
dim3 grid_dim, block_dim, cluster_dim;
dim3 grid_dim, block_dim;
dim3 cluster_dim = dim3(1, 1, 1);
int forloop_range;
int reduction_dimx;
std::vector<mirage::threadblock::TBOperator *> operators;
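The new 5-argument constructor places cluster_dim between grid_dim and block_dim in the parameter list. Below is a minimal host-side sketch of calling it (not part of this diff), reusing the launch shape from the Python demo above; the wrapper function name is an illustrative assumption.

#include "mirage/threadblock/graph.h"

void build_clustered_tb_graph() {
  // grid_dim, cluster_dim, block_dim, forloop_range, reduction_dimx
  mirage::threadblock::Graph tb_graph(dim3(2, 8, 1),
                                      dim3(1, 8, 1),
                                      dim3(256, 1, 1),
                                      /*forloop_range=*/4,
                                      /*reduction_dimx=*/64);
  // Graphs built through the original 4-argument constructor keep the
  // default cluster_dim of (1, 1, 1).
}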
2 changes: 1 addition & 1 deletion include/mirage/transpiler/runtime/kernel/element_unary.h
@@ -22,7 +22,7 @@ static __device__ __forceinline__ T perform_element_unary_op(T a) {
} else if constexpr (OP == ElementUnaryOpType::SILU) {
return (T)(((float)a) * (1.0f / (1.0f + expf((float)-a))));
} else if constexpr (OP == ElementUnaryOpType::GELU) {
return (T)((((float)a) / 2.0f)*(1.0f + erff(((float)a) / sqrtf(2.0f))));
return (T)((((float)a) / 2.0f) * (1.0f + erff(((float)a) / sqrtf(2.0f))));
} else if constexpr (OP == ElementUnaryOpType::SQUARE) {
return (T)((float)a * (float)a);
} else if constexpr (OP == ElementUnaryOpType::SQRT) {
37 changes: 25 additions & 12 deletions include/mirage/transpiler/runtime/threadblock/hopper_matmul.h
@@ -37,7 +37,10 @@ template <typename T,
// data, it does not use the standard
// "epilogue" semantic
bool IS_STORE_ACCUM,
bool IS_COORPERATIVE>
bool IS_COORPERATIVE,
bool IS_PIPELINE_A,
bool IS_PIPELINE_B,
int PIPELINE_STAGES>
class Hopper_Matmul {
public:
CUTE_STATIC_ASSERT_V(rank(SmemLayoutA_{}) == _2{});
@@ -77,6 +80,9 @@ class Hopper_Matmul {
CUTE_STATIC_ASSERT_V(M{} == get<0>(shape(SmemLayoutC{})));
CUTE_STATIC_ASSERT_V(N{} == get<1>(shape(SmemLayoutC{})));

static constexpr int PIPELINE_STAGE_A = IS_PIPELINE_A ? PIPELINE_STAGES : 1;
static constexpr int PIPELINE_STAGE_B = IS_PIPELINE_B ? PIPELINE_STAGES : 1;
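// An operand that is not pipelined falls back to a single smem stage.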

// using TiledMMA = decltype(make_tiled_mma(
// SM90_64x32x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::MN>{}));

@@ -158,15 +164,17 @@
int read_stage) {
// cutlass::arch::warpgroup_reg_alloc<192>();
TiledMMA tiled_mma;
auto sA_l = tile_to_shape(
TileALayout{},
make_shape(shape<0>(SmemLayoutA{}), shape<1>(SmemLayoutA{}), Int<2>{}),
Step<_1, _2, _3>{});

auto sB_l = tile_to_shape(
TileBLayout{},
make_shape(shape<0>(SmemLayoutB{}), shape<1>(SmemLayoutB{}), Int<2>{}),
Step<_1, _2, _3>{});
auto sA_l = tile_to_shape(TileALayout{},
make_shape(shape<0>(SmemLayoutA{}),
shape<1>(SmemLayoutA{}),
Int<PIPELINE_STAGE_A>{}),
Step<_1, _2, _3>{});

auto sB_l = tile_to_shape(TileBLayout{},
make_shape(shape<0>(SmemLayoutB{}),
shape<1>(SmemLayoutB{}),
Int<PIPELINE_STAGE_B>{}),
Step<_1, _2, _3>{});

Tensor sA = make_tensor(make_smem_ptr(a_ptr), sA_l); // [M, K]
Tensor sB = make_tensor(make_smem_ptr(b_ptr), sB_l); // [N, K]
@@ -182,12 +190,17 @@
warpgroup_fence_operand(mma_rC);
cute::warpgroup_arrive();
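// A non-pipelined operand always reads smem stage 0.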
gemm(tiled_mma,
tCrA(_, _, _, read_stage),
tCrB(_, _, _, read_stage),
tCrA(_, _, _, IS_PIPELINE_A ? read_stage : 0),
tCrB(_, _, _, IS_PIPELINE_B ? read_stage : 0),
mma_rC);
cute::warpgroup_commit_batch();
cute::warpgroup_wait<0>();
warpgroup_fence_operand(mma_rC);

// if(thread0()){
// print_tensor(mma_rC);
// print("-------\n");
// }
}

// no pipe version
186 changes: 186 additions & 0 deletions include/mirage/transpiler/runtime/threadblock/tb_cluster.h
@@ -0,0 +1,186 @@
// tb_cluster.h - Implementation of threadblock cluster-level operators
//
#pragma once

#include <cstdio>
#include <cuda/ptx>
#include <cuda/barrier>
#include <cooperative_groups.h>

namespace tb {

using namespace cooperative_groups;

// utils
static __device__ __forceinline__ void cluster_sync() {
// arrive & wait; use aligned/acquire variants?
// asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : :);
// asm volatile("barrier.cluster.arrive.aligned;\n" : :);
}

static __device__ __forceinline__ uint32_t _ptr_to_int32(void const* const ptr){
return static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
}


// mbarrier init
static __device__ __forceinline__ void
mbarrier_init(uint64_t *__addr,
const uint32_t &__count) {
asm("mbarrier.init.shared.b64 [%0], %1;"
:
: "r"(_ptr_to_int32(__addr)), "r"(__count)
: "memory");
}

// threads that participate in the computation
static __device__ __forceinline__ uint64_t
mbarrier_arrive_expect(uint64_t *__addr,
const uint32_t &__tx_count) {
uint64_t __state;
asm("mbarrier.arrive.expect_tx.release.cluster.shared::cta.b64 %0, [%1], %2; // 8. "
: "=l"(__state)
: "r"(_ptr_to_int32(__addr)), "r"(__tx_count)
: "memory");
return __state;
}

// // threads not participant in computation
// static __device__ __forceinline__ void
// mbarrier_arrive(uint64_t *__addr,
// const uint32_t &__tx_count) {
// uint64_t __state;
// // scope cluster
// asm("mbarrier.arrive.release.cluster.shared::cta.b64 %0, [%1]; // 3a. "
// : "=l"(__state)
// : "r"(_ptr_to_int32(__addr))
// : "memory");
// }

// put a value into a remote buffer in the same cluster
static __device__ __forceinline__ void cluster_put_float(
half_t *__addr, half_t const &__value, uint64_t *__remote_bar)
{
asm("st.async.weak.shared::cluster.mbarrier::complete_tx::bytes.b32 [%0], "
"%1, [%2]; // 1. "
:
: "r"(_ptr_to_int32(__addr)),
"r"(/*as_b32*/ *reinterpret_cast<const int32_t *>(
&__value)),
"r"(_ptr_to_int32(__remote_bar))
: "memory");
}

static __device__ __forceinline__ bool
mbarrier_try_wait(uint64_t *__addr,
const uint64_t &__state) {
uint32_t __waitComplete;
asm("{\n\t .reg .pred P_OUT; \n\t"
"mbarrier.try_wait.acquire.cluster.shared::cta.b64 P_OUT, [%1], %2; // 6a. \n\t"
"selp.b32 %0, 1, 0, P_OUT; \n"
"}"
: "=r"(__waitComplete)
: "r"(_ptr_to_int32(__addr)), "l"(__state)
: "memory");
return static_cast<bool>(__waitComplete);
}

// // Set the destination block-ID in cluster for a given SMEM Address
// static __device__ __forceinline__ uint32_t set_block_rank(uint64_t smemAddr,
// uint32_t rank) {
// uint32_t result;
// asm volatile("mapa.shared::cluster.u64 %0, %1, %2;\n"
// : "=r"(result)
// : "r"(smemAddr), "r"(rank));
// return result;
// }

// reduction
template <typename T, int CLUSTER_SIZE, int NUM_THREADS, int BUF_SIZE>
class ClusterReduction {

public:

static constexpr int THREAD_COUNT = BUF_SIZE / NUM_THREADS;
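// THREAD_COUNT = BUF_SIZE / NUM_THREADS is the number of elements each thread
// pushes in the copy loops below; it is also passed as the per-thread expected
// transaction count when arriving on the mbarrier.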


using barrier_t = cuda::barrier<cuda::thread_scope_block>;

static __device__ __forceinline__ void run(uint64_t* bar,
T * __restrict__ dst,
T *const __restrict__ src,
int block_rank,
int thread_idx) {
// using cuda::ptx::sem_release;
// using cuda::ptx::sem_acquire;
// using cuda::ptx::space_cluster;
// using cuda::ptx::space_shared;
// using cuda::ptx::scope_cluster;

auto cluster = this_cluster();

// init barrier; all threads participate
mbarrier_init(bar, blockDim.x * blockDim.y * blockDim.z);

// cluster sync
cluster.sync();

uint64_t arrival_token;

if (block_rank != 0) {
unsigned int block0rank = 0;
// uint64_t *remote_bar = set_block_rank(bar, block0rank);
uint64_t *remote_bar = cluster.map_shared_rank(bar, block0rank);
T *remote_receive_buffer = cluster.map_shared_rank(dst, block0rank);
// T *remote_receive_buffer = set_block_rank(dst, block0rank);

arrival_token = mbarrier_arrive_expect(bar, THREAD_COUNT);

// step 1: every block other than block 0 writes its buffer to block 0
for (int i = thread_idx; i < BUF_SIZE; i += NUM_THREADS) {
cluster_put_float(remote_receive_buffer + i, src[i], remote_bar);
}

while (!mbarrier_try_wait(bar, arrival_token)) {
}

}

cluster.sync();

if (block_rank == 0) {
for (int b = 1; b < CLUSTER_SIZE; b++) {
unsigned int target_blockrank = b;
// uint64_t *remote_bar = set_block_rank(bar, target_blockrank);
// T *remote_receive_buffer = set_block_rank(dst, target_blockrank);
uint64_t *remote_bar = cluster.map_shared_rank(bar, target_blockrank);
T *remote_receive_buffer = cluster.map_shared_rank(dst, target_blockrank);

arrival_token = mbarrier_arrive_expect(bar, THREAD_COUNT);

// step 2: block 0 writes its buffer to every other block in the cluster
for (int i = thread_idx; i < BUF_SIZE; i += NUM_THREADS) {
cluster_put_float(remote_receive_buffer + i, src[i], remote_bar);
}
while (!mbarrier_try_wait(bar, arrival_token)) {
}
}
}

cluster.sync();
}
};

} // namespace tb
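A minimal device-side sketch of driving ClusterReduction from a kernel (not part of this diff): every block in an 8-block cluster exposes a shared-memory buffer, non-zero blocks push theirs to block 0, and block 0 then pushes its buffer back out to the peers. The kernel name, the buffer sizes, and the choice of cutlass::half_t for T are illustrative assumptions; the launch must use a matching block size (256 threads here) and target compute capability 9.0.

#include <cooperative_groups.h>
#include <cutlass/numeric_types.h>

// The runtime normally has half_t in scope before this header is pulled in;
// we assume cutlass::half_t for this sketch.
using half_t = cutlass::half_t;
#include "threadblock/tb_cluster.h"

__global__ void __cluster_dims__(8, 1, 1) cluster_reduce_demo() {
  using T = half_t;
  constexpr int CLUSTER_SIZE = 8, NUM_THREADS = 256, BUF_SIZE = 64;

  __shared__ T src[BUF_SIZE];  // this block's partial values
  __shared__ T dst[BUF_SIZE];  // receives values pushed by peer blocks
  __shared__ uint64_t bar;     // shared-memory mbarrier managed inside run()

  // ... fill src with this block's partial results ...

  auto cluster = cooperative_groups::this_cluster();
  tb::ClusterReduction<T, CLUSTER_SIZE, NUM_THREADS, BUF_SIZE>::run(
      &bar, dst, src, cluster.block_rank(), threadIdx.x);
}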
@@ -10,4 +10,5 @@
#include "threadblock/output.h"
#include "threadblock/pipeline.h"
#include "threadblock/reduction.h"
#include "threadblock/tb_cluster.h"
#include "threadblock/utils.h"