877 lines
29 KiB
Python
877 lines
29 KiB
Python
import os
|
|
|
|
# Select the "egl" backend through the MUJOCO_GL environment variable.
# NOTE(review): this assignment sits ahead of the remaining imports, presumably
# so the value is already set when MuJoCo-backed modules are imported --
# confirm before reordering the import block.
os.environ["MUJOCO_GL"] = "egl"
|
|
|
|
import multiprocessing as mp
|
|
import time
|
|
import traceback
|
|
from contextlib import nullcontext
|
|
from pathlib import Path
|
|
import tempfile
|
|
|
|
import hydra
|
|
import numpy as np
|
|
import stable_pretraining as spt
|
|
import torch
|
|
from omegaconf import DictConfig, OmegaConf
|
|
from sklearn import preprocessing
|
|
from torchvision.transforms import v2 as transforms
|
|
import stable_worldmodel as swm
|
|
|
|
def img_transform(cfg):
    """Build the image preprocessing pipeline used for evaluation frames.

    The steps run in this order: convert to a torchvision image, cast to
    ``float32`` with value scaling, normalize with the ImageNet statistics
    from ``spt.data.dataset_stats``, then resize to ``cfg.eval.img_size``.

    Args:
        cfg: Config object exposing ``cfg.eval.img_size``.

    Returns:
        A ``transforms.Compose`` instance wrapping the four steps.
    """
    steps = [
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(**spt.data.dataset_stats.ImageNet),
        transforms.Resize(size=cfg.eval.img_size),
    ]
    return transforms.Compose(steps)
|
|
|
|
|
|
def get_episodes_length(dataset, episodes):
    """Return the length of every requested episode as a NumPy array.

    An episode's length is computed as its largest ``step_idx`` plus one.
    The episode id column is ``episode_idx`` when the dataset has it and
    ``ep_idx`` otherwise.

    Args:
        dataset: Object exposing ``column_names`` and ``get_col_data(name)``.
        episodes: Iterable of episode ids to measure.

    Returns:
        ``np.ndarray`` with one length per entry of ``episodes``, in order.
    """
    ep_column = "ep_idx"
    if "episode_idx" in dataset.column_names:
        ep_column = "episode_idx"

    row_episode = dataset.get_col_data(ep_column)
    row_step = dataset.get_col_data("step_idx")
    return np.array(
        [np.max(row_step[row_episode == ep_id]) + 1 for ep_id in episodes]
    )
|
|
|
|
|
|
def get_dataset(cfg, dataset_name):
    """Open the named HDF5 dataset through ``swm.data.HDF5Dataset``.

    The cache directory is ``cfg.cache_dir`` when it is truthy, otherwise the
    value returned by ``swm.data.utils.get_cache_dir()``.

    Args:
        cfg: Config exposing ``cache_dir`` and ``dataset.keys_to_cache``.
        dataset_name: Name forwarded as the first positional argument.

    Returns:
        The constructed ``swm.data.HDF5Dataset``.
    """
    cache_root = cfg.cache_dir or swm.data.utils.get_cache_dir()
    return swm.data.HDF5Dataset(
        dataset_name,
        keys_to_cache=cfg.dataset.keys_to_cache,
        cache_dir=Path(cache_root),
    )
