README.md
@@ -1,228 +1,29 @@
# UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family

<p style="font-size: 1.2em;">
  <a href="https://unigen-x.github.io/unifolm-world-model-action.github.io"><strong>Project Page</strong></a> |
  <a href="https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c"><strong>Models</strong></a> |
  <a href="https://huggingface.co/unitreerobotics/datasets"><strong>Dataset</strong></a>
</p>

<div align="center">
<p align="right">
<span> 🌎English </span> | <a href="README_cn.md"> 🇨🇳中文 </a>
</p>
</div>

<div align="justify">
<b>UnifoLM-WMA-0</b> is Unitree's open-source world-model-action architecture spanning multiple types of robotic embodiments, designed for general-purpose robot learning. Its core component is a world model capable of understanding the physical interactions between robots and their environments. This world model provides two key functions: (a) <b>Simulation Engine</b> – operates as an interactive simulator to generate synthetic data for robot learning; (b) <b>Policy Enhancement</b> – connects with an action head and, by predicting future interaction processes with the world model, further optimizes decision-making performance.
</div>

# World Model Interaction Mixed-Precision Acceleration Notes (case1)

## 🦾 Real-Robot Demonstrations

| <img src="assets/gifs/real_z1_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_dual_stackbox.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |
|:---:|:---:|
| <img src="assets/gifs/real_cleanup_pencils.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> | <img src="assets/gifs/real_g1_pack_camera.gif" style="border:none;box-shadow:none;margin:0;padding:0;" /> |

## Changed Locations

- Script path: `/home/dyz/unifolm-world-model-action/unitree_g1_pack_camera/case1/run_world_model_interaction.sh`
- Current status: some parameters that are normally not recommended to change, or that require caution, have been modified (they will be fixed as defaults once the optimal settings are confirmed).

**Note: the top-right window shows the world model's prediction of future action videos.**

## New Parameters (may become defaults once the optimum is confirmed)

- `--diffusion_dtype {fp32,bf16}`: dtype for diffusion weights and forward pass; default `fp32`
- `--projector_mode {fp32,autocast,bf16_full}`: projector precision strategy; default `fp32`
- `--encoder_mode {fp32,autocast,bf16_full}`: encoder precision strategy; default `fp32`
- `--vae_dtype {fp32,bf16}`: dtype for VAE weights and forward pass; default `fp32`
- `--export_casted_ckpt <path>`: export a ckpt under the current precision settings (for offline export of mixed-precision weights)
- `--export_only`: export the ckpt and exit; off by default

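The export flags suggest the script casts the loaded state dict before saving. A minimal sketch of such an export step (the function name, the prefix list, and the keep-the-VAE-in-fp32 exception are assumptions based on the configuration notes below, not the actual implementation):

```python
import io
import torch

def export_casted_ckpt(state_dict, f, keep_fp32_prefixes=("first_stage_model.",)):
    """Cast floating-point weights to bf16, except modules kept in fp32 (here: the VAE)."""
    casted = {k: v.to(torch.bfloat16)
              if v.is_floating_point() and not k.startswith(keep_fp32_prefixes)
              else v
              for k, v in state_dict.items()}
    torch.save(casted, f)  # f may be a path or any file-like object
    return casted

# Toy state dict: one diffusion weight, one VAE weight.
sd = {"model.weight": torch.randn(4, 4),
      "first_stage_model.weight": torch.randn(4, 4)}
out = export_casted_ckpt(sd, io.BytesIO())
```

Non-floating-point entries (e.g. integer buffers) are passed through unchanged, since casting them to bf16 would corrupt them.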
## 🔥 News

### Parameter Semantics

- `fp32`: weights and forward pass both in fp32
- `autocast`: weights stay fp32; the forward pass runs under `torch.autocast` (op-level mixed precision)
- `bf16_full`: weights are explicitly cast to bf16, and the forward pass runs primarily in bf16

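The three strategies can be sketched in PyTorch as follows (a CPU autocast context is used here so the snippet runs anywhere; in the scripts this would be the CUDA device):

```python
import torch

proj = torch.nn.Linear(8, 8)   # stand-in for the projector/encoder
x = torch.randn(2, 8)

# fp32: weights and forward both in fp32
y_fp32 = proj(x)

# autocast: weights stay fp32, eligible ops run in bf16 inside the context
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y_ac = proj(x)

# bf16_full: weights explicitly cast to bf16, inputs and forward in bf16
proj_bf16 = proj.to(torch.bfloat16)
y_bf16 = proj_bf16(x.to(torch.bfloat16))
```

Under `autocast` the module's parameters remain fp32 (the cast happens per-op), whereas `bf16_full` halves the weight memory permanently.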
* Sep 22, 2025: 🚀 We released the deployment code for assisting experiments with [Unitree](https://www.unitree.com/) robots.
* Sep 15, 2025: 🚀 We released the training and inference code along with the model weights of [**UnifoLM-WMA-0**](https://huggingface.co/collections/unitreerobotics/unifolm-wma-0-68ca23027310c0ca0f34959c).

## Current Best Configuration and Results

### Configuration

- All modules run in bf16 except the VAE module
- The mixed-precision ckpt is exported offline (using `--export_casted_ckpt`)

## 📑 Opensource Plan

- [x] Training
- [x] Inference
- [x] Checkpoints
- [x] Deployment

### Results

- Runtime: reduced from `15m6s` to `7m5s`
- PSNR: drops by less than `4` (`35 -> 31`)
- GPU memory: usage drops to about `50%` of the original

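The wall-clock numbers above work out to roughly a 2.1x speedup; a quick check (the parsing helper is ad hoc, for these two strings only):

```python
def to_seconds(t):
    """Parse a '15m6s'-style duration into seconds."""
    m, s = t.rstrip("s").split("m")
    return int(m) * 60 + int(s)

before, after = to_seconds("15m6s"), to_seconds("7m5s")
print(f"speedup: {before / after:.2f}x")  # about 2.13x
```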
## ⚙️ Installation

```
conda create -n unifolm-wma python==3.10.18
conda activate unifolm-wma

conda install pinocchio=3.2.0 -c conda-forge -y
conda install ffmpeg=7.1.1 -c conda-forge

git clone --recurse-submodules https://github.com/unitreerobotics/unifolm-world-model-action.git

# If you already downloaded the repo:
cd unifolm-world-model-action
git submodule update --init --recursive

pip install -e .

cd external/dlimp
pip install -e .
```

## 🧰 Model Checkpoints

| Model | Description | Link |
|---------|-------|------|
|$\text{UnifoLM-WMA-0}_{Base}$| Fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Base)|
|$\text{UnifoLM-WMA-0}_{Dual}$| Fine-tuned on five [Unitree open-source datasets](https://huggingface.co/collections/unitreerobotics/g1-dex1-datasets-68bae98bf0a26d617f9983ab) in both decision-making and simulation modes. | [HuggingFace](https://huggingface.co/unitreerobotics/UnifoLM-WMA-0-Dual)|

## 🛢️ Dataset

In our experiments, we consider the following open-source datasets:

| Dataset | Robot | Link |
|---------|-------|------|
|Z1_StackBox| [Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset/tree/v2.1)|
|Z1_DualArm_StackBox_V2|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_StackBox_Dataset_V2/tree/v2.1)|
|Z1_DualArm_Cleanup_Pencils|[Unitree Z1](https://www.unitree.com/z1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/Z1_Dual_Dex1_CleanupPencils_Dataset/tree/v2.1)|
|G1_Pack_Camera|[Unitree G1](https://www.unitree.com/g1)|[Huggingface](https://huggingface.co/datasets/unitreerobotics/G1_Dex1_MountCameraRedGripper_Dataset/tree/v2.1)|

To train on your own dataset, first make sure the data follows the [Huggingface LeRobot V2.1](https://github.com/huggingface/lerobot) dataset format. Assume the dataset's source directory structure is as follows:

```
source_dir/
├── dataset1_name
├── dataset2_name
├── dataset3_name
└── ...
```

Then, convert a dataset to the required format using the command below:

```
cd prepare_data
python prepare_training_data.py \
    --source_dir /path/to/your/source_dir \
    --target_dir /path/to/save/the/converted/data \
    --dataset_name "dataset1_name" \
    --robot_name "a tag of the robot in the dataset" # e.g., Unitree Z1 Robot Arm or Unitree G1 Robot with Gripper.
```

The resulting data structure is shown below. (Note: model training only supports input from the main-view camera. If the dataset includes multiple views, remove the corresponding values from the ```data_dir``` column in the CSV file.)
```
target_dir/
├── videos
│   ├── dataset1_name
│   │   ├── camera_view_dir
│   │   │   ├── 0.mp4
│   │   │   ├── 1.mp4
│   │   │   └── ...
│   └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data
│   │   ├── 0.h5
│   │   ├── 1.h5
│   │   └── ...
└── dataset1_name.csv
```

## 🚴♂️ Training

A. Our training strategy is outlined as follows:

- **Step 1**: Fine-tune a video generation model as the world model using the [Open-X](https://robotics-transformer-x.github.io/) dataset;
- **Step 2**: Post-train $\text{UnifoLM-WMA}$ in decision-making mode on the downstream task dataset;

<div align="left">
  <img src="assets/pngs/dm_mode.png" width="600">
</div>

- **Step 3**: Post-train $\text{UnifoLM-WMA}$ in simulation mode on the downstream task dataset.

<div align="left">
  <img src="assets/pngs/sim_mode.png" width="600">
</div>

**Note**: If you only require $\text{UnifoLM-WMA}$ to operate in a single mode, you may skip the corresponding step.

B. To conduct training on a single dataset or multiple datasets, please follow the steps below:

- **Step 1**: The maximum DoF is assumed to be 16; if you have more than 16 DoF, update ```agent_state_dim``` and ```agent_action_dim``` in [configs/train/config.yaml](https://github.com/unitreerobotics/unifolm-wma/blob/working/configs/train/config.yaml);
- **Step 2**: Set up the input shapes for each modality in [configs/train/meta.json](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/meta.json);
- **Step 3**: Configure the training parameters in [configs/train/config.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/train/config.yaml). For the ```pretrained_checkpoint```, we recommend using the checkpoint " $\text{UnifoLM-WMA-0}_{Base}$ " fine-tuned on the [Open-X](https://robotics-transformer-x.github.io/) dataset;

```yaml
model:
  pretrained_checkpoint: /path/to/pretrained/checkpoint
  ...
  decision_making_only: True # Train the world model only in decision-making mode. If False, jointly train it in both decision-making and simulation modes.
  ...
data:
  ...
  train:
    ...
    data_dir: /path/to/training/dataset/directory
    dataset_and_weights: # List the name of each dataset below and make sure the weights sum to 1.0
      dataset1_name: 0.2
      dataset2_name: 0.2
      dataset3_name: 0.2
      dataset4_name: 0.2
      dataset5_name: 0.2
```

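Since the config requires the dataset weights to sum to 1.0, a small sanity check like the following can catch typos before launching a run (the dictionary mirrors the example config; the check itself is not part of the codebase):

```python
import math

dataset_and_weights = {
    "dataset1_name": 0.2, "dataset2_name": 0.2, "dataset3_name": 0.2,
    "dataset4_name": 0.2, "dataset5_name": 0.2,
}

total = sum(dataset_and_weights.values())
# Use isclose rather than ==, since the weights are floats.
assert math.isclose(total, 1.0, rel_tol=1e-9), f"weights sum to {total}, expected 1.0"
```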
- **Step 4**: Set up the ```experiment_name``` and ```save_root``` variables in [scripts/train.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/train.sh);
- **Step 5**: Launch the training with the command:

```
bash scripts/train.sh
```

## 🌏 Inference under Interactive Simulation Mode

To run the world model in interactive simulation mode, follow these steps:

- **Step 1**: (Skip this step if you just want to test using the examples we provide.) Prepare your own prompt following the format used in [examples/world_model_interaction_prompts](https://github.com/unitreerobotics/unitree-world-model/tree/main/examples/world_model_interaction_prompts):

```
world_model_interaction_prompts/
├── images
│   ├── dataset1_name
│   │   ├── 0.png        # Image prompt
│   │   └── ...
│   └── ...
├── transitions
│   ├── dataset1_name
│   │   ├── meta_data    # Used for normalization
│   │   ├── 0.h5         # Robot state and action data; in interaction mode,
│   │   │                # only used to retrieve the robot state corresponding
│   │   │                # to the image prompt
│   │   └── ...
│   └── ...
├── dataset1_name.csv    # File for loading image prompts, text instructions and corresponding robot states
└── ...
```

- **Step 2**: Specify the correct paths for ```pretrained_checkpoint``` (e.g., $\text{UnifoLM-WMA-0}_{Dual}$) and ```data_dir``` in [configs/inference/world_model_interaction.yaml](https://github.com/unitreerobotics/unitree-world-model/blob/main/configs/inference/world_model_interaction.yaml);
- **Step 3**: Set the paths for ```checkpoint```, ```res_dir``` and ```prompt_dir``` in [scripts/run_world_model_interaction.sh](https://github.com/unitreerobotics/unitree-world-model/blob/main/scripts/run_world_model_interaction.sh), and specify all the dataset names in ```datasets=(...)```. Then, launch the inference with the command:

```
bash scripts/run_world_model_interaction.sh
```

## 🧠 Inference and Deployment under Decision-Making Mode

In this setup, inference is performed on a server, while a robot client gathers observations from the real robot and sends them to the server to query actions. The process unfolds through the following steps:

### Server Setup

- **Step-1**: Specify ```ckpt```, ```res_dir``` and ```datasets``` in [scripts/run_real_eval_server.sh](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/scripts/run_real_eval_server.sh);
- **Step-2**: Configure ```data_dir``` and ```dataset_and_weights``` in [config/inference/world_model_decision_making.yaml](https://github.com/unitreerobotics/unifolm-world-model-action/blob/f12b4782652ca00452941d851b17446e4ee7124a/configs/inference/world_model_decision_making.yaml#L225);
- **Step-3**: Launch the server:

```
conda activate unifolm-wma
cd unifolm-world-model-action
bash scripts/run_real_eval_server.sh
```

### Client Setup

- **Step-1**: Follow the instructions in [unitree_deploy/README.md](https://github.com/unitreerobotics/unifolm-world-model-action/blob/main/unitree_deploy/README.md) to create the ```unitree_deploy``` conda environment, install the required packages, and launch the controllers or services on the real robot.
- **Step-2**: Open a new terminal and establish a tunnel connection from the client to the server:

```
ssh user_name@remote_server_IP -CNg -L 8000:127.0.0.1:8000
```

- **Step-3**: Run the ```unitree_deploy/robot_client.py``` script to start inference:

```
cd unitree_deploy
python scripts/robot_client.py --robot_type "g1_dex1" --action_horizon 16 --exe_steps 16 --observation_horizon 2 --language_instruction "pack black camera into box" --output_dir ./results --control_freq 15
```

## 📝 Codebase Architecture

Here's a high-level overview of the project's code structure and core components:

```
unitree-world-model/
├── assets              # Media assets such as GIFs, images, and demo videos
├── configs             # Configuration files for training and inference
│   ├── inference
│   └── train
├── examples            # Example inputs and prompts for running inference
├── external            # External packages
├── prepare_data        # Scripts for dataset preprocessing and format conversion
├── scripts             # Main scripts for training, evaluation, and deployment
├── src
│   └── unitree_worldmodel   # Core Python package for the Unitree world model
│       ├── data        # Dataset loading, transformations, and dataloaders
│       ├── models      # Model architectures and backbone definitions
│       ├── modules     # Custom model modules and components
│       └── utils       # Utility functions and common helpers
└── unitree_deploy      # Deployment code
```

## 🙏 Acknowledgement

Much of the code is inherited from [DynamiCrafter](https://github.com/Doubiiu/DynamiCrafter), [Diffusion Policy](https://github.com/real-stanford/diffusion_policy), [ACT](https://github.com/MarkFzp/act-plus-plus) and [HPT](https://github.com/liruiw/HPT).

## 📝 Citation

```
@misc{unifolm-wma-0,
  author = {Unitree},
  title  = {UnifoLM-WMA-0: A World-Model-Action (WMA) Framework under UnifoLM Family},
  year   = {2025},
}
```

@@ -222,7 +222,7 @@ data:
   test:
     target: unifolm_wma.data.wma_data.WMAData
     params:
-      data_dir: '/home/dyz/unifolm-world-model-action/examples/world_model_interaction_prompts'
+      data_dir: '/home/qhy/unifolm-world-model-action/examples/world_model_interaction_prompts'
       video_length: ${model.params.wma_config.params.temporal_length}
       frame_stride: 2
       load_raw_resolution: True

@@ -99,7 +99,6 @@ class AutoencoderKL(pl.LightningModule):
         print(f"Restored from {path}")

     def encode(self, x, **kwargs):
-
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)

@@ -1073,15 +1073,19 @@ class LatentDiffusion(DDPM):
         if not self.perframe_ae:
             encoder_posterior = self.first_stage_model.encode(x)
             results = self.get_first_stage_encoding(encoder_posterior).detach()
-        else:  ## Consume less GPU memory but slower
-            results = []
-            for index in range(x.shape[0]):
-                frame_batch = self.first_stage_model.encode(x[index:index + 1, :, :, :])
-                frame_result = self.get_first_stage_encoding(frame_batch).detach()
-                results.append(frame_result)
-            results = torch.cat(results, dim=0)
+        else:  ## Batch encode with configurable batch size
+            bs = getattr(self, 'vae_encode_bs', 1)
+            if bs >= x.shape[0]:
+                encoder_posterior = self.first_stage_model.encode(x)
+                results = self.get_first_stage_encoding(encoder_posterior).detach()
+            else:
+                results = []
+                for i in range(0, x.shape[0], bs):
+                    frame_batch = self.first_stage_model.encode(x[i:i + bs])
+                    frame_result = self.get_first_stage_encoding(frame_batch).detach()
+                    results.append(frame_result)
+                results = torch.cat(results, dim=0)

         if reshape_back:
             results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)

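The hunk above generalizes the old one-frame-at-a-time loop into configurable chunks; with `bs=1` it degenerates to the old behavior, and with `bs >= batch size` it becomes a single call. The chunking pattern in isolation (plain Python; `encode` here is a stand-in for the VAE encoder):

```python
def encode_in_chunks(frames, encode, bs=1):
    """Run `encode` over `frames` in chunks of `bs`, mirroring the vae_encode_bs logic."""
    if bs >= len(frames):
        return encode(frames)  # one batch covers everything
    out = []
    for i in range(0, len(frames), bs):  # the last chunk may be partial
        out.extend(encode(frames[i:i + bs]))
    return out

double = lambda xs: [x * 2 for x in xs]  # toy "encoder"
print(encode_in_chunks(list(range(5)), double, bs=2))  # [0, 2, 4, 6, 8]
```

Larger `bs` trades GPU memory for fewer kernel launches, which is where the decode/encode speedup in the timing tables below comes from.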
@@ -1105,16 +1109,21 @@ class LatentDiffusion(DDPM):
         else:
             reshape_back = False

+        z = 1. / self.scale_factor * z
+
         if not self.perframe_ae:
-            z = 1. / self.scale_factor * z
             results = self.first_stage_model.decode(z, **kwargs)
         else:
-            results = []
-            for index in range(z.shape[0]):
-                frame_z = 1. / self.scale_factor * z[index:index + 1, :, :, :]
-                frame_result = self.first_stage_model.decode(frame_z, **kwargs)
-                results.append(frame_result)
-            results = torch.cat(results, dim=0)
+            bs = getattr(self, 'vae_decode_bs', 1)
+            if bs >= z.shape[0]:
+                # all frames in one batch
+                results = self.first_stage_model.decode(z, **kwargs)
+            else:
+                results = []
+                for i in range(0, z.shape[0], bs):
+                    results.append(
+                        self.first_stage_model.decode(z[i:i + bs], **kwargs))
+                results = torch.cat(results, dim=0)

         if reshape_back:
             results = rearrange(results, '(b t) c h w -> b c t h w', b=b, t=t)

@@ -55,16 +55,13 @@ class DDIMSampler(object):
                              to_torch(self.model.alphas_cumprod_prev))

         # Calculations for diffusion q(x_t | x_{t-1}) and others
-        self.register_buffer('sqrt_alphas_cumprod',
-                             to_torch(np.sqrt(alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_one_minus_alphas_cumprod',
-                             to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
-        self.register_buffer('log_one_minus_alphas_cumprod',
-                             to_torch(np.log(1. - alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recip_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
-        self.register_buffer('sqrt_recipm1_alphas_cumprod',
-                             to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
+        # Computed directly on GPU to avoid CPU↔GPU transfers
+        ac = to_torch(alphas_cumprod)
+        self.register_buffer('sqrt_alphas_cumprod', ac.sqrt())
+        self.register_buffer('sqrt_one_minus_alphas_cumprod', (1. - ac).sqrt())
+        self.register_buffer('log_one_minus_alphas_cumprod', (1. - ac).log())
+        self.register_buffer('sqrt_recip_alphas_cumprod', ac.rsqrt())
+        self.register_buffer('sqrt_recipm1_alphas_cumprod', (1. / ac - 1).sqrt())

         # DDIM sampling parameters
         ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(

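The rewritten buffers should be numerically identical to the old NumPy round-trip; a quick equivalence check on a toy schedule (the `linspace` values are arbitrary, just kept away from 0 and 1 for numerical safety):

```python
import numpy as np
import torch

# Toy alphas_cumprod schedule, float64 to mirror the numpy computation.
ac = torch.linspace(0.1, 0.99, 50, dtype=torch.float64)

# Old path: round-trip through NumPy on CPU.
old = torch.from_numpy(np.sqrt(1. / ac.numpy() - 1))
# New path: computed directly on the tensor (stays on GPU when ac does).
new = (1. / ac - 1).sqrt()

assert torch.allclose(old, new)
```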
@@ -86,6 +83,11 @@ class DDIMSampler(object):
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
torch.sqrt(1. - ddim_alphas))
|
||||
# Precomputed coefficients for DDIM update formula
|
||||
self.register_buffer('ddim_sqrt_alphas', ddim_alphas.sqrt())
|
||||
self.register_buffer('ddim_sqrt_alphas_prev', ddim_alphas_prev.sqrt())
|
||||
self.register_buffer('ddim_dir_coeff',
|
||||
(1. - ddim_alphas_prev - ddim_sigmas**2).sqrt())
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
||||
(1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) *
|
||||
(1 - self.alphas_cumprod / self.alphas_cumprod_prev))
|
||||
@@ -208,18 +210,11 @@ class DDIMSampler(object):
         dp_ddim_scheduler_state = self.model.dp_noise_scheduler_state

         b = shape[0]
-        if x_T is None:
-            img = torch.randn(shape, device=device)
-            action = torch.randn((b, 16, self.model.agent_action_dim),
-                                 device=device)
-            state = torch.randn((b, 16, self.model.agent_state_dim),
-                                device=device)
-        else:
-            img = x_T
-            action = torch.randn((b, 16, self.model.agent_action_dim),
-                                 device=device)
-            state = torch.randn((b, 16, self.model.agent_state_dim),
-                                device=device)
+        action = torch.randn((b, 16, self.model.agent_action_dim),
+                             device=device)
+        state = torch.randn((b, 16, self.model.agent_state_dim),
+                            device=device)
+        img = torch.randn(shape, device=device) if x_T is None else x_T

         if precision is not None:
             if precision == 16:

@@ -362,12 +357,13 @@ class DDIMSampler(object):
                                            **kwargs)
             else:
                 raise NotImplementedError
-            model_output = e_t_uncond + unconditional_guidance_scale * (
-                e_t_cond - e_t_uncond)
-            model_output_action = e_t_uncond_action + unconditional_guidance_scale * (
-                e_t_cond_action - e_t_uncond_action)
-            model_output_state = e_t_uncond_state + unconditional_guidance_scale * (
-                e_t_cond_state - e_t_uncond_state)
+            model_output = torch.lerp(e_t_uncond, e_t_cond,
+                                      unconditional_guidance_scale)
+            model_output_action = torch.lerp(e_t_uncond_action,
+                                             e_t_cond_action,
+                                             unconditional_guidance_scale)
+            model_output_state = torch.lerp(e_t_uncond_state, e_t_cond_state,
+                                            unconditional_guidance_scale)

         if guidance_rescale > 0.0:
             model_output = rescale_noise_cfg(

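This substitution is safe because `torch.lerp(start, end, weight)` computes exactly `start + weight * (end - start)`, which is the classifier-free guidance formula; weights greater than 1 (extrapolation, the usual CFG case) are allowed:

```python
import torch

e_t_uncond, e_t_cond = torch.randn(2, 4), torch.randn(2, 4)
scale = 7.5  # typical guidance scale > 1

manual = e_t_uncond + scale * (e_t_cond - e_t_uncond)
fused = torch.lerp(e_t_uncond, e_t_cond, scale)  # same formula, one fused kernel

assert torch.allclose(manual, fused)
```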
@@ -396,18 +392,28 @@ class DDIMSampler(object):
         sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
         sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

+        if use_original_steps:
+            sqrt_alphas = alphas.sqrt()
+            sqrt_alphas_prev = alphas_prev.sqrt()
+            dir_coeffs = (1. - alphas_prev - sigmas**2).sqrt()
+        else:
+            sqrt_alphas = self.ddim_sqrt_alphas
+            sqrt_alphas_prev = self.ddim_sqrt_alphas_prev
+            dir_coeffs = self.ddim_dir_coeff
+
         if is_video:
             size = (1, 1, 1, 1, 1)
         else:
             size = (1, 1, 1, 1)

         a_t = alphas[index].view(size)
         a_prev = alphas_prev[index].view(size)
+        sqrt_at = sqrt_alphas[index].view(size)
+        sqrt_a_prev = sqrt_alphas_prev[index].view(size)
         sigma_t = sigmas[index].view(size)
         sqrt_one_minus_at = sqrt_one_minus_alphas[index].view(size)
+        dir_coeff = dir_coeffs[index].view(size)

         if self.model.parameterization != "v":
-            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+            pred_x0 = (x - sqrt_one_minus_at * e_t) / sqrt_at
         else:
             pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)

@@ -420,14 +426,11 @@ class DDIMSampler(object):
         if quantize_denoised:
             pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)

-        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
-
         noise = sigma_t * noise_like(x.shape, device,
                                      repeat_noise) * temperature
         if noise_dropout > 0.:
             noise = torch.nn.functional.dropout(noise, p=noise_dropout)

-        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
+        x_prev = sqrt_a_prev * pred_x0 + dir_coeff * e_t + noise

         return x_prev, pred_x0, model_output_action, model_output_state

@@ -475,7 +478,7 @@ class DDIMSampler(object):
             sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
             sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
         else:
-            sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
+            sqrt_alphas_cumprod = self.ddim_sqrt_alphas
             sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas

         if noise is None:

@@ -10,8 +10,8 @@ from unifolm_wma.utils.utils import instantiate_from_config


 def nonlinearity(x):
-    # swish
-    return x * torch.sigmoid(x)
+    # swish / SiLU - single fused CUDA kernel instead of x * sigmoid(x)
+    return torch.nn.functional.silu(x)


 def Normalize(in_channels, num_groups=32):

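The two forms are the same function (SiLU is defined as `x * sigmoid(x)`), so this change only affects speed, not outputs; a quick numerical check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1000)
assert torch.allclose(x * torch.sigmoid(x), F.silu(x), atol=1e-6)
```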
@@ -7,79 +7,97 @@ MACRO-LEVEL TIMING SUMMARY
----------------------------------------
Section Count Total(ms) Avg(ms) CUDA Avg(ms)
--------------------------------------------------------------------------------------
Before (fp32 baseline):
action_generation 11 399707.47 36337.04 36336.85
data_loading 1 52.85 52.85 52.88
get_latent_z/encode 22 901.39 40.97 41.01
iteration_total 11 836793.23 76072.11 76071.63
load_transitions 1 2.24 2.24 2.28
model_loading/checkpoint 1 11833.31 11833.31 11833.43
model_loading/config 1 49774.19 49774.19 49774.16
model_to_cuda 1 8909.30 8909.30 8909.33
prepare_init_input 1 10.52 10.52 10.55
prepare_observation 11 5.41 0.49 0.53
prepare_wm_observation 11 2.12 0.19 0.22
save_results 11 38668.06 3515.28 3515.32
synthesis/conditioning_prep 22 2916.63 132.57 132.61
synthesis/ddim_sampling 22 782695.01 35577.05 35576.86
synthesis/decode_first_stage 22 12444.31 565.65 565.70
update_action_queues 11 6.85 0.62 0.65
update_state_queues 11 17.67 1.61 1.64
world_model_interaction 11 398375.58 36215.96 36215.75

After (mixed bf16):
action_generation 11 173133.54 15739.41 15739.36
data_loading 1 54.31 54.31 54.34
get_latent_z/encode 22 785.25 35.69 35.72
iteration_total 11 386482.08 35134.73 35134.55
load_transitions 1 2.07 2.07 2.10
model_loading/prepared 1 4749.22 4749.22 4749.83
prepare_init_input 1 29.19 29.19 29.22
prepare_observation 11 5.49 0.50 0.53
prepare_wm_observation 11 1.93 0.18 0.20
save_results 11 38791.18 3526.47 3526.51
synthesis/conditioning_prep 22 2528.23 114.92 114.95
synthesis/ddim_sampling 22 336003.29 15272.88 15272.83
synthesis/decode_first_stage 22 9095.14 413.42 413.46
update_action_queues 11 7.28 0.66 0.69
update_state_queues 11 17.38 1.58 1.61
world_model_interaction 11 174516.52 15865.14 15865.07
--------------------------------------------------------------------------------------
TOTAL (before) 2543116.13
TOTAL (after) 1126202.08

----------------------------------------
GPU MEMORY SUMMARY
----------------------------------------
Before: Peak allocated: 17890.50 MB    Average allocated: 16129.98 MB
After:  Peak allocated: 18188.29 MB    Average allocated:  9117.49 MB

----------------------------------------
TOP 30 OPERATORS BY CUDA TIME
----------------------------------------
Operator Count CUDA(ms) CPU(ms) Self CUDA(ms)
------------------------------------------------------------------------------------------------
Before (fp32 baseline):
ProfilerStep* 6 443804.16 237696.98 237689.25
aten::linear 171276 112286.23 13179.82 0.00
aten::addmm 81456 79537.36 3799.84 79296.37
ampere_sgemm_128x64_tn 26400 52052.10 0.00 52052.10
aten::matmul 90468 34234.05 6281.32 0.00
aten::_convolution 100242 33623.79 13105.89 0.00
aten::mm 89820 33580.74 3202.22 33253.18
aten::convolution 100242 33575.23 13714.47 0.00
aten::cudnn_convolution 98430 30932.19 8640.50 29248.12
ampere_sgemm_32x128_tn 42348 20394.52 0.00 20394.52
aten::conv2d 42042 18115.35 5932.30 0.00
ampere_sgemm_128x32_tn 40938 16429.81 0.00 16429.81
xformers::efficient_attention_forward_cutlass 24000 15222.23 2532.93 15120.44
fmha_cutlassF_f32_aligned_64x64_rf_sm80(Attenti... 24000 15121.31 0.00 15121.31
ampere_sgemm_64x64_tn 21000 14627.12 0.00 14627.12
aten::copy_ 231819 14504.87 127056.51 14038.39
aten::group_norm 87144 12033.73 10659.57 0.00
aten::native_group_norm 87144 11473.40 9449.36 11002.02
aten::conv3d 26400 8852.13 3365.43 0.00
void at::native::(anonymous namespace)::Rowwise... 87144 8714.68 0.00 8714.68
void cudnn::ops::nchwToNhwcKernel<float, float,... 169824 8525.44 0.00 8525.44
aten::clone 214314 8200.26 8568.82 0.00
void at::native::elementwise_kernel<128, 2, at:... 220440 8109.62 0.00 8109.62
void cutlass::Kernel<cutlass_80_simt_sgemm_128x... 15000 7919.30 0.00 7919.30
aten::_to_copy 12219 5963.43 122411.53 0.00
aten::to 58101 5952.65 122443.72 0.00
aten::conv1d 30000 5878.95 4556.48 0.00
Memcpy HtoD (Pageable -> Device) 6696 5856.39 0.00 5856.39
aten::reshape 671772 5124.03 9636.01 0.00
sm80_xmma_fprop_implicit_gemm_indexed_tf32f32_t... 16272 5097.70 0.00 5097.70

After (mixed bf16):
ProfilerStep* 18 690146.23 133688.74 616385.44
aten::group_norm 168624 24697.84 29217.27 0.00
aten::_convolution 96450 21420.26 12845.86 0.00
aten::convolution 96450 21408.68 13480.97 0.00
aten::linear 297398 20780.15 26257.38 0.00
aten::cudnn_convolution 94638 18660.24 8239.04 18329.28
aten::copy_ 772677 18135.46 17387.09 17864.87
aten::conv3d 52800 12922.42 8572.58 0.00
aten::conv2d 52469 12747.13 7725.70 0.00
aten::native_group_norm 84312 10285.37 8974.31 10197.66
aten::_to_copy 590277 10270.09 22570.90 0.00
aten::to 602979 9655.26 23666.06 0.00
aten::conv1d 56245 8174.37 10015.24 0.00
void at::native::(anonymous namespace)::Rowwise... 84312 7979.71 0.00 7979.71
aten::clone 177132 7502.90 7007.48 0.00
void cudnn::ops::nchwToNhwcKernel<__nv_bfloat16... 164700 7384.52 0.00 7384.52
aten::addmm 81456 6958.44 3903.01 6908.44
aten::layer_norm 65700 5698.92 7816.08 0.00
void at::native::elementwise_kernel<128, 4, at:... 149688 5372.46 0.00 5372.46
void at::native::unrolled_elementwise_kernel<at... 180120 5165.28 0.00 5165.28
ampere_bf16_s16816gemm_bf16_128x128_ldg8_relu_f... 24900 4449.05 0.00 4449.05
void at::native::unrolled_elementwise_kernel<at... 368664 4405.30 0.00 4405.30
aten::reshape 686778 3771.84 8309.51 0.00
aten::contiguous 46008 3400.88 1881.73 0.00
sm80_xmma_fprop_implicit_gemm_bf16bf16_bf16f32_... 15516 3398.03 0.00 3398.03
aten::matmul 90489 3366.62 4946.69 0.00
aten::mm 89820 3284.53 3308.76 3228.56
void at::native::elementwise_kernel<128, 2, at:... 46518 2441.55 0.00 2441.55
aten::add 113118 2426.66 2776.23 2385.52
void at::native::elementwise_kernel<128, 4, at:... 104550 2426.41 0.00 2426.41

----------------------------------------
OPERATOR CATEGORY BREAKDOWN
----------------------------------------
Category CUDA Time(ms) Percentage
---------------------------------------------------------
Before (fp32 baseline):
Other 481950.47 41.9%
Linear/GEMM 342333.09 29.8%
Convolution 159920.77 13.9%
Elementwise 54682.93 4.8%
Memory 36883.36 3.2%
Attention 34736.13 3.0%
Normalization 32081.19 2.8%
Activation 6449.19 0.6%

After (mixed bf16):
Other 723472.91 71.9%
Convolution 114469.81 11.4%
Memory 53845.46 5.4%
Normalization 46852.57 4.7%
Linear/GEMM 35354.58 3.5%
Elementwise 17078.44 1.7%
Activation 12296.29 1.2%
Attention 2956.61 0.3%

------------------------------------------------------------------------------------------
aten::addmm (Linear/GEMM) UTILIZATION ANALYSIS ON A100
------------------------------------------------------------------------------------------
Effective compute precision: BF16 Tensor Core (312 TFLOPS)
torch.backends.cuda.matmul.allow_tf32 = False

Metric Value
-------------------------------------------------------------------
Total aten::addmm calls 81,456
Total Self CUDA time 6908.44 ms
Total FLOPs (profiler) 1.33 PFLOPs
Achieved throughput 191.88 TFLOPS/s
A100 peak throughput 312.00 TFLOPS/s
MFU (Model FLOPs Utilization) 61.50%

INTERPRETATION:
-------------------------------------------------------------------
Good utilization (>60%). GEMM kernels are compute-bound
and running efficiently on Tensor Cores.

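The MFU figure follows directly from the numbers in the table: achieved throughput = total FLOPs / self CUDA time, and MFU = achieved / peak. Recomputing from the rounded values shown (the small difference from the reported 191.88 TFLOPS/s comes from the profiler's unrounded FLOP count):

```python
total_flops = 1.33e15        # 1.33 PFLOPs, from the profiler
cuda_time_s = 6908.44 / 1e3  # Self CUDA time in seconds
peak_tflops = 312.0          # A100 BF16 Tensor Core peak

achieved_tflops = total_flops / cuda_time_s / 1e12
mfu = achieved_tflops / peak_tflops
print(f"{achieved_tflops:.2f} TFLOPS/s, MFU {mfu:.2%}")
```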
@@ -2,7 +2,7 @@ res_dir="unitree_g1_pack_camera/case1"
 dataset="unitree_g1_pack_camera"

 {
-time CUDA_VISIBLE_DEVICES=1 python3 scripts/evaluation/world_model_interaction.py \
+time CUDA_VISIBLE_DEVICES=0 python3 scripts/evaluation/world_model_interaction.py \
     --seed 123 \
     --ckpt_path ckpts/unifolm_wma_dual_mixbf16.ckpt \
     --config configs/inference/world_model_interaction.yaml \
@@ -23,5 +23,6 @@ dataset="unitree_g1_pack_camera"
     --perframe_ae \
     --diffusion_dtype bf16 \
     --projector_mode bf16_full \
-    --encoder_mode bf16_full
+    --encoder_mode bf16_full \
+    --vae_dtype bf16
 } 2>&1 | tee "${res_dir}/output.log"