10 Commits

Author SHA1 Message Date
qhy
bb274870c2 Clean up code 2026-02-10 12:46:12 +08:00
qhy
f1f92072e6 remove profile 2026-02-10 11:28:26 +08:00
qhy
ff920b85a2 Theoretical performance analysis 2026-02-10 10:10:09 +08:00
qhy
6630952d2b Save results asynchronously 2026-02-09 21:23:00 +08:00
qhy
bc78815acf Temporary changes to script parameters 2026-02-07 21:28:54 +08:00
qhy
d5f6577fa8 Copy the model object, skip loading the model 2026-02-07 19:18:49 +08:00
qhy
7dcf9e8b89 VAE optimization; load the model directly onto the GPU 2026-02-07 17:36:00 +08:00
qhy
aba2a90045 Operator fusion 2026-02-07 16:40:33 +08:00
25de36b9bc Document the current optimizations, related parameter changes, and their effects 2026-01-19 16:58:37 +08:00
2fdcec6da0 Delete README.md 2026-01-19 16:39:49 +08:00
9 changed files with 868 additions and 1412 deletions

README.md

@@ -1,228 +1,29 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family
<p style="font-size: 1.2em;">
<a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
<a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
<a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>
<div align="center">
<p align="right">
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
</p>
</div>
<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed specifically for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b>: operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b>: connects with an action head and, by predicting future interaction processes with the world model, further optimizes decision-making performance.
</div>
# World Model Interaction Mixed-Precision Acceleration Notes (case1)
## 🦾 Real-Robot Demonstrations
| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
## Changed Locations
- Script path: `/home/dyz/unifolm-world-model-action/unitree_g1_pack_camera/case1/run_world_model_interaction.sh`
- Current status: some parameters that are normally not recommended to change (or require caution) have been modified; they will be fixed as defaults once the optimal settings are confirmed.
**Note: the top-right window shows the world model's prediction of future action videos.**
## New Parameters (can become defaults once the optimum is confirmed)
- `--diffusion_dtype {fp32,bf16}`: weight and forward dtype for the diffusion model; default `fp32`
- `--projector_mode {fp32,autocast,bf16_full}`: projector precision strategy; default `fp32`
- `--encoder_mode {fp32,autocast,bf16_full}`: encoder precision strategy; default `fp32`
- `--vae_dtype {fp32,bf16}`: weight and forward dtype for the VAE; default `fp32`
- `--export_casted_ckpt <path>`: export a ckpt cast to the current precision settings (for offline export of mixed-precision weights)
- `--export_only`: exit right after exporting the ckpt; off by default
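The flags above could be declared roughly as follows; this is an illustrative `argparse` sketch, not the repo's actual argument-parsing code (the helper name `add_precision_args` is hypothetical):

```python
import argparse

def add_precision_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Declare the mixed-precision flags described above (wiring is illustrative)."""
    parser.add_argument("--diffusion_dtype", choices=["fp32", "bf16"], default="fp32",
                        help="Weight and forward dtype for the diffusion model")
    parser.add_argument("--projector_mode", choices=["fp32", "autocast", "bf16_full"],
                        default="fp32", help="Projector precision strategy")
    parser.add_argument("--encoder_mode", choices=["fp32", "autocast", "bf16_full"],
                        default="fp32", help="Encoder precision strategy")
    parser.add_argument("--vae_dtype", choices=["fp32", "bf16"], default="fp32",
                        help="Weight and forward dtype for the VAE")
    parser.add_argument("--export_casted_ckpt", type=str, default=None,
                        help="Export a ckpt cast to the current precision settings")
    parser.add_argument("--export_only", action="store_true",
                        help="Exit right after exporting the ckpt")
    return parser

# Example: the best-known configuration from these notes, minus the VAE flag
args = add_precision_args(argparse.ArgumentParser()).parse_args(
    ["--diffusion_dtype", "bf16", "--encoder_mode", "bf16_full"])
print(args.diffusion_dtype, args.vae_dtype)  # bf16 fp32
```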
## 🔥 News
### Parameter Semantics
- `fp32`: both weights and the forward pass use fp32
- `autocast`: weights stay fp32; the forward pass runs under `torch.autocast` (operator-level mixed precision)
- `bf16_full`: weights are explicitly cast to bf16, and the forward pass also runs mainly in bf16
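The three conventions can be sketched in PyTorch as below; `apply_precision` is a hypothetical helper for illustration, not code from this repo:

```python
import torch
import torch.nn as nn

def apply_precision(module: nn.Module, mode: str) -> nn.Module:
    """Illustrative sketch of the fp32 / autocast / bf16_full conventions."""
    if mode == "fp32":
        return module.float()                 # weights and forward both fp32
    if mode == "bf16_full":
        return module.to(torch.bfloat16)      # weights cast; forward mostly bf16
    if mode == "autocast":
        # Weights stay fp32; wrap forward so eligible ops run in bf16.
        fwd = module.forward
        def autocast_forward(*args, **kwargs):
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                return fwd(*args, **kwargs)
        module.forward = autocast_forward
        return module
    raise ValueError(f"unknown mode: {mode}")

layer = apply_precision(nn.Linear(4, 4), "bf16_full")
print(layer.weight.dtype)  # torch.bfloat16
```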
* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).
## Current Best Configuration and Results
### Configuration
- All modules except the VAE run in bf16
- The mixed-precision ckpt is exported offline from the model using `--export_casted_ckpt`
## 📑 Opensource Plan
- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment
### Results
- Wall-clock time: reduced from `15m6s` to `7m5s`
- PSNR: drops by less than `4` dB (`35 -> 31`)
- GPU memory: usage drops to about `50%` of the original
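The PSNR drop above can be measured with a standard helper like the one below; this is a generic sketch, not the evaluation code used for these numbers:

```python
import numpy as np

def psnr(ref: np.ndarray, test: np.ndarray, peak: float = 255.0) -> float:
    """Peak signal-to-noise ratio in dB between two same-shaped images/videos."""
    mse = np.mean((ref.astype(np.float64) - test.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(peak ** 2 / mse)

ref = np.full((8, 8), 100.0)
noisy = ref + 4.0                    # uniform error of 4 gray levels
print(round(psnr(ref, noisy), 2))    # 36.09
```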
## ⚙️ Installation
```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma
conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge
git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git
# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive
pip install -e .
cd external/dlimp
pip install -e .
```
## 🧰 Model Checkpoints
| Model | Description | Link|
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|
## 🛢️ Dataset
In our experiments, we consider the following open-source datasets:
| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|
To train on your own dataset, first prepare the data in the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:
```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```
Then, convert a dataset to the required format using the command below:
```shell
cd prepare_data
python prepare_training_data.py \
--source_dir /path/to/your/source_dir \
--target_dir /path/to/save/the/converted/data \
--dataset_name "dataset1_name" \
--robot_name "a tag of the robot in the dataset" # e.g., "Unitree Z1 Robot Arm" or "Unitree G1 Robot with Gripper"
```
The resulting data structure is shown below. (Note: model training only supports input from the main-view camera. If the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.)
```
target_dir/
├── videos
│ ├──dataset1_name
│ │ ├──camera_view_dir
│ │ ├── 0.mp4
│ │ ├── 1.mp4
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ ├── meta_data
│ ├── 0.h5
│ ├── 1.h5
│ └── ...
└── dataset1_name.csv
```
## 🚴‍♂️ Training
A. Our training strategy is outlined as follows:
- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;
<div align="left">
<img src="assets/pngs/dm_mode.png" width="600">
</div>
- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.
<div align="left">
<img src="assets/pngs/sim_mode.png" width="600">
</div>
**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.
B. To conduct training on a single or multiple datasets, please follow the steps below:
- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;
```yaml
model:
pretrained_checkpoint: /path/to/pretrained/checkpoint;
...
decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
...
data:
...
train:
...
data_dir: /path/to/training/dataset/directory
dataset_and_weights: # list the name of each dataset below and make sure the summation of weights is 1.0
dataset1_name: 0.2
dataset2_name: 0.2
dataset3_name: 0.2
dataset4_name: 0.2
dataset5_name: 0.2
```
- **Step 4**: Setup ```experiment_name```, ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:
```
bash scripts/train.sh
```
## 🌏 Inference under Interactive Simulation Mode
To run the world model in an interactive simulation mode, follow these steps:
- **Step 1**: (Skip this step if you only want to test with the provided examples.) Prepare your own prompt following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):
```
world_model_interaction_prompts/
├── images
│ ├── dataset1_name
│ │ ├── 0.png # Image prompt
│ │ └── ...
│ └── ...
├── transitions
│ ├── dataset1_name
│ │ ├── meta_data # Used for normalization
│ │ ├── 0.h5 # Robot state and action data; in interaction mode,
│ │ │ # only used to retrieve the robot state corresponding
│ │ │ # to the image prompt
│ │ └── ...
│ └── ...
├── dataset1_name.csv # File for loading image prompts, text instruction and corresponding robot states
└── ...
```
- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml)
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all dataset names in ```datasets=(...)```. Then launch the inference with the command:
```
bash scripts/run_world_model_interaction.sh
```
## 🧠 Inference and Deployment under Decision-Making Mode
In this setup, inference runs on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds in the following steps:
### Server Setup:
- **Step-1**: Specify ```ckpt```, ```res_dir```, ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:
```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```
### Client Setup
- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:
```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```
- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:
```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```
## 📝 Codebase Architecture
Here's a high-level overview of the project's code structure and core components:
```
unitree-world-model/
├── assets # Media assets such as GIFs, images, and demo videos
├── configs # Configuration files for training and inference
│ ├── inference
│ └── train
├── examples # Example inputs and prompts for running inference
├── external # External packages
├── prepare_data # Scripts for dataset preprocessing and format conversion
├── scripts # Main scripts for training, evaluation, and deployment
├── src
│ ├──unitree_worldmodel # Core Python package for the Unitree world model
│ │ ├── data # Dataset loading, transformations, and dataloaders
│ │ ├── models # Model architectures and backbone definitions
│ │ ├── modules # Custom model modules and components
│ │ └── utils # Utility functions and common helpers
└── unitree_deploy # Deployment code
```
## 🙏 Acknowledgement
Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).
## 📝 Citation
```
@misc{unifolm-wma-0,
author = {Unitree},
title = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
year = {2025},
}
```


@@ -222,7 +222,7 @@ data:
test:
target: unifolm_wma.data.wma_data.WMAData
params:
data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
video_length: ${model.params.wma_config.params.temporal_length}
frame_stride: 2
load_raw_resolution: True

File diff suppressed because it is too large


@@ -99,7 +99,6 @@ class AutoencoderKL(pl.LightningModule):
print(f"Restored from {path}")
def encode(self, x, **kwargs):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)


@@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM):
if not self.perframe_ae:
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else: ## Consume less GPU memory but slower
results = []
for index in range(x.shape[0]):
frame_batch = self.first_stage_model.encode(x[index:index +
1, :, :, :])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
results = torch.cat(results, dim=0)
else: ## Batch encode with configurable batch size
bs = getattr(self, 'vae_encode_bs', 1)
if bs >= x.shape[0]:
encoder_posterior = self.first_stage_model.encode(x)
results = self.get_first_stage_encoding(encoder_posterior).detach()
else:
results = []
for i in range(0, x.shape[0], bs):
frame_batch = self.first_stage_model.encode(x[i:i + bs])
frame_result = self.get_first_stage_encoding(
frame_batch).detach()
results.append(frame_result)
results = torch.cat(results, dim=0)
if reshape_back:
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)
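The chunked-encode pattern in this hunk (full batch when `vae_encode_bs` covers all frames, otherwise slices of `bs`) can be exercised in isolation; the toy encoder below is a hypothetical stand-in for `first_stage_model.encode`:

```python
import torch

def encode_in_chunks(encode, x: torch.Tensor, bs: int) -> torch.Tensor:
    """Mirror of the diff's logic: one call when bs >= N, else chunks of bs."""
    if bs >= x.shape[0]:
        return encode(x)
    return torch.cat([encode(x[i:i + bs]) for i in range(0, x.shape[0], bs)],
                     dim=0)

toy_encode = lambda t: t * 0.5            # stand-in for the VAE encoder
x = torch.arange(12.0).reshape(6, 2)
full = encode_in_chunks(toy_encode, x, bs=8)     # single full-batch call
chunked = encode_in_chunks(toy_encode, x, bs=2)  # three chunks of 2
print(torch.equal(full, chunked))  # True
```

Larger `bs` trades peak GPU memory for fewer kernel launches; the results are identical either way.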
@@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM):
else:
reshape_back = False
z = 1. / self.scale_factor * z
if not self.perframe_ae:
z = 1. / self.scale_factor * z
results = self.first_stage_model.decode(z, **kwargs)
else:
results = []
for index in range(z.shape[0]):
frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
frame_result = self.first_stage_model.decode(frame_z, **kwargs)
results.append(frame_result)
results = torch.cat(results, dim=0)
bs = getattr(self, 'vae_decode_bs', 1)
if bs >= z.shape[0]:
# all frames in one batch
results = self.first_stage_model.decode(z, **kwargs)
else:
results = []
for i in range(0, z.shape[0], bs):
results.append(
self.first_stage_model.decode(z[i:i + bs], **kwargs))
results = torch.cat(results, dim=0)
if reshape_back:
results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)


@@ -55,16 +55,13 @@ class DDIMSampler(object):
to_torch(self.model.alphas_cumprod_prev))
# Calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# Computed directly on GPU to avoid CPU↔GPU transfers
ac = to_torch(alphas_cumprod)
self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())
# DDIM sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
@@ -86,6 +83,11 @@ class DDIMSampler(object):
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
torch.sqrt(1. - ddim_alphas))
# Precomputed coefficients for DDIM update formula
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
self.register_buffer('ddim_dir_coeff',
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
@@ -208,18 +210,11 @@ class DDIMSampler(object):
dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
else:
img = x_T
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
action = torch.randn((b, 16, self.model.agent_action_dim),
device=device)
state = torch.randn((b, 16, self.model.agent_state_dim),
device=device)
img = torch.randn(shape, device=device) if x_T is None else x_T
if precision is not None:
if precision == 16:
@@ -362,12 +357,13 @@ class DDIMSampler(object):
**kwargs)
else:
raise NotImplementedError
model_output = e_t_uncond + unconditional_guidance_scale * (
e_t_cond - e_t_uncond)
model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
e_t_cond_action - e_t_uncond_action)
model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
e_t_cond_state - e_t_uncond_state)
model_output = torch.lerp(e_t_uncond, e_t_cond,
unconditional_guidance_scale)
model_output_action = torch.lerp(e_t_uncond_action,
e_t_cond_action,
unconditional_guidance_scale)
model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
unconditional_guidance_scale)
if guidance_rescale > 0.0:
model_output = rescale_noise_cfg(
@@ -396,18 +392,28 @@ class DDIMSampler(object):
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
if use_original_steps:
sqrt_alphas = alphas.sqrt()
sqrt_alphas_prev = alphas_prev.sqrt()
dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
else:
sqrt_alphas = self.ddim_sqrt_alphas
sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
dir_coeffs = self.ddim_dir_coeff
if is_video:
size = (1, 1, 1, 1, 1)
else:
size = (1, 1, 1, 1)
a_t = alphas[index].view(size)
a_prev = alphas_prev[index].view(size)
sqrt_at = sqrt_alphas[index].view(size)
sqrt_a_prev = sqrt_alphas_prev[index].view(size)
sigma_t = sigmas[index].view(size)
sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
dir_coeff = dir_coeffs[index].view(size)
if self.model.parameterization != "v":
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
@@ -420,14 +426,11 @@ class DDIMSampler(object):
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise
return x_prev, pred_x0, model_output_action, model_output_state
@@ -475,7 +478,7 @@ class DDIMSampler(object):
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_alphas_cumprod = self.ddim_sqrt_alphas
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
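The guidance rewrite in this file relies on the identity `torch.lerp(a, b, w) = a + w * (b - a)`, which with `w > 1` extrapolates exactly like the original CFG expression; a quick standalone check:

```python
import torch

torch.manual_seed(0)
e_t_uncond = torch.randn(2, 4)
e_t_cond = torch.randn(2, 4)
scale = 7.5  # unconditional_guidance_scale, typically > 1

original = e_t_uncond + scale * (e_t_cond - e_t_uncond)  # pre-rewrite formula
fused = torch.lerp(e_t_uncond, e_t_cond, scale)          # post-rewrite formula
print(torch.allclose(original, fused, atol=1e-5))  # True
```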


@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
# swish / SiLU — single fused CUDA kernel instead of x * sigmoid(x)
return torch.nn.functional.silu(x)
def Normalize(in_channels, num_groups=32):
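The operator-fusion change above banks on `F.silu` computing exactly `x * sigmoid(x)` in a single kernel; a quick numerical check of the identity (standalone, not repo code):

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5.0, 5.0, steps=101)
manual = x * torch.sigmoid(x)   # original swish implementation, two kernels
fused = F.silu(x)               # fused SiLU kernel
print(torch.allclose(manual, fused, atol=1e-6))  # True
```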


@@ -7,79 +7,97 @@ MACRO-LEVEL TIMING SUMMARY
----------------------------------------
Section Count Total(ms) Avg(ms) CUDA Avg(ms)
--------------------------------------------------------------------------------------
action_generation 11 399707.47 36337.04 36336.85
data_loading 1 52.85 52.85 52.88
get_latent_z/encode 22 901.39 40.97 41.01
iteration_total 11 836793.23 76072.11 76071.63
load_transitions 1 2.24 2.24 2.28
model_loading/checkpoint 1 11833.31 11833.31 11833.43
model_loading/config 1 49774.19 49774.19 49774.16
model_to_cuda 1 8909.30 8909.30 8909.33
prepare_init_input 1 10.52 10.52 10.55
prepare_observation 11 5.41 0.49 0.53
prepare_wm_observation 11 2.12 0.19 0.22
save_results 11 38668.06 3515.28 3515.32
synthesis/conditioning_prep 22 2916.63 132.57 132.61
synthesis/ddim_sampling 22 782695.01 35577.05 35576.86
synthesis/decode_first_stage 22 12444.31 565.65 565.70
update_action_queues 11 6.85 0.62 0.65
update_state_queues 11 17.67 1.61 1.64
world_model_interaction 11 398375.58 36215.96 36215.75
action_generation 11 173133.54 15739.41 15739.36
data_loading 1 54.31 54.31 54.34
get_latent_z/encode 22 785.25 35.69 35.72
iteration_total 11 386482.08 35134.73 35134.55
load_transitions 1 2.07 2.07 2.10
model_loading/prepared 1 4749.22 4749.22 4749.83
prepare_init_input 1 29.19 29.19 29.22
prepare_observation 11 5.49 0.50 0.53
prepare_wm_observation 11 1.93 0.18 0.20
save_results 11 38791.18 3526.47 3526.51
synthesis/conditioning_prep 22 2528.23 114.92 114.95
synthesis/ddim_sampling 22 336003.29 15272.88 15272.83
synthesis/decode_first_stage 22 9095.14 413.42 413.46
update_action_queues 11 7.28 0.66 0.69
update_state_queues 11 17.38 1.58 1.61
world_model_interaction 11 174516.52 15865.14 15865.07
--------------------------------------------------------------------------------------
TOTAL 2543116.13
TOTAL 1126202.08
----------------------------------------
GPU MEMORY SUMMARY
----------------------------------------
Peak allocated: 17890.50 MB
Average allocated: 16129.98 MB
Peak allocated: 18188.29 MB
Average allocated: 9117.49 MB
----------------------------------------
TOP 30 OPERATORS BY CUDA TIME
----------------------------------------
Operator Count CUDA(ms) CPU(ms) Self CUDA(ms)
------------------------------------------------------------------------------------------------
ProfilerStep* 6 443804.16 237696.98 237689.25
aten::linear 171276 112286.23 13179.82 0.00
aten::addmm 81456 79537.36 3799.84 79296.37
ampere_sgemm_128x64_tn 26400 52052.10 0.00 52052.10
aten::matmul 90468 34234.05 6281.32 0.00
aten::_convolution 100242 33623.79 13105.89 0.00
aten::mm 89820 33580.74 3202.22 33253.18
aten::convolution 100242 33575.23 13714.47 0.00
aten::cudnn_convolution 98430 30932.19 8640.50 29248.12
ampere_sgemm_32x128_tn 42348 20394.52 0.00 20394.52
aten::conv2d 42042 18115.35 5932.30 0.00
ampere_sgemm_128x32_tn 40938 16429.81 0.00 16429.81
xformers::efficient_attention_forward_cutlass 24000 15222.23 2532.93 15120.44
fmha_cutlassF_f32_aligned_64x64_rf_sm80(Attenti... 24000 15121.31 0.00 15121.31
ampere_sgemm_64x64_tn 21000 14627.12 0.00 14627.12
aten::copy_ 231819 14504.87 127056.51 14038.39
aten::group_norm 87144 12033.73 10659.57 0.00
aten::native_group_norm 87144 11473.40 9449.36 11002.02
aten::conv3d 26400 8852.13 3365.43 0.00
void at::native::(anonymous namespace)::Rowwise... 87144 8714.68 0.00 8714.68
void cudnn::ops::nchwToNhwcKernel<float, float,... 169824 8525.44 0.00 8525.44
aten::clone 214314 8200.26 8568.82 0.00
void at::native::elementwise_kernel<128, 2, at:... 220440 8109.62 0.00 8109.62
void cutlass::Kernel<cutlass_80_simt_sgemm_128x... 15000 7919.30 0.00 7919.30
aten::_to_copy 12219 5963.43 122411.53 0.00
aten::to 58101 5952.65 122443.72 0.00
aten::conv1d 30000 5878.95 4556.48 0.00
Memcpy HtoD (Pageable -> Device) 6696 5856.39 0.00 5856.39
aten::reshape 671772 5124.03 9636.01 0.00
sm80_xmma_fprop_implicit_gemm_indexed_tf32f32_t... 16272 5097.70 0.00 5097.70
ProfilerStep* 18 690146.23 133688.74 616385.44
aten::group_norm 168624 24697.84 29217.27 0.00
aten::_convolution 96450 21420.26 12845.86 0.00
aten::convolution 96450 21408.68 13480.97 0.00
aten::linear 297398 20780.15 26257.38 0.00
aten::cudnn_convolution 94638 18660.24 8239.04 18329.28
aten::copy_ 772677 18135.46 17387.09 17864.87
aten::conv3d 52800 12922.42 8572.58 0.00
aten::conv2d 52469 12747.13 7725.70 0.00
aten::native_group_norm 84312 10285.37 8974.31 10197.66
aten::_to_copy 590277 10270.09 22570.90 0.00
aten::to 602979 9655.26 23666.06 0.00
aten::conv1d 56245 8174.37 10015.24 0.00
void at::native::(anonymous namespace)::Rowwise... 84312 7979.71 0.00 7979.71
aten::clone 177132 7502.90 7007.48 0.00
void cudnn::ops::nchwToNhwcKernel<__nv_bfloat16... 164700 7384.52 0.00 7384.52
aten::addmm 81456 6958.44 3903.01 6908.44
aten::layer_norm 65700 5698.92 7816.08 0.00
void at::native::elementwise_kernel<128, 4, at:... 149688 5372.46 0.00 5372.46
void at::native::unrolled_elementwise_kernel<at... 180120 5165.28 0.00 5165.28
ampere_bf16_s16816gemm_bf16_128x128_ldg8_relu_f... 24900 4449.05 0.00 4449.05
void at::native::unrolled_elementwise_kernel<at... 368664 4405.30 0.00 4405.30
aten::reshape 686778 3771.84 8309.51 0.00
aten::contiguous 46008 3400.88 1881.73 0.00
sm80_xmma_fprop_implicit_gemm_bf16bf16_bf16f32_... 15516 3398.03 0.00 3398.03
aten::matmul 90489 3366.62 4946.69 0.00
aten::mm 89820 3284.53 3308.76 3228.56
void at::native::elementwise_kernel<128, 2, at:... 46518 2441.55 0.00 2441.55
aten::add 113118 2426.66 2776.23 2385.52
void at::native::elementwise_kernel<128, 4, at:... 104550 2426.41 0.00 2426.41
----------------------------------------
OPERATOR CATEGORY BREAKDOWN
----------------------------------------
Category CUDA Time(ms) Percentage
---------------------------------------------------------
Other 481950.47 41.9%
Linear/GEMM 342333.09 29.8%
Convolution 159920.77 13.9%
Elementwise 54682.93 4.8%
Memory 36883.36 3.2%
Attention 34736.13 3.0%
Normalization 32081.19 2.8%
Activation 6449.19 0.6%
Other 723472.91 71.9%
Convolution 114469.81 11.4%
Memory 53845.46 5.4%
Normalization 46852.57 4.7%
Linear/GEMM 35354.58 3.5%
Elementwise 17078.44 1.7%
Activation 12296.29 1.2%
Attention 2956.61 0.3%
------------------------------------------------------------------------------------------
aten::addmm (Linear/GEMM) UTILIZATION ANALYSIS ON A100
------------------------------------------------------------------------------------------
Effective compute precision: BF16 Tensor Core (312 TFLOPS)
torch.backends.cuda.matmul.allow_tf32 = False
Metric Value
-------------------------------------------------------------------
Total aten::addmm calls 81,456
Total Self CUDA time 6908.44 ms
Total FLOPs (profiler) 1.33 PFLOPS
Achieved throughput 191.88 TFLOPS/s
A100 peak throughput 312.00 TFLOPS/s
MFU (Model FLOPs Utilization) 61.50%
INTERPRETATION:
-------------------------------------------------------------------
Good utilization (>60%). GEMM kernels are compute-bound
and running efficiently on Tensor Cores.
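The MFU figure follows directly from FLOPs / (self-CUDA time × peak throughput). Reproducing the arithmetic from the rounded values printed above (the result differs slightly from the reported 191.88 TFLOPS because 1.33 PFLOPs is itself rounded):

```python
total_flops = 1.33e15            # profiler-reported aten::addmm FLOPs (1.33 PFLOPs)
self_cuda_s = 6908.44 / 1000.0   # total self CUDA time, in seconds
peak_tflops = 312.0              # A100 BF16 Tensor Core peak throughput

achieved_tflops = total_flops / self_cuda_s / 1e12
mfu = achieved_tflops / peak_tflops
print(f"{achieved_tflops:.1f} TFLOPS, MFU {mfu:.1%}")
```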


@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
dataset="unitree_g1_pack_camera"
{
time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
--seed 123 \
--ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
--config configs/inference/world_model_interaction.yaml \
@@ -23,5 +23,6 @@ dataset="unitree_g1_pack_camera"
--perframe_ae \
--diffusion_dtype bf16 \
--projector_mode bf16_full \
--encoder_mode bf16_full
--encoder_mode bf16_full \
--vae_dtype bf16
} 2>&1 | tee "${res_dir}/output.log"