VAE 也做 BF16
这个权重不做修改更好精度
This commit is contained in:
@@ -649,7 +649,7 @@ def prepare_init_input(start_idx: int,
|
||||
return data, ori_state_dim, ori_action_dim
|
||||
|
||||
|
||||
def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
"""
|
||||
Extracts latent features from a video batch using the model's first-stage encoder.
|
||||
|
||||
@@ -661,11 +661,15 @@ def get_latent_z(model, videos: Tensor) -> Tensor:
|
||||
Tensor: Latent video tensor of shape [B, C, T, H, W].
|
||||
"""
|
||||
profiler = get_profiler()
|
||||
with profiler.profile_section("get_latent_z/encode"):
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
with profiler.profile_section("get_latent_z/encode"):
|
||||
b, c, t, h, w = videos.shape
|
||||
x = rearrange(videos, 'b c t h w -> (b t) c h w')
|
||||
vae_ctx = nullcontext()
|
||||
if getattr(model, "vae_bf16", False) and model.device.type == "cuda":
|
||||
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||
with vae_ctx:
|
||||
z = model.encode_first_stage(x)
|
||||
z = rearrange(z, '(b t) c h w -> b c t h w', b=b, t=t)
|
||||
return z
|
||||
|
||||
|
||||
@@ -879,9 +883,18 @@ def image_guided_synthesis_sim_mode(
|
||||
|
||||
# Reconstruct from latent to pixel space
|
||||
with profiler.profile_section("synthesis/decode_first_stage"):
|
||||
if samples.dtype != torch.float32:
|
||||
samples = samples.float()
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
if getattr(model, "vae_bf16", False):
|
||||
if samples.dtype != torch.bfloat16:
|
||||
samples = samples.to(dtype=torch.bfloat16)
|
||||
vae_ctx = nullcontext()
|
||||
if model.device.type == "cuda":
|
||||
vae_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||
with vae_ctx:
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
else:
|
||||
if samples.dtype != torch.float32:
|
||||
samples = samples.float()
|
||||
batch_images = model.decode_first_stage(samples)
|
||||
batch_variants = batch_images
|
||||
|
||||
return batch_variants, actions, states
|
||||
@@ -944,6 +957,14 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
diffusion_autocast_dtype = torch.bfloat16
|
||||
print(">>> diffusion backbone set to bfloat16")
|
||||
|
||||
if hasattr(model, "first_stage_model") and model.first_stage_model is not None:
|
||||
if args.vae_dtype == "bf16":
|
||||
model.first_stage_model.to(dtype=torch.bfloat16)
|
||||
else:
|
||||
model.first_stage_model.to(dtype=torch.float32)
|
||||
model.vae_bf16 = args.vae_dtype == "bf16"
|
||||
print(f">>> VAE dtype set to {args.vae_dtype}")
|
||||
|
||||
encoder_mode = args.encoder_mode
|
||||
encoder_bf16 = encoder_mode in ("autocast", "bf16_full")
|
||||
encoder_weight_dtype = torch.bfloat16 if encoder_mode == "bf16_full" else torch.float32
|
||||
@@ -957,9 +978,21 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
||||
f">>> encoder mode set to {encoder_mode} (weights={encoder_weight_dtype})"
|
||||
)
|
||||
|
||||
projector_mode = args.projector_mode
|
||||
projector_bf16 = projector_mode in ("autocast", "bf16_full")
|
||||
projector_weight_dtype = torch.bfloat16 if projector_mode == "bf16_full" else torch.float32
|
||||
if hasattr(model, "image_proj_model") and model.image_proj_model is not None:
|
||||
model.image_proj_model.to(dtype=projector_weight_dtype)
|
||||
if hasattr(model, "state_projector") and model.state_projector is not None:
|
||||
model.state_projector.to(dtype=projector_weight_dtype)
|
||||
if hasattr(model, "action_projector") and model.action_projector is not None:
|
||||
model.action_projector.to(dtype=projector_weight_dtype)
|
||||
if hasattr(model, "projector_bf16"):
|
||||
model.projector_bf16 = args.projector_dtype == "bf16"
|
||||
print(f">>> projector dtype set to {args.projector_dtype}")
|
||||
model.projector_bf16 = projector_bf16
|
||||
model.projector_mode = projector_mode
|
||||
print(
|
||||
f">>> projector mode set to {projector_mode} (weights={projector_weight_dtype})"
|
||||
)
|
||||
|
||||
log_inference_precision(model)
|
||||
|
||||
@@ -1305,11 +1338,14 @@ def get_parser():
|
||||
help="Dtype for diffusion backbone weights and sampling autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--projector_dtype",
|
||||
"--projector_mode",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
choices=["fp32", "autocast", "bf16_full"],
|
||||
default="fp32",
|
||||
help="Dtype for image/state/action projectors (autocast in forward)."
|
||||
help=
|
||||
"Projector precision mode for image/state/action projectors: "
|
||||
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
||||
"bf16_full=bf16 weights + bf16 forward."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_mode",
|
||||
@@ -1321,6 +1357,13 @@ def get_parser():
|
||||
"fp32=full fp32, autocast=fp32 weights + bf16 autocast in forward, "
|
||||
"bf16_full=bf16 weights + bf16 forward."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_dtype",
|
||||
type=str,
|
||||
choices=["fp32", "bf16"],
|
||||
default="fp32",
|
||||
help="Dtype for VAE/first_stage_model weights and forward autocast."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_action_steps",
|
||||
type=int,
|
||||
|
||||
@@ -2032,6 +2032,13 @@ class LatentVisualDiffusion(LatentDiffusion):
|
||||
target_dtype: torch.dtype | None) -> Tensor:
|
||||
use_bf16 = (self.projector_bf16 and x.device.type == "cuda"
|
||||
and torch.cuda.is_bf16_supported())
|
||||
if not use_bf16:
|
||||
weight_dtype = None
|
||||
for param in projector.parameters():
|
||||
weight_dtype = param.dtype
|
||||
break
|
||||
if weight_dtype is not None and x.dtype != weight_dtype:
|
||||
x = x.to(dtype=weight_dtype)
|
||||
if use_bf16:
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
out = projector(x)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -22,6 +22,7 @@ dataset="unitree_g1_pack_camera"
|
||||
--guidance_rescale 0.7 \
|
||||
--perframe_ae \
|
||||
--diffusion_dtype bf16 \
|
||||
--projector_dtype bf16 \
|
||||
--encoder_mode autocast #fp32/autocast/bf16_full
|
||||
--projector_mode autocast \
|
||||
--encoder_mode bf16_full \
|
||||
--vae_dtype bf16
|
||||
} 2>&1 | tee "${res_dir}/output.log"
|
||||
|
||||
Reference in New Issue
Block a user