|
|
|
|
|
|
def get_profile_cfg(cfg):
    """Return the profiler settings: built-in defaults overlaid with ``cfg.profile``.

    Args:
        cfg: Config supporting ``cfg.get("profile")``; the section is optional.

    Returns:
        A plain ``dict``. Keys present in the ``profile`` section override the
        defaults listed below; keys absent from it keep their default.
    """
    settings = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )
    overrides = cfg.get("profile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
|
|
|
|
|
def get_compile_cfg(cfg):
    """Return the ``torch.compile`` settings: defaults overlaid with ``cfg.compile``.

    Args:
        cfg: Config supporting ``cfg.get("compile")``; the section is optional.

    Returns:
        A plain ``dict`` with the keys ``enabled``, ``target``, ``mode``,
        ``fullgraph``, ``dynamic`` and ``cuda_only`` (plus any extra keys the
        config section supplies).
    """
    settings = dict(
        enabled=True,
        target="predictor",
        mode="reduce-overhead",
        fullgraph=False,
        dynamic=False,
        cuda_only=True,
    )
    overrides = cfg.get("compile")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
|
|
|
|
|
def get_compile_warmup_cfg(cfg):
    """Return the compile warm-up settings: defaults overlaid with ``cfg.compile_warmup``.

    Args:
        cfg: Config supporting ``cfg.get("compile_warmup")``; the section is
            optional.

    Returns:
        A plain ``dict`` with at least ``enabled`` (default ``True``) and
        ``num_eval`` (default ``1``).
    """
    merged = {"enabled": True, "num_eval": 1}
    user_section = cfg.get("compile_warmup")
    if user_section is not None:
        merged.update(OmegaConf.to_container(user_section, resolve=True))
    return merged
|
|
|
|
|
|
def get_preload_wait_cfg(cfg):
    """Return the preload-wait settings: defaults overlaid with ``cfg.preload_wait``.

    Args:
        cfg: Config supporting ``cfg.get("preload_wait")``; the section is
            optional.

    Returns:
        A plain ``dict`` with at least ``enabled``, ``file`` (path of the
        start-signal file) and ``poll_interval``.
    """
    settings = dict(
        enabled=False,
        file="/tmp/lewm_preload_start",
        poll_interval=1.0,
    )
    overrides = cfg.get("preload_wait")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
|
|
|
|
|
def wait_for_preload_signal(cfg, rank=0):
    """Hold evaluation until the configured start-signal file exists.

    Does nothing when ``preload_wait.enabled`` is false. Otherwise, if a
    ``torch.distributed`` process group is initialized, a barrier is entered
    before and after the wait. Only rank 0 polls the file system; it prints a
    prompt, sleeps ``poll_interval`` seconds between checks, and prints a
    confirmation once the file is present.

    Args:
        cfg: Config consumed by ``get_preload_wait_cfg``.
        rank: Rank of the calling process; only ``0`` polls for the file.
    """
    preload_cfg = get_preload_wait_cfg(cfg)
    if not preload_cfg["enabled"]:
        return

    use_barrier = torch.distributed.is_available() and torch.distributed.is_initialized()
    if use_barrier:
        torch.distributed.barrier()

    signal_path = Path(str(preload_cfg["file"])).expanduser()
    sleep_seconds = float(preload_cfg["poll_interval"])

    if rank == 0:
        print(
            "Preload ready. Create this file to start evaluation: "
            f"{signal_path}",
            flush=True,
        )
        while True:
            if signal_path.exists():
                break
            time.sleep(sleep_seconds)
        print("Preload start signal received. Starting evaluation.", flush=True)

    if use_barrier:
        torch.distributed.barrier()
|
|
|
|
|
|
def maybe_compile_inference_target(model, cfg, device):
    """Optionally wrap one attribute of ``model`` with ``torch.compile``.

    The attribute named by ``compile.target`` (``predictor`` or ``predict``)
    is replaced in place by its compiled version. Compilation is skipped, with
    a printed explanation where applicable, when it is disabled in the config,
    when ``torch.compile`` does not exist, when ``compile.cuda_only`` is set
    and ``device`` is not a CUDA device, when the target is unknown, or when
    the model lacks the target attribute.

    Args:
        model: Object whose ``predictor`` / ``predict`` attribute may be compiled.
        cfg: Config consumed by ``get_compile_cfg``.
        device: Device identifier; only its string form's ``cuda`` prefix is checked.

    Returns:
        Tuple ``(model, compile_cfg, compile_target)`` where ``compile_target``
        is the compiled attribute name or ``"disabled"`` if nothing was compiled.
    """
    compile_cfg = get_compile_cfg(cfg)
    not_compiled = (model, compile_cfg, "disabled")

    if not compile_cfg["enabled"]:
        return not_compiled

    if not hasattr(torch, "compile"):
        print("torch.compile is unavailable, skipping inference compilation.")
        return not_compiled

    if compile_cfg["cuda_only"] and not str(device).startswith("cuda"):
        print("Skipping torch.compile because compile.cuda_only=true and device is not CUDA.")
        return not_compiled

    target = str(compile_cfg["target"]).lower()
    compile_kwargs = {
        key: compile_cfg[key] for key in ("mode", "fullgraph", "dynamic")
    }

    # Supported targets mapped to the message printed when the model lacks them.
    unavailable_message = {
        "predictor": "Requested compile target 'predictor' is unavailable on the model.",
        "predict": "Requested compile target 'predict' is unavailable on the model.",
    }

    if target not in unavailable_message:
        print(
            f"Unsupported compile.target={target}. Expected one of: predictor, predict."
        )
        return not_compiled

    if not hasattr(model, target):
        print(unavailable_message[target])
        return not_compiled

    setattr(model, target, torch.compile(getattr(model, target), **compile_kwargs))
    return model, compile_cfg, target
|
|
|
|
|
|
def get_inference_context(cfg, device):
    """Return the autocast context and precision label for inference.

    Args:
        cfg: Config supporting ``cfg.get("inference_precision", "fp32")``.
            Accepted values (case-insensitive): ``fp32``, ``bf16`` /
            ``bfloat16``, ``fp16`` / ``float16``.
        device: Device identifier, either a string such as ``"cuda:0"`` or an
            object whose ``str()`` form starts with ``cuda`` for CUDA devices
            (for example ``torch.device``).

    Returns:
        Tuple ``(context_manager, label)``. ``label`` is the precision that is
        actually used: ``"fp32"``, ``"bf16"`` or ``"fp16"``. A request for fp16
        on a non-CUDA device falls back to fp32 with a printed notice.

    Raises:
        ValueError: If the configured precision is not one of the accepted values.
    """
    precision = str(cfg.get("inference_precision", "fp32")).lower()
    # Go through str() so non-string device objects are accepted, matching how
    # the other helpers in this file inspect the device.
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    if precision == "fp32":
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
|
|
|
|
|
|
def get_eval_grad_context(solver=None):
    """Pick the autograd context for an evaluation run.

    Args:
        solver: The planner in use, or ``None``.

    Returns:
        ``torch.enable_grad()`` when ``solver`` is an instance of
        ``swm.solver.GradientSolver``; ``torch.inference_mode()`` otherwise.
    """
    needs_grad = isinstance(solver, swm.solver.GradientSolver)
    return torch.enable_grad() if needs_grad else torch.inference_mode()
|
|
|
|
|
|
def make_profiler(cfg, results_path):
    """Create the profiler context for an evaluation run.

    Args:
        cfg: Config consumed by ``get_profile_cfg``.
        results_path: Directory (``Path``) under which the trace directory
            ``profile.trace_dirname`` is created when profiling is enabled.

    Returns:
        Tuple ``(context, profile_dir, profile_cfg)``. With profiling disabled
        this is ``(nullcontext(), None, profile_cfg)``; otherwise ``context``
        is a ``torch.profiler.profile`` instance and ``profile_dir`` is the
        created trace directory.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    # CPU activity is always recorded; CUDA only when a CUDA runtime is present.
    cpu_activity = torch.profiler.ProfilerActivity.CPU
    cuda_activity = torch.profiler.ProfilerActivity.CUDA
    activities = [cpu_activity, cuda_activity] if torch.cuda.is_available() else [cpu_activity]

    profile_dir = results_path / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)

    option_names = ("record_shapes", "profile_memory", "with_stack", "with_flops")
    profiler = torch.profiler.profile(
        activities=activities,
        **{name: profile_cfg[name] for name in option_names},
    )
    return profiler, profile_dir, profile_cfg
|
|
|
|
|
|
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write the profiler summary table and one trace export to ``profile_dir``.

    Args:
        profiler: A finished ``torch.profiler.profile`` object, or ``None``.
        profile_dir: Output directory (``Path``), or ``None``.
        profile_cfg: Settings dict providing ``row_limit``, ``export_tensorboard``,
            ``export_chrome_trace`` and ``worker_name``.

    Returns:
        Path of the written ``key_averages.txt``, or ``None`` when either
        ``profiler`` or ``profile_dir`` is ``None`` (nothing is written then).
    """
    if profiler is None or profile_dir is None:
        return None

    # Rank operators by self CUDA time when CUDA is available, else by self CPU time.
    if torch.cuda.is_available():
        sort_key = "self_cuda_time_total"
    else:
        sort_key = "self_cpu_time_total"
    table = profiler.key_averages().table(
        sort_by=sort_key,
        row_limit=profile_cfg["row_limit"],
    )

    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(table)

    # The two exports are mutually exclusive here: the TensorBoard handler wins
    # and the Chrome trace is only written when TensorBoard export is off.
    if profile_cfg["export_tensorboard"]:
        export_trace = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        export_trace(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))

    return summary_path
