284 lines
9.8 KiB
Python
284 lines
9.8 KiB
Python
import os
|
||
|
||
# Select MuJoCo's EGL rendering backend (commonly used for headless GPU
# rendering). It is set before the remaining imports, presumably so that any
# module initializing MuJoCo at import time already sees it.
os.environ["MUJOCO_GL"] = "egl"
|
||
|
||
import time
|
||
from contextlib import nullcontext
|
||
from pathlib import Path
|
||
|
||
import hydra
|
||
import numpy as np
|
||
import stable_pretraining as spt
|
||
import torch
|
||
from omegaconf import DictConfig, OmegaConf
|
||
from sklearn import preprocessing
|
||
from torchvision.transforms import v2 as transforms
|
||
import stable_worldmodel as swm
|
||
|
||
def img_transform(cfg):
    """Build the image preprocessing pipeline shared by observations and goals.

    The pipeline converts the input to a tensor image, casts it to float32 with
    value scaling, normalizes it with the ImageNet statistics exposed by
    ``stable_pretraining``, and finally resizes it to ``cfg.eval.img_size``.
    """
    pipeline_steps = [
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize(**spt.data.dataset_stats.ImageNet),
        transforms.Resize(size=cfg.eval.img_size),
    ]
    return transforms.Compose(pipeline_steps)
|
||
|
||
|
||
def get_episodes_length(dataset, episodes):
    """Return the number of steps of each requested episode.

    An episode's length is its largest ``step_idx`` plus one. The episode id
    column is ``episode_idx`` when the dataset has it, otherwise ``ep_idx``.

    Args:
        dataset: Object exposing ``column_names`` and ``get_col_data(name)``.
        episodes: Iterable of episode ids to look up.

    Returns:
        np.ndarray with one length per entry of ``episodes``, in the same order.

    Raises:
        ValueError: if a requested episode has no rows in ``dataset``.
    """
    col_name = "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"

    episode_idx = np.asarray(dataset.get_col_data(col_name)).ravel()
    step_idx = np.asarray(dataset.get_col_data("step_idx")).ravel()
    episodes = np.asarray(episodes).ravel()

    if episodes.size == 0:
        return np.array([])
    if episode_idx.size == 0:
        raise ValueError("Dataset has no rows; cannot compute episode lengths.")

    # One pass over the dataset instead of a full-column mask per episode:
    # group rows by episode id and keep the largest step index of each group.
    unique_eps, inverse = np.unique(episode_idx, return_inverse=True)
    inverse = inverse.ravel()  # shape of `inverse` differs across numpy versions
    max_steps = np.full(unique_eps.shape, step_idx.min(), dtype=step_idx.dtype)
    np.maximum.at(max_steps, inverse, step_idx)

    # unique_eps is sorted, so searchsorted locates each requested id; clip so
    # ids larger than every known id can still be compared (and then rejected).
    pos = np.minimum(np.searchsorted(unique_eps, episodes), unique_eps.size - 1)
    missing = unique_eps[pos] != episodes
    if missing.any():
        raise ValueError(
            f"Episodes not found in dataset: {episodes[missing].tolist()}"
        )

    return max_steps[pos] + 1
|
||
|
||
|
||
def get_dataset(cfg, dataset_name):
    """Open ``dataset_name`` as a stable_worldmodel HDF5 dataset.

    The cache directory is ``cfg.cache_dir`` when it is truthy, otherwise the
    default returned by ``swm.data.utils.get_cache_dir()``. The columns listed
    in ``cfg.dataset.keys_to_cache`` are forwarded as ``keys_to_cache``.
    """
    cache_root = Path(cfg.cache_dir or swm.data.utils.get_cache_dir())
    return swm.data.HDF5Dataset(
        dataset_name,
        keys_to_cache=cfg.dataset.keys_to_cache,
        cache_dir=cache_root,
    )
|
||
|
||
|
||
def get_profile_cfg(cfg):
    """Return the torch-profiler settings as a plain dict.

    Starts from built-in defaults (profiling disabled) and, when ``cfg`` has a
    non-``None`` ``profile`` entry, overlays its resolved contents on top.
    A fresh dict is returned on every call.
    """
    merged = dict(
        enabled=False,
        trace_dirname="torch_profile",
        record_shapes=True,
        profile_memory=True,
        with_stack=False,
        with_flops=False,
        row_limit=40,
        worker_name="eval",
        export_chrome_trace=True,
        export_tensorboard=True,
    )
    overrides = cfg.get("profile")
    if overrides is None:
        return merged
    merged.update(OmegaConf.to_container(overrides, resolve=True))
    return merged
