diff --git a/scripts/evaluation/world_model_interaction.py b/scripts/evaluation/world_model_interaction.py
index ac5ebde..92025b8 100644
--- a/scripts/evaluation/world_model_interaction.py
+++ b/scripts/evaluation/world_model_interaction.py
@@ -13,7 +13,7 @@ import time
 import json
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass, field, asdict
-from typing import Optional, Dict, List, Any
+from typing import Optional, Dict, List, Any, Mapping

 from pytorch_lightning import seed_everything
 from omegaconf import OmegaConf
@@ -673,8 +673,8 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
     return z


-def preprocess_observation(
-        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
+def preprocess_observation(
+        model, observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
     """Convert environment observation to LeRobot format observation.

     Args:
         observation: Dictionary of observation batches from a Gym vector environment.
@@ -715,7 +715,18 @@ def preprocess_observation(
             return_observations['observation.state'].to(model.device)
         })['observation.state']

-    return return_observations
+    return return_observations
+
+
+def _move_to_device(batch: Mapping[str, Any],
+                    device: torch.device) -> dict[str, Any]:
+    moved = {}
+    for key, value in batch.items():
+        if isinstance(value, torch.Tensor) and value.device != device:
+            moved[key] = value.to(device, non_blocking=True)
+        else:
+            moved[key] = value
+    return moved


 def image_guided_synthesis_sim_mode(
@@ -768,8 +779,11 @@ def image_guided_synthesis_sim_mode(
     profiler = get_profiler()

     b, _, t, _, _ = noise_shape
-    ddim_sampler = DDIMSampler(model)
-    batch_size = noise_shape[0]
+    ddim_sampler = getattr(model, "_ddim_sampler", None)
+    if ddim_sampler is None:
+        ddim_sampler = DDIMSampler(model)
+        model._ddim_sampler = ddim_sampler
+    batch_size = noise_shape[0]
     fs = torch.tensor([fs] * batch_size, dtype=torch.long, device=model.device)
@@ -900,7 +914,7 @@ def image_guided_synthesis_sim_mode(
     return batch_variants, actions, states


-def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
+def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     """
     Run inference pipeline on prompts and image inputs.
@@ -912,7 +926,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
     Returns:
         None
     """
-    profiler = get_profiler()
+    profiler = get_profiler()

     # Create inference and tensorboard dirs
     os.makedirs(args.savedir + '/inference', exist_ok=True)
@@ -1077,10 +1091,7 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
             'action': torch.zeros_like(batch['action'][-1]).unsqueeze(0)
         }

-        observation = {
-            key: observation[key].to(device, non_blocking=True)
-            for key in observation
-        }
+        observation = _move_to_device(observation, device)

         # Update observation queues
         cond_obs_queues = populate_queues(cond_obs_queues, observation)
@@ -1093,7 +1104,9 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:

         # Multi-round interaction with the world-model
         with pytorch_prof_ctx:
-            for itr in tqdm(range(args.n_iter)):
+            for itr in tqdm(range(args.n_iter)):
+                log_every = max(1, args.step_log_every)
+                log_step = (itr % log_every == 0)
                 profiler.current_iteration = itr
                 profiler.record_memory(f"iter_{itr}_start")
@@ -1111,13 +1124,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                     'action': torch.stack(list(cond_obs_queues['action']), dim=1),
                 }

-                observation = {
-                    key: observation[key].to(device, non_blocking=True)
-                    for key in observation
-                }
+                observation = _move_to_device(observation, device)

                 # Use world-model in policy to generate action
-                print(f'>>> Step {itr}: generating actions ...')
+                if log_step:
+                    print(f'>>> Step {itr}: generating actions ...')
                 with profiler.profile_section("action_generation"):
                     pred_videos_0, pred_actions, _ = image_guided_synthesis_sim_mode(
                         model,
@@ -1156,13 +1167,11 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
                     'action': torch.stack(list(cond_obs_queues['action']), dim=1),
                 }

-                observation = {
-                    key: observation[key].to(device, non_blocking=True)
-                    for key in observation
-                }
+                observation = _move_to_device(observation, device)

                 # Interaction with the world-model
-                print(f'>>> Step {itr}: interacting with world model ...')
+                if log_step:
+                    print(f'>>> Step {itr}: interacting with world model ...')
                 with profiler.profile_section("world_model_interaction"):
                     pred_videos_1, _, pred_states = image_guided_synthesis_sim_mode(
                         model,
@@ -1364,6 +1373,12 @@ def get_parser():
         default="fp32",
         help="Dtype for VAE/first_stage_model weights and forward autocast."
     )
+    parser.add_argument(
+        "--step_log_every",
+        type=int,
+        default=1,
+        help="Print per-iteration step logs every N iterations."
+    )
     parser.add_argument(
         "--n_action_steps",
         type=int,
diff --git a/src/unifolm_wma/models/samplers/ddim.py b/src/unifolm_wma/models/samplers/ddim.py
index 77602a1..cd761c4 100644
--- a/src/unifolm_wma/models/samplers/ddim.py
+++ b/src/unifolm_wma/models/samplers/ddim.py
@@ -28,6 +28,11 @@ class DDIMSampler(object):
                       ddim_discretize="uniform",
                       ddim_eta=0.,
                       verbose=True):
+        device = self.model.betas.device
+        cache_key = (ddim_num_steps, ddim_discretize, float(ddim_eta),
+                     str(device))
+        if getattr(self, "_schedule_cache", None) == cache_key:
+            return
         self.ddim_timesteps = make_ddim_timesteps(
             ddim_discr_method=ddim_discretize,
             num_ddim_timesteps=ddim_num_steps,
@@ -67,16 +72,26 @@ class DDIMSampler(object):
             ddim_timesteps=self.ddim_timesteps,
             eta=ddim_eta,
             verbose=verbose)
+        ddim_sigmas = torch.as_tensor(ddim_sigmas,
+                                      device=self.model.device,
+                                      dtype=torch.float32)
+        ddim_alphas = torch.as_tensor(ddim_alphas,
+                                      device=self.model.device,
+                                      dtype=torch.float32)
+        ddim_alphas_prev = torch.as_tensor(ddim_alphas_prev,
+                                           device=self.model.device,
+                                           dtype=torch.float32)
         self.register_buffer('ddim_sigmas', ddim_sigmas)
         self.register_buffer('ddim_alphas', ddim_alphas)
         self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
         self.register_buffer('ddim_sqrt_one_minus_alphas',
-                             np.sqrt(1. - ddim_alphas))
+                             torch.sqrt(1. - ddim_alphas))
         sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
             (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
             (1 - self.alphas_cumprod / self.alphas_cumprod_prev))
         self.register_buffer('ddim_sigmas_for_original_num_steps',
                              sigmas_for_original_sampling_steps)
+        self._schedule_cache = cache_key

     @torch.no_grad()
     def sample(
@@ -228,10 +243,14 @@ class DDIMSampler(object):
             'x_inter_state': [state],
             'pred_x0_state': [state],
         }
-        time_range = reversed(range(
-            0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
-        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
-            0]
+        if ddim_use_original_steps:
+            time_range = np.arange(timesteps - 1, -1, -1)
+        else:
+            time_range = np.flip(timesteps)
+        time_range = np.ascontiguousarray(time_range)
+        total_steps = int(time_range.shape[0])
+        t_seq = torch.as_tensor(time_range, device=device, dtype=torch.long)
+        ts_batch = t_seq.unsqueeze(1).expand(total_steps, b).contiguous()
         if verbose:
             iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
         else:
@@ -243,7 +262,7 @@ class DDIMSampler(object):
             dp_ddim_scheduler_state.set_timesteps(len(timesteps))
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
-            ts = torch.full((b, ), step, device=device, dtype=torch.long)
+            ts = ts_batch[i]

             # Use mask to blend noised original latent (img_orig) & new sampled latent (img)
             if mask is not None:
@@ -378,16 +397,14 @@ class DDIMSampler(object):
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

         if is_video:
-            size = (b, 1, 1, 1, 1)
+            size = (1, 1, 1, 1, 1)
         else:
-            size = (b, 1, 1, 1)
+            size = (1, 1, 1, 1)

-        a_t = torch.full(size, alphas[index], device=device)
-        a_prev = torch.full(size, alphas_prev[index], device=device)
-        sigma_t = torch.full(size, sigmas[index], device=device)
-        sqrt_one_minus_at = torch.full(size,
-                                       sqrt_one_minus_alphas[index],
-                                       device=device)
+        a_t = alphas[index].view(size)
+        a_prev = alphas_prev[index].view(size)
+        sigma_t = sigmas[index].view(size)
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)

         if self.model.parameterization != "v":
             pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
@@ -395,12 +412,8 @@ class DDIMSampler(object):
             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

         if self.model.use_dynamic_rescale:
-            scale_t = torch.full(size,
-                                 self.ddim_scale_arr[index],
-                                 device=device)
-            prev_scale_t = torch.full(size,
-                                      self.ddim_scale_arr_prev[index],
-                                      device=device)
+            scale_t = self.ddim_scale_arr[index].view(size)
+            prev_scale_t = self.ddim_scale_arr_prev[index].view(size)
             rescale = (prev_scale_t / scale_t)
             pred_x0 *= rescale
diff --git a/src/unifolm_wma/modules/attention.py b/src/unifolm_wma/modules/attention.py
index 7b21317..cf087c9 100644
--- a/src/unifolm_wma/modules/attention.py
+++ b/src/unifolm_wma/modules/attention.py
@@ -99,6 +99,7 @@ class CrossAttention(nn.Module):
         self.agent_state_context_len = agent_state_context_len
         self.agent_action_context_len = agent_action_context_len
         self.cross_attention_scale_learnable = cross_attention_scale_learnable
+        self._attn_mask_cache = {}
         if self.image_cross_attention:
             self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
             self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
@@ -275,7 +276,8 @@ class CrossAttention(nn.Module):
                 attn_mask_aa = self._get_attn_mask_aa(x.shape[0],
                                                       q.shape[1],
                                                       k_aa.shape[1],
-                                                      block_size=16).to(k_aa.device)
+                                                      block_size=16,
+                                                      device=k_aa.device)
             else:
                 if not spatial_self_attn:
                     assert 1 > 2, ">>> ERROR: you should never go into here ..."
@@ -386,14 +388,26 @@ class CrossAttention(nn.Module):

         return self.to_out(out)

-    def _get_attn_mask_aa(self, b, l1, l2, block_size=16):
+    def _get_attn_mask_aa(self,
+                          b,
+                          l1,
+                          l2,
+                          block_size=16,
+                          device=None):
+        if device is None:
+            device = self.to_q.weight.device
+        cache_key = (b, l1, l2, block_size, str(device))
+        if cache_key in self._attn_mask_cache:
+            return self._attn_mask_cache[cache_key]
         num_token = l2 // block_size
-        start_positions = ((torch.arange(b) % block_size) + 1) * num_token
-        col_indices = torch.arange(l2)
+        start_positions = ((torch.arange(b, device=device) % block_size) +
+                           1) * num_token
+        col_indices = torch.arange(l2, device=device)
         mask_2d = col_indices.unsqueeze(0) >= start_positions.unsqueeze(1)
         mask = mask_2d.unsqueeze(1).expand(b, l1, l2)
-        attn_mask = torch.zeros_like(mask, dtype=torch.float)
+        attn_mask = torch.zeros_like(mask, dtype=torch.float32)
         attn_mask[mask] = float('-inf')
+        self._attn_mask_cache[cache_key] = attn_mask
         return attn_mask
diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768742644.node-0.490488.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768742644.node-0.490488.0
new file mode 100644
index 0000000..04fb1ce
Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768742644.node-0.490488.0 differ
diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768744224.node-0.506787.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768744224.node-0.506787.0
new file mode 100644
index 0000000..e3c8f22
Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768744224.node-0.506787.0 differ
diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746193.node-0.512651.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746193.node-0.512651.0
new file mode 100644
index 0000000..ffad828
Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746193.node-0.512651.0 differ
diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746333.node-0.514047.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746333.node-0.514047.0
new file mode 100644
index 0000000..b9fde46
Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746333.node-0.514047.0 differ
diff --git a/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746500.node-0.514744.0 b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746500.node-0.514744.0
new file mode 100644
index 0000000..5413ff8
Binary files /dev/null and b/unitree_g1_pack_camera/case1/output/tensorboard/events.out.tfevents.1768746500.node-0.514744.0 differ
diff --git a/useful.sh b/useful.sh
index cc5e9de..86a04b5 100644
--- a/useful.sh
+++ b/useful.sh
@@ -106,4 +106,16 @@ embedder:
    1. Add --encoder_mode {fp32, autocast, bf16_full}
    2. bf16_full = BF16 weights + BF16 forward
-   3. autocast = FP32 weights + autocast on the backbone only (current implementation)
\ No newline at end of file
+   3. autocast = FP32 weights + autocast on the backbone only (current implementation)
+
+
+
+   1. Reduce small-tensor allocations inside the DDIM loop (done)
+
+      - Replace the per-step torch.full(...) calls with tensors built/broadcast up front, cutting allocations inside the loop
+      - Location: src/unifolm_wma/models/samplers/ddim.py
+
+   2. Cache the attention mask on the GPU (done)
+
+      - _get_attn_mask_aa now builds and caches the mask directly on the target device, avoiding a CPU→GPU copy every step
+      - Location: src/unifolm_wma/modules/attention.py
\ No newline at end of file
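
Note (not part of the patch): a minimal standalone sketch of the precompute-then-index idea behind item 1 above, with hypothetical names. The per-step timestep batches are materialized once outside the sampling loop, and each iteration only indexes into them.

import torch

def make_ts_batches(timesteps, batch_size, device):
    # Build a (num_steps, batch_size) LongTensor once, before the sampling loop.
    t_seq = torch.as_tensor(list(timesteps), device=device, dtype=torch.long)
    return t_seq.unsqueeze(1).expand(t_seq.shape[0], batch_size).contiguous()

# Inside the loop, `ts = ts_batch[i]` then replaces a per-step
# `torch.full((b,), step, device=device, dtype=torch.long)` allocation.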
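Likewise, a standalone sketch of the cache-on-device pattern behind item 2 (hypothetical function and cache names; the actual implementation is CrossAttention._get_attn_mask_aa in the diff above):

import torch

_mask_cache = {}

def block_causal_attn_mask(b, l1, l2, block_size=16, device="cpu"):
    # Build the additive -inf mask once per (shape, device) and reuse it afterwards.
    key = (b, l1, l2, block_size, str(device))
    if key not in _mask_cache:
        num_token = l2 // block_size
        start = ((torch.arange(b, device=device) % block_size) + 1) * num_token
        cols = torch.arange(l2, device=device)
        blocked = cols.unsqueeze(0) >= start.unsqueeze(1)        # (b, l2)
        mask = torch.zeros(b, l1, l2, dtype=torch.float32, device=device)
        mask.masked_fill_(blocked.unsqueeze(1), float("-inf"))   # broadcast over l1
        _mask_cache[key] = mask
    return _mask_cache[key]

One trade-off of this pattern: a cached mask is kept for every distinct (b, l1, l2, block_size, device) combination, so it assumes the shapes stay stable across steps.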