|
|
|
|
|
|
def get_multi_gpu_cfg(cfg):
    """Return the multi-GPU settings: defaults overlaid with ``cfg.multi_gpu``.

    Args:
        cfg: Config supporting ``cfg.get("multi_gpu")``; the section is optional.

    Returns:
        A plain ``dict`` with at least ``enabled``, ``devices`` and
        ``start_method``.
    """
    settings = dict(
        enabled=False,
        devices=None,
        start_method="spawn",
    )
    overrides = cfg.get("multi_gpu")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
|
|
|
|
|
def get_multi_node_cfg(cfg):
    """Return the multi-node settings: defaults overlaid with ``cfg.multi_node``.

    Args:
        cfg: Config supporting ``cfg.get("multi_node")``; the section is optional.

    Returns:
        A plain ``dict`` holding the process-group backend, the names of the
        environment variables that carry rank information, and the flags that
        control sharding, result aggregation and teardown.
    """
    settings = dict(
        enabled=False,
        backend="gloo",
        rank_env="RANK",
        world_size_env="WORLD_SIZE",
        local_rank_env="LOCAL_RANK",
        output_mode="single",
        aggregate_results=True,
        sync_before_return=False,
        destroy_process_group=True,
        shard_strategy="round_robin",
    )
    overrides = cfg.get("multi_node")
    if overrides is None:
        return settings
    settings.update(OmegaConf.to_container(overrides, resolve=True))
    return settings
|
|
|
|
|
|
def get_dist_env(name, default=None):
    """Read an integer from the environment variable ``name``.

    Args:
        name: Environment variable to look up.
        default: Value used when the variable is unset.

    Returns:
        ``None`` when the variable is unset and ``default`` is ``None``;
        otherwise ``int(...)`` of the variable's value (or of ``default``).
    """
    raw = os.environ.get(name, default)
    return None if raw is None else int(raw)
|
|
|
|
|
|
def get_rank_context(cfg):
    """Resolve ``(rank, world_size, local_rank)`` for the current process.

    With ``multi_node.enabled`` false the result is always ``(0, 1, 0)``.
    Otherwise the three values are read from the environment variables named
    by ``multi_node.rank_env``, ``multi_node.world_size_env`` and
    ``multi_node.local_rank_env``; the local rank defaults to ``0`` when unset.

    Args:
        cfg: Config consumed by ``get_multi_node_cfg``.

    Returns:
        Tuple ``(rank, world_size, local_rank)`` of integers.

    Raises:
        ValueError: If rank or world size is missing, if the world size is
            below 1, or if the rank lies outside ``[0, world_size)``.
    """
    multi_node_cfg = get_multi_node_cfg(cfg)
    if not multi_node_cfg["enabled"]:
        return 0, 1, 0

    rank = get_dist_env(multi_node_cfg["rank_env"])
    world_size = get_dist_env(multi_node_cfg["world_size_env"])
    local_rank = get_dist_env(multi_node_cfg["local_rank_env"], 0)

    missing = rank is None or world_size is None
    if missing:
        raise ValueError(
            "multi_node.enabled=true requires torchrun env vars RANK and WORLD_SIZE"
        )
    if world_size < 1:
        raise ValueError("WORLD_SIZE must be >= 1")
    if not (0 <= rank < world_size):
        raise ValueError("RANK must be in [0, WORLD_SIZE)")
    return rank, world_size, local_rank
|
|
|
|
|
|
def all_gather_eval_result(result):
    """Collect every rank's ``result`` object via ``all_gather_object``.

    Args:
        result: Object contributed by the calling rank.

    Returns:
        List with one entry per rank of the current ``torch.distributed``
        world, filled in by ``torch.distributed.all_gather_object``.
    """
    gathered = [None] * torch.distributed.get_world_size()
    torch.distributed.all_gather_object(gathered, result)
    return gathered
|
|
|
|
|
|
def finalize_multi_node_process_group(cfg):
    """Destroy the ``torch.distributed`` process group when configured to.

    Nothing happens unless ``multi_node.destroy_process_group`` is truthy and
    a process group is both available and initialized.

    Args:
        cfg: Config consumed by ``get_multi_node_cfg``.
    """
    if get_multi_node_cfg(cfg)["destroy_process_group"]:
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
|
|
|
|
|
|
def get_rank_result_path(output_dir: Path, cfg: DictConfig, rank: int) -> Path:
    """Return the results-file path for ``rank`` inside ``output_dir``.

    Rank 0 uses ``cfg.output.filename`` unchanged. Every other rank gets a
    ``.rank{rank}`` marker inserted before the final suffix
    (``results.txt`` -> ``results.rank1.txt``) or appended when the filename
    has no suffix (``results`` -> ``results.rank1``).

    Any directory component of the configured filename is preserved for all
    ranks, so every rank's file lands in the same directory as rank 0's.

    Args:
        output_dir: Base directory for the results.
        cfg: Config exposing ``cfg.output.filename``.
        rank: Rank of the calling process.

    Returns:
        The rank-specific results path.
    """
    relative = Path(str(cfg.output.filename))
    if rank == 0:
        return output_dir / relative

    suffix = relative.suffix
    if suffix:
        ranked_filename = f"{relative.stem}.rank{rank}{suffix}"
    else:
        ranked_filename = f"{relative.name}.rank{rank}"
    # Re-attach the directory part; rebuilding the name from stem/suffix alone
    # would silently drop it and scatter per-rank files across directories.
    return output_dir / relative.parent / ranked_filename