|
||
|
||
|
||
def get_inference_context(cfg, device):
    """Return ``(context_manager, precision_label)`` to wrap model inference in.

    ``cfg.inference_precision`` (default ``"fp32"``; case- and
    whitespace-insensitive) selects the mode:

    * ``fp32`` / ``float32``  -> ``nullcontext()``, label ``"fp32"``.
    * ``bf16`` / ``bfloat16`` -> ``torch.autocast`` with ``torch.bfloat16``,
      label ``"bf16"``.
    * ``fp16`` / ``float16``  -> ``torch.autocast`` with ``torch.float16``,
      label ``"fp16"``, on CUDA only; elsewhere a message is printed and the
      fp32 ``nullcontext()`` is returned instead.

    Args:
        cfg: Config mapping supporting ``.get``.
        device: Device string (``"cuda"``, ``"cuda:0"``, ``"cpu"``...) or a
            ``torch.device``. Anything whose string form starts with
            ``"cuda"`` is treated as CUDA, everything else as CPU.

    Raises:
        ValueError: for any other precision value.
    """
    precision = str(cfg.get("inference_precision", "fp32")).strip().lower()
    # str() also accepts torch.device objects, whose str() is e.g. "cuda:0".
    device_type = "cuda" if str(device).startswith("cuda") else "cpu"

    # Every mode accepts its short and long dtype spelling.
    if precision in {"fp32", "float32"}:
        return nullcontext(), "fp32"

    if precision in {"bf16", "bfloat16"}:
        return (
            torch.autocast(device_type=device_type, dtype=torch.bfloat16),
            "bf16",
        )

    if precision in {"fp16", "float16"}:
        if device_type != "cuda":
            print("fp16 inference is only supported on CUDA, falling back to fp32.")
            return nullcontext(), "fp32"
        return (
            torch.autocast(device_type=device_type, dtype=torch.float16),
            "fp16",
        )

    raise ValueError(
        f"Unsupported inference_precision={precision}. Expected one of: fp32, bf16, fp16."
    )
|
||
|
||
|
||
def make_profiler(cfg, results_path):
    """Create a ``torch.profiler.profile`` context from the ``profile`` config.

    Returns ``(profiler_ctx, profile_dir, profile_cfg)``. When profiling is
    disabled, ``profiler_ctx`` is a ``nullcontext()`` and ``profile_dir`` is
    ``None``. Otherwise ``results_path / trace_dirname`` is created, and the
    profiler records CPU activity, plus CUDA activity when CUDA is available.
    """
    profile_cfg = get_profile_cfg(cfg)
    if not profile_cfg["enabled"]:
        return nullcontext(), None, profile_cfg

    cuda_activity = (
        [torch.profiler.ProfilerActivity.CUDA] if torch.cuda.is_available() else []
    )
    activities = [torch.profiler.ProfilerActivity.CPU, *cuda_activity]

    profile_dir = results_path / profile_cfg["trace_dirname"]
    profile_dir.mkdir(parents=True, exist_ok=True)

    # Only the recording switches are constructor arguments; export settings
    # stay in profile_cfg for whoever dumps the results.
    recording_opts = {
        name: profile_cfg[name]
        for name in ("record_shapes", "profile_memory", "with_stack", "with_flops")
    }
    profiler = torch.profiler.profile(activities=activities, **recording_opts)
    return profiler, profile_dir, profile_cfg
