diff --git a/DASP/Makefile b/DASP/Makefile index 64d2581..67e1f58 100644 --- a/DASP/Makefile +++ b/DASP/Makefile @@ -16,7 +16,7 @@ double: $(CC) $(NVCC_FLAGS) src/main_f64.cu -o spmv_double -D f64 $(OPTIONS) $(LIBS) double3: - $(CC) $(NVCC_FLAGS) src/main_spmv3_f64.cu src/dasp_spmv3.cu -o spmv_double3 -Isrc -D f64 $(OPTIONS) $(LIBS) + $(CC) $(NVCC_FLAGS) src/main_spmv3_f64.cu -o spmv_double3 -Isrc -D f64 $(OPTIONS) $(LIBS) half: $(CC) $(NVCC_FLAGS) src/main_f16.cu -o spmv_half $(OPTIONS) $(LIBS) diff --git a/DASP/src/dasp_spmv3.h b/DASP/src/dasp_spmv3.h index dabbc5b..2378ffc 100644 --- a/DASP/src/dasp_spmv3.h +++ b/DASP/src/dasp_spmv3.h @@ -1,9 +1,12 @@ -#ifndef DASP_SPMV3_H -#define DASP_SPMV3_H - #include "common.h" #include "utils.h" +#define groupNum 1 +#define warpNum_short 4 +#define loopNum_short 4 +#define warpNum_long 4 +#define loopNum_long 2 + struct DASPSparseMatrix { // === 1. Long Part === MAT_VAL_TYPE *dlongA_val; @@ -42,10 +45,1299 @@ struct DASPSparseMatrix { int offset_short22; }; -template -__global__ void dasp_spmv3(MAT_VAL_TYPE *dX_val, MAT_VAL_TYPE *dY_val, DASPSparseMatrix *A); +__device__ __forceinline__ MAT_VAL_TYPE warpReduceSum(MAT_VAL_TYPE sum){ + sum += __shfl_down_sync(0xffffffff,sum,16); + sum += __shfl_down_sync(0xffffffff,sum,8); + sum += __shfl_down_sync(0xffffffff,sum,4); + sum += __shfl_down_sync(0xffffffff,sum,2); + sum += __shfl_down_sync(0xffffffff,sum,1); + return sum; +} + +__device__ __forceinline__ MAT_VAL_TYPE load_double_from_global(const MAT_VAL_TYPE* a) +{ + MAT_VAL_TYPE r; + asm volatile("ld.global.cs.f64 %0, [%1];" : "=d"(r) : "l"(a)); + return r; +} + +__device__ __forceinline__ void store_double_to_global(const MAT_VAL_TYPE* a, MAT_VAL_TYPE v) +{ + asm volatile("st.global.cs.f64 [%0], %1;" :: "l"(a), "d"(v)); +} + +__device__ __forceinline__ int load_int_from_global(const int* a) +{ + int r; + asm volatile("ld.global.cs.s32 %0, [%1];" : "=r"(r) : "l"(a)); + return r; +} + +__global__ void longPart_sum(int *dlongA_rpt, MAT_VAL_TYPE *dwarp_val, MAT_VAL_TYPE *dY_val, int row_long) +{ + int bid = blockIdx.x; + int tid = threadIdx.x; + int laneid = 31 & tid; + int global_warpid = bid * warpNum_long + (tid >> 5); + + if (global_warpid >= row_long) return; + + int offset_longA = load_int_from_global(dlongA_rpt + global_warpid); + MAT_VAL_TYPE *cur_temp_val = dwarp_val + offset_longA; + int len = load_int_from_global(dlongA_rpt + global_warpid + 1) - offset_longA; + + MAT_VAL_TYPE thread_val = 0; + for (int i = laneid; i < len; i += WARP_SIZE) + { + thread_val += load_double_from_global(cur_temp_val + i); + } + thread_val = warpReduceSum(thread_val); + + if (laneid == 0) + store_double_to_global(dY_val + global_warpid, thread_val); +} + +template // this parameter must be 1 or 2 or 4 +__global__ void dasp_spmv3(MAT_VAL_TYPE *dX_val, MAT_VAL_TYPE *dY_val, DASPSparseMatrix *A) +{ + int bid = blockIdx.x; + int tid = threadIdx.x; + int laneid = 31 & tid; + + if (bid < A->offset_reg) + { + // long part + int global_warpid = bid * warpNum_long + (tid >> 5); + int offset = global_warpid * loopNum_long * MMA_M * MMA_K; + MAT_VAL_TYPE *curA_val = A->dlongA_val + offset; + int *curA_cid = A->dlongA_cid + offset; + + int groupID = laneid >> 2; + int tID_in_group = 3 & laneid; + + MAT_VAL_TYPE fragA, fragB; + MAT_VAL_TYPE fragC[2]; + fragC[0] = 0.0, fragC[1] = 0.0; + + int idx = tID_in_group + groupID * MMA_K; + + #pragma unroll + for (int i = 0; i < loopNum_long; i++) + { + fragA = load_double_from_global(curA_val + idx); + int x_idx = load_int_from_global(curA_cid + idx); + fragB = dX_val[x_idx]; + + mma_m8n8k4(fragC, fragA, fragB); + idx += 32; + } + + fragC[0] += __shfl_down_sync(0xffffffff, fragC[0], 9, 32); + fragC[0] += __shfl_down_sync(0xffffffff, fragC[0], 18, 32); + fragC[1] += __shfl_down_sync(0xffffffff, fragC[1], 9, 32); + fragC[1] += __shfl_down_sync(0xffffffff, fragC[1], 18, 32); + fragC[0] += __shfl_sync(0xffffffff, fragC[1], 4); + + if (laneid == 0) + store_double_to_global(A->dwarp_val + global_warpid, fragC[0]); + + } + else if (bid >= A->offset_reg && bid < A->offset_short1) + { + // row-block part + int bid_reg = bid - A->offset_reg; + int warp_local = tid >> 5; + + int groupID = laneid >> 2; + int tID_in_group = 3 & laneid; + MAT_VAL_TYPE fragA, fragB, fragC[2]; + + if (rowloop == 1) + { + int block_idx = bid_reg * 4 + warp_local; + // int offset_A = dblockA_ptr[block_idx]; + int offset_A = load_int_from_global(A->dblockA_ptr + block_idx); + int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3; + + if (block_idx >= A->blocknum) return; + + MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A; + int *curA_cid = A->dregA_cid + offset_A; + + fragC[0] = 0.0, fragC[1] = 0.0; + int idx = tID_in_group + groupID * MMA_K; + for (int i = 0; i < blocklen; i += MMA_K) + { + fragA = load_double_from_global(curA_val + idx); + int x_idx = load_int_from_global(curA_cid + idx); + fragB = dX_val[x_idx]; + mma_m8n8k4(fragC, fragA, fragB); + idx += 32; + } + + int offset_y = block_idx * BlockSize + groupID; + if (tID_in_group == (groupID >> 1) && offset_y < A->row_block) + { + store_double_to_global(dY_val + A->row_long + offset_y, fragC[1 & groupID]); + } + + int cur_row = block_idx * BlockSize + laneid; + if (laneid < BlockSize && cur_row < A->row_block) + { + MAT_VAL_TYPE cur_y = 0.0; + // for (int i = dirregA_rpt[cur_row]; i < dirregA_rpt[cur_row + 1]; i ++) + for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++) + { + cur_y += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]]; + } + cur_y += load_double_from_global(dY_val + A->row_long + cur_row); + store_double_to_global(dY_val + A->row_long + cur_row, cur_y); + } + } + + if (rowloop == 2) + { + MAT_VAL_TYPE res; + #pragma unroll + for (int i = 0; i < 2; i ++) + { + int block_idx = bid_reg * 8 + warp_local * 2 + i; + int offset_A = load_int_from_global(A->dblockA_ptr + block_idx); + int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3; + // int offset_A = dblockA_ptr[block_idx]; + // int blocklen = (dblockA_ptr[block_idx + 1] - offset_A) >> 3; + + MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A; + int *curA_cid = A->dregA_cid + offset_A; + + fragC[0] = 0.0, fragC[1] = 0.0; + int idx = tID_in_group + groupID * MMA_K; + for (int j = 0; j < blocklen; j += MMA_K) + { + fragA = load_double_from_global(curA_val + idx); + int x_idx = load_int_from_global(curA_cid + idx); + fragB = dX_val[x_idx]; + mma_m8n8k4(fragC, fragA, fragB); + idx += 32; + } + int target_id = ((laneid - i * 8) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if ((laneid >> 3) == i) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + } + + int cur_row = bid_reg * 8 * BlockSize + warp_local * 2 * BlockSize + laneid; + if (laneid < 16 && cur_row < A->row_block) + { + for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++) + { + res += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]]; + } + store_double_to_global(dY_val + A->row_long + cur_row, res); + } + } + + if (rowloop == 4) + { + MAT_VAL_TYPE res; + #pragma unroll + for (int i = 0; i < 4; i ++) + { + int block_idx = bid_reg * 16 + warp_local * 4 + i; + // int offset_A = dblockA_ptr[block_idx]; + // int blocklen = (dblockA_ptr[block_idx + 1] - offset_A) >> 3; + int offset_A = load_int_from_global(A->dblockA_ptr + block_idx); + int blocklen = (load_int_from_global(A->dblockA_ptr + block_idx + 1) - offset_A) >> 3; + + MAT_VAL_TYPE *curA_val = A->dregA_val + offset_A; + int *curA_cid = A->dregA_cid + offset_A; + + fragC[0] = 0.0, fragC[1] = 0.0; + int idx = tID_in_group + groupID * MMA_K; + for (int j = 0; j < blocklen; j += MMA_K) + { + fragA = load_double_from_global(curA_val + idx); + int x_idx = load_int_from_global(curA_cid + idx); + fragB = dX_val[x_idx]; + mma_m8n8k4(fragC, fragA, fragB); + idx += 32; + } + int target_id = ((laneid - i * 8) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if ((laneid >> 3) == i) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + } + int cur_row = bid_reg * 16 * BlockSize + warp_local * 4 * BlockSize + laneid; + if (cur_row < A->row_block) + { + for (int i = A->dirregA_rpt[cur_row]; i < A->dirregA_rpt[cur_row + 1]; i ++) + { + res += load_double_from_global(A->dirregA_val + i) * dX_val[A->dirregA_cid[i]]; + } + store_double_to_global(dY_val + A->row_long + cur_row, res); + } + } + } + else if (bid >= A->offset_short1 && bid < A->offset_short13) + { + // short part - 1 nnz/row + int bid1 = bid - A->offset_short1; + int tid1 = bid1 * blockDim.x + tid; + if (tid1 >= A->short_row_1) + { + return; + } + + int x_idx = load_int_from_global(A->dshort_cid + tid1); + MAT_VAL_TYPE temp_y = load_double_from_global(A->dshort_val + tid1) * dX_val[x_idx]; + store_double_to_global(dY_val + A->row_long + A->row_block + tid1, temp_y); + + } + else if (bid >= A->offset_short13 && bid < A->offset_short34) + { + // short part - block 1&3 + int warpid_local = tid >> 5; + int bid13 = bid - A->offset_short13; + + MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res; + int groupID = laneid >> 2; + int tID_in_group = 3 & laneid; + + #pragma unroll + for (int i = 0; i < groupNum; i ++) + { + int offset = A->short_row_1 + ((bid13 * groupNum + i) * warpNum_short + warpid_local) * MMA_M * MMA_K * 2; + MAT_VAL_TYPE *cur_val = A->dshort_val + offset; + int *cur_cid = A->dshort_cid + offset; + int idx = tID_in_group + groupID * MMA_K; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + int cid = load_int_from_global(cur_cid + idx); + fragB = tID_in_group == 0 ? dX_val[cid] : 0; + mma_m8n8k4(fragC, fragA, fragB); + int target_id = (laneid >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragB = tID_in_group == 0 ? 0 : dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 8) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + idx += 32; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + cid = load_int_from_global(cur_cid + idx); + fragB = tID_in_group == 0 ? dX_val[cid] : 0; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 16) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragB = tID_in_group == 0 ? 0 : dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 24) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + int offset_y = ((bid13 * groupNum + i) * warpNum_short + warpid_local) * WARP_SIZE + laneid; + if (offset_y < A->common_13 * 2) + store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + offset_y, res); + + } + } + else if (bid >= A->offset_short34 && bid < A->offset_short22) + { + // short part - block3 & block4 + int warpid_local = tid >> 5; + int bid34 = bid - A->offset_short34; + + MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res; + int groupID = laneid >> 2; + int tID_in_group = 3 & laneid; + + #pragma unroll + for (int j = 0; j < groupNum; j ++) + { + int offset = A->short_row_1 + A->fill0_nnz_short13 + ((bid34 * groupNum + j) * warpNum_short + warpid_local) * MMA_M * MMA_K * loopNum_short; + MAT_VAL_TYPE *cur_val = A->dshort_val + offset; + int *cur_cid = A->dshort_cid + offset; + int idx = tID_in_group + groupID * MMA_K; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + int cid = load_int_from_global(cur_cid + idx); + fragB = dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + int target_id = (laneid >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + idx += 32; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + cid = load_int_from_global(cur_cid + idx); + fragB = dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 8) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + idx += 32; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + cid = load_int_from_global(cur_cid + idx); + fragB = dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 16) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + idx += 32; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + cid = load_int_from_global(cur_cid + idx); + fragB = dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 24) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + int offset_y = ((bid34 * groupNum + j) * warpNum_short + warpid_local) * WARP_SIZE + laneid; + if (offset_y < A->short_row_34) + store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + A->common_13 * 2 + offset_y, res); + + } + } + else + { + // short part - blocl 2&2 + int warpid_local = tid >> 5; + int bid22 = bid - A->offset_short22; + + MAT_VAL_TYPE fragA = 0.0, fragB = 0.0, fragC[2], res; + int groupID = laneid >> 2; + int tID_in_group = 3 & laneid; + + #pragma unroll + for (int i = 0; i < groupNum; i ++) + { + int offset = A->short_row_1 + A->fill0_nnz_short13 + A->fill0_nnz_short34 + ((bid22 * groupNum + i) * warpNum_short + warpid_local) * MMA_M * MMA_K * 2; + MAT_VAL_TYPE *cur_val = A->dshort_val + offset; + int *cur_cid = A->dshort_cid + offset; + int idx = tID_in_group + groupID * MMA_K; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + int cid = load_int_from_global(cur_cid + idx); + fragB = tID_in_group < 2 ? dX_val[cid] : 0; + mma_m8n8k4(fragC, fragA, fragB); + int target_id = (laneid >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid < 8) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragB = tID_in_group < 2 ? 0 : dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 8) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 1) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + idx += 32; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragA = load_double_from_global(cur_val + idx); + cid = load_int_from_global(cur_cid + idx); + fragB = tID_in_group < 2 ? dX_val[cid] : 0; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 16) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >> 3 == 2) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + fragC[0] = 0.0, fragC[1] = 0.0; + fragB = tID_in_group < 2 ? 0 : dX_val[cid]; + mma_m8n8k4(fragC, fragA, fragB); + target_id = ((laneid - 24) >> 1) * 9; + fragC[0] = __shfl_sync(0xffffffff, fragC[0], target_id); + fragC[1] = __shfl_sync(0xffffffff, fragC[1], target_id + 4); + if (laneid >= 24) res = (1 & laneid) == 0 ? fragC[0] : fragC[1]; + + int offset_y = ((bid22 * groupNum + i) * warpNum_short + warpid_local) * WARP_SIZE + laneid; + if (offset_y < A->short_row_2) + store_double_to_global(dY_val + A->row_long + A->row_block + A->short_row_1 + A->common_13 * 2 + A->short_row_34 + offset_y, res); + } + } +} __host__ void spmv_all3(char *filename, MAT_VAL_TYPE *csrValA, MAT_PTR_TYPE *csrRowPtrA, int *csrColIdxA, - MAT_VAL_TYPE *X_val, MAT_VAL_TYPE *Y_val, int *order_rid, int rowA, int colA, MAT_PTR_TYPE nnzA, int NUM, double threshold, int block_longest); + MAT_VAL_TYPE *X_val, MAT_VAL_TYPE *Y_val, int *order_rid, int rowA, int colA, MAT_PTR_TYPE nnzA, int NUM, double threshold, int block_longest) +{ + struct timeval t1, t2; -#endif \ No newline at end of file + // three parts: short row (1 & 3 & 2 & 4), long row, row-block (regular(origin & fill0) & irregular) + MAT_PTR_TYPE nnz_short, nnz_long, origin_nnz_reg, fill0_nnz_reg, nnz_irreg; + int row_long = 0, row_block = 0, row_zero = 0; + + // get the short part data + // (short_val, short_cid) + int short_row_1 = 0, short_row_3 = 0, short_row_2 = 0, short_row_4 = 0; + + for (int i = 0; i < rowA; i ++) + { + int row_len = csrRowPtrA[i + 1] - csrRowPtrA[i]; + if (row_len == 1) + { + short_row_1 ++; + } + else if (row_len == 3) + { + short_row_3 ++; + } + else if (row_len == 2) + { + short_row_2 ++; + } + else if (row_len == 0) + { + row_zero ++; + } + else if (row_len == 4) + { + short_row_4 ++; + } + // else if (row_len >= warpNum_long * loopNum_long * MMA_M * MMA_K) + else if (row_len >= block_longest) + { + row_long ++; + } + else + { + row_block ++; + } + } + + int rowloop; + if (row_block < 59990) rowloop = 1; + else if (row_block >= 59990 && row_block < 400000) rowloop = 2; + else rowloop = 4; + + int *short_rid_1 = (int *)malloc(sizeof(int) * short_row_1); + int *short_rid_2 = (int *)malloc(sizeof(int) * short_row_2); + int *short_rid_3 = (int *)malloc(sizeof(int) * short_row_3); + int *short_rid_4 = (int *)malloc(sizeof(int) * short_row_4); + int *long_rid = (int *)malloc(sizeof(int) * row_long); + int *zero_rid = (int *)malloc(sizeof(int) * row_zero); + int *ridA = (int *)malloc(sizeof(int) * row_block); + + MAT_PTR_TYPE *rptA = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_block + 1)); + memset(rptA, 0, sizeof(MAT_PTR_TYPE) * (row_block + 1)); + MAT_PTR_TYPE *long_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_long + 1)); + memset(long_rpt, 0, sizeof(MAT_PTR_TYPE) * (row_long + 1)); + + int short_row_flag1 = 0, short_row_flag3 = 0, short_row_flag2 = 0, short_row_flag4 = 0; + int row_long_flag = 0, flag0 = 0, row_block_flag = 0; + for (int i = 0; i < rowA; i ++) + { + int row_len = csrRowPtrA[i + 1] - csrRowPtrA[i]; + if (row_len == 1) + { + short_rid_1[short_row_flag1] = i; + short_row_flag1 ++; + } + else if (row_len == 3) + { + short_rid_3[short_row_flag3] = i; + short_row_flag3 ++; + } + else if (row_len == 2) + { + short_rid_2[short_row_flag2] = i; + short_row_flag2 ++; + } + else if (row_len == 0) + { + zero_rid[flag0] = i; + flag0 ++; + } + else if (row_len == 4) + { + short_rid_4[short_row_flag4] = i; + short_row_flag4 ++; + } + // else if (row_len >= warpNum_long * loopNum_long * MMA_M * MMA_K) + else if (row_len >= block_longest) + { + long_rpt[row_long_flag] = row_len; + long_rid[row_long_flag] = i; + row_long_flag ++; + } + else + { + rptA[row_block_flag] = row_len; + ridA[row_block_flag] = i; + row_block_flag ++; + } + } + nnz_short = short_row_1 + short_row_3 * 3 + short_row_2 * 2 + short_row_4 * 4; + + int common_13 = short_row_1 < short_row_3 ? short_row_1 : short_row_3; + if (common_13 / BlockSize >= 16) + { + common_13 = BlockSize * (common_13 / BlockSize); + short_row_1 = short_row_1 - common_13; + short_row_3 = short_row_3 - common_13; + } + else + { + common_13 = 0; + } + + int short_block13 = (common_13 + BlockSize - 1) / BlockSize; + int half_short_row_2 = (short_row_2 + 1) / 2; + int short_block22 = (half_short_row_2 + BlockSize - 1) / BlockSize; + int short_row_34 = short_row_3 + short_row_4; + int short_block34 = (short_row_34 + BlockSize - 1) / BlockSize; + + int block13_per_threadblock = warpNum_short * groupNum * 2; + int block22_per_threadblock = warpNum_short * groupNum * 2; + int block34_per_threadblock = warpNum_short * groupNum * loopNum_short; + + int threadblock13 = (short_block13 + block13_per_threadblock - 1) / block13_per_threadblock; + int threadblock22 = (short_block22 + block22_per_threadblock - 1) / block22_per_threadblock; + int threadblock34 = (short_block34 + block34_per_threadblock - 1) / block34_per_threadblock; + + MAT_PTR_TYPE fill0_nnz_short13 = threadblock13 * block13_per_threadblock * MMA_M * MMA_K; + MAT_PTR_TYPE fill0_nnz_short34 = threadblock34 * block34_per_threadblock * MMA_M * MMA_K; + MAT_PTR_TYPE fill0_nnz_short22 = threadblock22 * block22_per_threadblock * MMA_M * MMA_K; + MAT_PTR_TYPE fill0_nnz_short = short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + fill0_nnz_short22; + MAT_VAL_TYPE *short_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_short); + int *short_cid = (int *)malloc(sizeof(int) * fill0_nnz_short); + memset(short_val, 0.0, sizeof(MAT_VAL_TYPE) * fill0_nnz_short); + memset(short_cid, 0, sizeof(int) * fill0_nnz_short); + + int super_group = 1 + threadblock13 + threadblock34 + threadblock22; + MAT_PTR_TYPE *superX_ptr = (MAT_PTR_TYPE *)malloc(sizeof(int) * (super_group + 1)); + + for (int i = 0; i < short_row_1; i ++) + { + int cur_row = short_rid_1[i]; + short_val[i] = csrValA[csrRowPtrA[cur_row]]; + short_cid[i] = csrColIdxA[csrRowPtrA[cur_row]]; + } + + for (int i = 0; i < short_block13; i ++) + { + MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + i * MMA_M * MMA_K; + int *cur_short_cid = short_cid + short_row_1 + i * MMA_M * MMA_K; + + for (int j = 0; j < BlockSize && i * BlockSize + j < common_13; j ++) + { + int cur_row_1 = short_rid_1[short_row_1 + i * BlockSize + j]; + int cur_row_3 = short_rid_3[i * BlockSize + j]; + cur_short_val[j * MMA_K] = csrValA[csrRowPtrA[cur_row_1]]; + cur_short_cid[j * MMA_K] = csrColIdxA[csrRowPtrA[cur_row_1]]; + cur_short_val[j * MMA_K + 1] = csrValA[csrRowPtrA[cur_row_3]]; + cur_short_val[j * MMA_K + 2] = csrValA[csrRowPtrA[cur_row_3] + 1]; + cur_short_val[j * MMA_K + 3] = csrValA[csrRowPtrA[cur_row_3] + 2]; + cur_short_cid[j * MMA_K + 1] = csrColIdxA[csrRowPtrA[cur_row_3]]; + cur_short_cid[j * MMA_K + 2] = csrColIdxA[csrRowPtrA[cur_row_3] + 1]; + cur_short_cid[j * MMA_K + 3] = csrColIdxA[csrRowPtrA[cur_row_3] + 2]; + } + } + + for (int i = 0; i < short_row_3; i ++) + { + MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + i * MMA_K; + int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + i * MMA_K; + + int cur_row = short_rid_3[common_13 + i]; + + cur_short_val[0] = csrValA[csrRowPtrA[cur_row]]; + cur_short_val[1] = csrValA[csrRowPtrA[cur_row] + 1]; + cur_short_val[2] = csrValA[csrRowPtrA[cur_row] + 2]; + cur_short_cid[0] = csrColIdxA[csrRowPtrA[cur_row]]; + cur_short_cid[1] = csrColIdxA[csrRowPtrA[cur_row] + 1]; + cur_short_cid[2] = csrColIdxA[csrRowPtrA[cur_row] + 2]; + } + + for (int i = 0; i < short_row_4; i ++) + { + MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + (short_row_3 + i) * MMA_K; + int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + (short_row_3 + i) * MMA_K; + + int cur_row = short_rid_4[i]; + + cur_short_val[0] = csrValA[csrRowPtrA[cur_row]]; + cur_short_val[1] = csrValA[csrRowPtrA[cur_row] + 1]; + cur_short_val[2] = csrValA[csrRowPtrA[cur_row] + 2]; + cur_short_val[3] = csrValA[csrRowPtrA[cur_row] + 3]; + cur_short_cid[0] = csrColIdxA[csrRowPtrA[cur_row]]; + cur_short_cid[1] = csrColIdxA[csrRowPtrA[cur_row] + 1]; + cur_short_cid[2] = csrColIdxA[csrRowPtrA[cur_row] + 2]; + cur_short_cid[3] = csrColIdxA[csrRowPtrA[cur_row] + 3]; + } + + for (int i = 0; i < short_block22; i ++) + { + MAT_VAL_TYPE *cur_short_val = short_val + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * MMA_M * MMA_K; + int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * MMA_M * MMA_K; + + for (int j = 0; j < BlockSize * 2 && (i * BlockSize * 2 + j) < short_row_2; j ++) + { + int cur_row = short_rid_2[i * BlockSize * 2 + j]; + cur_short_val[j % BlockSize * MMA_K + (j / BlockSize) * 2] = csrValA[csrRowPtrA[cur_row]]; + cur_short_val[j % BlockSize * MMA_K + (j / BlockSize) * 2 + 1] = csrValA[csrRowPtrA[cur_row] + 1]; + cur_short_cid[j % BlockSize * MMA_K + (j / BlockSize) * 2] = csrColIdxA[csrRowPtrA[cur_row]]; + cur_short_cid[j % BlockSize * MMA_K + (j / BlockSize) * 2 + 1] = csrColIdxA[csrRowPtrA[cur_row] + 1]; + } + } + + int *short_cid_temp = (int *)malloc(sizeof(int) * fill0_nnz_short); + memcpy(short_cid_temp, short_cid, sizeof(int) * fill0_nnz_short); + + quick_sort_key(short_cid_temp, short_row_1); + int nnzr = short_row_1 > 0 ? 1 : 0; + for (int i = 1; i < short_row_1; i ++) + { + nnzr += short_cid_temp[i] != short_cid_temp[i - 1] ? 1 : 0; + } + superX_ptr[0] = nnzr; + + MAT_PTR_TYPE *cur_superX_ptr = superX_ptr + 1; + for (int i = 0; i < threadblock13; i++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K; + int len = block13_per_threadblock * MMA_M * MMA_K; + quick_sort_key(cur_short_cid_temp, len); + int nnzcid = len > 0 ? 1 : 0; + for (int j = 1; j < len; j ++) + { + nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0; + } + cur_superX_ptr[i] = nnzcid; + } + + cur_superX_ptr = superX_ptr + 1 + threadblock13; + for (int i = 0; i < threadblock34; i++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K; + int len = block34_per_threadblock * MMA_M * MMA_K; + quick_sort_key(cur_short_cid_temp, len); + int nnzcid = len > 0 ? 1 : 0; + for (int j = 1; j < len; j ++) + { + nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0; + } + cur_superX_ptr[i] = nnzcid; + } + + cur_superX_ptr = superX_ptr + 1 + threadblock13 + threadblock34; + for (int i = 0; i < threadblock22; i++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K; + int len = block22_per_threadblock * MMA_M * MMA_K; + quick_sort_key(cur_short_cid_temp, len); + int nnzcid = len > 0 ? 1 : 0; + for (int j = 1; j < len; j ++) + { + nnzcid += cur_short_cid_temp[j] != cur_short_cid_temp[j - 1] ? 1 : 0; + } + cur_superX_ptr[i] = nnzcid; + } + exclusive_scan(superX_ptr, super_group + 1); + MAT_PTR_TYPE nnz_superX = superX_ptr[super_group]; + + int new_cid_len = short_row_1 + threadblock13 * block13_per_threadblock * MMA_M * MMA_K / 4 + + threadblock34 * block34_per_threadblock * MMA_M * MMA_K / 4 + + threadblock22 * block22_per_threadblock * MMA_M * MMA_K / 4; + + int *short_cid_new = (int *)malloc(sizeof(int) * new_cid_len); + + int *superX_cid = (int *)malloc(sizeof(int) * nnz_superX); + int flag = 0; + if (short_row_1) + { + superX_cid[0] = short_cid_temp[0]; + flag ++; + } + for (int j = 1; j < short_row_1; j ++) + { + if (short_cid_temp[j] != short_cid_temp[j - 1]) + { + superX_cid[flag] = short_cid_temp[j]; + flag ++; + } + } + if (flag != superX_ptr[1]) printf("flag1 = %d, len = %d\n", flag, superX_ptr[1]); + for (int i = 0; i < short_row_1; i ++) + { + short_cid_new[i] = BinarySearch(superX_cid, superX_ptr[1], short_cid[i]); + } + + cur_superX_ptr = superX_ptr + 1; + for (int i = 0; i < threadblock13; i ++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K; + int len = block13_per_threadblock * MMA_M * MMA_K; + int *cur_superX_cid = superX_cid + cur_superX_ptr[i]; + int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i]; + int flag_cid = 0; + if (len) + { + cur_superX_cid[0] = cur_short_cid_temp[0]; + flag_cid ++; + } + else + { + continue; + } + for (int j = 1; j < len; j ++) + { + if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1]) + { + cur_superX_cid[flag_cid] = cur_short_cid_temp[j]; + flag_cid ++; + } + } + if (flag_cid != xlen) printf("flag13 = %d, len = %d\n", flag_cid, xlen); + + int *cur_short_cid_new = short_cid_new + short_row_1 + i * (block13_per_threadblock * MMA_M * MMA_K / 4); + int *cur_short_cid = short_cid + short_row_1 + i * block13_per_threadblock * MMA_M * MMA_K; + for (int j = 0; j < len; j ++) + { + // cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]); + SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4); + } + } + + cur_superX_ptr = superX_ptr + 1 + threadblock13; + for (int i = 0; i < threadblock34; i ++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K; + int len = block34_per_threadblock * MMA_M * MMA_K; + int *cur_superX_cid = superX_cid + cur_superX_ptr[i]; + int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i]; + int flag_cid = 0; + if (len) + { + cur_superX_cid[0] = cur_short_cid_temp[0]; + flag_cid ++; + } + else + { + continue; + } + for (int j = 1; j < len; j ++) + { + if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1]) + { + cur_superX_cid[flag_cid] = cur_short_cid_temp[j]; + flag_cid ++; + } + } + if (flag_cid != xlen) printf("flag34 = %d, len = %d\n", flag_cid, xlen); + + int *cur_short_cid_new = short_cid_new + short_row_1 + fill0_nnz_short13 / 4 + i * (block34_per_threadblock * MMA_M * MMA_K / 4); + int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + i * block34_per_threadblock * MMA_M * MMA_K; + for (int j = 0; j < len; j ++) + { + // cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]); + SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4); + } + } + + cur_superX_ptr = superX_ptr + 1 + threadblock13 + threadblock34; + for (int i = 0; i < threadblock22; i ++) + { + int *cur_short_cid_temp = short_cid_temp + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K; + int len = block22_per_threadblock * MMA_M * MMA_K; + int *cur_superX_cid = superX_cid + cur_superX_ptr[i]; + int xlen = cur_superX_ptr[i + 1] - cur_superX_ptr[i]; + int flag_cid = 0; + if (len) + { + cur_superX_cid[0] = cur_short_cid_temp[0]; + flag_cid ++; + } + else + { + continue; + } + for (int j = 1; j < len; j ++) + { + if (cur_short_cid_temp[j] != cur_short_cid_temp[j - 1]) + { + cur_superX_cid[flag_cid] = cur_short_cid_temp[j]; + flag_cid ++; + } + } + if (flag_cid != xlen) printf("flag22 = %d, len = %d\n", flag_cid, xlen); + + int *cur_short_cid_new = short_cid_new + short_row_1 + (fill0_nnz_short13 + fill0_nnz_short34) / 4 + i * (block22_per_threadblock * MMA_M * MMA_K / 4); + int *cur_short_cid = short_cid + short_row_1 + fill0_nnz_short13 + fill0_nnz_short34 + i * block22_per_threadblock * MMA_M * MMA_K; + for (int j = 0; j < len; j ++) + { + // cur_short_cid_new[j] = BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]); + SET_8_BIT(cur_short_cid_new[j / 4], BinarySearch(cur_superX_cid, xlen, cur_short_cid[j]), j % 4); + } + } + + MAT_VAL_TYPE *superX_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_superX); + for (int i = 0; i < nnz_superX; i ++) + { + superX_val[i] = X_val[superX_cid[i]]; + } + + radix_sort(rptA, ridA, row_block); + + exclusive_scan(rptA, row_block + 1); + exclusive_scan(long_rpt, row_long + 1); + nnz_long = long_rpt[row_long]; + + memcpy(order_rid, long_rid, sizeof(int) * row_long); + memcpy(order_rid + row_long, ridA, sizeof(int) * row_block); + memcpy(order_rid + row_long + row_block, short_rid_1, sizeof(int) * short_row_1); + for (int i = 0; i < short_block13; i ++) + { + int *cur_order_rid = order_rid + row_long + row_block + short_row_1 + i * BlockSize * 2; + + for (int j = 0; j < BlockSize; j ++) + { + cur_order_rid[j] = short_rid_1[short_row_1 + i * BlockSize + j]; + cur_order_rid[BlockSize + j] = short_rid_3[i * BlockSize + j]; + } + } + memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2, short_rid_3 + common_13, sizeof(int) * short_row_3); + memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3, short_rid_4, sizeof(int) * short_row_4); + memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3 + short_row_4, short_rid_2, sizeof(int) * short_row_2); + memcpy(order_rid + row_long + row_block + short_row_1 + common_13 * 2 + short_row_3 + short_row_4 + short_row_2, zero_rid, sizeof(int) * row_zero); + + int short_row = short_row_1 + common_13 * 2 + short_row_34 + short_row_2; + int offset_short_row = row_long + row_block; + + MAT_VAL_TYPE *short3_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_short); + int *short3_cid = (int *)malloc(sizeof(int) * nnz_short); + MAT_PTR_TYPE *short3_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (short_row + 1)); + + for (int i = 0; i < short_row; i ++) + { + int idx = order_rid[offset_short_row + i]; + short3_rpt[i] = csrRowPtrA[idx + 1] - csrRowPtrA[idx]; + } + exclusive_scan(short3_rpt, short_row + 1); + + for (int i = 0; i < short_row; i ++) + { + int idx = order_rid[offset_short_row + i]; + memcpy(short3_val + short3_rpt[i], csrValA + csrRowPtrA[idx], sizeof(MAT_VAL_TYPE) * (csrRowPtrA[idx + 1] - csrRowPtrA[idx])); + memcpy(short3_cid + short3_rpt[i], csrColIdxA + csrRowPtrA[idx], sizeof(int) * (csrRowPtrA[idx + 1] - csrRowPtrA[idx])); + } + + MAT_PTR_TYPE *long_rpt_new = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_long + 1)); + memset(long_rpt_new, 0, sizeof(MAT_PTR_TYPE) * (row_long + 1)); + int warp_number = 0; + for (int i = 0; i < row_long; i ++) + { + int nnz_num = long_rpt[i + 1] - long_rpt[i]; + int cur_warp_num = (nnz_num + MMA_M * MMA_K * loopNum_long - 1) / (MMA_M * MMA_K * loopNum_long); + long_rpt_new[i] = cur_warp_num; + } + exclusive_scan(long_rpt_new, row_long + 1); + warp_number = long_rpt_new[row_long]; + + int BlockNum_long = (warp_number + warpNum_long - 1) / warpNum_long; + int fill0_nnz_long = BlockNum_long * warpNum_long * loopNum_long * MMA_M * MMA_K; + warp_number = BlockNum_long * warpNum_long; + MAT_VAL_TYPE *val_by_warp = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * warp_number); + int *rid_by_warp = (int *)malloc(sizeof(int) * warp_number); + MAT_VAL_TYPE *long_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_long); + memset(long_val, 0.0, sizeof(MAT_VAL_TYPE) * fill0_nnz_long); + int *long_cid = (int *)malloc(sizeof(int) * fill0_nnz_long); + memset(long_cid, 0, sizeof(int) * fill0_nnz_long); + + for (int i = 0; i < row_long; i ++) + { + MAT_VAL_TYPE *cur_val = long_val + long_rpt_new[i] * loopNum_long * MMA_M * MMA_K; + int *cur_cid = long_cid + long_rpt_new[i] * loopNum_long * MMA_M * MMA_K; + int real_rid = long_rid[i]; + for (int j = 0; j < long_rpt[i + 1] - long_rpt[i]; j ++) + { + cur_val[j] = csrValA[csrRowPtrA[real_rid] + j]; + cur_cid[j] = csrColIdxA[csrRowPtrA[real_rid] + j]; + } + + for (int j = long_rpt_new[i]; j < long_rpt_new[i + 1]; j ++) + { + rid_by_warp[j] = i; + } + } + + int blocknum = (row_block + BlockSize - 1) / BlockSize; + blocknum = ((blocknum + rowloop * 4 - 1) / (rowloop * 4)) * rowloop * 4; + MAT_PTR_TYPE *blockPtr = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (blocknum + 1)); + memset(blockPtr, 0, sizeof(MAT_PTR_TYPE) * (blocknum + 1)); + + MAT_PTR_TYPE *irreg_rpt = (MAT_PTR_TYPE *)malloc(sizeof(MAT_PTR_TYPE) * (row_block + 1)); + memset(irreg_rpt, 0, sizeof(MAT_PTR_TYPE) * (row_block + 1)); + + #pragma omp parallel for + for (int i = 0; i < blocknum; i++) + { + int row_start = i * BlockSize; + int row_end = (i + 1) * BlockSize >= row_block ? row_block : (i + 1) * BlockSize; + int k = 1; + while(1) + { + int block_nnz = 0; + for (int cur_row = row_start; cur_row < row_end; cur_row++) + { + int row_len = rptA[cur_row + 1] - rptA[cur_row]; + if (row_len / MMA_K >= k) block_nnz += MMA_K; + else if(row_len / MMA_K == k - 1) block_nnz += row_len % MMA_K; + } + + if (block_nnz >= threshold * MMA_K * MMA_M) + { + blockPtr[i] += MMA_K * MMA_M; + } + else + { + for (int cur_row = row_start; cur_row < row_end; cur_row++ ) + { + int row_len = rptA[cur_row + 1] - rptA[cur_row]; + irreg_rpt[cur_row] = row_len - (k - 1) * MMA_K > 0 ? row_len - (k - 1) * MMA_K : 0; + } + break; + } + k++; + } + } + + exclusive_scan(blockPtr, blocknum + 1); + exclusive_scan(irreg_rpt, row_block + 1); + + fill0_nnz_reg = blockPtr[blocknum]; + nnz_irreg = irreg_rpt[row_block]; + origin_nnz_reg = nnzA - nnz_irreg - nnz_long - nnz_short; + + MAT_VAL_TYPE *irreg_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_irreg); + int *irreg_cid = (int *)malloc(sizeof(int) * nnz_irreg); + for (int i = 0; i < row_block; i ++) + { + int cur_rid = ridA[i]; + int irreg_offset = irreg_rpt[i]; + int irreg_len = irreg_rpt[i + 1] - irreg_offset; + for (int j = 0; j < irreg_len; j ++) + { + irreg_val[irreg_offset + j] = csrValA[csrRowPtrA[cur_rid + 1] - irreg_len + j]; + irreg_cid[irreg_offset + j] = csrColIdxA[csrRowPtrA[cur_rid + 1] - irreg_len + j]; + } + } + + MAT_VAL_TYPE *reg_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * fill0_nnz_reg); + int *reg_cid = (int *)malloc(sizeof(int) * fill0_nnz_reg); + + for (int bid = 0; bid < blocknum; bid ++) + { + int nnz_block = (blockPtr[bid + 1] - blockPtr[bid]); + int blocklen = nnz_block / BlockSize; + + for (int rowid = bid * BlockSize; rowid < (bid + 1) * BlockSize; rowid ++) + { + int regA_start = blockPtr[bid] + blocklen * (rowid - bid * BlockSize); + if (rowid < row_block) + { + int real_id = ridA[rowid]; + int A_start = csrRowPtrA[real_id]; + int row_len = csrRowPtrA[real_id + 1] - A_start; + for (int i = 0; i < blocklen; i ++) + { + reg_val[regA_start + i] = i < row_len ? csrValA[A_start + i] : 0.0; + reg_cid[regA_start + i] = i < row_len ? csrColIdxA[A_start + i] : 0; + } + } + else + { + for (int i = 0; i < blocklen; i ++) + { + reg_val[regA_start + i] = 0.0; + reg_cid[regA_start + i] = 0; + } + } + + } + + MAT_VAL_TYPE *temp_val = (MAT_VAL_TYPE *)malloc(sizeof(MAT_VAL_TYPE) * nnz_block); + int *temp_cid = (int *)malloc(sizeof(int) * nnz_block); + MAT_VAL_TYPE *cur_val = reg_val + blockPtr[bid]; + int *cur_cid = reg_cid + blockPtr[bid]; + + for (int i = 0; i < nnz_block; i ++) + { + int new_id = ((i % blocklen) / MMA_K) * BlockSize * MMA_K + (i / blocklen) * MMA_K + i % MMA_K; + temp_val[new_id] = cur_val[i]; + temp_cid[new_id] = cur_cid[i]; + } + memcpy(cur_val, temp_val, sizeof(MAT_VAL_TYPE) * nnz_block); + memcpy(cur_cid, temp_cid, sizeof(int) * nnz_block); + free(temp_val); + free(temp_cid); + } + + long fill0_nnz = fill0_nnz_short + fill0_nnz_long + nnz_irreg + fill0_nnz_reg; + double rate_fill0 = (double)(fill0_nnz - nnzA) / nnzA; + + long long int data_X = (rowA + colA) * sizeof(MAT_VAL_TYPE) + \ + fill0_nnz_long * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + warp_number * sizeof(MAT_VAL_TYPE) + (row_long + 1) * sizeof(int) + \ + fill0_nnz_short * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + \ + fill0_nnz_reg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (blocknum + 1) * sizeof(MAT_PTR_TYPE) + \ + nnz_irreg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (row_block + 1) * sizeof(MAT_PTR_TYPE); + + long long int data_X2 = (rowA + nnzA) * sizeof(MAT_VAL_TYPE) + \ + fill0_nnz_long * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + warp_number * sizeof(MAT_VAL_TYPE) + (row_long + 1) * sizeof(int) + \ + fill0_nnz_short * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + \ + fill0_nnz_reg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (blocknum + 1) * sizeof(MAT_PTR_TYPE) + \ + nnz_irreg * (sizeof(MAT_VAL_TYPE) + sizeof(int)) + (row_block + 1) * sizeof(MAT_PTR_TYPE); + + int BlockNum = (blocknum + rowloop * 4 - 1) / (rowloop * 4); + + int ThreadNum_short = warpNum_short * WARP_SIZE; + int BlockNum_short_1 = (short_row_1 + ThreadNum_short - 1) / ThreadNum_short; + int BlockNum_short = BlockNum_short_1 + threadblock13 + threadblock34 + threadblock22; + + int offset_reg = BlockNum_long; + int offset_short1 = offset_reg + BlockNum; + int offset_short13 = offset_short1 + BlockNum_short_1; + int offset_short34 = offset_short13 + threadblock13; + int offset_short22 = offset_short34 + threadblock34; + + int BlockNum_all = BlockNum_long + BlockNum + BlockNum_short; + int ThreadNum_all = 4 * WARP_SIZE; + + int sumBlockNum = (row_long + 3) / 4; + + MAT_VAL_TYPE *dX_val, *dY_val; + + // Allocate struct on device + DASPSparseMatrix hA; + hA.row_long = row_long; + hA.row_block = row_block; + hA.blocknum = blocknum; + hA.short_row_1 = short_row_1; + hA.common_13 = common_13; + hA.short_row_34 = short_row_34; + hA.short_row_2 = short_row_2; + hA.fill0_nnz_short13 = fill0_nnz_short13; + hA.fill0_nnz_short34 = fill0_nnz_short34; + hA.offset_reg = offset_reg; + hA.offset_short1 = offset_short1; + hA.offset_short13 = offset_short13; + hA.offset_short34 = offset_short34; + hA.offset_short22 = offset_short22; + + cudaMalloc((void **)&dX_val, sizeof(MAT_VAL_TYPE) * colA); + cudaMalloc((void **)&dY_val, sizeof(MAT_VAL_TYPE) * rowA); + cudaMemcpy(dX_val, X_val, sizeof(MAT_VAL_TYPE) * colA, cudaMemcpyHostToDevice); + cudaMemset(dY_val, 0.0, sizeof(MAT_VAL_TYPE) * rowA); + + cudaMalloc((void **)&hA.dlongA_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_long); + cudaMalloc((void **)&hA.dlongA_cid, sizeof(int) * fill0_nnz_long); + // hA.dwarp_val is dval_by_warp in original + cudaMalloc((void **)&hA.dwarp_val, sizeof(MAT_VAL_TYPE) * warp_number); + cudaMalloc((void **)&hA.dlongA_rpt, sizeof(MAT_PTR_TYPE) * (row_long + 1)); + cudaMemcpy(hA.dlongA_val, long_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_long, cudaMemcpyHostToDevice); + cudaMemcpy(hA.dlongA_cid, long_cid, sizeof(int) * fill0_nnz_long, cudaMemcpyHostToDevice); + // drid_by_warp is not in the struct but allocated in original, seems unused in kernel? + // checking kernel... it is not used in kernel. Only dval_by_warp (dwarp_val) is used. + cudaMemcpy(hA.dlongA_rpt, long_rpt_new, sizeof(MAT_PTR_TYPE) * (row_long + 1), cudaMemcpyHostToDevice); + + cudaMalloc((void **)&hA.dshort_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_short); + cudaMalloc((void **)&hA.dshort_cid, sizeof(int) * fill0_nnz_short); + cudaMemcpy(hA.dshort_val, short_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_short, cudaMemcpyHostToDevice); + cudaMemcpy(hA.dshort_cid, short_cid, sizeof(int) * fill0_nnz_short, cudaMemcpyHostToDevice); + + cudaMalloc((void **)&hA.dregA_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_reg); + cudaMalloc((void **)&hA.dregA_cid, sizeof(int) * fill0_nnz_reg); + cudaMalloc((void **)&hA.dblockA_ptr, sizeof(MAT_PTR_TYPE) * (blocknum + 1)); + cudaMemcpy(hA.dregA_val, reg_val, sizeof(MAT_VAL_TYPE) * fill0_nnz_reg, cudaMemcpyHostToDevice); + cudaMemcpy(hA.dregA_cid, reg_cid, sizeof(int) * fill0_nnz_reg, cudaMemcpyHostToDevice); + cudaMemcpy(hA.dblockA_ptr, blockPtr, sizeof(MAT_PTR_TYPE) * (blocknum + 1), cudaMemcpyHostToDevice); + + cudaMalloc((void **)&hA.dirregA_val, sizeof(MAT_VAL_TYPE) * nnz_irreg); + cudaMalloc((void **)&hA.dirregA_rpt, sizeof(MAT_PTR_TYPE) * (row_block + 1)); + cudaMalloc((void **)&hA.dirregA_cid, sizeof(int) * nnz_irreg); + cudaMemcpy(hA.dirregA_val, irreg_val, sizeof(MAT_VAL_TYPE) * nnz_irreg, cudaMemcpyHostToDevice); + cudaMemcpy(hA.dirregA_rpt, irreg_rpt, sizeof(MAT_PTR_TYPE) * (row_block + 1), cudaMemcpyHostToDevice); + cudaMemcpy(hA.dirregA_cid, irreg_cid, sizeof(int) * nnz_irreg, cudaMemcpyHostToDevice); + + DASPSparseMatrix *dA; + cudaMalloc((void **)&dA, sizeof(DASPSparseMatrix)); + cudaMemcpy(dA, &hA, sizeof(DASPSparseMatrix), cudaMemcpyHostToDevice); + + int carveout = 0; + cudaFuncSetAttribute(dasp_spmv3<1>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout); + cudaFuncSetAttribute(dasp_spmv3<2>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout); + cudaFuncSetAttribute(dasp_spmv3<4>, cudaFuncAttributePreferredSharedMemoryCarveout, carveout); + + int warmup_time = 100; + int execute_time = 1000; + if (rowloop == 1) + { + for (int i = 0; i < warmup_time; ++i) + { + dasp_spmv3<1><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + gettimeofday(&t1, NULL); + for (int i = 0; i < execute_time; ++i) + { + dasp_spmv3<1><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + if (row_long) + { + for (int i = 0; i < execute_time; ++i) + { + longPart_sum<<>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long); + } + cudaDeviceSynchronize(); + } + gettimeofday(&t2, NULL); + } + else if (rowloop == 2) + { + for (int i = 0; i < warmup_time; ++i) + { + dasp_spmv3<2><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + gettimeofday(&t1, NULL); + for (int i = 0; i < execute_time; ++i) + { + dasp_spmv3<2><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + if (row_long) + { + for (int i = 0; i < execute_time; ++i) + { + longPart_sum<<>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long); + } + cudaDeviceSynchronize(); + } + gettimeofday(&t2, NULL); + } + else + { + for (int i = 0; i < warmup_time; ++i) + { + dasp_spmv3<4><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + gettimeofday(&t1, NULL); + for (int i = 0; i < execute_time; ++i) + { + dasp_spmv3<4><<>>(dX_val, dY_val, dA); + } + cudaDeviceSynchronize(); + if (row_long) + { + for (int i = 0; i < execute_time; ++i) + { + longPart_sum<<>>(hA.dlongA_rpt, hA.dwarp_val, dY_val, row_long); + } + cudaDeviceSynchronize(); + } + gettimeofday(&t2, NULL); + } + + + double dasp_time = ((t2.tv_sec - t1.tv_sec) * 1000.0 + (t2.tv_usec - t1.tv_usec) / 1000.0) / execute_time; + double dasp_gflops = (double)((long)nnzA * 2) / (dasp_time * 1e6); + double dasp_bandwidth1 = (double)data_X / (dasp_time * 1e6); + double dasp_bandwidth2 = (double)data_X2 / (dasp_time * 1e6); + printf("SpMV_3: %8.4lf ms, %8.4lf GFlop/s, %9.4lf GB/s, %9.4lf GB/s\n", dasp_time, dasp_gflops, dasp_bandwidth1, dasp_bandwidth2); + + cudaMemcpy(Y_val, dY_val, sizeof(MAT_VAL_TYPE) * rowA, cudaMemcpyDeviceToHost); + + cudaFree(dX_val); + cudaFree(dY_val); + + cudaFree(hA.dlongA_val); + cudaFree(hA.dlongA_cid); + cudaFree(hA.dwarp_val); + // cudaFree(drid_by_warp); // not in struct + cudaFree(hA.dlongA_rpt); + + cudaFree(hA.dshort_cid); + cudaFree(hA.dshort_val); + + cudaFree(hA.dregA_val); + cudaFree(hA.dregA_cid); + cudaFree(hA.dblockA_ptr); + cudaFree(hA.dirregA_cid); + cudaFree(hA.dirregA_rpt); + cudaFree(hA.dirregA_val); + + cudaFree(dA); + + FILE* fout; + fout = fopen("data/spmv_f64_record.csv", "a"); + fprintf(fout, "%s,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,", filename, rowA, colA, nnzA, short_row_1, common_13, short_row_3, short_row_4, short_row_2, row_long, row_block, nnz_short, fill0_nnz_short, nnz_long, fill0_nnz_long, origin_nnz_reg, fill0_nnz_reg, nnz_irreg); + fprintf(fout, "%lf,%d,%lld,%lf,%lf,%lf,%lf,", rate_fill0, block_longest, data_X, dasp_time, dasp_gflops, dasp_bandwidth1, dasp_bandwidth2); + fclose(fout); + + printf("\n"); + + free(short_rid_1); + free(short_rid_2); + free(short_rid_3); + free(short_rid_4); + free(long_rid); + free(zero_rid); + free(ridA); + + free(superX_ptr); + free(superX_cid); + free(superX_val); + free(short_cid_temp); + free(short_cid_new); + + free(rptA); + free(long_rpt); + + free(short_val); + free(short_cid); + + free(short3_cid); + free(short3_rpt); + free(short3_val); + + free(long_cid); + free(long_val); + free(long_rpt_new); + free(val_by_warp); + free(rid_by_warp); + + free(reg_val); + free(reg_cid); + free(blockPtr); + + free(irreg_rpt); + free(irreg_cid); + free(irreg_val); +} \ No newline at end of file