|
|
|
|
|
|
def build_process(cfg, dataset):
    """Fit one ``StandardScaler`` per cached, non-image dataset column.

    For every key in ``cfg.dataset.keys_to_cache`` except ``"pixels"``, a
    scaler is fitted on the column after dropping rows that contain a NaN.
    Each column other than ``"action"`` also registers the same scaler
    instance under ``goal_<column>``.

    Args:
        cfg: Config exposing ``cfg.dataset.keys_to_cache``.
        dataset: Object exposing ``get_col_data(name)``.
            NOTE(review): the NaN filter uses ``axis=1``, so each column is
            assumed to be 2-D -- confirm against the dataset class.

    Returns:
        Dict mapping column names (and their ``goal_`` aliases) to fitted scalers.
    """
    scalers = {}
    for col in cfg.dataset.keys_to_cache:
        if col == "pixels":
            continue

        scaler = preprocessing.StandardScaler()
        values = dataset.get_col_data(col)
        has_nan = np.isnan(values).any(axis=1)
        scaler.fit(values[~has_nan])
        scalers[col] = scaler

        if col != "action":
            # Goal observations share the scaler of their source column.
            scalers[f"goal_{col}"] = scaler
    return scalers
|
|
|
|
|
|
def sample_eval_cases(cfg, dataset):
    """Sample the (episode, start step) pairs used for evaluation.

    A dataset row is a valid starting point when its ``step_idx`` leaves at
    least ``cfg.eval.goal_offset_steps`` further steps in its episode.
    ``cfg.eval.num_eval`` rows are drawn without replacement from the valid
    ones using ``np.random.default_rng(cfg.seed)`` and returned in ascending
    row order.

    Args:
        cfg: Config exposing ``seed``, ``eval.goal_offset_steps`` and
            ``eval.num_eval``.
        dataset: Object exposing ``column_names``, ``get_col_data(name)`` and
            ``get_row_data(indices)``.

    Returns:
        Tuple ``(eval_episodes, eval_start_idx)`` taken from the sampled rows'
        episode-id column and ``step_idx`` column.

    Raises:
        ValueError: If there are too few valid starting points to draw
            ``num_eval`` cases.
    """
    col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"
    row_episode = dataset.get_col_data(col_name)
    row_step = dataset.get_col_data("step_idx")

    # `row_to_ep[i]` is the position of row i's episode inside `ep_indices`,
    # which lets the per-episode limit be broadcast to rows without a Python loop.
    ep_indices, row_to_ep = np.unique(row_episode, return_inverse=True)
    episode_len = get_episodes_length(dataset, ep_indices)
    max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
    max_start_per_row = max_start_idx[np.ravel(row_to_ep)]

    valid_mask = row_step <= max_start_per_row
    valid_indices = np.nonzero(valid_mask)[0]
    print(valid_mask.sum(), "valid starting points found for evaluation.")

    # NOTE(review): the population is `len(valid_indices) - 1`, so the last
    # valid row can never be drawn. This looks like an off-by-one, but it is
    # kept as is because changing it would alter which cases a given seed
    # selects -- confirm before "fixing".
    population = len(valid_indices) - 1

    # Check up front: sampling without replacement from a too-small population
    # fails inside NumPy with a message that does not mention episodes.
    if cfg.eval.num_eval > max(population, 0):
        raise ValueError("Not enough episodes with sufficient length for evaluation.")

    g = np.random.default_rng(cfg.seed)
    random_episode_indices = g.choice(
        population, size=cfg.eval.num_eval, replace=False
    )
    random_episode_indices = np.sort(valid_indices[random_episode_indices])
    print(random_episode_indices)

    rows = dataset.get_row_data(random_episode_indices)
    eval_episodes = rows[col_name]
    eval_start_idx = rows["step_idx"]

    # Defensive: the row lookup must hand back one row per sampled index.
    if len(eval_episodes) < cfg.eval.num_eval:
        raise ValueError("Not enough episodes with sufficient length for evaluation.")

    return eval_episodes, eval_start_idx
|
|
|
|
|
|
def normalize_multi_gpu_devices(devices):
    """Turn a device specification into a list of device strings.

    Args:
        devices: ``None`` to use every CUDA device reported by
            ``torch.cuda.device_count()``, or an iterable whose items are
            integers, digit-only strings, or anything else.

    Returns:
        List of strings. Integers and digit-only strings become ``cuda:<n>``;
        every other item is passed through ``str()`` unchanged.
    """
    if devices is None:
        return [f"cuda:{idx}" for idx in range(torch.cuda.device_count())]

    def _as_device_name(device):
        if isinstance(device, int):
            return f"cuda:{device}"
        if isinstance(device, str) and device.isdigit():
            return f"cuda:{int(device)}"
        return str(device)

    return [_as_device_name(device) for device in devices]
|
|
|
|
|
|
def shard_eval_cases(eval_episodes, eval_start_idx, num_shards):
    """Split the evaluation cases into contiguous, near-equal shards.

    The first ``len(eval_episodes) % num_shards`` shards receive one extra
    case. Shards that would be empty are omitted, so fewer than
    ``num_shards`` shards may be returned.

    Args:
        eval_episodes: Sliceable sequence of episode ids.
        eval_start_idx: Sliceable sequence of start steps, aligned with
            ``eval_episodes``.
        num_shards: Maximum number of shards to produce.

    Returns:
        List of ``(episodes_slice, start_idx_slice)`` tuples in input order.

    Raises:
        ValueError: If ``num_shards`` is below 1.
    """
    if num_shards < 1:
        raise ValueError("num_shards must be >= 1")

    base_size, remainder = divmod(len(eval_episodes), num_shards)

    shards = []
    cursor = 0
    for shard_no in range(num_shards):
        size = base_size + (1 if shard_no < remainder else 0)
        if size == 0:
            continue
        stop = cursor + size
        shards.append((eval_episodes[cursor:stop], eval_start_idx[cursor:stop]))
        cursor = stop
    return shards
