amd构建说明

This commit is contained in:
qihuanye
2026-05-14 03:52:50 +00:00
parent f08f2b82f4
commit 02c3cea3f9

241
AMD_SETUP.md Normal file
View File

@@ -0,0 +1,241 @@
# AMD ROCm 环境配置说明
这份文档记录了在 AMD ROCm 环境下运行 LeWM 的可复现配置,重点是保留
`torch.compile` 时的 PyTorch 版本选择。
目标运行命令:
```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
## 已验证环境
本次验证通过的环境:
- Ubuntu 24.04
- AMD Radeon PRO W7900D (`gfx1100`)
- 系统 ROCm 7.1.1
- Python 3.10
- `torch==2.10.0+rocm7.1`
- `torchvision==0.25.0+rocm7.1`
- `triton-rocm==3.6.0`
注意:`torch==2.12.0+rocm7.1` 可以正常导入,也能识别 GPU但在本项目里开启
`torch.compile` 后会崩溃,错误类似:
```text
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
CUDA error: unspecified launch failure
```
降级到 `torch==2.10.0+rocm7.1` 后,`torch.compile` 路径可以正常跑通。
## 检查系统 ROCm
在新 AMD 机器上,先确认系统能识别 GPU
```bash
rocminfo
amd-smi version
hipcc --version
```
`rocminfo` 里应该能看到 AMD GPU agent例如 `gfx1100`
## 创建 Python 环境
使用 `uv` 创建 Python 3.10 虚拟环境:
```bash
cd /path/to/lewm
uv venv --python 3.10 --allow-existing .venv
source .venv/bin/activate
```
给 uv 创建的 venv 补上 pip。ROCm 版 PyTorch wheel 很大,如果 uv 解析或下载卡住,
用 pip 安装大 wheel 更容易观察进度。
```bash
uv pip install pip
```
## 安装 ROCm 版 PyTorch
安装本项目已验证可用的 ROCm PyTorch 组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
PyTorch wheel 有数 GB。如果网络慢不要频繁中断重试尽量等它下载完成。
## 安装项目依赖
普通 Python 包建议走国内 PyPI 镜像:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0" \
"stable-worldmodel[train,env]"
```
然后修正两个容易被 pip 带偏的依赖版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"fsspec==2025.3.0" \
"pillow==11.3.0"
```
检查环境:
```bash
python -m pip check
python - <<'PY'
import torch
import torchvision
print("torch:", torch.__version__)
print("hip:", torch.version.hip)
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
print("torchvision:", torchvision.__version__)
PY
```
期望看到类似输出:
```text
torch: 2.10.0+rocm7.1
cuda available: True
torchvision: 0.25.0+rocm7.1
```
## 恢复本仓库里的 stable-worldmodel 修改
这个仓库把一些本地修改后的 `stable_worldmodel` 文件纳入了 git 管控,路径在:
```text
.venv/lib/python3.10/site-packages/stable_worldmodel/
```
从 PyPI 安装 `stable-worldmodel` 时可能会覆盖这些文件。安装依赖后执行:
```bash
git restore -- \
.venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
.venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
```
然后确认没有意外修改:
```bash
git status --short
```
## 数据和 checkpoint 路径
`eval.py` 会从 `$STABLEWM_HOME` 里找数据和 checkpoint。
PushT 评估至少需要:
```text
$STABLEWM_HOME/pusht_expert_train.h5
$STABLEWM_HOME/pusht/lewm_object.ckpt
```
例如本机使用:
```bash
export STABLEWM_HOME=/mnt/ASC1637/stablewm
```
如果没有正确设置,运行时会报找不到 `pusht_expert_train.h5`
## 运行评估
默认 PushT 评估,保留 `torch.compile`
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```
快速 smoke test
```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
eval.num_eval=1 \
world.num_envs=1 \
output.filename=/tmp/lewm_smoke_test.txt
```
smoke test 应该能正常结束,并打印类似:
```text
{'success_rate': 100.0, ...}
```
## 常见问题
### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`
如果开启 `torch.compile` 时出现这个错误,先检查 torch 版本:
```bash
python -c "import torch; print(torch.__version__, torch.version.hip)"
```
如果是 `2.12.0+rocm7.1`,建议降级到本项目验证通过的组合:
```bash
python -m pip install --force-reinstall \
--index-url https://download.pytorch.org/whl/rocm7.1 \
--extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"torch==2.10.0" \
"torchvision==0.25.0"
```
### 找不到 `pusht_expert_train.h5`
设置 `STABLEWM_HOME` 到包含数据和 checkpoint 的目录:
```bash
export STABLEWM_HOME=/path/to/stablewm
```
### pip 尝试构建旧版 `gym==0.21`
这是依赖解析回退导致的。先显式安装兼容版本:
```bash
python -m pip install \
--index-url https://pypi.tuna.tsinghua.edu.cn/simple \
"gymnasium[all]==1.2.2" \
"stable-baselines3==2.8.0"
```
### uv 或 pip 访问海外源很慢
普通 Python 包使用国内 PyPI 镜像:
```bash
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
```
PyTorch ROCm wheel 继续使用 PyTorch 官方 ROCm 源:
```bash
--index-url https://download.pytorch.org/whl/rocm7.1
```