|
||
|
||
|
||
def dump_profiler_results(profiler, profile_dir, profile_cfg):
    """Write a profiler summary table and a trace into ``profile_dir``.

    Returns the path of the written ``key_averages.txt`` summary, or ``None``
    (writing nothing) when either ``profiler`` or ``profile_dir`` is ``None``.

    The summary is sorted by self CUDA time when CUDA is available, otherwise
    by self CPU time, and is limited to ``profile_cfg["row_limit"]`` rows.
    A trace is then exported with the TensorBoard trace handler when
    ``export_tensorboard`` is set; only when it is not does
    ``export_chrome_trace`` write ``trace.json``.
    """
    if profiler is None or profile_dir is None:
        return None

    sort_key = (
        "self_cuda_time_total" if torch.cuda.is_available() else "self_cpu_time_total"
    )
    summary_text = profiler.key_averages().table(
        sort_by=sort_key, row_limit=profile_cfg["row_limit"]
    )
    summary_path = profile_dir / "key_averages.txt"
    summary_path.write_text(summary_text)

    # The two exports are mutually exclusive here, presumably because the
    # TensorBoard handler already writes a Chrome-format JSON trace.
    if profile_cfg["export_tensorboard"]:
        handler = torch.profiler.tensorboard_trace_handler(
            str(profile_dir), worker_name=profile_cfg["worker_name"]
        )
        handler(profiler)
    elif profile_cfg["export_chrome_trace"]:
        profiler.export_chrome_trace(str(profile_dir / "trace.json"))

    return summary_path
|
||
|
||
def _episode_col(dataset):
    """Return the episode-id column name: ``episode_idx`` if present, else ``ep_idx``."""
    return "episode_idx" if "episode_idx" in dataset.column_names else "ep_idx"


def _fit_preprocessors(stats_dataset, keys):
    """Fit one StandardScaler per non-pixel column in ``keys``.

    Each non-action column's scaler is also registered under ``goal_<col>``
    (the same fitted object), so goals are normalized like observations.
    """
    process = {}
    for col in keys:
        if col == "pixels":
            continue
        col_data = stats_dataset.get_col_data(col)
        # Drop rows containing any NaN so they do not poison the fitted statistics.
        col_data = col_data[~np.isnan(col_data).any(axis=1)]
        processor = preprocessing.StandardScaler()
        processor.fit(col_data)
        process[col] = processor
        if col != "action":
            process[f"goal_{col}"] = processor
    return process


def _build_policy(cfg, process, transform):
    """Create the evaluation policy and the inference context to run it under.

    Returns ``(policy, inference_ctx, inference_precision)``. When
    ``cfg.policy`` is absent or equals ``"random"`` a ``RandomPolicy`` is used
    with full precision; otherwise ``cfg.policy`` is loaded as a cost model and
    wrapped in a planning ``WorldModelPolicy``.
    """
    if cfg.get("policy", "random") == "random":
        return swm.policy.RandomPolicy(), nullcontext(), "fp32"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = swm.policy.AutoCostModel(cfg.policy)
    model = model.to(device)
    model = model.eval()
    model.requires_grad_(False)
    print(f"model parameter dtype: {next(model.parameters()).dtype}")

    inference_ctx, inference_precision = get_inference_context(cfg, device)
    print(f"inference execution precision: {inference_precision}")

    model.interpolate_pos_encoding = True
    config = swm.PlanConfig(**cfg.plan_config)
    solver = hydra.utils.instantiate(cfg.solver, model=model)
    policy = swm.policy.WorldModelPolicy(
        solver=solver, config=config, process=process, transform=transform
    )
    return policy, inference_ctx, inference_precision


def _sample_eval_starts(cfg, dataset, col_name):
    """Sample ``cfg.eval.num_eval`` distinct (episode, start step) pairs.

    A row is a valid start when its goal, ``cfg.eval.goal_offset_steps`` steps
    later, still lies inside the same episode. Sampling is seeded by
    ``cfg.seed``. Returns ``(episodes, start_steps)`` taken from the sampled
    rows in increasing row order.

    Raises:
        ValueError: if fewer than ``cfg.eval.num_eval`` valid starts exist.
    """
    episode_col = np.asarray(dataset.get_col_data(col_name))
    step_col = np.asarray(dataset.get_col_data("step_idx"))

    ep_ids = np.unique(episode_col)
    episode_len = get_episodes_length(dataset, ep_ids)
    # Largest start step such that start + goal_offset_steps <= last step.
    max_start_idx = episode_len - cfg.eval.goal_offset_steps - 1
    # ep_ids is sorted (np.unique), so searchsorted maps each row to its episode.
    max_start_per_row = max_start_idx[np.searchsorted(ep_ids, episode_col)]

    valid_indices = np.nonzero(step_col <= max_start_per_row)[0]
    print(len(valid_indices), "valid starting points found for evaluation.")

    num_eval = cfg.eval.num_eval
    # Check before sampling: choice(replace=False) would otherwise fail with a
    # less helpful numpy error.
    if len(valid_indices) < num_eval:
        raise ValueError(
            "Not enough episodes with sufficient length for evaluation "
            f"({len(valid_indices)} valid starting points, {num_eval} requested)."
        )

    rng = np.random.default_rng(cfg.seed)
    picked = rng.choice(len(valid_indices), size=num_eval, replace=False)
    # sort increasingly to avoid issues with HDF5Dataset indexing
    row_indices = np.sort(valid_indices[picked])
    print(row_indices)

    rows = dataset.get_row_data(row_indices)
    return rows[col_name], rows["step_idx"]


