AMD build instructions
241
AMD_SETUP.md
Normal file
@@ -0,0 +1,241 @@
# AMD ROCm Environment Setup

This document records a reproducible configuration for running LeWM on AMD ROCm, with a focus on choosing a PyTorch version that keeps `torch.compile` working.

Target command:

```bash
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```

## Verified environment

The setup was verified with:

- Ubuntu 24.04
- AMD Radeon PRO W7900D (`gfx1100`)
- System ROCm 7.1.1
- Python 3.10
- `torch==2.10.0+rocm7.1`
- `torchvision==0.25.0+rocm7.1`
- `triton-rocm==3.6.0`

Note: `torch==2.12.0+rocm7.1` imports fine and detects the GPU, but in this project it crashes once `torch.compile` is enabled, with errors like:

```text
HSA_STATUS_ERROR_INVALID_PACKET_FORMAT
CUDA error: unspecified launch failure
```

After downgrading to `torch==2.10.0+rocm7.1`, the `torch.compile` path runs cleanly.

## Check the system ROCm

On a new AMD machine, first confirm that the system detects the GPU:

```bash
rocminfo
amd-smi version
hipcc --version
```

The `rocminfo` output should list an AMD GPU agent, for example `gfx1100`.

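To script the `rocminfo` check, the sketch below scans its text output for `gfx` target names. The helper names `extract_gfx_targets` and `detect_gfx_targets` are hypothetical, and the regular expression assumes targets appear as `gfx` followed by hex digits (as in `gfx1100`); the exact layout of `rocminfo` output may differ between ROCm releases.

```python
import re
import shutil
import subprocess

# Assumption: GPU targets show up as "gfx" followed by hex digits, e.g. gfx1100.
GFX_PATTERN = re.compile(r"\bgfx[0-9a-f]+\b")


def extract_gfx_targets(rocminfo_output: str) -> list:
    """Return unique gfx target names in order of first appearance."""
    seen = []
    for name in GFX_PATTERN.findall(rocminfo_output):
        if name not in seen:
            seen.append(name)
    return seen


def detect_gfx_targets() -> list:
    """Run rocminfo if it is on PATH; return an empty list otherwise."""
    if shutil.which("rocminfo") is None:
        return []
    result = subprocess.run(
        ["rocminfo"], capture_output=True, text=True, check=False
    )
    return extract_gfx_targets(result.stdout)
```

An empty list from `detect_gfx_targets()` suggests that either `rocminfo` is missing or no GPU agent was reported, so the system ROCm install is worth revisiting before touching the Python environment.
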
## Create the Python environment

Create a Python 3.10 virtual environment with `uv`:

```bash
cd /path/to/lewm
uv venv --python 3.10 --allow-existing .venv
source .venv/bin/activate
```

Add pip to the venv created by uv. The ROCm PyTorch wheels are large, and if uv stalls while resolving or downloading, installing them with pip makes progress easier to watch.

```bash
uv pip install pip
```

## Install ROCm PyTorch

Install the ROCm PyTorch combination verified for this project:

```bash
python -m pip install --force-reinstall \
  --index-url https://download.pytorch.org/whl/rocm7.1 \
  --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "torch==2.10.0" \
  "torchvision==0.25.0"
```

The PyTorch wheels are several GB. On a slow network, avoid interrupting and retrying repeatedly; let the download finish if you can.

## Install project dependencies

For regular Python packages, a PyPI mirror in China is recommended:

```bash
python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "gymnasium[all]==1.2.2" \
  "stable-baselines3==2.8.0" \
  "stable-worldmodel[train,env]"
```

Then pin the two dependencies that pip tends to resolve to the wrong version:

```bash
python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "fsspec==2025.3.0" \
  "pillow==11.3.0"
```

Check the environment:

```bash
python -m pip check
python - <<'PY'
import torch
import torchvision

print("torch:", torch.__version__)
print("hip:", torch.version.hip)
print("cuda available:", torch.cuda.is_available())
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else None)
print("torchvision:", torchvision.__version__)
PY
```

The output should look similar to:

```text
torch: 2.10.0+rocm7.1
cuda available: True
torchvision: 0.25.0+rocm7.1
```

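The same check can be automated without importing torch. The following is a minimal sketch that compares installed distribution versions against the verified pins using the standard-library `importlib.metadata`. `VERIFIED_PINS`, `installed_version`, and `find_mismatches` are hypothetical names, and the sketch assumes the reported version string includes the `+rocm7.1` local suffix.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Pins copied from the verified environment in this document.
VERIFIED_PINS = {
    "torch": "2.10.0+rocm7.1",
    "torchvision": "0.25.0+rocm7.1",
    "triton-rocm": "3.6.0",
}


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each mismatching package to (expected, found)."""
    mismatches = {}
    for name, expected in pins.items():
        found = lookup(name)
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches
```

Calling `find_mismatches(VERIFIED_PINS)` inside the activated venv returns an empty dict when everything matches. The `lookup` parameter exists only so the comparison can be exercised without a real installation.
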
## Restore this repository's stable-worldmodel changes

This repository tracks some locally modified `stable_worldmodel` files in git, under:

```text
.venv/lib/python3.10/site-packages/stable_worldmodel/
```

Installing `stable-worldmodel` from PyPI may overwrite these files. After installing the dependencies, run:

```bash
git restore -- \
  .venv/lib/python3.10/site-packages/stable_worldmodel/policy.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/world.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/solver/cem.py \
  .venv/lib/python3.10/site-packages/stable_worldmodel/solver/gd.py
```

Then confirm there are no unexpected changes:

```bash
git status --short
```

## Data and checkpoint paths

`eval.py` looks for data and checkpoints under `$STABLEWM_HOME`.

PushT evaluation needs at least:

```text
$STABLEWM_HOME/pusht_expert_train.h5
$STABLEWM_HOME/pusht/lewm_object.ckpt
```

For example, this machine uses:

```bash
export STABLEWM_HOME=/mnt/ASC1637/stablewm
```

If the variable is not set correctly, the run fails because `pusht_expert_train.h5` cannot be found.

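A quick pre-flight check can catch a wrong `STABLEWM_HOME` before the evaluation starts. The sketch below only tests for the two files listed above; `missing_artifacts` and `check_stablewm_home` are hypothetical helpers, and the sketch does not model any default location the library might use when the variable is unset.

```python
import os
from pathlib import Path

# Relative paths taken from the PushT requirements in this document.
REQUIRED_FILES = ("pusht_expert_train.h5", "pusht/lewm_object.ckpt")


def missing_artifacts(home) -> list:
    """Return the required relative paths that are not regular files under home."""
    root = Path(home)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]


def check_stablewm_home(environ=os.environ) -> list:
    """Return human-readable problems; an empty list means the layout looks fine."""
    home = environ.get("STABLEWM_HOME")
    if not home:
        return ["STABLEWM_HOME is not set"]
    return [f"missing: {Path(home) / rel}" for rel in missing_artifacts(home)]
```

Printing the result of `check_stablewm_home()` before launching `eval.py` gives a clearer message than a file-not-found error raised partway through startup.
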
## Run the evaluation

Default PushT evaluation, with `torch.compile` left enabled:

```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm
```

Quick smoke test:

```bash
export STABLEWM_HOME=/path/to/stablewm
python eval.py --config-name=pusht.yaml policy=pusht/lewm \
  eval.num_eval=1 \
  world.num_envs=1 \
  output.filename=/tmp/lewm_smoke_test.txt
```

The smoke test should finish normally and print something like:

```text
{'success_rate': 100.0, ...}
```

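If the smoke test needs to run from a script (for example in a setup check), the command above can be assembled programmatically. This is a sketch only: `build_eval_command`, `SMOKE_OVERRIDES`, and `run_smoke_test` are hypothetical names, and the overrides simply mirror the `key=value` arguments shown in the shell command.

```python
import subprocess
import sys

# Overrides copied from the smoke-test command in this document.
SMOKE_OVERRIDES = {
    "eval.num_eval": 1,
    "world.num_envs": 1,
    "output.filename": "/tmp/lewm_smoke_test.txt",
}


def build_eval_command(overrides=None, python=sys.executable) -> list:
    """Build the eval.py argument list, appending key=value overrides in order."""
    cmd = [python, "eval.py", "--config-name=pusht.yaml", "policy=pusht/lewm"]
    for key, value in (overrides or {}).items():
        cmd.append(f"{key}={value}")
    return cmd


def run_smoke_test() -> int:
    """Run the smoke test from the repository root and return its exit code."""
    return subprocess.run(build_eval_command(SMOKE_OVERRIDES), check=False).returncode
```

`run_smoke_test()` expects the current directory to be the repository root and `STABLEWM_HOME` to be exported already. A non-zero return code means the evaluation did not finish normally.
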
## Troubleshooting

### `HSA_STATUS_ERROR_INVALID_PACKET_FORMAT`

If this error appears with `torch.compile` enabled, check the torch version first:

```bash
python -c "import torch; print(torch.__version__, torch.version.hip)"
```

If it reports `2.12.0+rocm7.1`, downgrade to the combination verified for this project:

```bash
python -m pip install --force-reinstall \
  --index-url https://download.pytorch.org/whl/rocm7.1 \
  --extra-index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "torch==2.10.0" \
  "torchvision==0.25.0"
```

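The triage rule above can be written down as a small function. This sketch encodes only the two data points this document reports (one verified build, one build that crashed under `torch.compile`); `classify_torch_build` is a hypothetical helper, and every other version is labelled untested rather than guessed at.

```python
# Version strings taken from this document's verified and failing setups.
VERIFIED_TORCH = "2.10.0+rocm7.1"
KNOWN_BAD_TORCH = "2.12.0+rocm7.1"


def classify_torch_build(version: str) -> str:
    """Classify a torch.__version__ string against the builds tried here."""
    if version == VERIFIED_TORCH:
        return "verified"
    if version == KNOWN_BAD_TORCH:
        return "known-bad"
    if "+rocm" not in version:
        # No ROCm local suffix: probably a CPU or CUDA wheel.
        return "not-rocm"
    return "untested"
```

Passing `torch.__version__` to `classify_torch_build` gives `"known-bad"` for the build that crashed in this project; `"untested"` means this document has no evidence either way.
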
### `pusht_expert_train.h5` not found

Point `STABLEWM_HOME` at the directory that contains the data and checkpoints:

```bash
export STABLEWM_HOME=/path/to/stablewm
```

### pip tries to build the old `gym==0.21`

This is caused by backtracking during dependency resolution. Install the compatible versions explicitly first:

```bash
python -m pip install \
  --index-url https://pypi.tuna.tsinghua.edu.cn/simple \
  "gymnasium[all]==1.2.2" \
  "stable-baselines3==2.8.0"
```

### uv or pip is slow to reach overseas indexes

Use a PyPI mirror in China for regular Python packages:

```bash
--index-url https://pypi.tuna.tsinghua.edu.cn/simple
```

Keep using the official PyTorch ROCm index for the PyTorch ROCm wheels:

```bash
--index-url https://download.pytorch.org/whl/rocm7.1
```