|
|
|
|
|
|
def get_rank_eval_subset(
    eval_episodes,
    eval_start_idx,
    rank,
    world_size,
    *,
    strategy="contiguous",
):
    """Select the evaluation cases assigned to one rank.

    Args:
        eval_episodes: Sliceable sequence of episode ids.
        eval_start_idx: Sliceable sequence of start steps, aligned with
            ``eval_episodes``.
        rank: Index of the calling rank, in ``[0, world_size)``.
        world_size: Total number of ranks.
        strategy: ``"round_robin"`` gives rank ``r`` the cases at positions
            ``r, r + world_size, ...``. ``"contiguous"`` gives each rank one
            consecutive slice, with the first ``total % world_size`` ranks
            holding one extra case.

    Returns:
        Tuple ``(episodes_subset, start_idx_subset)``.

    Raises:
        ValueError: If ``world_size`` is below 1, ``rank`` is out of range,
            or ``strategy`` is not one of the two supported names.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    if not (0 <= rank < world_size):
        raise ValueError("rank must be in [0, world_size)")

    if strategy == "round_robin":
        strided = slice(rank, None, world_size)
        return eval_episodes[strided], eval_start_idx[strided]
    if strategy != "contiguous":
        raise ValueError("strategy must be one of: contiguous, round_robin")

    base_size, remainder = divmod(len(eval_episodes), world_size)
    # Ranks before this one hold `base_size` cases each, plus one extra for
    # every earlier rank that falls inside the remainder.
    start = rank * base_size + min(rank, remainder)
    stop = start + base_size + (1 if rank < remainder else 0)
    return eval_episodes[start:stop], eval_start_idx[start:stop]
|
|
|
|
|
|
def run_eval_subset(
    cfg: DictConfig,
    eval_episodes,
    eval_start_idx,
    output_dir: Path,
    *,
    device_override: str | None = None,
    enable_profile: bool = True,
    before_evaluate=None,
):
    """Evaluate one set of (episode, start step) cases in this process.

    Builds a private copy of ``cfg``, creates the world, dataset, scalers and
    policy, runs ``world.evaluate_from_dataset`` once over all given cases and
    returns timing plus bookkeeping alongside the metrics.

    Args:
        cfg: Full evaluation config. It is copied; the caller's object is not
            modified by this function.
        eval_episodes: Episode ids to evaluate; its length also becomes the
            number of environments.
        eval_start_idx: Start step for each entry of ``eval_episodes``.
        output_dir: Directory passed as ``video_path`` and used as the parent
            of the profiler trace directory.
        device_override: When given, written to ``solver.device`` in the local
            config and used as the model device; CUDA overrides are also made
            the current CUDA device.
        enable_profile: When false, profiling is forced off in the local config.
        before_evaluate: Optional zero-argument callable invoked after setup
            and immediately before the timed evaluation.

    Returns:
        Dict with the keys ``metrics``, ``evaluation_time``,
        ``inference_precision``, ``compile_target``, ``compile_mode``,
        ``profile_dir`` and ``profile_summary_path``.
    """
    # Work on an unresolved deep copy so the overrides below stay local.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    # One environment per evaluation case.
    local_cfg.eval.num_eval = len(eval_episodes)
    local_cfg.world.num_envs = len(eval_episodes)
    local_cfg.world.max_episode_steps = 2 * local_cfg.eval.eval_budget

    if device_override is not None:
        local_cfg.solver.device = device_override
        if torch.cuda.is_available() and str(device_override).startswith("cuda"):
            torch.cuda.set_device(torch.device(device_override))

    if not enable_profile:
        # Create the section if it is missing so make_profiler sees enabled=False.
        if local_cfg.get("profile") is None:
            local_cfg.profile = OmegaConf.create({"enabled": False})
        else:
            local_cfg.profile.enabled = False

    world = swm.World(**local_cfg.world, image_shape=(224, 224))
    # The same image pipeline is applied to observations and to goal images.
    transform = {
        "pixels": img_transform(local_cfg),
        "goal": img_transform(local_cfg),
    }
    dataset = get_dataset(local_cfg, local_cfg.eval.dataset_name)
    process = build_process(local_cfg, dataset)

    policy_name = local_cfg.get("policy", "random")
    if policy_name != "random":
        # World-model policy: load the cost model, freeze it, optionally
        # compile part of it, and hand it to the configured solver.
        model = swm.policy.AutoCostModel(local_cfg.policy)
        device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model = model.eval()
        model.requires_grad_(False)
        model, compile_cfg, compile_target = maybe_compile_inference_target(
            model, local_cfg, device
        )
        inference_ctx, inference_precision = get_inference_context(local_cfg, device)
        model.interpolate_pos_encoding = True
        config = swm.PlanConfig(**local_cfg.plan_config)
        solver = hydra.utils.instantiate(local_cfg.solver, model=model)
        policy = swm.policy.WorldModelPolicy(
            solver=solver, config=config, process=process, transform=transform
        )
    else:
        # Random baseline: no model, so fill the bookkeeping values with the
        # "nothing compiled, fp32" defaults the return dict expects.
        policy = swm.policy.RandomPolicy()
        solver = None
        inference_ctx = nullcontext()
        inference_precision = "fp32"
        compile_cfg = get_compile_cfg(local_cfg)
        compile_target = "disabled"
        device = device_override or ("cuda" if torch.cuda.is_available() else "cpu")

    profiler_ctx, profile_dir, profile_cfg = make_profiler(local_cfg, output_dir)
    world.set_policy(policy)

    # Synchronize CUDA before the timer starts.
    if str(device).startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()

    if before_evaluate is not None:
        before_evaluate()
        if str(device).startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    def evaluate_subset(episodes, start_indices, *, eval_cfg=local_cfg):
        # NOTE(review): `eval.callables` is passed straight to
        # OmegaConf.to_container; confirm the config always defines it.
        return world.evaluate_from_dataset(
            dataset,
            start_steps=list(start_indices),
            goal_offset_steps=eval_cfg.eval.goal_offset_steps,
            eval_budget=eval_cfg.eval.eval_budget,
            episodes_idx=list(episodes),
            callables=OmegaConf.to_container(
                eval_cfg.eval.get("callables"), resolve=True
            ),
            save_video=bool(eval_cfg.eval.get("save_video", False)),
            video_path=output_dir,
        )

    start_time = time.time()
    with get_eval_grad_context(solver):
        # With profiling disabled profiler_ctx is a nullcontext, so `profiler`
        # is None and dump_profiler_results below returns None.
        with profiler_ctx as profiler:
            with inference_ctx:
                with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
                    metrics = evaluate_subset(eval_episodes, eval_start_idx)
            # Synchronize CUDA before the elapsed time is read.
            if str(device).startswith("cuda") and torch.cuda.is_available():
                torch.cuda.synchronize()
    evaluation_time = time.time() - start_time
    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)

    return {
        "metrics": metrics,
        "evaluation_time": evaluation_time,
        "inference_precision": inference_precision,
        "compile_target": compile_target,
        # The mode is only reported when something was actually compiled.
        "compile_mode": compile_cfg["mode"] if compile_target != "disabled" else None,
        "profile_dir": profile_dir,
        "profile_summary_path": profile_summary_path,
    }
|
|
|
|
|
|
def maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx):
    """Run a small throw-away evaluation before the timed one.

    The warm-up is skipped when ``compile_warmup.enabled`` is false, when
    ``multi_gpu.enabled`` is true (a notice is printed), or when no cases are
    available for it. In multi-node mode each rank warms up on its own subset
    and on the device derived from its local rank. Outputs go to a temporary
    directory, and video saving and profiling are disabled for the warm-up.

    Args:
        cfg: Full evaluation config; it is copied, not modified.
        eval_episodes: Array of sampled episode ids (must support ``.tolist()``
            after slicing).
        eval_start_idx: Array of sampled start steps, aligned with
            ``eval_episodes``.
    """
    warmup_cfg = get_compile_warmup_cfg(cfg)
    if not warmup_cfg["enabled"]:
        return

    if get_multi_gpu_cfg(cfg)["enabled"]:
        print("Skipping compile warmup because multi_gpu.enabled=true uses spawned workers.")
        return

    device_override = None
    if get_multi_node_cfg(cfg)["enabled"]:
        rank, world_size, local_rank = get_rank_context(cfg)
        eval_episodes, eval_start_idx = get_rank_eval_subset(
            eval_episodes, eval_start_idx, rank, world_size
        )
        device_override = "cpu"
        if torch.cuda.is_available():
            device_override = f"cuda:{local_rank}"

    warmup_count = min(int(warmup_cfg["num_eval"]), len(eval_episodes))
    if warmup_count < 1:
        return

    warmup_eval_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    warmup_eval_cfg.eval.num_eval = warmup_count
    warmup_eval_cfg.eval.save_video = False
    if warmup_eval_cfg.get("profile") is not None:
        warmup_eval_cfg.profile.enabled = False
    else:
        warmup_eval_cfg.profile = OmegaConf.create({"enabled": False})

    warmup_episodes = eval_episodes[:warmup_count].tolist()
    warmup_starts = eval_start_idx[:warmup_count].tolist()
    with tempfile.TemporaryDirectory(prefix="lewm_compile_warmup_") as tmpdir:
        run_eval_subset(
            warmup_eval_cfg,
            warmup_episodes,
            warmup_starts,
            Path(tmpdir),
            device_override=device_override,
            enable_profile=False,
        )
|
|
|
|
|
|
|
|
def _multi_gpu_eval_worker(
    cfg_container,
    eval_episodes,
    eval_start_idx,
    output_dir,
    device,
    shard_idx,
    queue,
):
    """Worker entry point: evaluate one shard and report through ``queue``.

    Exactly one message is put on the queue. On success it is
    ``{"ok": True, "shard_idx": ..., "result": ...}``; if anything raises it
    is ``{"ok": False, "shard_idx": ..., "error": <formatted traceback>}``.

    Args:
        cfg_container: Plain container form of the config; rebuilt with
            ``OmegaConf.create`` inside the worker.
        eval_episodes: Episode ids of this shard.
        eval_start_idx: Start steps of this shard.
        output_dir: Output directory as a string.
        device: Device string forwarded as ``device_override``.
        shard_idx: Position of this shard, echoed back in the message.
        queue: Queue shared with the parent process.
    """
    try:
        shard_result = run_eval_subset(
            OmegaConf.create(cfg_container),
            eval_episodes,
            eval_start_idx,
            Path(output_dir),
            device_override=device,
            enable_profile=False,
        )
        queue.put(dict(ok=True, shard_idx=shard_idx, result=shard_result))
    except Exception:
        # Process boundary: ship the traceback to the parent instead of
        # letting the worker die without a report.
        queue.put(dict(ok=False, shard_idx=shard_idx, error=traceback.format_exc()))
|
|
|
|
|
|
def run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate the cases in parallel, one worker process per device.

    The cases are split into contiguous shards (at most one per configured
    device), each shard is evaluated by ``_multi_gpu_eval_worker`` in its own
    process, and the per-shard results are merged in shard order.

    Args:
        cfg: Full evaluation config; ``multi_gpu.devices`` and
            ``multi_gpu.start_method`` select the devices and the
            multiprocessing start method.
        eval_episodes: Sampled episode ids.
        eval_start_idx: Sampled start steps, aligned with ``eval_episodes``.
        output_dir: Output directory handed to every worker.

    Returns:
        Dict with the same keys as ``run_eval_subset`` returns. ``metrics`` is
        the merged result, ``evaluation_time`` is the wall time across all
        workers, the precision/compile fields come from shard 0, and the
        profile fields are ``None``.

    Raises:
        ValueError: If fewer than two devices are configured.
        RuntimeError: If a worker reports an exception (its traceback is the
            message) or exits without reporting a result.
    """
    # Imported locally because the module does not import `queue`; the alias
    # avoids any confusion with the queue object created below.
    from queue import Empty as QueueEmpty

    multi_gpu_cfg = get_multi_gpu_cfg(cfg)
    devices = normalize_multi_gpu_devices(multi_gpu_cfg["devices"])
    if len(devices) < 2:
        raise ValueError("multi_gpu.enabled=true requires at least 2 CUDA devices")

    shards = shard_eval_cases(eval_episodes, eval_start_idx, min(len(devices), len(eval_episodes)))
    # Fewer cases than devices: only use as many devices as there are shards.
    devices = devices[: len(shards)]

    ctx = mp.get_context(multi_gpu_cfg["start_method"])
    result_queue = ctx.Queue()
    cfg_container = OmegaConf.to_container(cfg, resolve=False)
    processes = []

    start_time = time.time()
    for shard_idx, ((shard_episodes, shard_start_idx), device) in enumerate(
        zip(shards, devices, strict=True)
    ):
        process = ctx.Process(
            target=_multi_gpu_eval_worker,
            args=(
                cfg_container,
                list(shard_episodes),
                list(shard_start_idx),
                str(output_dir),
                device,
                shard_idx,
                result_queue,
            ),
        )
        process.start()
        processes.append(process)

    shard_results = {}
    errors = []
    reported = set()
    poll_seconds = 1.0

    # Collect one message per worker. A plain blocking get() would wait
    # forever if a worker is killed before it can report, so poll with a
    # timeout and stop once every worker has exited and the queue is drained.
    while len(reported) < len(processes):
        try:
            message = result_queue.get(timeout=poll_seconds)
        except QueueEmpty:
            if any(process.is_alive() for process in processes):
                continue
            # All workers are gone; pick up anything still in flight.
            try:
                message = result_queue.get(timeout=poll_seconds)
            except QueueEmpty:
                break
        reported.add(message["shard_idx"])
        if message["ok"]:
            shard_results[message["shard_idx"]] = message["result"]
        else:
            errors.append(message["error"])

    # Results are drained before joining so workers are never blocked on a
    # full queue while the parent waits for them to exit.
    for process in processes:
        process.join()

    for shard_idx, (process, device) in enumerate(zip(processes, devices)):
        if shard_idx not in reported:
            errors.append(
                f"Multi-GPU eval worker for shard {shard_idx} on {device} exited "
                f"with code {process.exitcode} without reporting a result"
            )

    if errors:
        raise RuntimeError(errors[0])

    ordered_results = [shard_results[idx] for idx in range(len(processes))]
    metrics, reference = combine_eval_results(ordered_results)
    return {
        "metrics": metrics,
        "evaluation_time": time.time() - start_time,
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
    }
