From 02c3cea3f94018667ce00d307d0b7fc51e07bbf7 Mon Sep 17 00:00:00 2001 From: qihuanye Date: Thu, 14 May 2026 03:52:50 +0000 Subject: [PATCH] =?UTF-8?q?amd=E6=9E=84=E5=BB=BA=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AMD_SETUP.md | 241 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 241 insertions(+) create mode 100644 AMD_SETUP.md diff --git a/AMD_SETUP.md b/AMD_SETUP.md new file mode 100644 index 0000000..5b3f438 --- /dev/null +++ b/AMD_SETUP.md @@ -0,0 +1,241 @@ +# AMD ROCm 环境配置说明 + +这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留 +`torch.compile` 时的 PyTorch 版本选择。 + +目标运行命令: + +```bash +python eval.py --config-name=pusht.yaml policy=pusht/lewm +``` + +## 已验证环境 + +本次验证通过的环境: + +- Ubuntu 24.04 +- AMD Radeon PRO W7900D (`gfx1100`) +- 系统 ROCm 7.1.1 +- Python 3.10 +- `torch==2.10.0+rocm7.1` +- `torchvision==0.25.0+rocm7.1` +- `triton-rocm==3.6.0` + +注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU,但在本项目里开启 +`torch.compile` 后会崩溃,错误类似: + +```text +HSA_STATUS_ERROR_INVALID_PACKET_FORMAT +CUDA error: unspecified launch failure +``` + +降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。 + +## 检查系统 ROCm + +在新 AMD 机器上,先确认系统能识别 GPU: + +```bash +rocminfo +amd-smi version +hipcc --version +``` + +`rocminfo` 里应该能看到 AMD GPU agent,例如 `gfx1100`。 + +## 创建 Python 环境 + +使用 `uv` 创建 Python 3.10 虚拟环境: + +```bash +cd /path/to/lewm +uv venv --python 3.10 --allow-existing .venv +source .venv/bin/activate +``` + +给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住, +用 pip 安装大 wheel 更容易观察进度。 + +```bash +uv pip install pip +``` + +## 安装 ROCm 版 PyTorch + +安装本项目已验证可用的 ROCm PyTorch 组合: + +```bash +python -m pip install --force-reinstall \ + --index-url https://download.pytorch.org/whl/rocm7.1 \ + --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + "torch==2.10.0" \ + "torchvision==0.25.0" +``` + +PyTorch wheel 有数 GB。如果网络慢,不要频繁中断重试,尽量等它下载完成。 + +## 安装项目依赖 + +普通 Python 包建议走国内 PyPI 镜像: + +```bash +python -m pip install \ + --index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + "gymnasium[all]==1.2.2" \ + "stable-baselines3==2.8.0" \ + "stable-worldmodel[train,env]" +``` + +然后修正两个容易被 pip 带偏的依赖版本: + +```bash +python -m pip install \ + --index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + "fsspec==2025.3.0" \ + "pillow==11.3.0" +``` + +检查环境: + +```bash +python -m pip check +python - <<'PY' +import torch +import torchvision + +print("torch:", torch.__version__) +print("hip:", torch.version.hip) +print("cuda available:", torch.cuda.is_available()) +print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None) +print("torchvision:", torchvision.__version__) +PY +``` + +期望看到类似输出: + +```text +torch: 2.10.0+rocm7.1 +cuda available: True +torchvision: 0.25.0+rocm7.1 +``` + +## 恢复本仓库里的 stable-worldmodel 修改 + +这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在: + +```text +.venv/lib/python3.10/site-packages/stable_worldmodel/ +``` + +从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行: + +```bash +git restore -- \ + .venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \ + .venv/lib/python3.10/site-packages/stable_worldmodel/world.py \ + .venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \ + .venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py +``` + +然后确认没有意外修改: + +```bash +git status --short +``` + +## 数据和 checkpoint 路径 + +`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。 + +PushT 评估至少需要: + +```text +$STABLEWM_HOME/pusht_expert_train.h5 +$STABLEWM_HOME/pusht/lewm_object.ckpt +``` + +例如本机使用: + +```bash +export STABLEWM_HOME=/mnt/ASC1637/stablewm +``` + +如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`。 + +## 运行评估 + +默认 PushT 评估,保留 `torch.compile`: + +```bash +export STABLEWM_HOME=/path/to/stablewm +python eval.py --config-name=pusht.yaml policy=pusht/lewm +``` + +快速 smoke test: + +```bash +export STABLEWM_HOME=/path/to/stablewm +python eval.py --config-name=pusht.yaml policy=pusht/lewm \ + eval.num_eval=1 \ + world.num_envs=1 \ + output.filename=/tmp/lewm_smoke_test.txt +``` + +smoke test 应该能正常结束,并打印类似: + +```text +{'success_rate': 100.0, ...} +``` + +## 常见问题 + +### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT` + +如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本: + +```bash +python -c "import torch; print(torch.__version__, torch.version.hip)" +``` + +如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合: + +```bash +python -m pip install --force-reinstall \ + --index-url https://download.pytorch.org/whl/rocm7.1 \ + --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + "torch==2.10.0" \ + "torchvision==0.25.0" +``` + +### 找不到 `pusht_expert_train.h5` + +设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录: + +```bash +export STABLEWM_HOME=/path/to/stablewm +``` + +### pip 尝试构建旧版 `gym==0.21` + +这是依赖解析回退导致的。先显式安装兼容版本: + +```bash +python -m pip install \ + --index-url https://pypi.tuna.tsinghua.edu.cn/simple \ + "gymnasium[all]==1.2.2" \ + "stable-baselines3==2.8.0" +``` + +### uv 或 pip 访问海外源很慢 + +普通 Python 包使用国内 PyPI 镜像: + +```bash +--index-url https://pypi.tuna.tsinghua.edu.cn/simple +``` + +PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源: + +```bash +--index-url https://download.pytorch.org/whl/rocm7.1 +```