unifolm-world-model-action/profile_unet_flops.md
olivame a2cd34dd51
1. einsum('b i d, b j d -> b i j') → torch.bmm(q, k.transpose(-1,-2)): maps directly onto a rocBLAS batched GEMM
2. baddbmm fuses the scale into the GEMM, saving one kernel launch
3. The second einsum is likewise replaced with torch.bmm
Each round is 1 to 2 seconds faster.
2026-02-08 18:54:48 +00:00
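Below is a minimal sketch of the three changes listed in that commit message. It assumes the usual latent-diffusion `CrossAttention` layout, where `q`, `k` and `v` are already reshaped to `(batch * heads, tokens, dim_head)` and the softmax scale is `dim_head ** -0.5`. The function names are hypothetical. Treating the second einsum as the `attn @ v` product (`'b i j, b j d -> b i d'`) is also an assumption.

All three score functions return the same values. The `baddbmm` variant passes the scale as `alpha`, which is how the commit avoids a separate scaling kernel.

```python
import torch


def scores_einsum(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # Original form: einsum for q @ k^T, then a separate elementwise scale.
    return torch.einsum("b i d, b j d -> b i j", q, k) * scale


def scores_bmm(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # Change 1: an explicit batched GEMM, (b, i, d) x (b, d, j) -> (b, i, j).
    return torch.bmm(q, k.transpose(-1, -2)) * scale


def scores_baddbmm(q: torch.Tensor, k: torch.Tensor, scale: float) -> torch.Tensor:
    # Change 2: baddbmm computes beta * input + alpha * (q @ k^T).
    # With beta=0 the input buffer is ignored, so the scale rides inside the GEMM.
    out = torch.empty(q.shape[0], q.shape[1], k.shape[1], dtype=q.dtype, device=q.device)
    return torch.baddbmm(out, q, k.transpose(-1, -2), beta=0, alpha=scale)


def attention_bmm(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    scale = q.shape[-1] ** -0.5
    attn = scores_baddbmm(q, k, scale).softmax(dim=-1)
    # Change 3 (assumed form): einsum('b i j, b j d -> b i d') becomes a second bmm.
    return torch.bmm(attn, v)
```

With `beta=0`, PyTorch does not propagate NaN or Inf from the `input` tensor. That is why an uninitialized `torch.empty` buffer is safe to pass here.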


TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 CUDA_VISIBLE_DEVICES=0 python scripts/evaluation/profile_unet.py --ckpt_path ckpts/unifolm_wma_dual_mix_bf16.ckpt --config configs/inference/world_model_interaction.yaml

================================================================================================================================== FLOPS BY ATen OPERATOR (FlopCounterMode)

              ATen Op |       GFLOPS | % of Total

          convolution |      6185.17 |      46.4%
                addmm |      4411.17 |      33.1%
                   mm |      1798.34 |      13.5%
                  bmm |       949.54 |       7.1%

================================================================================================================================== FLOPS BY MODULE (FlopCounterMode)

                                                  Module |       GFLOPS | % of Total

                                                  Global |     13344.23 |     100.0%
                                        DiffusionWrapper |     13344.23 |     100.0%
                        DiffusionWrapper.diffusion_model |     13344.23 |     100.0%
        DiffusionWrapper.diffusion_model.output_blocks.8 |       997.87 |       7.5%
        DiffusionWrapper.diffusion_model.output_blocks.5 |       992.91 |       7.4%
        DiffusionWrapper.diffusion_model.output_blocks.9 |       941.81 |       7.1%
       DiffusionWrapper.diffusion_model.output_blocks.10 |       857.93 |       6.4%
       DiffusionWrapper.diffusion_model.output_blocks.11 |       857.93 |       6.4%
        DiffusionWrapper.diffusion_model.output_blocks.6 |       821.71 |       6.2%
         DiffusionWrapper.diffusion_model.input_blocks.1 |       765.65 |       5.7%
         DiffusionWrapper.diffusion_model.input_blocks.2 |       765.65 |       5.7%
        DiffusionWrapper.diffusion_model.output_blocks.7 |       737.82 |       5.5%
        DiffusionWrapper.diffusion_model.output_blocks.3 |       732.87 |       5.5%
        DiffusionWrapper.diffusion_model.output_blocks.4 |       732.87 |       5.5%
         DiffusionWrapper.diffusion_model.input_blocks.5 |       645.55 |       4.8%
         DiffusionWrapper.diffusion_model.input_blocks.8 |       640.59 |       4.8%
         DiffusionWrapper.diffusion_model.input_blocks.4 |       611.99 |       4.6%
         DiffusionWrapper.diffusion_model.input_blocks.7 |       607.04 |       4.5%
            DiffusionWrapper.diffusion_model.init_attn.0 |       459.02 |       3.4%
              DiffusionWrapper.diffusion_model.init_attn |       459.02 |       3.4%

nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 |       432.18 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.6.0 |       427.85 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.9.0 |       427.83 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.3.0 |       343.99 |       2.6%
      DiffusionWrapper.diffusion_model.output_blocks.4.0 |       343.99 |       2.6%
      DiffusionWrapper.diffusion_model.output_blocks.7.0 |       343.96 |       2.6%
     DiffusionWrapper.diffusion_model.output_blocks.10.0 |       343.95 |       2.6%
     DiffusionWrapper.diffusion_model.output_blocks.11.0 |       343.95 |       2.6%
       DiffusionWrapper.diffusion_model.input_blocks.1.1 |       327.75 |       2.5%
       DiffusionWrapper.diffusion_model.input_blocks.2.1 |       327.75 |       2.5%
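Each row of the module breakdown is inclusive: a module's GFLOPS include everything its submodules execute. That is why `Global`, `DiffusionWrapper` and `DiffusionWrapper.diffusion_model` all show 100%, and why the rows add up to far more than the total.

The helper below is hypothetical and not part of the profiling script. It recovers "self" counts by subtracting each module's direct children. It assumes dotted module paths and a complete module list, so it cannot be applied directly to the truncated top-N table above.

```python
def exclusive_counts(inclusive: dict[str, float]) -> dict[str, float]:
    """Turn inclusive per-module counts (keyed by dotted path) into self counts."""
    result = {}
    for name, total in inclusive.items():
        prefix = name + "."
        # Direct children have exactly one more path component than `name`.
        children = [
            value
            for other, value in inclusive.items()
            if other.startswith(prefix) and "." not in other[len(prefix):]
        ]
        result[name] = total - sum(children)
    return result
```

For example, `{"a": 10, "a.b": 6, "a.b.c": 4, "a.d": 3}` yields self counts of 1, 2, 4 and 3, which sum back to the root's 10.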

================================================================================================================================== SUMMARY

Total CUDA time:           761.4 ms
Matmul CUDA time:          404.2 ms (53.1%)
Non-matmul CUDA time:      357.1 ms (46.9%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput:         33.01 TFLOPS/s (54.1% of BF16 peak)
Overall throughput:        17.53 TFLOPS/s (28.7% of BF16 peak)
GPU peak (BF16):           61.0 TFLOPS
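The SUMMARY lines can be reproduced from the raw numbers:

- Overall throughput is the FlopCounter total divided by total CUDA time.
- Percentages are relative to the 61.0 TFLOPS BF16 peak.
- The reported matmul throughput matches the same 13344.23 GFLOPs divided by the matmul-only CUDA time. In other words, every counted FLOP (convolutions included) appears to be attributed to matmul time. This is inferred from the arithmetic, not documented by the script.

The function name is illustrative.

```python
def summarize(total_gflop: float, total_ms: float, matmul_ms: float, peak_tflops: float) -> dict:
    tflop = total_gflop / 1e3
    overall = tflop / (total_ms / 1e3)  # TFLOP/s over all kernels
    matmul = tflop / (matmul_ms / 1e3)  # the same FLOPs over matmul kernel time only
    return {
        "matmul_time_pct": 100 * matmul_ms / total_ms,
        "matmul_tflops": matmul,
        "matmul_pct_peak": 100 * matmul / peak_tflops,
        "overall_tflops": overall,
        "overall_pct_peak": 100 * overall / peak_tflops,
    }
```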

================================================================================================================================== FLOPS BY ATen OPERATOR (FlopCounterMode)

              ATen Op |       GFLOPS | % of Total

          convolution |      6185.17 |      46.4%
                addmm |      4411.17 |      33.1%
                   mm |      1798.34 |      13.5%
                  bmm |       949.54 |       7.1%

================================================================================================================================== FLOPS BY MODULE (FlopCounterMode)

                                                  Module |       GFLOPS | % of Total

                                        DiffusionWrapper |     13344.23 |     100.0%
                                                  Global |     13344.23 |     100.0%
                        DiffusionWrapper.diffusion_model |     13344.23 |     100.0%
        DiffusionWrapper.diffusion_model.output_blocks.8 |       997.87 |       7.5%
        DiffusionWrapper.diffusion_model.output_blocks.5 |       992.91 |       7.4%
        DiffusionWrapper.diffusion_model.output_blocks.9 |       941.81 |       7.1%
       DiffusionWrapper.diffusion_model.output_blocks.10 |       857.93 |       6.4%
       DiffusionWrapper.diffusion_model.output_blocks.11 |       857.93 |       6.4%
        DiffusionWrapper.diffusion_model.output_blocks.6 |       821.71 |       6.2%
         DiffusionWrapper.diffusion_model.input_blocks.1 |       765.65 |       5.7%
         DiffusionWrapper.diffusion_model.input_blocks.2 |       765.65 |       5.7%
        DiffusionWrapper.diffusion_model.output_blocks.7 |       737.82 |       5.5%
        DiffusionWrapper.diffusion_model.output_blocks.3 |       732.87 |       5.5%
        DiffusionWrapper.diffusion_model.output_blocks.4 |       732.87 |       5.5%
         DiffusionWrapper.diffusion_model.input_blocks.5 |       645.55 |       4.8%
         DiffusionWrapper.diffusion_model.input_blocks.8 |       640.59 |       4.8%
         DiffusionWrapper.diffusion_model.input_blocks.4 |       611.99 |       4.6%
         DiffusionWrapper.diffusion_model.input_blocks.7 |       607.04 |       4.5%
              DiffusionWrapper.diffusion_model.init_attn |       459.02 |       3.4%
            DiffusionWrapper.diffusion_model.init_attn.0 |       459.02 |       3.4%

nWrapper.diffusion_model.init_attn.0.transformer_blocks.0 |       432.18 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.6.0 |       427.85 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.9.0 |       427.83 |       3.2%
      DiffusionWrapper.diffusion_model.output_blocks.3.0 |       343.99 |       2.6%
      DiffusionWrapper.diffusion_model.output_blocks.4.0 |       343.99 |       2.6%
      DiffusionWrapper.diffusion_model.output_blocks.7.0 |       343.96 |       2.6%
     DiffusionWrapper.diffusion_model.output_blocks.10.0 |       343.95 |       2.6%
     DiffusionWrapper.diffusion_model.output_blocks.11.0 |       343.95 |       2.6%
       DiffusionWrapper.diffusion_model.input_blocks.1.1 |       327.75 |       2.5%
       DiffusionWrapper.diffusion_model.input_blocks.2.1 |       327.75 |       2.5%

================================================================================================================================== SUMMARY

Total CUDA time:           707.1 ms
Matmul CUDA time:          403.1 ms (57.0%)
Non-matmul CUDA time:      304.0 ms (43.0%)
Total FLOPS (FlopCounter): 13344.23 GFLOPS
Matmul throughput:         33.11 TFLOPS/s (54.3% of BF16 peak)
Overall throughput:        18.87 TFLOPS/s (30.9% of BF16 peak)
GPU peak (BF16):           61.0 TFLOPS
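Both runs count exactly the same 13344.23 GFLOPs, so the difference between them is purely in time. The comparison below takes the two SUMMARY blocks as input. It shows that almost the whole 54.3 ms gap (about a 1.08x speedup per forward) sits in non-matmul kernels. The log itself does not say what changed between the two runs, and the names are illustrative.

```python
FIRST = {"total_ms": 761.4, "matmul_ms": 404.2, "non_matmul_ms": 357.1}
SECOND = {"total_ms": 707.1, "matmul_ms": 403.1, "non_matmul_ms": 304.0}


def compare_runs(first: dict, second: dict):
    saved = {key: first[key] - second[key] for key in first}
    # Share of the total time difference explained by each category.
    share = {key: 100 * saved[key] / saved["total_ms"] for key in saved if key != "total_ms"}
    speedup = first["total_ms"] / second["total_ms"]
    return saved, share, speedup
```

The matmul and non-matmul savings (1.1 ms and 53.1 ms) add up to 54.2 ms rather than 54.3 ms only because the logged values are rounded.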

======================================================================== TABLE 1: STAGE TIMING

Stage                   Mean(ms) Std(ms)      %

1_Image_Embedding           29.5    0.16   0.1%
2_VAE_Encode                51.3    0.06   0.1%
3_Text_Conditioning         14.7    0.18   0.0%
4_Projectors                 0.2    0.03   0.0%
5_DDIM_Loop              33392.5    3.21  97.3%
6_VAE_Decode               808.4    1.00   2.4%
7_Post_Process              15.8    0.56   0.0%

TOTAL                    34312.4
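The % column is each stage's mean divided by the 34312.4 ms total. The DDIM loop accounts for 97.3% of the pipeline, so an Amdahl-style estimate is a useful sanity check on any UNet optimization. Speeding up only the loop by a factor s gives total = other stages + loop / s. The sketch below uses the table's numbers; the function names are illustrative.

```python
STAGE_MEAN_MS = {
    "1_Image_Embedding": 29.5,
    "2_VAE_Encode": 51.3,
    "3_Text_Conditioning": 14.7,
    "4_Projectors": 0.2,
    "5_DDIM_Loop": 33392.5,
    "6_VAE_Decode": 808.4,
    "7_Post_Process": 15.8,
}


def stage_shares(stages: dict[str, float]) -> dict[str, float]:
    total = sum(stages.values())
    return {name: 100 * ms / total for name, ms in stages.items()}


def projected_total_ms(stages: dict[str, float], loop_speedup: float) -> float:
    # Amdahl's law: only the DDIM loop gets faster; every other stage is unchanged.
    loop = stages["5_DDIM_Loop"]
    return sum(stages.values()) - loop + loop / loop_speedup
```

Halving the loop time would bring the pipeline to about 17616 ms. Even an infinitely fast loop would leave the 919.9 ms spent in the other stages, most of it in VAE decode.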

================================================================================ TABLE 2: UNET SUB-MODULE BREAKDOWN

Module Type            Total(ms)  Count Per-call(ms)      %

ResBlock                 10256.3   1100         9.32  23.2%
SpatialTransformer        9228.2    800        11.54  20.9%
CrossAttention            8105.8   3300         2.46  18.3%
ConditionalUnet1D         6409.5    100        64.10  14.5%
TemporalTransformer       5847.0    850         6.88  13.2%
FeedForward               4338.1   1650         2.63   9.8%
UNet.out                    73.8     50         1.48   0.2%

TOTAL (hooked)           44258.7
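Per-call is Total(ms) / Count. Dividing Count by the 50 UNet forwards in TABLE 4 gives calls per forward:

- 22 ResBlocks
- 16 SpatialTransformer plus 17 TemporalTransformer calls, for 33 transformer calls in total
- 33 FeedForward and 66 CrossAttention calls

These counts are consistent with each transformer block holding two attention layers and one feed-forward.

TOTAL (hooked), 44258.7 ms, exceeds the 33392.5 ms DDIM loop, so the hooked modules must overlap. Presumably the attention and feed-forward layers run inside the transformer blocks. The % column is therefore a share of the hooked sum, not of wall time. The names below are illustrative.

```python
UNET_FORWARDS = 50  # TABLE 4: 50 DDIM steps at CFG 1.0, so one UNet forward per step

# Module type -> (Total(ms), Count), copied from TABLE 2.
TABLE2 = {
    "ResBlock": (10256.3, 1100),
    "SpatialTransformer": (9228.2, 800),
    "CrossAttention": (8105.8, 3300),
    "ConditionalUnet1D": (6409.5, 100),
    "TemporalTransformer": (5847.0, 850),
    "FeedForward": (4338.1, 1650),
    "UNet.out": (73.8, 50),
}


def per_call_stats(table: dict, forwards: int) -> dict:
    """Return {module type: (ms per call, calls per UNet forward)}."""
    return {name: (ms / count, count / forwards) for name, (ms, count) in table.items()}


# About 44258.7 ms: larger than the 33392.5 ms DDIM loop because hooks are nested.
hooked_total_ms = sum(ms for ms, _ in TABLE2.values())
```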

========================================================================================== TABLE 2b: PER-UNET-BLOCK TIMING (coarse modules, per DDIM loop)

Block             Total(ms)      %  Breakdown

input_blocks.1       3376.2   7.6%  SpatialTransformer=1101, CrossAttention=990, ResBlock=543, TemporalTransformer=454, FeedForward=288
input_blocks.2       3374.0   7.6%  SpatialTransformer=1100, CrossAttention=991, ResBlock=540, TemporalTransformer=455, FeedForward=288
input_blocks.4       1592.4   3.6%  SpatialTransformer=394, ResBlock=374, CrossAttention=303, TemporalTransformer=272, FeedForward=249
input_blocks.5       1642.5   3.7%  ResBlock=425, SpatialTransformer=397, CrossAttention=303, TemporalTransformer=271, FeedForward=247
input_blocks.7       1469.0   3.3%  ResBlock=416, SpatialTransformer=324, FeedForward=251, CrossAttention=240, TemporalTransformer=237
input_blocks.8       1543.7   3.5%  ResBlock=491, SpatialTransformer=325, FeedForward=250, CrossAttention=240, TemporalTransformer=238
input_blocks.10       217.5   0.5%  ResBlock=218
input_blocks.11       216.8   0.5%  ResBlock=217
middle_block          848.9   1.9%  ResBlock=434, SpatialTransformer=151, CrossAttention=134, TemporalTransformer=69, FeedForward=61
output_blocks.0       303.2   0.7%  ResBlock=303
output_blocks.1       303.1   0.7%  ResBlock=303
output_blocks.2       302.8   0.7%  ResBlock=303
output_blocks.3      1734.8   3.9%  ResBlock=687, SpatialTransformer=322, FeedForward=249, CrossAttention=239, TemporalTransformer=237
output_blocks.4      1739.8   3.9%  ResBlock=688, SpatialTransformer=323, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.5      1622.3   3.7%  ResBlock=570, SpatialTransformer=324, FeedForward=251, CrossAttention=239, TemporalTransformer=238
output_blocks.6      1881.0   4.3%  ResBlock=664, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=250
output_blocks.7      1768.0   4.0%  ResBlock=554, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.8      1688.7   3.8%  ResBlock=474, SpatialTransformer=393, CrossAttention=301, TemporalTransformer=272, FeedForward=249
output_blocks.9      3558.6   8.0%  SpatialTransformer=1096, CrossAttention=992, ResBlock=727, TemporalTransformer=454, FeedForward=290
output_blocks.10     3492.8   7.9%  SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
output_blocks.11     3493.3   7.9%  SpatialTransformer=1096, CrossAttention=992, ResBlock=662, TemporalTransformer=454, FeedForward=289
out                    73.8   0.2%  UNet.out=74
action_unet          3212.0   7.3%  ConditionalUnet1D=3212
state_unet           3197.6   7.2%  ConditionalUnet1D=3198
other                1606.2   3.6%  TemporalTransformer=960, FeedForward=337, CrossAttention=309

TOTAL               44258.7
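Each Breakdown entry is a comma-separated list of `Type=ms` pairs rounded to whole milliseconds. The pairs therefore sum to the block's Total(ms) only within rounding. For example, input_blocks.1 gives 1101 + 990 + 543 + 454 + 288 = 3376 against 3376.2.

Below is a sketch of parsing these strings and checking that they reconcile. The helper names are illustrative. The same style of cross-check also holds across tables: action_unet + state_unet = 3212.0 + 3197.6 = 6409.6 ms, which matches TABLE 2's ConditionalUnet1D total of 6409.5 ms up to rounding.

```python
def parse_breakdown(text: str) -> dict[str, int]:
    """Parse 'SpatialTransformer=1101, CrossAttention=990' into {type: ms}."""
    result = {}
    for item in text.split(","):
        name, ms = item.strip().split("=")
        result[name] = int(ms)
    return result


def reconciles(total_ms: float, breakdown: str, tol_per_term: float = 0.5) -> bool:
    """Check that the rounded breakdown terms add up to the block total."""
    parts = parse_breakdown(breakdown)
    # Each term is rounded to the nearest millisecond, so allow 0.5 ms per term.
    return abs(sum(parts.values()) - total_ms) <= tol_per_term * len(parts)
```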

====================================================================== TABLE 2c: ATTENTION vs FEEDFORWARD (deep hooks)

Component          Total(ms)      %

CrossAttention        8105.8  65.1%
FeedForward           4338.1  34.9%

TOTAL (attn+ff)      12443.9

================================================== TABLE 3: MEMORY SUMMARY

Initial allocated: 11.82 GB
Peak allocated:    14.43 GB
Delta (pipeline):   2.61 GB

============================================================ TABLE 4: THROUGHPUT

Total pipeline latency: 34312.4 ms
DDIM loop latency:      33392.5 ms
DDIM steps:             50
CFG scale:              1.0 (1x UNet/step)
UNet forward calls:     50
Per DDIM step:          667.9 ms
Per UNet forward:       667.9 ms
VAE encode bandwidth:   0.1 GB/s (peak HBM: 864.0 GB/s)
VAE decode bandwidth:   0.0 GB/s (peak HBM: 864.0 GB/s)
GPU BF16 peak:          61.0 TFLOPS
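Per DDIM step is the loop latency divided by the step count: 33392.5 / 50 = 667.85 ms. Per UNet forward equals it here only because CFG scale 1.0 means one UNet call per step.

The sketch below assumes, as in standard classifier-free guidance, that any other scale would double the calls (one conditional and one unconditional pass per step). That doubling is an assumption, not something this log reports, and the function is hypothetical.

```python
def latency_breakdown(loop_ms: float, steps: int, cfg_scale: float) -> dict:
    # Assumption: CFG != 1.0 runs a conditional and an unconditional UNet pass per step.
    forwards_per_step = 1 if cfg_scale == 1.0 else 2
    calls = steps * forwards_per_step
    return {
        "unet_calls": calls,
        "per_step_ms": loop_ms / steps,
        "per_forward_ms": loop_ms / calls,
    }
```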

Done.