多机调整

This commit is contained in:
qihuanye
2026-05-17 20:49:33 +08:00
parent 0164e21f48
commit 113e591899
6 changed files with 84 additions and 9 deletions

1
.gitignore vendored
View File

@@ -31,3 +31,4 @@ eval_tmp_*.npy
.DS_Store
.idea/
.vscode/
*.log

View File

@@ -67,6 +67,10 @@ multi_node:
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false

View File

@@ -55,6 +55,10 @@ multi_node:
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false

View File

@@ -56,6 +56,10 @@ multi_node:
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false

View File

@@ -54,6 +54,10 @@ multi_node:
rank_env: RANK
world_size_env: WORLD_SIZE
local_rank_env: LOCAL_RANK
aggregate_results: true
sync_before_return: false
destroy_process_group: true
shard_strategy: round_robin
preload_wait:
enabled: false

76
eval.py
View File

@@ -277,6 +277,10 @@ def get_multi_node_cfg(cfg):
"world_size_env": "WORLD_SIZE",
"local_rank_env": "LOCAL_RANK",
"output_mode": "single",
"aggregate_results": True,
"sync_before_return": False,
"destroy_process_group": True,
"shard_strategy": "round_robin",
}
cfg_multi_node = cfg.get("multi_node")
if cfg_multi_node is not None:
@@ -318,6 +322,28 @@ def all_gather_eval_result(result):
return payload
def finalize_multi_node_process_group(cfg):
    """Tear down the torch.distributed process group when the config asks for it.

    Reads the ``destroy_process_group`` entry of the resolved multi-node
    config. When that entry is falsy the function does nothing. Otherwise
    the default process group is destroyed, but only if torch.distributed is
    available and a group has actually been initialized, so calling this
    without an active group is a no-op.

    Args:
        cfg: Evaluation config handed to ``get_multi_node_cfg``.
    """
    teardown_requested = get_multi_node_cfg(cfg)["destroy_process_group"]
    # Short-circuit keeps the original ordering: config flag first, then
    # availability, then initialization state.
    if (
        teardown_requested
        and torch.distributed.is_available()
        and torch.distributed.is_initialized()
    ):
        torch.distributed.destroy_process_group()
def get_rank_result_path(output_dir: Path, cfg: "DictConfig", rank: int) -> Path:
    """Return the results-file path that the given rank should use.

    Rank 0 uses ``cfg.output.filename`` unchanged. Every other rank gets
    ``.rank{rank}`` inserted before the final suffix, or appended when the
    filename has no suffix (``eval.txt`` -> ``eval.rank2.txt``,
    ``eval`` -> ``eval.rank2``).

    Any directory component in ``cfg.output.filename`` is preserved for all
    ranks, so ``sub/eval.txt`` becomes ``sub/eval.rank2.txt``. Previously the
    suffix branch kept only the stem and silently dropped ``sub/``, which
    disagreed with both rank 0 and the no-suffix branch.

    Args:
        output_dir: Directory the filename is resolved against.
        cfg: Config object exposing ``output.filename``.
        rank: Rank of the calling process.

    Returns:
        ``output_dir`` joined with the (possibly rank-tagged) filename.
    """
    relative = Path(str(cfg.output.filename))
    if rank == 0:
        return output_dir / relative
    # stem + suffix covers both cases: with no suffix, suffix is "" and stem
    # is the whole final component.
    ranked_name = f"{relative.stem}.rank{rank}{relative.suffix}"
    # Re-attach the parent so nested filenames stay in their subdirectory.
    return output_dir / relative.parent / ranked_name
def build_process(cfg, dataset):
process = {}
for col in cfg.dataset.keys_to_cache:
@@ -400,12 +426,26 @@ def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
return shards
def get_rank_eval_subset(eval_episodes, eval_start_idx, rank, world_size):
def get_rank_eval_subset(
eval_episodes,
eval_start_idx,
rank,
world_size,
*,
strategy="contiguous",
):
if world_size < 1:
raise ValueError("world_size must be >= 1")
if rank < 0 or rank >= world_size:
raise ValueError("rank must be in [0, world_size)")
if strategy == "round_robin":
episode_subset = eval_episodes[rank::world_size]
start_subset = eval_start_idx[rank::world_size]
return episode_subset, start_subset
if strategy != "contiguous":
raise ValueError("strategy must be one of: contiguous, round_robin")
total = len(eval_episodes)
shard_sizes = [total // world_size] * world_size
for idx in range(total % world_size):
@@ -697,8 +737,13 @@ def combine_eval_results(ordered_results):
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
rank, world_size, local_rank = get_rank_context(cfg)
multi_node_cfg = get_multi_node_cfg(cfg)
shard_episodes, shard_start_idx = get_rank_eval_subset(
eval_episodes, eval_start_idx, rank, world_size
eval_episodes,
eval_start_idx,
rank,
world_size,
strategy=multi_node_cfg["shard_strategy"],
)
if len(shard_episodes) == 0:
raise ValueError("No evaluation episodes assigned to this rank")
@@ -716,21 +761,27 @@ def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is required for preload_wait")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"])
torch.distributed.init_process_group(backend=multi_node_cfg["backend"])
rank_output_path = get_rank_result_path(output_dir, cfg, rank)
result = run_eval_subset(
local_cfg,
list(shard_episodes),
list(shard_start_idx),
output_dir,
rank_output_path.parent,
device_override=device,
enable_profile=False,
before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank),
)
if not multi_node_cfg["aggregate_results"]:
result["output_filename"] = rank_output_path.name
finalize_multi_node_process_group(cfg)
return result
if not torch.distributed.is_available():
raise RuntimeError("torch.distributed is required for multi-node evaluation")
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=get_multi_node_cfg(cfg)["backend"])
torch.distributed.init_process_group(backend=multi_node_cfg["backend"])
gathered = all_gather_eval_result(result)
metrics, reference = combine_eval_results(gathered)
@@ -742,8 +793,11 @@ def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
"compile_mode": reference["compile_mode"],
"profile_dir": None,
"profile_summary_path": None,
"output_filename": cfg.output.filename,
}
torch.distributed.barrier()
if multi_node_cfg["sync_before_return"]:
torch.distributed.barrier()
finalize_multi_node_process_group(cfg)
if rank != 0:
return None
return combined
@@ -761,6 +815,7 @@ def run(cfg: DictConfig):
profile_cfg = get_profile_cfg(cfg)
maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
eval_wall_start = time.time()
if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")
@@ -788,10 +843,11 @@ def run(cfg: DictConfig):
compile_mode = eval_result["compile_mode"]
profile_dir = eval_result["profile_dir"]
profile_summary_path = eval_result["profile_summary_path"]
output_filename = eval_result.get("output_filename", cfg.output.filename)
print(metrics)
results_path = output_dir / cfg.output.filename
results_path = output_dir / output_filename
results_path.parent.mkdir(parents=True, exist_ok=True)
with results_path.open("a") as f:
@@ -810,8 +866,10 @@ def run(cfg: DictConfig):
f.write(f"inference_compile_mode: {compile_mode}\n")
if profile_cfg["enabled"]:
f.write(f"profile_dir: {profile_dir}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
if profile_summary_path is not None:
f.write(f"profile_summary: {profile_summary_path}\n")
f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n")
if __name__ == "__main__":