Move resident sync comm buffers into StepAllocation pool

This commit is contained in:
2026-04-13 21:04:44 +08:00
parent b3ec244cf9
commit 7191fc0b96

View File

@@ -459,7 +459,12 @@ struct StepAllocation {
double *d_state_curr_mem;
double *d_state_next_mem;
double *d_matter_mem;
double *d_comm_mem;
double *h_comm_mem;
size_t cap_all;
size_t cap_comm;
bool h_comm_pinned;
size_t cap_h_comm;
};
static std::unordered_map<void *, StepContext> g_step_ctx;
@@ -467,7 +472,11 @@ static std::vector<StepAllocation> g_step_pool;
// Builds and returns a StepAllocation by value using aggregate initialization,
// with every initializer set to a null/zero/false value.
static StepAllocation empty_step_allocation()
{
// NOTE(review): this six-initializer declaration and the eleven-initializer
// declaration right after it both define `alloc`. This looks like the
// pre-change and post-change lines of a diff hunk shown without +/- markers;
// confirm that only one of them exists in the real file.
StepAllocation alloc = {nullptr, nullptr, nullptr, nullptr, nullptr, 0};
// Eleven initializers in declaration order: seven nullptr values, then
// 0, 0, false, 0.
// Presumably these line up with the five state/matter pointers, the two comm
// pointers, cap_all, cap_comm, h_comm_pinned and cap_h_comm — verify against
// the full StepAllocation definition.
StepAllocation alloc = {
nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr,
0, 0, false, 0
};
return alloc;
}
@@ -480,14 +489,21 @@ static StepAllocation detach_step_allocation(StepContext &ctx)
{
StepAllocation alloc = {
ctx.d_state0_mem, ctx.d_accum_mem, ctx.d_state_curr_mem,
ctx.d_state_next_mem, ctx.d_matter_mem, ctx.cap_all
ctx.d_state_next_mem, ctx.d_matter_mem,
ctx.d_comm_mem, ctx.h_comm_mem,
ctx.cap_all, ctx.cap_comm, ctx.h_comm_pinned, ctx.cap_h_comm
};
ctx.d_state0_mem = nullptr;
ctx.d_accum_mem = nullptr;
ctx.d_state_curr_mem = nullptr;
ctx.d_state_next_mem = nullptr;
ctx.d_matter_mem = nullptr;
ctx.d_comm_mem = nullptr;
ctx.h_comm_mem = nullptr;
ctx.cap_all = 0;
ctx.cap_comm = 0;
ctx.h_comm_pinned = false;
ctx.cap_h_comm = 0;
ctx.matter_ready = false;
ctx.state_ready = false;
ctx.d_state0.fill(nullptr);
@@ -505,7 +521,12 @@ static void attach_step_allocation(StepContext &ctx, const StepAllocation &alloc
ctx.d_state_curr_mem = alloc.d_state_curr_mem;
ctx.d_state_next_mem = alloc.d_state_next_mem;
ctx.d_matter_mem = alloc.d_matter_mem;
ctx.d_comm_mem = alloc.d_comm_mem;
ctx.h_comm_mem = alloc.h_comm_mem;
ctx.cap_all = alloc.cap_all;
ctx.cap_comm = alloc.cap_comm;
ctx.h_comm_pinned = alloc.h_comm_pinned;
ctx.cap_h_comm = alloc.cap_h_comm;
ctx.matter_ready = false;
ctx.state_ready = false;
}
@@ -619,18 +640,6 @@ static void release_step_ctx(void *block_tag)
{
auto it = g_step_ctx.find(block_tag);
if (it == g_step_ctx.end()) return;
if (it->second.d_comm_mem) {
cudaFree(it->second.d_comm_mem);
it->second.d_comm_mem = nullptr;
it->second.cap_comm = 0;
}
if (it->second.h_comm_mem) {
if (it->second.h_comm_pinned) cudaFreeHost(it->second.h_comm_mem);
else free(it->second.h_comm_mem);
it->second.h_comm_mem = nullptr;
it->second.h_comm_pinned = false;
it->second.cap_h_comm = 0;
}
StepAllocation alloc = detach_step_allocation(it->second);
recycle_step_allocation(alloc);
g_step_ctx.erase(it);