embedder 权重改成 bf16
似乎因为权重的处理更慢了，整体速度反而变慢了一点点
This commit is contained in:
@@ -772,7 +772,11 @@ def image_guided_synthesis_sim_mode(
|
|||||||
with profiler.profile_section("synthesis/conditioning_prep"):
|
with profiler.profile_section("synthesis/conditioning_prep"):
|
||||||
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
img = observation['observation.images.top'].permute(0, 2, 1, 3, 4)
|
||||||
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
cond_img = rearrange(img, 'b o c h w -> (b o) c h w')[-1:]
|
||||||
cond_img_emb = model.embedder(cond_img)
|
embedder_ctx = nullcontext()
|
||||||
|
if getattr(model, "encoder_bf16", False) and model.device.type == "cuda":
|
||||||
|
embedder_ctx = torch.autocast("cuda", dtype=torch.bfloat16)
|
||||||
|
with embedder_ctx:
|
||||||
|
cond_img_emb = model.embedder(cond_img)
|
||||||
|
|
||||||
if model.model.conditioning_key == 'hybrid':
|
if model.model.conditioning_key == 'hybrid':
|
||||||
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
z = get_latent_z(model, img.permute(0, 2, 1, 3, 4))
|
||||||
@@ -912,6 +916,16 @@ def run_inference(args: argparse.Namespace, gpu_num: int, gpu_no: int) -> None:
|
|||||||
diffusion_autocast_dtype = torch.bfloat16
|
diffusion_autocast_dtype = torch.bfloat16
|
||||||
print(">>> diffusion backbone set to bfloat16")
|
print(">>> diffusion backbone set to bfloat16")
|
||||||
|
|
||||||
|
encoder_dtype = torch.float32
|
||||||
|
if args.encoder_dtype == "bf16":
|
||||||
|
encoder_dtype = torch.bfloat16
|
||||||
|
if hasattr(model, "cond_stage_model") and model.cond_stage_model is not None:
|
||||||
|
model.cond_stage_model.to(dtype=encoder_dtype)
|
||||||
|
if hasattr(model, "embedder") and model.embedder is not None:
|
||||||
|
model.embedder.to(dtype=encoder_dtype)
|
||||||
|
model.encoder_bf16 = args.encoder_dtype == "bf16"
|
||||||
|
print(f">>> encoder dtype set to {args.encoder_dtype}")
|
||||||
|
|
||||||
if hasattr(model, "projector_bf16"):
|
if hasattr(model, "projector_bf16"):
|
||||||
model.projector_bf16 = args.projector_dtype == "bf16"
|
model.projector_bf16 = args.projector_dtype == "bf16"
|
||||||
print(f">>> projector dtype set to {args.projector_dtype}")
|
print(f">>> projector dtype set to {args.projector_dtype}")
|
||||||
@@ -1266,6 +1280,13 @@ def get_parser():
|
|||||||
default="fp32",
|
default="fp32",
|
||||||
help="Dtype for image/state/action projectors (autocast in forward)."
|
help="Dtype for image/state/action projectors (autocast in forward)."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder_dtype",
|
||||||
|
type=str,
|
||||||
|
choices=["fp32", "bf16"],
|
||||||
|
default="fp32",
|
||||||
|
help="Dtype for text/image encoders (cond_stage_model/embedder)."
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--n_action_steps",
|
"--n_action_steps",
|
||||||
type=int,
|
type=int,
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
@@ -22,5 +22,6 @@ dataset="unitree_g1_pack_camera"
|
|||||||
--guidance_rescale 0.7 \
|
--guidance_rescale 0.7 \
|
||||||
--perframe_ae \
|
--perframe_ae \
|
||||||
--diffusion_dtype bf16 \
|
--diffusion_dtype bf16 \
|
||||||
--projector_dtype bf16
|
--projector_dtype bf16 \
|
||||||
|
--encoder_dtype bf16
|
||||||
} 2>&1 | tee "${res_dir}/output.log"
|
} 2>&1 | tee "${res_dir}/output.log"
|
||||||
|
|||||||
@@ -79,3 +79,6 @@ BF16 projector比FP32 projector更准的可能原因:
|
|||||||
- LayerNorm/Softmax 敏感:Resampler/MLP 里 LN/Softmax 对精度很敏感,FP32 计算后再降精度,数值边界更容易“硬截断”;BF16 全程计算可能更平滑。
|
- LayerNorm/Softmax 敏感:Resampler/MLP 里 LN/Softmax 对精度很敏感,FP32 计算后再降精度,数值边界更容易“硬截断”;BF16 全程计算可能更平滑。
|
||||||
|
|
||||||
这也解释了为什么你看到 BF16 projector 反而更准。
|
这也解释了为什么你看到 BF16 projector 反而更准。
|
||||||
|
|
||||||
|
embedder:
|
||||||
|
改成仅使用 autocast（权重保持 FP32，预处理保持 FP32，仅主干以 BF16 计算）
|
||||||
Reference in New Issue
Block a user