@hydra.main(version_base=None, config_path="./config/eval", config_name="pusht")
def run(cfg: DictConfig):
    """Evaluate a policy from dataset start states and append the results to a file.

    Samples ``cfg.eval.num_eval`` start states whose goal frame
    (``cfg.eval.goal_offset_steps`` later) lies in the same episode. It then
    runs ``world.evaluate_from_dataset`` on them with either a world-model
    planning policy or a random policy (``cfg.policy == "random"``). Finally it
    appends the config, metrics, wall-clock time and inference precision to
    ``cfg.output.filename`` inside the Hydra run directory.

    Raises:
        ValueError: if ``horizon * action_block`` exceeds ``cfg.eval.eval_budget``,
            or if there are too few valid start points.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if cfg.plan_config.horizon * cfg.plan_config.action_block > cfg.eval.eval_budget:
        raise ValueError("Planning horizon must be smaller than or equal to eval_budget")

    # create world environment
    cfg.world.max_episode_steps = 2 * cfg.eval.eval_budget
    world = swm.World(**cfg.world, image_shape=(224, 224))

    transform = {
        "pixels": img_transform(cfg),
        "goal": img_transform(cfg),
    }

    dataset = get_dataset(cfg, cfg.eval.dataset_name)
    stats_dataset = dataset  # get_dataset(cfg, cfg.dataset.stats)
    col_name = _episode_col(dataset)

    process = _fit_preprocessors(stats_dataset, cfg.dataset.keys_to_cache)

    # -- run evaluation
    policy, inference_ctx, inference_precision = _build_policy(cfg, process, transform)

    # Hydra switches the working directory to the per-run outputs folder.
    # Keep all generated artifacts with that run instead of scattering them
    # next to the cache or source tree.
    output_dir = Path.cwd().resolve()
    profiler_ctx, profile_dir, profile_cfg = make_profiler(cfg, output_dir)

    eval_episodes, eval_start_idx = _sample_eval_starts(cfg, dataset, col_name)

    world.set_policy(policy)

    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.perf_counter()
    with torch.inference_mode():
        with profiler_ctx as profiler:
            with inference_ctx:
                with torch.profiler.record_function("eval.world_evaluate_from_dataset"):
                    metrics = world.evaluate_from_dataset(
                        dataset,
                        start_steps=eval_start_idx.tolist(),
                        goal_offset_steps=cfg.eval.goal_offset_steps,
                        eval_budget=cfg.eval.eval_budget,
                        episodes_idx=eval_episodes.tolist(),
                        callables=OmegaConf.to_container(cfg.eval.get("callables"), resolve=True),
                        video_path=output_dir,
                    )
            # Wait for queued GPU work so the measured time covers it.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end_time = time.perf_counter()
    # Profiler results are only available once its context has exited;
    # `profiler` is None when profiling is disabled (nullcontext).
    profile_summary_path = dump_profiler_results(profiler, profile_dir, profile_cfg)

    print(metrics)

    results_path = output_dir / cfg.output.filename
    results_path.parent.mkdir(parents=True, exist_ok=True)

    with results_path.open("a", encoding="utf-8") as f:
        f.write("\n")  # separate from previous runs

        f.write("==== CONFIG ====\n")
        f.write(OmegaConf.to_yaml(cfg))
        f.write("\n")

        f.write("==== RESULTS ====\n")
        f.write(f"metrics: {metrics}\n")
        f.write(f"evaluation_time: {end_time - start_time} seconds\n")
        f.write(f"inference_precision: {inference_precision}\n")
        if profile_cfg["enabled"]:
            f.write(f"profile_dir: {profile_dir}\n")
            if profile_summary_path is not None:
                f.write(f"profile_summary: {profile_summary_path}\n")
|
||
|
||
|
||
if __name__ == "__main__":
    # `run` takes no arguments here: the `hydra.main` decorator builds `cfg`
    # from the config files and command-line overrides.
    run()
|