Files
tssc-hpcg/DASP/src/dasp_spmv3.cu
2026-01-18 21:54:29 +08:00

1308 lines
53 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "dasp_spmv3.h"
// Tuning constants shared by the DASP SpMV kernels below and by the host-side
// data layout built in spmv_all3; kernel and host must agree on these values.
// groupNum: tile groups each thread block iterates over in the short-row
//           kernels (outer `for (... < groupNum ...)` loops).
#define groupNum 1
// warpNum_short: warps per thread block in the short-row (1&3, 3&4, 2&2) parts.
#define warpNum_short 4
// loopNum_short: MMA_M x MMA_K tiles handled per warp in the 3&4 short part.
//                The 3&4 kernel body hard-codes four unrolled tile steps, so
//                changing this value requires changing that code too.
#define loopNum_short 4
// warpNum_long: warps per thread block in the long-row part and in longPart_sum.
#define warpNum_long 4
// loopNum_long: MMA_M x MMA_K tiles handled per warp in the long-row part.
#define loopNum_long 2
/**
 * Tree reduction across one warp using shuffle-down with offsets
 * 16, 8, 4, 2, 1. On return, lane 0 holds the sum of `sum` over all 32 lanes;
 * the other lanes hold partial sums.
 * The full-warp mask 0xffffffff is used, so all 32 lanes must call this.
 */
__device__ __forceinline__ MAT_VAL_TYPE warpReduceSum(MAT_VAL_TYPE sum){
    #pragma unroll
    for (int delta = 16; delta >= 1; delta /= 2)
    {
        sum += __shfl_down_sync(0xffffffff, sum, delta);
    }
    return sum;
}
/**
 * Loads one value from global memory with `ld.global.cs.f64`
 * (PTX "cache streaming" hint: data is likely accessed once).
 * Uses the "d" (64-bit float register) constraint, so MAT_VAL_TYPE is
 * assumed to be a 64-bit floating type. TODO confirm in dasp_spmv3.h.
 */
__device__ __forceinline__ MAT_VAL_TYPE load_double_from_global(const MAT_VAL_TYPE* a)
{
    MAT_VAL_TYPE loaded;
    asm volatile(
        "ld.global.cs.f64 %0, [%1];"
        : "=d"(loaded)
        : "l"(a)
    );
    return loaded;
}
/**
 * Stores `v` to global address `a` with `st.global.cs.f64`
 * (PTX cache-streaming store).
 * NOTE(review): `a` is declared pointer-to-const even though this writes
 * through it, and the asm declares no "memory" clobber. The compiler is
 * therefore not told that memory changes, so reads of the same location
 * should also go through asm (as load_double_from_global does).
 */
__device__ __forceinline__ void store_double_to_global(const MAT_VAL_TYPE* a, MAT_VAL_TYPE v)
{
    asm volatile(
        "st.global.cs.f64 [%0], %1;"
        :
        : "l"(a), "d"(v)
    );
}
/**
 * Loads one 32-bit signed integer from global memory with
 * `ld.global.cs.s32` (PTX cache-streaming hint).
 */
__device__ __forceinline__ int load_int_from_global(const int* a)
{
    int loaded;
    asm volatile(
        "ld.global.cs.s32 %0, [%1];"
        : "=r"(loaded)
        : "l"(a)
    );
    return loaded;
}
/**
 * Second pass of the long-row SpMV: one warp per long row.
 *
 * The warp sums that row's per-warp partial results
 * dwarp_val[dlongA_rpt[row] .. dlongA_rpt[row + 1]), and lane 0 writes the
 * total to dY_val[row].
 *
 * Launch layout: warpNum_long warps per block, row = blockIdx.x * warpNum_long + warp.
 * The row index is computed per warp, so the early exit is warp-uniform and
 * the full-mask reduction stays valid. This assumes blockDim.x is a multiple
 * of 32 (TODO confirm at the launch site).
 */
__global__ void longPart_sum(int *dlongA_rpt, MAT_VAL_TYPE *dwarp_val, MAT_VAL_TYPE *dY_val, int row_long)
{
    const int lane = 31 & threadIdx.x;
    const int row = blockIdx.x * warpNum_long + (threadIdx.x >> 5);
    if (row >= row_long) return;

    const int begin = load_int_from_global(dlongA_rpt + row);
    const MAT_VAL_TYPE *partials = dwarp_val + begin;
    const int count = load_int_from_global(dlongA_rpt + row + 1) - begin;

    // Each lane accumulates a strided subset of the partials, then the warp reduces.
    MAT_VAL_TYPE acc = 0;
    for (int k = lane; k < count; k += WARP_SIZE)
    {
        acc += load_double_from_global(partials + k);
    }
    acc = warpReduceSum(acc);

    if (lane == 0)
    {
        store_double_to_global(dY_val + row, acc);
    }
}
/**
 * DASP SpMV kernel: computes y = A * x for a matrix pre-split (on the host)
 * into several parts. Each thread block selects its part by comparing
 * blockIdx.x against the A->offset_* boundaries:
 *   [0, offset_reg)                  long rows: one warp handles loopNum_long tiles of one
 *                                    row; partial sums go to A->dwarp_val and are
 *                                    reduced later by longPart_sum.
 *   [offset_reg, offset_short1)      row-block part: regular MMA tiles of 8-row blocks
 *                                    (dregA_*) plus an irregular CSR remainder (dirregA_*).
 *   [offset_short1, offset_short13)  short rows with 1 nnz, one thread per row.
 *   [offset_short13, offset_short34) rows with 1 nnz paired with rows with 3 nnz.
 *   [offset_short34, offset_short22) rows with 3 or 4 nnz.
 *   [offset_short22, ...)            rows with 2 nnz, two rows per tile line.
 * Results for the non-long parts are written in the host's reordered row order,
 * starting at dY_val + A->row_long.
 *
 * rowloop: number of BlockSize-row blocks each warp processes in the row-block
 *          part; must be 1, 2 or 4.
 *
 * Assumptions (not checked here; TODO confirm at the launch site):
 *   - blockDim.x is a multiple of 32, so every early exit below is warp-uniform
 *     and the full-mask warp intrinsics are valid;
 *   - the row-block part uses 4 warps per block (block_idx = bid_reg * 4 * rowloop + ...);
 *   - mma_m8n8k4 (defined elsewhere) is presumably the f64 mma.m8n8k4 tensor-core op (sm_80+).
 *
 * Result extraction: with the fragment layout used here, the dot product for
 * tile row r lands on the diagonal element C[r][r]. That element is held by
 * lane 4*r + (r >> 1), in fragC[r & 1].
 */
template <int rowloop> // this parameter must be 1 or 2 or 4
__global__ void dasp_spmv3(MAT_VAL_TYPE *dX_val, MAT_VAL_TYPE *dY_val, DASPSparseMatrix *A)
{
    int bid = blockIdx.x;
    int tid = threadIdx.x;
    int laneid = 31 & tid;
    if (bid < A->offset_reg)
    {
        // long part: every tile entry of this warp belongs to the same long row
        int global_warpid = bid * warpNum_long + (tid >> 5);
        int offset = global_warpid * loopNum_long * MMA_M * MMA_K;
        MAT_VAL_TYPE *curA_val = A->dlongA_val + offset;
        int *curA_cid = A->dlongA_cid + offset;
        int groupID = laneid >> 2;
        int tID_in_group = 3 & laneid;
        MAT_VAL_TYPE fragA, fragB;
        MAT_VAL_TYPE fragC[2];
        fragC[0] = 0.0, fragC[1] = 0.0;
        int idx = tID_in_group + groupID * MMA_K;
        #pragma unroll
        for (int i = 0; i < loopNum_long; i++)
        {
            fragA = load_double_from_global(curA_val + idx);
            int x_idx = load_int_from_global(curA_cid + idx);
            fragB = dX_val[x_idx];
            mma_m8n8k4(fragC, fragA, fragB);
            idx += 32;
        }
        // Sum the diagonal of C into lane 0.
        // Even rows: fragC[0] of lanes 0, 9, 18, 27.
        // Odd rows:  fragC[1] of lanes 4, 13, 22, 31.
        fragC[0] += __shfl_down_sync(0xffffffff, fragC[0], 9, 32);
        fragC[0] += __shfl_down_sync(0xffffffff, fragC[0], 18, 32);
        fragC[1] += __shfl_down_sync(0xffffffff, fragC[1], 9, 32);
        fragC[1] += __shfl_down_sync(0xffffffff, fragC[1], 18, 32);
        fragC[0] += __shfl_sync(0xffffffff, fragC[1], 4);
        if (laneid == 0)
            store_double_to_global(A->dwarp_val + global_warpid, fragC[0]);
    }
    else if (bid >= A->offset_reg && bid < A->offset_short1)
    {
        // row-block part
        int bid_reg = bid - A->offset_reg;
        int warp_local = tid >> 5;
        int groupID = laneid >> 2;
        int tID_in_group = 3 & laneid;
        MAT_VAL_TYPE fragA, fragB, fragC[2];
        if (rowloop == 1)
        {
            int block_idx = bid_reg * 4 + warp_local;
            // Guard BEFORE touching dblockA_ptr. Reading [block_idx + 1] past the
            // last block would be out of bounds (the host pointer array presumably
            // holds blocknum + 1 entries). block_idx is warp-uniform, so the whole
            // warp exits together.
            if (block_idx >= A->blocknum) return;
            int offset_A = load_int_from_global(A->dblockA_ptr + block_idx);
            int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3;
            MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A;
            int *curA_cid = A->dregA_cid + offset_A;
            fragC[0] = 0.0, fragC[1] = 0.0;
            int idx = tID_in_group + groupID * MMA_K;
            for (int i = 0; i < blocklen; i += MMA_K)
            {
                fragA = load_double_from_global(curA_val + idx);
                int x_idx = load_int_from_global(curA_cid + idx);
                fragB = dX_val[x_idx];
                mma_m8n8k4(fragC, fragA, fragB);
                idx += 32;
            }
            // Lane 4*r + (r >> 1) owns the regular result of row r and writes it out...
            int offset_y = block_idx * BlockSize + groupID;
            if (tID_in_group == (groupID >> 1) && offset_y < A->row_block)
            {
                store_double_to_global(dY_val + A->row_long + offset_y, fragC[1 & groupID]);
            }
            // ...and lane r re-reads it below. Different lanes communicate through
            // global memory, which needs a warp barrier for memory ordering under
            // independent thread scheduling (Volta+). All 32 lanes reach this point.
            __syncwarp();
            int cur_row = block_idx * BlockSize + laneid;
            if (laneid < BlockSize && cur_row < A->row_block)
            {
                MAT_VAL_TYPE cur_y = 0.0;
                // irregular CSR remainder of this row
                for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++)
                {
                    cur_y += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]];
                }
                cur_y += load_double_from_global(dY_val + A->row_long + cur_row);
                store_double_to_global(dY_val + A->row_long + cur_row, cur_y);
            }
        }
        if (rowloop == 2)
        {
            // Two 8-row blocks per warp. Block i's diagonal results are shuffled
            // into `res` of lanes 8*i .. 8*i + 7.
            MAT_VAL_TYPE res;
            #pragma unroll
            for (int i = 0; i < 2; i ++)
            {
                int block_idx = bid_reg * 8 + warp_local * 2 + i;
                int offset_A = load_int_from_global(A->dblockA_ptr + block_idx);
                int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3;
                MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A;
                int *curA_cid = A->dregA_cid + offset_A;
                fragC[0] = 0.0, fragC[1] = 0.0;
                int idx = tID_in_group + groupID * MMA_K;
                for (int j = 0; j < blocklen; j += MMA_K)
                {
                    fragA = load_double_from_global(curA_val + idx);
                    int x_idx = load_int_from_global(curA_cid + idx);
                    fragB = dX_val[x_idx];
                    mma_m8n8k4(fragC, fragA, fragB);
                    idx += 32;
                }
                int target_id = ((laneid - i * 8) >> 1) * 9;
                fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
                fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
                if ((laneid >> 3) == i) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            }
            int cur_row = bid_reg * 8 * BlockSize + warp_local * 2 * BlockSize + laneid;
            if (laneid < 16 && cur_row < A->row_block)
            {
                for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++)
                {
                    res += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]];
                }
                store_double_to_global(dY_val + A->row_long + cur_row, res);
            }
        }
        if (rowloop == 4)
        {
            // Four 8-row blocks per warp. Block i's diagonal results are shuffled
            // into `res` of lanes 8*i .. 8*i + 7.
            MAT_VAL_TYPE res;
            #pragma unroll
            for (int i = 0; i < 4; i ++)
            {
                int block_idx = bid_reg * 16 + warp_local * 4 + i;
                int offset_A = load_int_from_global(A->dblockA_ptr + block_idx);
                int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3;
                MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A;
                int *curA_cid = A->dregA_cid + offset_A;
                fragC[0] = 0.0, fragC[1] = 0.0;
                int idx = tID_in_group + groupID * MMA_K;
                for (int j = 0; j < blocklen; j += MMA_K)
                {
                    fragA = load_double_from_global(curA_val + idx);
                    int x_idx = load_int_from_global(curA_cid + idx);
                    fragB = dX_val[x_idx];
                    mma_m8n8k4(fragC, fragA, fragB);
                    idx += 32;
                }
                int target_id = ((laneid - i * 8) >> 1) * 9;
                fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
                fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
                if ((laneid >> 3) == i) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            }
            int cur_row = bid_reg * 16 * BlockSize + warp_local * 4 * BlockSize + laneid;
            if (cur_row < A->row_block)
            {
                for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++)
                {
                    res += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]];
                }
                store_double_to_global(dY_val + A->row_long + cur_row, res);
            }
        }
    }
    else if (bid >= A->offset_short1 && bid < A->offset_short13)
    {
        // short part - 1 nnz/row (one thread per row, no warp intrinsics)
        int bid1 = bid - A->offset_short1;
        int tid1 = bid1 * blockDim.x + tid;
        if (tid1 >= A->short_row_1)
        {
            return;
        }
        int x_idx = load_int_from_global(A->dshort_cid + tid1);
        MAT_VAL_TYPE temp_y = load_double_from_global(A->dshort_val + tid1) * dX_val[x_idx];
        store_double_to_global(dY_val + A->row_long + A->row_block + tid1, temp_y);
    }
    else if (bid >= A->offset_short13 && bid < A->offset_short34)
    {
        // short part - block 1&3: each tile line holds a 1-nnz row in column 0
        // and a 3-nnz row in columns 1..3. Each of the two tiles per warp is
        // multiplied twice, masking fragB, to separate the two rows.
        int warpid_local = tid >> 5;
        int bid13 = bid - A->offset_short13;
        MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res;
        int groupID = laneid >> 2;
        int tID_in_group = 3 & laneid;
        #pragma unroll
        for (int i = 0; i < groupNum; i ++)
        {
            int offset = A->short_row_1 + ((bid13 * groupNum + i) * warpNum_short + warpid_local) * MMA_M * MMA_K * 2;
            MAT_VAL_TYPE *cur_val = A->dshort_val + offset;
            int *cur_cid = A->dshort_cid + offset;
            int idx = tID_in_group + groupID * MMA_K;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            int cid = load_int_from_global(cur_cid + idx);
            fragB = tID_in_group == 0 ? dX_val[cid] : 0;
            mma_m8n8k4(fragC, fragA, fragB);
            int target_id = (laneid >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragB = tID_in_group == 0 ? 0 : dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 8) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            idx += 32;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            cid = load_int_from_global(cur_cid + idx);
            fragB = tID_in_group == 0 ? dX_val[cid] : 0;
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 16) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragB = tID_in_group == 0 ? 0 : dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 24) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            int offset_y = ((bid13 * groupNum + i) * warpNum_short + warpid_local) * WARP_SIZE + laneid;
            if (offset_y < A->common_13 * 2)
                store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + offset_y, res);
        }
    }
    else if (bid >= A->offset_short34 && bid < A->offset_short22)
    {
        // short part - block3 & block4: one row (3 or 4 nnz) per tile line,
        // four tiles per warp (hard-coded; must match loopNum_short).
        int warpid_local = tid >> 5;
        int bid34 = bid - A->offset_short34;
        MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res;
        int groupID = laneid >> 2;
        int tID_in_group = 3 & laneid;
        #pragma unroll
        for (int j = 0; j < groupNum; j ++)
        {
            int offset = A->short_row_1 + A->fill0_nnz_short13 + ((bid34 * groupNum + j) * warpNum_short + warpid_local) * MMA_M * MMA_K * loopNum_short;
            MAT_VAL_TYPE *cur_val = A->dshort_val + offset;
            int *cur_cid = A->dshort_cid + offset;
            int idx = tID_in_group + groupID * MMA_K;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            int cid = load_int_from_global(cur_cid + idx);
            fragB = dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            int target_id = (laneid >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            idx += 32;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            cid = load_int_from_global(cur_cid + idx);
            fragB = dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 8) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            idx += 32;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            cid = load_int_from_global(cur_cid + idx);
            fragB = dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 16) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            idx += 32;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            cid = load_int_from_global(cur_cid + idx);
            fragB = dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 24) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            int offset_y = ((bid34 * groupNum + j) * warpNum_short + warpid_local) * WARP_SIZE + laneid;
            if (offset_y < A->short_row_34)
                store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + A->common_13 * 2 + offset_y, res);
        }
    }
    else
    {
        // short part - block 2&2: each tile line holds two 2-nnz rows
        // (columns 0..1 and 2..3). Each of the two tiles per warp is multiplied
        // twice, masking fragB, to separate the two rows.
        int warpid_local = tid >> 5;
        int bid22 = bid - A->offset_short22;
        MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res;
        int groupID = laneid >> 2;
        int tID_in_group = 3 & laneid;
        #pragma unroll
        for (int i = 0; i < groupNum; i ++)
        {
            int offset = A->short_row_1 + A->fill0_nnz_short13 + A->fill0_nnz_short34 + ((bid22 * groupNum + i) * warpNum_short + warpid_local) * MMA_M * MMA_K * 2;
            MAT_VAL_TYPE *cur_val = A->dshort_val + offset;
            int *cur_cid = A->dshort_cid + offset;
            int idx = tID_in_group + groupID * MMA_K;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            int cid = load_int_from_global(cur_cid + idx);
            fragB = tID_in_group < 2 ? dX_val[cid] : 0;
            mma_m8n8k4(fragC, fragA, fragB);
            int target_id = (laneid >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragB = tID_in_group < 2 ? 0 : dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 8) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            idx += 32;
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragA = load_double_from_global(cur_val + idx);
            cid = load_int_from_global(cur_cid + idx);
            fragB = tID_in_group < 2 ? dX_val[cid] : 0;
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 16) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            fragC[0] = 0.0, fragC[1] = 0.0;
            fragB = tID_in_group < 2 ? 0 : dX_val[cid];
            mma_m8n8k4(fragC, fragA, fragB);
            target_id = ((laneid - 24) >> 1) * 9;
            fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id);
            fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4);
            if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1];
            int offset_y = ((bid22 * groupNum + i) * warpNum_short + warpid_local) * WARP_SIZE + laneid;
            if (offset_y < A->short_row_2)
                store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + A->common_13 * 2 + A->short_row_34 + offset_y, res);
        }
    }
}
__host__ void spmv_all3(char *filename, MAT_VAL_TYPE *csrValA, MAT_PTR_TYPE *csrRowPtrA, int *csrColIdxA,
MAT_VAL_TYPE *X_val, MAT_VAL_TYPE *Y_val, int *order_rid, int rowA, int colA, MAT_PTR_TYPE nnzA, int NUM, double threshold, int block_longest)
{
struct timeval t1, t2;
// three parts: short row (1 & 3 & 2 & 4), long row, row-block (regularorigin & fill0 & irregular)
MAT_PTR_TYPE nnz_short, nnz_long, origin_nnz_reg, fill0_nnz_reg, nnz_irreg;
int row_long = 0, row_block = 0, row_zero = 0;
// get the short part data
// (short_val, short_cid)
int short_row_1 = 0, short_row_3 = 0, short_row_2 = 0, short_row_4 = 0;
for (int i = 0; i < rowA; i ++)
{
int row_len = csrRowPtrA[i + 1] - csrRowPtrA[i];
if (row_len == 1)
{
short_row_1 ++;
}
else if (row_len == 3)
{
short_row_3 ++;
}
else if (row_len == 2)
{
short_row_2 ++;
}
else if (row_len == 0)
{
row_zero ++;
}
else if (row_len == 4)
{
short_row_4 ++;
}
// else if (row_len >= warpNum_long * loopNum_long * MMA_M * MMA_K)
else if (row_len >= block_longest)
{
row_long ++;
}
else
{
row_block ++;
}
}
int rowloop;
if (row_block < 59990) rowloop = 1;
else if (row_block >= 59990 && row_block < 400000) rowloop = 2;
else rowloop = 4;
int *short_rid_1 = (int *)malloc(sizeof(int) * short_row_1);
int *short_rid_2 = (int *)malloc(sizeof(int) * short_row_2);
int *short_rid_3 = (int *)malloc(sizeof(int) * short_row_3);
int *short_rid_4 = (int *)malloc(sizeof(int) * short_row_4);
int *long_rid = (int *)malloc(sizeof(int) * row_long);
int *zero_rid = (int *)malloc(sizeof(int) * row_zero);
int *ridA = (int *)malloc(sizeof(int) * row_block);
MAT_PTR_TYPE *rptA = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_block + 1));
memset(rptA, 0, sizeof(MAT_PTR_TYPE) * (row_block + 1));
MAT_PTR_TYPE *long_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_long + 1));
memset(long_rpt, 0, sizeof(MAT_PTR_TYPE) * (row_long + 1));
int short_row_flag1 = 0, short_row_flag3 = 0, short_row_flag2 = 0, short_row_flag4 = 0;
int row_long_flag = 0, flag0 = 0, row_block_flag = 0;
for (int i = 0; i < rowA; i ++)
{
int row_len = csrRowPtrA[i + 1] - csrRowPtrA[i];
if (row_len == 1)
{
short_rid_1[short_row_flag1] = i;
short_row_flag1 ++;
}
else if (row_len == 3)
{
short_rid_3[short_row_flag3] = i;
short_row_flag3 ++;
}
else if (row_len == 2)
{
short_rid_2[short_row_flag2] = i;
short_row_flag2 ++;
}
else if (row_len == 0)
{
zero_rid[flag0] = i;
flag0 ++;
}
else if (row_len == 4)
{
short_rid_4[short_row_flag4] = i;
short_row_flag4 ++;
}
// else if (row_len >= warpNum_long * loopNum_long * MMA_M * MMA_K)
else if (row_len >= block_longest)
{
long_rpt[row_long_flag] = row_len;
long_rid[row_long_flag] = i;
row_long_flag ++;
}
else
{
rptA[row_block_flag] = row_len;
ridA[row_block_flag] = i;
row_block_flag ++;
}
}
nnz_short = short_row_1 + short_row_3 * 3 + short_row_2 * 2 + short_row_4 * 4;
int common_13 = short_row_1 < short_row_3 ? short_row_1 : short_row_3;
if (common_13 / BlockSize >= 16)
{
common_13 = BlockSize * (common_13 / BlockSize);
short_row_1 = short_row_1 - common_13;
short_row_3 = short_row_3 - common_13;
}
else
{
common_13 = 0;
}
int short_block13 = (common_13 + BlockSize - 1) / BlockSize;
int half_short_row_2 = (short_row_2 + 1) / 2;
int short_block22 = (half_short_row_2 + BlockSize - 1) / BlockSize;
int short_row_34 = short_row_3 + short_row_4;
int short_block34 = (short_row_34 + BlockSize - 1) / BlockSize;
int block13_per_threadblock = warpNum_short * groupNum * 2;
int block22_per_threadblock = warpNum_short * groupNum * 2;
int block34_per_threadblock = warpNum_short * groupNum * loopNum_short;
int threadblock13 = (short_block13 + block13_per_threadblock - 1) / block13_per_threadblock;
int threadblock22 = (short_block22 + block22_per_threadblock - 1) / block22_per_threadblock;
int threadblock34 = (short_block34 + block34_per_threadblock - 1) / block34_per_threadblock;
MAT_PTR_TYPE fill0_nnz_short13 = threadblock13 * block13_per_threadblock * MMA_M * MMA_K;
MAT_PTR_TYPE fill0_nnz_short34 = threadblock34 * block34_per_threadblock * MMA_M * MMA_K;
MAT_PTR_TYPE fill0_nnz_short22 = threadblock22 * block22_per_threadblock * MMA_M * MMA_K;
MAT_PTR_TYPE fill0_nnz_short = short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + fill0_nnz_short22;
MAT_VAL_TYPE *short_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_short);
int *short_cid = (int *)malloc(sizeof(int) * fill0_nnz_short);
memset(short_val, 0.0, sizeof(MAT_VAL_TYPE) * fill0_nnz_short);
memset(short_cid, 0, sizeof(int) * fill0_nnz_short);
int super_group = 1 + threadblock13 + threadblock34 + threadblock22;
MAT_PTR_TYPE *superX_ptr = (MAT_PTR_TYPE *)malloc(sizeof(int) * (super_group + 1));
for (int i = 0; i < short_row_1; i ++)
{
int cur_row = short_rid_1[i];
short_val[i] = csrValA[csrRowPtrA[cur_row]];
short_cid[i] = csrColIdxA[csrRowPtrA[cur_row]];
}
for (int i = 0; i < short_block13; i ++)
{
MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + i * MMA_M * MMA_K;
int *cur_short_cid = short_cid + short_row_1 + i * MMA_M * MMA_K;
for (int j = 0; j < BlockSize && i * BlockSize + j < common_13; j ++)
{
int cur_row_1 = short_rid_1[short_row_1 + i * BlockSize + j];
int cur_row_3 = short_rid_3[i * BlockSize + j];
cur_short_val[j * MMA_K] = csrValA[csrRowPtrA[cur_row_1]];
cur_short_cid[j * MMA_K] = csrColIdxA[csrRowPtrA[cur_row_1]];
cur_short_val[j * MMA_K + 1] = csrValA[csrRowPtrA[cur_row_3]];
cur_short_val[j * MMA_K + 2] = csrValA[csrRowPtrA[cur_row_3] + 1];
cur_short_val[j * MMA_K + 3] = csrValA[csrRowPtrA[cur_row_3] + 2];
cur_short_cid[j * MMA_K + 1] = csrColIdxA[csrRowPtrA[cur_row_3]];
cur_short_cid[j * MMA_K + 2] = csrColIdxA[csrRowPtrA[cur_row_3] + 1];
cur_short_cid[j * MMA_K + 3] = csrColIdxA[csrRowPtrA[cur_row_3] + 2];
}
}
for (int i = 0; i < short_row_3; i ++)
{
MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + i * MMA_K;
int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + i * MMA_K;
int cur_row = short_rid_3[common_13 + i];
cur_short_val[0] = csrValA[csrRowPtrA[cur_row]];
cur_short_val[1] = csrValA[csrRowPtrA[cur_row] + 1];
cur_short_val[2] = csrValA[csrRowPtrA[cur_row] + 2];
cur_short_cid[0] = csrColIdxA[csrRowPtrA[cur_row]];
cur_short_cid[1] = csrColIdxA[csrRowPtrA[cur_row] + 1];
cur_short_cid[2] = csrColIdxA[csrRowPtrA[cur_row] + 2];
}
for (int i = 0; i < short_row_4; i ++)
{
MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + (short_row_3 + i) * MMA_K;
int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + (short_row_3 + i) * MMA_K;
int cur_row = short_rid_4[i];
cur_short_val[0] = csrValA[csrRowPtrA[cur_row]];
cur_short_val[1] = csrValA[csrRowPtrA[cur_row] + 1];
cur_short_val[2] = csrValA[csrRowPtrA[cur_row] + 2];
cur_short_val[3] = csrValA[csrRowPtrA[cur_row] + 3];
cur_short_cid[0] = csrColIdxA[csrRowPtrA[cur_row]];
cur_short_cid[1] = csrColIdxA[csrRowPtrA[cur_row] + 1];
cur_short_cid[2] = csrColIdxA[csrRowPtrA[cur_row] + 2];
cur_short_cid[3] = csrColIdxA[csrRowPtrA[cur_row] + 3];
}
for (int i = 0; i < short_block22; i ++)
{
MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * MMA_M * MMA_K;
int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * MMA_M * MMA_K;
for (int j = 0; j < BlockSize * 2 && (i * BlockSize * 2 + j) < short_row_2; j ++)
{
int cur_row = short_rid_2[i * BlockSize * 2 + j];
cur_short_val[j % BlockSize * MMA_K + (j / BlockSize) * 2] = csrValA[csrRowPtrA[cur_row]];
cur_short_val[j % BlockSize * MMA_K + (j / BlockSize) * 2 + 1] = csrValA[csrRowPtrA[cur_row] + 1];
cur_short_cid[j % BlockSize * MMA_K + (j / BlockSize) * 2] = csrColIdxA[csrRowPtrA[cur_row]];
cur_short_cid[j % BlockSize * MMA_K + (j / BlockSize) * 2 + 1] = csrColIdxA[csrRowPtrA[cur_row] + 1];
}
}
int *short_cid_temp = (int *)malloc(sizeof(int) * fill0_nnz_short);
memcpy(short_cid_temp, short_cid, sizeof(int) * fill0_nnz_short);
quick_sort_key(short_cid_temp, short_row_1);
int nnzr = short_row_1 > 0 ? 1 : 0;
for (int i = 1; i < short_row_1; i ++)
{
nnzr += short_cid_temp[i] != short_cid_temp[i - 1] ? 1 : 0;
}
superX_ptr[0] = nnzr;
MAT_PTR_TYPE *cur_superX_ptr = superX_ptr + 1;
for (int i = 0; i < threadblock13; i++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K;
int len = block13_per_threadblock * MMA_M * MMA_K;
quick_sort_key(cur_short_cid_temp, len);
int nnzcid = len > 0 ? 1 : 0;
for (int j = 1; j < len; j ++)
{
nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0;
}
cur_superX_ptr[i] = nnzcid;
}
cur_superX_ptr = superX_ptr + 1 + threadblock13;
for (int i = 0; i < threadblock34; i++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K;
int len = block34_per_threadblock * MMA_M * MMA_K;
quick_sort_key(cur_short_cid_temp, len);
int nnzcid = len > 0 ? 1 : 0;
for (int j = 1; j < len; j ++)
{
nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0;
}
cur_superX_ptr[i] = nnzcid;
}
cur_superX_ptr = superX_ptr + 1 + threadblock13 + threadblock34;
for (int i = 0; i < threadblock22; i++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K;
int len = block22_per_threadblock * MMA_M * MMA_K;
quick_sort_key(cur_short_cid_temp, len);
int nnzcid = len > 0 ? 1 : 0;
for (int j = 1; j < len; j ++)
{
nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0;
}
cur_superX_ptr[i] = nnzcid;
}
exclusive_scan(superX_ptr, super_group + 1);
MAT_PTR_TYPE nnz_superX = superX_ptr[super_group];
int new_cid_len = short_row_1 + threadblock13 * block13_per_threadblock * MMA_M * MMA_K / 4 +
threadblock34 * block34_per_threadblock * MMA_M * MMA_K / 4 +
threadblock22 * block22_per_threadblock * MMA_M * MMA_K / 4;
int *short_cid_new = (int *)malloc(sizeof(int) * new_cid_len);
int *superX_cid = (int *)malloc(sizeof(int) * nnz_superX);
int flag = 0;
if (short_row_1)
{
superX_cid[0] = short_cid_temp[0];
flag ++;
}
for (int j = 1; j < short_row_1; j ++)
{
if (short_cid_temp[j] != short_cid_temp[j - 1])
{
superX_cid[flag] = short_cid_temp[j];
flag ++;
}
}
if (flag != superX_ptr[1]) printf("flag1 = %d, len = %d\n", flag, superX_ptr[1]);
for (int i = 0; i < short_row_1; i ++)
{
short_cid_new[i] = BinarySearch(superX_cid, superX_ptr[1], short_cid[i]);
}
cur_superX_ptr = superX_ptr + 1;
for (int i = 0; i < threadblock13; i ++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K;
int len = block13_per_threadblock * MMA_M * MMA_K;
int *cur_superX_cid = superX_cid + cur_superX_ptr[i];
int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i];
int flag_cid = 0;
if (len)
{
cur_superX_cid[0] = cur_short_cid_temp[0];
flag_cid ++;
}
else
{
continue;
}
for (int j = 1; j < len; j ++)
{
if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1])
{
cur_superX_cid[flag_cid] = cur_short_cid_temp[j];
flag_cid ++;
}
}
if (flag_cid != xlen) printf("flag13 = %d, len = %d\n", flag_cid, xlen);
int *cur_short_cid_new = short_cid_new + short_row_1 + i * (block13_per_threadblock * MMA_M * MMA_K / 4);
int *cur_short_cid = short_cid + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K;
for (int j = 0; j < len; j ++)
{
// cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]);
SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4);
}
}
cur_superX_ptr = superX_ptr + 1 + threadblock13;
for (int i = 0; i < threadblock34; i ++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K;
int len = block34_per_threadblock * MMA_M * MMA_K;
int *cur_superX_cid = superX_cid + cur_superX_ptr[i];
int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i];
int flag_cid = 0;
if (len)
{
cur_superX_cid[0] = cur_short_cid_temp[0];
flag_cid ++;
}
else
{
continue;
}
for (int j = 1; j < len; j ++)
{
if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1])
{
cur_superX_cid[flag_cid] = cur_short_cid_temp[j];
flag_cid ++;
}
}
if (flag_cid != xlen) printf("flag34 = %d, len = %d\n", flag_cid, xlen);
int *cur_short_cid_new = short_cid_new + short_row_1 + fill0_nnz_short13 / 4 + i * (block34_per_threadblock * MMA_M * MMA_K / 4);
int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K;
for (int j = 0; j < len; j ++)
{
// cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]);
SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4);
}
}
cur_superX_ptr = superX_ptr + 1 + threadblock13 + threadblock34;
for (int i = 0; i < threadblock22; i ++)
{
int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K;
int len = block22_per_threadblock * MMA_M * MMA_K;
int *cur_superX_cid = superX_cid + cur_superX_ptr[i];
int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i];
int flag_cid = 0;
if (len)
{
cur_superX_cid[0] = cur_short_cid_temp[0];
flag_cid ++;
}
else
{
continue;
}
for (int j = 1; j < len; j ++)
{
if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1])
{
cur_superX_cid[flag_cid] = cur_short_cid_temp[j];
flag_cid ++;
}
}
if (flag_cid != xlen) printf("flag22 = %d, len = %d\n", flag_cid, xlen);
int *cur_short_cid_new = short_cid_new + short_row_1 + (fill0_nnz_short13 + fill0_nnz_short34) / 4 + i * (block22_per_threadblock * MMA_M * MMA_K / 4);
int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K;
for (int j = 0; j < len; j ++)
{
// cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]);
SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4);
}
}
MAT_VAL_TYPE *superX_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_superX);
for (int i = 0; i < nnz_superX; i ++)
{
superX_val[i] = X_val[superX_cid[i]];
}
radix_sort(rptA, ridA, row_block);
exclusive_scan(rptA, row_block + 1);
exclusive_scan(long_rpt, row_long + 1);
nnz_long = long_rpt[row_long];
memcpy(order_rid, long_rid, sizeof(int) * row_long);
memcpy(order_rid + row_long, ridA, sizeof(int) * row_block);
memcpy(order_rid + row_long + row_block, short_rid_1, sizeof(int) * short_row_1);
for (int i = 0; i < short_block13; i ++)
{
int *cur_order_rid = order_rid + row_long + row_block + short_row_1 + i * BlockSize * 2;
for (int j = 0; j < BlockSize; j ++)
{
cur_order_rid[j] = short_rid_1[short_row_1 + i * BlockSize + j];
cur_order_rid[BlockSize + j] = short_rid_3[i * BlockSize + j];
}
}
memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2, short_rid_3 + common_13, sizeof(int) * short_row_3);
memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3, short_rid_4, sizeof(int) * short_row_4);
memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3 + short_row_4, short_rid_2, sizeof(int) * short_row_2);
memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3 + short_row_4 + short_row_2, zero_rid, sizeof(int) * row_zero);
int short_row = short_row_1 + common_13 * 2 + short_row_34 + short_row_2;
int offset_short_row = row_long + row_block;
MAT_VAL_TYPE *short3_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_short);
int *short3_cid = (int *)malloc(sizeof(int) * nnz_short);
MAT_PTR_TYPE *short3_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (short_row + 1));
for (int i = 0; i < short_row; i ++)
{
int idx = order_rid[offset_short_row + i];
short3_rpt[i] = csrRowPtrA[idx + 1] - csrRowPtrA[idx];
}
exclusive_scan(short3_rpt, short_row + 1);
for (int i = 0; i < short_row; i ++)
{
int idx = order_rid[offset_short_row + i];
memcpy(short3_val + short3_rpt[i], csrValA + csrRowPtrA[idx], sizeof(MAT_VAL_TYPE) * (csrRowPtrA[idx + 1] - csrRowPtrA[idx]));
memcpy(short3_cid + short3_rpt[i], csrColIdxA + csrRowPtrA[idx], sizeof(int) * (csrRowPtrA[idx + 1] - csrRowPtrA[idx]));
}
MAT_PTR_TYPE *long_rpt_new = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_long + 1));
memset(long_rpt_new, 0, sizeof(MAT_PTR_TYPE) * (row_long + 1));
int warp_number = 0;
for (int i = 0; i < row_long; i ++)
{
int nnz_num = long_rpt[i + 1] - long_rpt[i];
int cur_warp_num = (nnz_num + MMA_M * MMA_K * loopNum_long - 1) / (MMA_M * MMA_K * loopNum_long);
long_rpt_new[i] = cur_warp_num;
}
exclusive_scan(long_rpt_new, row_long + 1);
warp_number = long_rpt_new[row_long];
int BlockNum_long = (warp_number + warpNum_long - 1) / warpNum_long;
int fill0_nnz_long = BlockNum_long * warpNum_long * loopNum_long * MMA_M * MMA_K;
warp_number = BlockNum_long * warpNum_long;
MAT_VAL_TYPE *val_by_warp = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * warp_number);
int *rid_by_warp = (int *)malloc(sizeof(int) * warp_number);
MAT_VAL_TYPE *long_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_long);
memset(long_val, 0.0, sizeof(MAT_VAL_TYPE) * fill0_nnz_long);
int *long_cid = (int *)malloc(sizeof(int) * fill0_nnz_long);
memset(long_cid, 0, sizeof(int) * fill0_nnz_long);
for (int i = 0; i < row_long; i ++)
{
MAT_VAL_TYPE *cur_val = long_val + long_rpt_new[i] * loopNum_long * MMA_M * MMA_K;
int *cur_cid = long_cid + long_rpt_new[i] * loopNum_long * MMA_M * MMA_K;
int real_rid = long_rid[i];
for (int j = 0; j < long_rpt[i + 1] - long_rpt[i]; j ++)
{
cur_val[j] = csrValA[csrRowPtrA[real_rid] + j];
cur_cid[j] = csrColIdxA[csrRowPtrA[real_rid] + j];
}
for (int j = long_rpt_new[i]; j < long_rpt_new[i + 1]; j ++)
{
rid_by_warp[j] = i;
}
}
int blocknum = (row_block + BlockSize - 1) / BlockSize;
blocknum = ((blocknum + rowloop * 4 - 1) / (rowloop * 4)) * rowloop * 4;
MAT_PTR_TYPE *blockPtr = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (blocknum + 1));
memset(blockPtr, 0, sizeof(MAT_PTR_TYPE) * (blocknum + 1));
MAT_PTR_TYPE *irreg_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_block + 1));
memset(irreg_rpt, 0, sizeof(MAT_PTR_TYPE) * (row_block + 1));
#pragma omp parallel for
for (int i = 0; i < blocknum; i++)
{
int row_start = i * BlockSize;
int row_end = (i + 1) * BlockSize >= row_block ? row_block : (i + 1) * BlockSize;
int k = 1;
while(1)
{
int block_nnz = 0;
for (int cur_row = row_start; cur_row < row_end; cur_row++)
{
int row_len = rptA[cur_row + 1] - rptA[cur_row];
if (row_len / MMA_K >= k) block_nnz += MMA_K;
else if(row_len / MMA_K == k - 1) block_nnz += row_len % MMA_K;
}
if (block_nnz >= threshold * MMA_K * MMA_M)
{
blockPtr[i] += MMA_K * MMA_M;
}
else
{
for (int cur_row = row_start; cur_row < row_end; cur_row++ )
{
int row_len = rptA[cur_row + 1] - rptA[cur_row];
irreg_rpt[cur_row] = row_len - (k - 1) * MMA_K > 0 ? row_len - (k - 1) * MMA_K : 0;
}
break;
}
k++;
}
}
exclusive_scan(blockPtr, blocknum + 1);
exclusive_scan(irreg_rpt, row_block + 1);
fill0_nnz_reg = blockPtr[blocknum];
nnz_irreg = irreg_rpt[row_block];
origin_nnz_reg = nnzA - nnz_irreg - nnz_long - nnz_short;
MAT_VAL_TYPE *irreg_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_irreg);
int *irreg_cid = (int *)malloc(sizeof(int) * nnz_irreg);
for (int i = 0; i < row_block; i ++)
{
int cur_rid = ridA[i];
int irreg_offset = irreg_rpt[i];
int irreg_len = irreg_rpt[i + 1] - irreg_offset;
for (int j = 0; j < irreg_len; j ++)
{
irreg_val[irreg_offset + j] = csrValA[csrRowPtrA[cur_rid + 1] - irreg_len + j];
irreg_cid[irreg_offset + j] = csrColIdxA[csrRowPtrA[cur_rid + 1] - irreg_len + j];
}
}
MAT_VAL_TYPE *reg_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_reg);
int *reg_cid = (int *)malloc(sizeof(int) * fill0_nnz_reg);
for (int bid = 0; bid < blocknum; bid ++)
{
int nnz_block = (blockPtr[bid + 1] - blockPtr[bid]);
int blocklen = nnz_block / BlockSize;
for (int rowid = bid * BlockSize; rowid < (bid + 1) * BlockSize; rowid ++)
{
int regA_start = blockPtr[bid] + blocklen * (rowid - bid * BlockSize);
if (rowid < row_block)
{
int real_id = ridA[rowid];
int A_start = csrRowPtrA[real_id];
int row_len = csrRowPtrA[real_id + 1] - A_start;
for (int i = 0; i < blocklen; i ++)
{
reg_val[regA_start + i] = i < row_len ? csrValA[A_start + i] : 0.0;
reg_cid[regA_start + i] = i < row_len ? csrColIdxA[A_start + i] : 0;
}
}
else
{
for (int i = 0; i < blocklen; i ++)
{
reg_val[regA_start + i] = 0.0;
reg_cid[regA_start + i] = 0;
}
}
}
MAT_VAL_TYPE *temp_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_block);
int *temp_cid = (int *)malloc(sizeof(int) * nnz_block);
MAT_VAL_TYPE *cur_val = reg_val + blockPtr[bid];
int *cur_cid = reg_cid + blockPtr[bid];
for (int i = 0; i < nnz_block; i ++)
{
int new_id = ((i % blocklen) / MMA_K) * BlockSize * MMA_K + (i / blocklen) * MMA_K + i % MMA_K;
temp_val[new_id] = cur_val[i];
temp_cid[new_id] = cur_cid[i];
}
memcpy(cur_val, temp_val, sizeof(MAT_VAL_TYPE) * nnz_block);
memcpy(cur_cid, temp_cid, sizeof(int) * nnz_block);
free(temp_val);
free(temp_cid);
}
long fill0_nnz = fill0_nnz_short + fill0_nnz_long + nnz_irreg + fill0_nnz_reg;
double rate_fill0 = (double)(fill0_nnz - nnzA) / nnzA;
long long int data_X = (rowA + colA) * sizeof(MAT_VAL_TYPE) + \
fill0_nnz_long * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + warp_number * sizeof(MAT_VAL_TYPE) + (row_long + 1) * sizeof(int) + \
fill0_nnz_short * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + \
fill0_nnz_reg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (blocknum + 1) * sizeof(MAT_PTR_TYPE) + \
nnz_irreg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (row_block + 1) * sizeof(MAT_PTR_TYPE);
long long int data_X2 = (rowA + nnzA) * sizeof(MAT_VAL_TYPE) + \
fill0_nnz_long * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + warp_number * sizeof(MAT_VAL_TYPE) + (row_long + 1) * sizeof(int) + \
fill0_nnz_short * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + \
fill0_nnz_reg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (blocknum + 1) * sizeof(MAT_PTR_TYPE) + \
nnz_irreg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (row_block + 1) * sizeof(MAT_PTR_TYPE);
int BlockNum = (blocknum + rowloop * 4 - 1) / (rowloop * 4);
int ThreadNum_short = warpNum_short * WARP_SIZE;
int BlockNum_short_1 = (short_row_1 + ThreadNum_short - 1) / ThreadNum_short;
int BlockNum_short = BlockNum_short_1 + threadblock13 + threadblock34 + threadblock22;
int offset_reg = BlockNum_long;
int offset_short1 = offset_reg + BlockNum;
int offset_short13 = offset_short1 + BlockNum_short_1;
int offset_short34 = offset_short13 + threadblock13;
int offset_short22 = offset_short34 + threadblock34;
int BlockNum_all = BlockNum_long + BlockNum + BlockNum_short;
int ThreadNum_all = 4 * WARP_SIZE;
int sumBlockNum = (row_long + 3) / 4;
MAT_VAL_TYPE *dX_val, *dY_val;
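// Build the host-side descriptor hA (sizes, launch offsets, device buffer pointers);
// it is copied into the device struct dA only after every buffer below is allocated.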
DASPSparseMatrix hA;
hA.row_long = row_long;
hA.row_block = row_block;
hA.blocknum = blocknum;
hA.short_row_1 = short_row_1;
hA.common_13 = common_13;
hA.short_row_34 = short_row_34;
hA.short_row_2 = short_row_2;
hA.fill0_nnz_short13 = fill0_nnz_short13;
hA.fill0_nnz_short34 = fill0_nnz_short34;
hA.offset_reg = offset_reg;
hA.offset_short1 = offset_short1;
hA.offset_short13 = offset_short13;
hA.offset_short34 = offset_short34;
hA.offset_short22 = offset_short22;
cudaMalloc((void **)&dX_val, sizeof(MAT_VAL_TYPE) * colA);
cudaMalloc((void **)&dY_val, sizeof(MAT_VAL_TYPE) * rowA);
cudaMemcpy(dX_val, X_val, sizeof(MAT_VAL_TYPE) * colA, cudaMemcpyHostToDevice);
cudaMemset(dY_val, 0.0, sizeof(MAT_VAL_TYPE) * rowA);
cudaMalloc((void **)&hA.dlongA_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_long);
cudaMalloc((void **)&hA.dlongA_cid, sizeof(int) * fill0_nnz_long);
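// hA.dwarp_val replaces dval_by_warp from the original code: one partial-sum slot per
// long-row warp (warp_number entries); longPart_sum reduces these per long row into dY_val.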
cudaMalloc((void **)&hA.dwarp_val, sizeof(MAT_VAL_TYPE) * warp_number);
cudaMalloc((void **)&hA.dlongA_rpt, sizeof(MAT_PTR_TYPE) * (row_long + 1));
cudaMemcpy(hA.dlongA_val, long_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_long, cudaMemcpyHostToDevice);
cudaMemcpy(hA.dlongA_cid, long_cid, sizeof(int) * fill0_nnz_long, cudaMemcpyHostToDevice);
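// The original code also allocated and uploaded drid_by_warp. It is not a DASPSparseMatrix
// field and is intentionally not uploaded here: per the original author's check, the kernel
// reads only dval_by_warp (hA.dwarp_val). The host arrays rid_by_warp and val_by_warp are
// still allocated above but are only freed at the end, never copied to the device.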
cudaMemcpy(hA.dlongA_rpt, long_rpt_new, sizeof(MAT_PTR_TYPE) * (row_long + 1), cudaMemcpyHostToDevice);
cudaMalloc((void **)&hA.dshort_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_short);
cudaMalloc((void **)&hA.dshort_cid, sizeof(int) * fill0_nnz_short);
cudaMemcpy(hA.dshort_val, short_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_short, cudaMemcpyHostToDevice);
cudaMemcpy(hA.dshort_cid, short_cid, sizeof(int) * fill0_nnz_short, cudaMemcpyHostToDevice);
cudaMalloc((void **)&hA.dregA_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_reg);
cudaMalloc((void **)&hA.dregA_cid, sizeof(int) * fill0_nnz_reg);
cudaMalloc((void **)&hA.dblockA_ptr, sizeof(MAT_PTR_TYPE) * (blocknum + 1));
cudaMemcpy(hA.dregA_val, reg_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_reg, cudaMemcpyHostToDevice);
cudaMemcpy(hA.dregA_cid, reg_cid, sizeof(int) * fill0_nnz_reg, cudaMemcpyHostToDevice);
cudaMemcpy(hA.dblockA_ptr, blockPtr, sizeof(MAT_PTR_TYPE) * (blocknum + 1), cudaMemcpyHostToDevice);
cudaMalloc((void **)&hA.dirregA_val, sizeof(MAT_VAL_TYPE) * nnz_irreg);
cudaMalloc((void **)&hA.dirregA_rpt, sizeof(MAT_PTR_TYPE) * (row_block + 1));
cudaMalloc((void **)&hA.dirregA_cid, sizeof(int) * nnz_irreg);
cudaMemcpy(hA.dirregA_val, irreg_val, sizeof(MAT_VAL_TYPE) * nnz_irreg, cudaMemcpyHostToDevice);
cudaMemcpy(hA.dirregA_rpt, irreg_rpt, sizeof(MAT_PTR_TYPE) * (row_block + 1), cudaMemcpyHostToDevice);
cudaMemcpy(hA.dirregA_cid, irreg_cid, sizeof(int) * nnz_irreg, cudaMemcpyHostToDevice);
DASPSparseMatrix *dA;
cudaMalloc((void **)&dA, sizeof(DASPSparseMatrix));
cudaMemcpy(dA, &hA, sizeof(DASPSparseMatrix), cudaMemcpyHostToDevice);
int carveout = 0;
cudaFuncSetAttribute(dasp_spmv3<1>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout);
cudaFuncSetAttribute(dasp_spmv3<2>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout);
cudaFuncSetAttribute(dasp_spmv3<4>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout);
int warmup_time = 100;
int execute_time = 1000;
if (rowloop == 1)
{
for (int i = 0; i < warmup_time; ++i)
{
dasp_spmv3<1><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
gettimeofday(&t1, NULL);
for (int i = 0; i < execute_time; ++i)
{
dasp_spmv3<1><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
if (row_long)
{
for (int i = 0; i < execute_time; ++i)
{
longPart_sum<<<sumBlockNum, ThreadNum_all>>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long);
}
cudaDeviceSynchronize();
}
gettimeofday(&t2, NULL);
}
else if (rowloop == 2)
{
for (int i = 0; i < warmup_time; ++i)
{
dasp_spmv3<2><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
gettimeofday(&t1, NULL);
for (int i = 0; i < execute_time; ++i)
{
dasp_spmv3<2><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
if (row_long)
{
for (int i = 0; i < execute_time; ++i)
{
longPart_sum<<<sumBlockNum, ThreadNum_all>>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long);
}
cudaDeviceSynchronize();
}
gettimeofday(&t2, NULL);
}
else
{
for (int i = 0; i < warmup_time; ++i)
{
dasp_spmv3<4><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
gettimeofday(&t1, NULL);
for (int i = 0; i < execute_time; ++i)
{
dasp_spmv3<4><<<BlockNum_all, ThreadNum_all>>>(dX_val, dY_val, dA);
}
cudaDeviceSynchronize();
if (row_long)
{
for (int i = 0; i < execute_time; ++i)
{
longPart_sum<<<sumBlockNum, ThreadNum_all>>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long);
}
cudaDeviceSynchronize();
}
gettimeofday(&t2, NULL);
}
double dasp_time = ((t2.tv_sec - t1.tv_sec) * 1000.0 + (t2.tv_usec - t1.tv_usec) / 1000.0) / execute_time;
double dasp_gflops = (double)((long)nnzA * 2) / (dasp_time * 1e6);
double dasp_bandwidth1 = (double)data_X / (dasp_time * 1e6);
double dasp_bandwidth2 = (double)data_X2 / (dasp_time * 1e6);
printf("SpMV_3: %8.4lf ms, %8.4lf GFlop/s, %9.4lf GB/s, %9.4lf GB/s\n", dasp_time, dasp_gflops, dasp_bandwidth1, dasp_bandwidth2);
cudaMemcpy(Y_val, dY_val, sizeof(MAT_VAL_TYPE) * rowA, cudaMemcpyDeviceToHost);
cudaFree(dX_val);
cudaFree(dY_val);
cudaFree(hA.dlongA_val);
cudaFree(hA.dlongA_cid);
cudaFree(hA.dwarp_val);
// cudaFree(drid_by_warp); // not needed: drid_by_warp is not allocated in this version (not a DASPSparseMatrix field)
cudaFree(hA.dlongA_rpt);
cudaFree(hA.dshort_cid);
cudaFree(hA.dshort_val);
cudaFree(dshort3_cid);
cudaFree(dshort3_rpt);
cudaFree(dshort3_val);
cudaFree(hA.dregA_val);
cudaFree(hA.dregA_cid);
cudaFree(hA.dblockA_ptr);
cudaFree(hA.dirregA_cid);
cudaFree(hA.dirregA_rpt);
cudaFree(hA.dirregA_val);
cudaFree(dA);
FILE* fout;
fout = fopen("data/spmv_f64_record.csv", "a");
fprintf(fout, "%s,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,", filename, rowA, colA, nnzA, short_row_1, common_13, short_row_3, short_row_4, short_row_2, row_long, row_block, nnz_short, fill0_nnz_short, nnz_long, fill0_nnz_long, origin_nnz_reg, fill0_nnz_reg, nnz_irreg);
fprintf(fout, "%lf,%d,%lld,%lf,%lf,%lf,%lf,", rate_fill0, block_longest, data_X, dasp_time, dasp_gflops, dasp_bandwidth1, dasp_bandwidth2);
fclose(fout);
printf("\n");
free(short_rid_1);
free(short_rid_2);
free(short_rid_3);
free(short_rid_4);
free(long_rid);
free(zero_rid);
free(ridA);
free(superX_ptr);
free(superX_cid);
free(superX_val);
free(short_cid_temp);
free(short_cid_new);
free(rptA);
free(long_rpt);
free(short_val);
free(short_cid);
free(short3_cid);
free(short3_rpt);
free(short3_val);
free(long_cid);
free(long_val);
free(long_rpt_new);
free(val_by_warp);
free(rid_by_warp);
free(reg_val);
free(reg_cid);
free(blockPtr);
free(irreg_rpt);
free(irreg_cid);
free(irreg_val);
}