|
|
|
|
|
|
def combine_eval_results(ordered_results):
    """Merge per-shard evaluation results into one metrics dict.

    Args:
        ordered_results: Non-empty sequence of result dicts, each holding a
            ``metrics`` dict with ``episode_successes`` and optionally ``seeds``.

    Returns:
        Tuple ``(metrics, reference)``. ``metrics`` has ``success_rate`` (a
        percentage), the concatenated boolean ``episode_successes`` and
        ``seeds``; ``seeds`` is ``None`` unless every shard supplied them.
        ``reference`` is the first entry of ``ordered_results``.
    """
    shard_metrics = [result["metrics"] for result in ordered_results]

    success_chunks = [
        np.asarray(shard["episode_successes"], dtype=np.bool_)
        for shard in shard_metrics
    ]
    episode_successes = np.concatenate(success_chunks)

    shard_seeds = [shard.get("seeds") for shard in shard_metrics]
    if any(seed is None for seed in shard_seeds):
        seeds = None
    else:
        seeds = np.concatenate(shard_seeds)

    success_rate = float(np.sum(episode_successes)) / len(episode_successes) * 100.0
    metrics = {
        "success_rate": success_rate,
        "episode_successes": episode_successes,
        "seeds": seeds,
    }
    return metrics, ordered_results[0]
