sgemm_tcore: Hardcode threadblock id 0
this is fine since we're statically dispatching only one threadblock to the whole cluster.
This commit is contained in:
@@ -547,7 +547,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
const uint32_t threadblock_dim_y,
|
const uint32_t threadblock_dim_y,
|
||||||
/*const uint32_t threadblock_id_x,
|
/*const uint32_t threadblock_id_x,
|
||||||
const uint32_t threadblock_id_y,*/
|
const uint32_t threadblock_id_y,*/
|
||||||
const uint32_t threadblock_id_in_cluster,
|
// const uint32_t threadblock_id_in_cluster,
|
||||||
float *sharedmem_per_threadblock) {
|
float *sharedmem_per_threadblock) {
|
||||||
const float *A = (const float *)arg->addr_a;
|
const float *A = (const float *)arg->addr_a;
|
||||||
const float *B = (const float *)arg->addr_b;
|
const float *B = (const float *)arg->addr_b;
|
||||||
@@ -602,7 +602,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b,
|
global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b,
|
||||||
tid_in_warpgroup, block_n, block_m);
|
tid_in_warpgroup, block_n, block_m);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: this *should* be signed integer to trigger arithmetic
|
// NOTE: this *should* be signed integer to trigger arithmetic
|
||||||
@@ -633,11 +633,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
local_a_produce, local_b_produce, tid_in_warpgroup,
|
local_a_produce, local_b_produce, tid_in_warpgroup,
|
||||||
block_n, block_m);
|
block_n, block_m);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||||
}
|
}
|
||||||
|
|
||||||
// sync with final consumer stage in the k-loop
|
// sync with final consumer stage in the k-loop
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -650,7 +650,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
initialize_C(1);
|
initialize_C(1);
|
||||||
|
|
||||||
// sync with initial producer stage in the k-loop
|
// sync with initial producer stage in the k-loop
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||||
|
|
||||||
// NOTE: this *should* be signed integer to trigger arithmetic
|
// NOTE: this *should* be signed integer to trigger arithmetic
|
||||||
// right-shift
|
// right-shift
|
||||||
@@ -718,7 +718,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||||
|
|
||||||
#else
|
#else
|
||||||
|
|
||||||
@@ -819,7 +819,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const int warp_id = vx_warp_id();
|
const int warp_id = vx_warp_id();
|
||||||
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
||||||
threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x,
|
threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x,
|
||||||
threadblock_id_y,*/ threadblock_id_in_cluster,
|
threadblock_id_y,*/ /*threadblock_id_in_cluster, */
|
||||||
sharedmem_per_threadblock);
|
sharedmem_per_threadblock);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user