Optimize 6th-order CUDA AMR stencils

This commit is contained in:
2026-05-07 19:22:37 +08:00
parent 9ff2f065be
commit 0076b3ca18
2 changed files with 90 additions and 26 deletions

View File

@@ -234,7 +234,12 @@ bool cuda_cell_gw3_restrict_params(const Parallel::gridseg *src,
const Parallel::gridseg *dst,
int first_fine[3])
{
#if USE_CUDA_BSSN && defined(Cell) && (ghost_width == 3)
#if USE_CUDA_BSSN && defined(Cell) && ((ghost_width == 3) || (ghost_width == 4))
#if ghost_width == 4
const int stencil_hi = 4;
#else
const int stencil_hi = 3;
#endif
if (!src || !dst || !src->Bg || !dst->Bg)
return false;
for (int d = 0; d < dim; ++d)
@@ -260,7 +265,7 @@ bool cuda_cell_gw3_restrict_params(const Parallel::gridseg *src,
first_fine[d] = 2 * lbc - lbf - 1;
if (first_fine[d] < 0)
return false;
if (first_fine[d] + 2 * (dst->shape[d] - 1) + 3 >= src->Bg->shape[d])
if (first_fine[d] + 2 * (dst->shape[d] - 1) + stencil_hi >= src->Bg->shape[d])
return false;
}
return true;
@@ -275,7 +280,12 @@ bool cuda_cell_gw3_prolong_params(const Parallel::gridseg *src,
int first_fine_ii[3],
int coarse_lb[3])
{
#if USE_CUDA_BSSN && defined(Cell) && (ghost_width == 3)
#if USE_CUDA_BSSN && defined(Cell) && ((ghost_width == 3) || (ghost_width == 4))
#if ghost_width == 4
const int stencil_hi = 4;
#else
const int stencil_hi = 3;
#endif
if (!src || !dst || !src->Bg || !dst->Bg)
return false;
for (int d = 0; d < dim; ++d)
@@ -305,7 +315,7 @@ bool cuda_cell_gw3_prolong_params(const Parallel::gridseg *src,
const int last_coarse = last_fine_ii / 2 - coarse_lb[d];
if (first_coarse < -1)
return false;
if (last_coarse + 3 >= src->Bg->shape[d])
if (last_coarse + stencil_hi >= src->Bg->shape[d])
return false;
}
return true;

View File

@@ -7622,11 +7622,22 @@ __global__ void kern_restrict_state_region_batch(const double * __restrict__ src
{
const int state_index = blockIdx.y;
if (state_index >= state_count) return;
#if ghost_width == 4
const double c1 = -5.0 / 2048.0;
const double c2 = 49.0 / 2048.0;
const double c3 = -245.0 / 2048.0;
const double c4 = 1225.0 / 2048.0;
const int offs[8] = {-3, -2, -1, 0, 1, 2, 3, 4};
const double w[8] = {c1, c2, c3, c4, c4, c3, c2, c1};
const int nst = 8;
#else
const double c1 = 3.0 / 256.0;
const double c2 = -25.0 / 256.0;
const double c3 = 75.0 / 128.0;
const int offs[6] = {-2, -1, 0, 1, 2, 3};
const double w[6] = {c1, c2, c3, c3, c2, c1};
const int nst = 6;
#endif
for (int local = blockIdx.x * blockDim.x + threadIdx.x;
local < region_all;
@@ -7639,20 +7650,20 @@ __global__ void kern_restrict_state_region_batch(const double * __restrict__ src
const int fc_j = fj0 + 2 * jj;
const int fc_k = fk0 + 2 * kk;
double sum = 0.0;
for (int oz = 0; oz < 6; ++oz)
for (int oz = 0; oz < nst; ++oz)
{
const int z = fc_k + offs[oz];
const double wz = w[oz];
for (int oy = 0; oy < 6; ++oy)
for (int oy = 0; oy < nst; ++oy)
{
const int y = fc_j + offs[oy];
const double wyz = wz * w[oy];
for (int ox = 0; ox < 6; ++ox)
{
const int x = fc_i + offs[ox];
sum += wyz * w[ox] *
load_comm_state_cell_sym(src_mem, state_index, x, y, z, nx, ny, all);
}
for (int ox = 0; ox < nst; ++ox)
{
const int x = fc_i + offs[ox];
sum += wyz * w[ox] *
load_comm_state_cell_sym(src_mem, state_index, x, y, z, nx, ny, all);
}
}
}
dst[(size_t)state_index * region_all + local] = sum;
@@ -7697,11 +7708,22 @@ __global__ void kern_restrict_state_segments_batch(const double * __restrict__ s
const int offset = m[4];
const int fi0 = m[5], fj0 = m[6], fk0 = m[7];
if (state_index >= state_count) return;
#if ghost_width == 4
const double c1 = -5.0 / 2048.0;
const double c2 = 49.0 / 2048.0;
const double c3 = -245.0 / 2048.0;
const double c4 = 1225.0 / 2048.0;
const int offs[8] = {-3, -2, -1, 0, 1, 2, 3, 4};
const double w[8] = {c1, c2, c3, c4, c4, c3, c2, c1};
const int nst = 8;
#else
const double c1 = 3.0 / 256.0;
const double c2 = -25.0 / 256.0;
const double c3 = 75.0 / 128.0;
const int offs[6] = {-2, -1, 0, 1, 2, 3};
const double w[6] = {c1, c2, c3, c3, c2, c1};
const int nst = 6;
#endif
for (int local = blockIdx.x * blockDim.x + threadIdx.x;
local < region_all;
@@ -7714,15 +7736,15 @@ __global__ void kern_restrict_state_segments_batch(const double * __restrict__ s
const int fc_j = fj0 + 2 * jj;
const int fc_k = fk0 + 2 * kk;
double sum = 0.0;
for (int oz = 0; oz < 6; ++oz)
for (int oz = 0; oz < nst; ++oz)
{
const int z = fc_k + offs[oz];
const double wz = w[oz];
for (int oy = 0; oy < 6; ++oy)
for (int oy = 0; oy < nst; ++oy)
{
const int y = fc_j + offs[oy];
const double wyz = wz * w[oy];
for (int ox = 0; ox < 6; ++ox)
for (int ox = 0; ox < nst; ++ox)
{
const int x = fc_i + offs[ox];
sum += wyz * w[ox] *
@@ -7746,6 +7768,20 @@ __global__ void kern_prolong_state_region_batch(const double * __restrict__ src_
{
const int state_index = blockIdx.y;
if (state_index >= state_count) return;
#if ghost_width == 4
const double c1 = -495.0 / 262144.0;
const double c2 = 5005.0 / 262144.0;
const double c3 = -27027.0 / 262144.0;
const double c4 = 225225.0 / 262144.0;
const double c5 = 75075.0 / 262144.0;
const double c6 = -19305.0 / 262144.0;
const double c7 = 4095.0 / 262144.0;
const double c8 = -429.0 / 262144.0;
const int offs[8] = {-3, -2, -1, 0, 1, 2, 3, 4};
const double wl[8] = {c1, c2, c3, c4, c5, c6, c7, c8};
const double wr[8] = {c8, c7, c6, c5, c4, c3, c2, c1};
const int nst = 8;
#else
const double c1 = 77.0 / 8192.0;
const double c2 = -693.0 / 8192.0;
const double c3 = 3465.0 / 4096.0;
@@ -7755,6 +7791,8 @@ __global__ void kern_prolong_state_region_batch(const double * __restrict__ src_
const int offs[6] = {-2, -1, 0, 1, 2, 3};
const double wl[6] = {c1, c2, c3, c4, c5, c6};
const double wr[6] = {c6, c5, c4, c3, c2, c1};
const int nst = 6;
#endif
for (int local = blockIdx.x * blockDim.x + threadIdx.x;
local < region_all;
@@ -7773,15 +7811,15 @@ __global__ void kern_prolong_state_region_batch(const double * __restrict__ src_
const double *wy = ((fine_j / 2) * 2 == fine_j) ? wl : wr;
const double *wz = ((fine_k / 2) * 2 == fine_k) ? wl : wr;
double sum = 0.0;
for (int oz = 0; oz < 6; ++oz)
for (int oz = 0; oz < nst; ++oz)
{
const int z = ck + offs[oz];
const double wzv = wz[oz];
for (int oy = 0; oy < 6; ++oy)
for (int oy = 0; oy < nst; ++oy)
{
const int y = cj + offs[oy];
const double wyz = wzv * wy[oy];
for (int ox = 0; ox < 6; ++ox)
for (int ox = 0; ox < nst; ++ox)
{
const int x = ci + offs[ox];
sum += wyz * wx[ox] *
@@ -7809,6 +7847,20 @@ __global__ void kern_prolong_state_segments_batch(const double * __restrict__ sr
const int ii0 = m[5], jj0 = m[6], kk0 = m[7];
const int lbc_i = m[8], lbc_j = m[9], lbc_k = m[10];
if (state_index >= state_count) return;
#if ghost_width == 4
const double c1 = -495.0 / 262144.0;
const double c2 = 5005.0 / 262144.0;
const double c3 = -27027.0 / 262144.0;
const double c4 = 225225.0 / 262144.0;
const double c5 = 75075.0 / 262144.0;
const double c6 = -19305.0 / 262144.0;
const double c7 = 4095.0 / 262144.0;
const double c8 = -429.0 / 262144.0;
const int offs[8] = {-3, -2, -1, 0, 1, 2, 3, 4};
const double wl[8] = {c1, c2, c3, c4, c5, c6, c7, c8};
const double wr[8] = {c8, c7, c6, c5, c4, c3, c2, c1};
const int nst = 8;
#else
const double c1 = 77.0 / 8192.0;
const double c2 = -693.0 / 8192.0;
const double c3 = 3465.0 / 4096.0;
@@ -7818,6 +7870,8 @@ __global__ void kern_prolong_state_segments_batch(const double * __restrict__ sr
const int offs[6] = {-2, -1, 0, 1, 2, 3};
const double wl[6] = {c1, c2, c3, c4, c5, c6};
const double wr[6] = {c6, c5, c4, c3, c2, c1};
const int nst = 6;
#endif
for (int local = blockIdx.x * blockDim.x + threadIdx.x;
local < region_all;
@@ -7836,20 +7890,20 @@ __global__ void kern_prolong_state_segments_batch(const double * __restrict__ sr
const double *wy = ((fine_j / 2) * 2 == fine_j) ? wl : wr;
const double *wz = ((fine_k / 2) * 2 == fine_k) ? wl : wr;
double sum = 0.0;
for (int oz = 0; oz < 6; ++oz)
for (int oz = 0; oz < nst; ++oz)
{
const int z = ck + offs[oz];
const double wzv = wz[oz];
for (int oy = 0; oy < 6; ++oy)
for (int oy = 0; oy < nst; ++oy)
{
const int y = cj + offs[oy];
const double wyz = wzv * wy[oy];
for (int ox = 0; ox < 6; ++ox)
{
const int x = ci + offs[ox];
sum += wyz * wx[ox] *
load_comm_state_cell_sym(src_mem, state_index, x, y, z, nx, ny, all);
}
for (int ox = 0; ox < nst; ++ox)
{
const int x = ci + offs[ox];
sum += wyz * wx[ox] *
load_comm_state_cell_sym(src_mem, state_index, x, y, z, nx, ny, all);
}
}
}
dst[(size_t)offset + (size_t)state_index * region_all + local] = sum;