|
|
|
|
|
|
def run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir: Path):
    """Evaluate this rank's share of the cases and optionally merge all ranks.

    Each rank picks its subset with ``multi_node.shard_strategy``, evaluates
    it through ``run_eval_subset`` on the device derived from its local rank,
    and then either returns its own result (``aggregate_results`` false) or
    takes part in an all-gather whose merged result only rank 0 returns.

    Args:
        cfg: Full evaluation config with ``multi_node`` enabled.
        eval_episodes: All sampled episode ids (identical on every rank).
        eval_start_idx: All sampled start steps, aligned with ``eval_episodes``.
        output_dir: Base output directory.

    Returns:
        Without aggregation: this rank's result dict plus ``output_filename``.
        With aggregation: the merged result dict on rank 0, ``None`` on every
        other rank.

    Raises:
        ValueError: If this rank receives no evaluation cases.
        RuntimeError: If ``torch.distributed`` is needed but unavailable.
    """
    rank, world_size, local_rank = get_rank_context(cfg)
    multi_node_cfg = get_multi_node_cfg(cfg)
    shard_episodes, shard_start_idx = get_rank_eval_subset(
        eval_episodes,
        eval_start_idx,
        rank,
        world_size,
        strategy=multi_node_cfg["shard_strategy"],
    )
    # NOTE(review): this raises on the affected rank only; when results are
    # aggregated the remaining ranks would still reach the all-gather below --
    # confirm this cannot leave them waiting.
    if len(shard_episodes) == 0:
        raise ValueError("No evaluation episodes assigned to this rank")

    # The per-rank run is a plain single-process evaluation, so both
    # distribution modes are switched off in the copy handed to it.
    local_cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=False))
    local_cfg.multi_node.enabled = False
    if local_cfg.get("multi_gpu") is None:
        local_cfg.multi_gpu = OmegaConf.create({"enabled": False})
    else:
        local_cfg.multi_gpu.enabled = False

    device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
    preload_cfg = get_preload_wait_cfg(cfg)
    if preload_cfg["enabled"]:
        # The preload wait uses barriers when a process group is initialized,
        # so the group is created before the evaluation in that case.
        if not torch.distributed.is_available():
            raise RuntimeError("torch.distributed is required for preload_wait")
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=multi_node_cfg["backend"])

    rank_output_path = get_rank_result_path(output_dir, cfg, rank)
    result = run_eval_subset(
        local_cfg,
        list(shard_episodes),
        list(shard_start_idx),
        rank_output_path.parent,
        device_override=device,
        enable_profile=False,
        # Uses the original cfg (not local_cfg) so the preload settings and
        # the real rank are seen by the wait.
        before_evaluate=lambda: wait_for_preload_signal(cfg, rank=rank),
    )
    if not multi_node_cfg["aggregate_results"]:
        # Every rank reports on its own, under its rank-specific file name.
        result["output_filename"] = rank_output_path.name
        finalize_multi_node_process_group(cfg)
        return result

    if not torch.distributed.is_available():
        raise RuntimeError("torch.distributed is required for multi-node evaluation")
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend=multi_node_cfg["backend"])

    gathered = all_gather_eval_result(result)
    metrics, reference = combine_eval_results(gathered)
    combined = {
        "metrics": metrics,
        # Ranks run concurrently, so the slowest rank defines the reported time.
        "evaluation_time": max(item["evaluation_time"] for item in gathered),
        "inference_precision": reference["inference_precision"],
        "compile_target": reference["compile_target"],
        "compile_mode": reference["compile_mode"],
        "profile_dir": None,
        "profile_summary_path": None,
        "output_filename": cfg.output.filename,
    }
    if multi_node_cfg["sync_before_return"]:
        torch.distributed.barrier()
    finalize_multi_node_process_group(cfg)
    # Only rank 0 hands the merged result back to the caller.
    if rank != 0:
        return None
    return combined
|
|
|
|
@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
    """Run evaluation of dinowm vs random policy.

    Samples the evaluation cases, optionally performs a compile warm-up,
    dispatches to the multi-node, multi-GPU or single-process evaluation path,
    prints the metrics and appends the config plus results to the output file
    inside the current working directory.

    Args:
        cfg: Hydra config loaded from ``./config/eval`` (default ``pusht``).

    Raises:
        AssertionError: If ``horizon * action_block`` exceeds ``eval_budget``.
        ValueError: If multi-node and multi-GPU are both enabled, or if
            profiling is enabled together with multi-GPU.
    """
    # NOTE(review): this input check is an assert, so it is skipped when
    # Python runs with -O; consider raising an exception instead.
    assert (
        cfg.plan_config.horizon * cfg.plan_config.action_block <= cfg.eval.eval_budget
    ), "Planning horizon must be smaller than or equal to eval_budget"

    dataset = get_dataset(cfg, cfg.eval.dataset_name)
    eval_episodes, eval_start_idx = sample_eval_cases(cfg, dataset)
    output_dir = Path.cwd().resolve()
    profile_cfg = get_profile_cfg(cfg)

    # The warm-up runs before the wall-clock timer starts, so its cost is not
    # part of total_wall_time.
    maybe_run_compile_warmup(cfg, eval_episodes, eval_start_idx)
    eval_wall_start = time.time()

    if get_multi_node_cfg(cfg)["enabled"] and get_multi_gpu_cfg(cfg)["enabled"]:
        raise ValueError("multi_node.enabled and multi_gpu.enabled are mutually exclusive")

    if get_multi_node_cfg(cfg)["enabled"]:
        eval_result = run_multi_node_eval(cfg, eval_episodes, eval_start_idx, output_dir)
        # None means another rank owns the merged result; nothing to write here.
        if eval_result is None:
            return
    elif get_multi_gpu_cfg(cfg)["enabled"]:
        if profile_cfg["enabled"]:
            raise ValueError("Profiling is not supported together with multi_gpu.enabled=true")
        eval_result = run_multi_gpu_eval(cfg, eval_episodes, eval_start_idx, output_dir)
    else:
        eval_result = run_eval_subset(
            cfg,
            eval_episodes.tolist(),
            eval_start_idx.tolist(),
            output_dir,
        )

    metrics = eval_result["metrics"]
    evaluation_time = eval_result["evaluation_time"]
    inference_precision = eval_result["inference_precision"]
    compile_target = eval_result["compile_target"]
    compile_mode = eval_result["compile_mode"]
    profile_dir = eval_result["profile_dir"]
    profile_summary_path = eval_result["profile_summary_path"]
    # Only the multi-node path sets output_filename; otherwise use the config value.
    output_filename = eval_result.get("output_filename", cfg.output.filename)

    print(metrics)

    results_path = output_dir / output_filename
    results_path.parent.mkdir(parents=True, exist_ok=True)

    # Append mode: earlier runs in the same file are kept.
    with results_path.open("a") as f:
        f.write("\n")  # separate from previous runs

        f.write("==== CONFIG ====\n")
        f.write(OmegaConf.to_yaml(cfg))
        f.write("\n")

        f.write("==== RESULTS ====\n")
        f.write(f"metrics: {metrics}\n")
        f.write(f"evaluation_time: {evaluation_time} seconds\n")
        f.write(f"inference_precision: {inference_precision}\n")
        f.write(f"inference_compile_target: {compile_target}\n")
        if compile_target != "disabled":
            f.write(f"inference_compile_mode: {compile_mode}\n")
        if profile_cfg["enabled"]:
            f.write(f"profile_dir: {profile_dir}\n")
            if profile_summary_path is not None:
                f.write(f"profile_summary: {profile_summary_path}\n")

        f.write(f"total_wall_time: {time.time() - eval_wall_start} seconds\n")
|
|
|
|
|
|
# Script entry point: run() takes no arguments here because it is wrapped by
# the hydra.main decorator.
if __name__ == "__main__":
    run()
|