107 lines
3.4 KiB
Bash
107 lines
3.4 KiB
Bash
#!/usr/bin/env bash
#
# Run world-model interaction inference once per head-schedule scheme and,
# optionally, the metric analysis, copying each scheme's console output to
# <res_dir>/output_logs/<scheme>.log.
#
# Usage: <script> <res_dir> <dataset> <frame_stride> <n_iter>
#
# Environment:
#   RUN_ANALYSIS          "1" (default) also runs analyze_metrics.py.
#   SCHEME_FILTER         Only run schemes whose name contains this substring.
#   CUDA_VISIBLE_DEVICES  Device list for the inference run (default: 0).
#   PYTHON_BIN            Interpreter to use. Otherwise $CONDA_PREFIX/bin/python,
#                         then $VIRTUAL_ENV/bin/python, then python or python3
#                         found on PATH.
#
# The script cd's to the repository root (the parent of this script's
# directory) before doing any work, so relative paths - res_dir included -
# are resolved against the repository root, not the caller's directory.

set -euo pipefail

if [[ $# -ne 4 ]]; then
  echo "Usage: $0 <res_dir> <dataset> <frame_stride> <n_iter>" >&2
  exit 1
fi

res_dir="$1"
dataset="$2"
frame_stride="$3"
n_iter="$4"

# Fail fast, before anything expensive starts. An empty res_dir would turn
# every output path into an absolute path under "/", and the numeric
# arguments are forwarded verbatim to the Python entry point.
if [[ -z "${res_dir}" ]]; then
  echo "res_dir must be non-empty." >&2
  exit 1
fi
for numeric_arg in frame_stride n_iter; do
  # ${!numeric_arg} reads the variable whose name is stored in numeric_arg.
  if [[ ! "${!numeric_arg}" =~ ^[0-9]+$ ]]; then
    echo "${numeric_arg} must be a non-negative integer, got '${!numeric_arg}'." >&2
    exit 1
  fi
done

repo_root="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "${repo_root}"

run_analysis="${RUN_ANALYSIS:-1}"
scheme_filter="${SCHEME_FILTER:-}"
gpu_devices="${CUDA_VISIBLE_DEVICES:-0}"

# Interpreter resolution: explicit PYTHON_BIN wins, then the active conda env,
# then the active virtualenv, then whatever is on PATH.
python_bin="${PYTHON_BIN:-}"
if [[ -z "${python_bin}" ]]; then
  if [[ -n "${CONDA_PREFIX:-}" && -x "${CONDA_PREFIX}/bin/python" ]]; then
    python_bin="${CONDA_PREFIX}/bin/python"
  elif [[ -n "${VIRTUAL_ENV:-}" && -x "${VIRTUAL_ENV}/bin/python" ]]; then
    python_bin="${VIRTUAL_ENV}/bin/python"
  else
    # Many systems ship only "python3"; try it when "python" is absent.
    python_bin="$(command -v python || command -v python3 || true)"
  fi
fi

if [[ -z "${python_bin}" ]] || [[ ! -x "${python_bin}" ]]; then
  echo "Unable to resolve a usable Python interpreter. Set PYTHON_BIN explicitly." >&2
  exit 1
fi

# Step indices forwarded to --head_log_steps for every scheme.
head_log_steps=(40 43 46 47 48 49)
# Schemes to run, in order; each name is looked up with get_head_schedule.
scheme_names=(sparse_8)
#######################################
# Print the head schedule for a named scheme as one space-separated line.
# The caller forwards these values to --head_schedule_steps.
# Arguments: $1 - scheme name: sparse_8, sparse_4, tail_only or tail_heavy
# Outputs:   the schedule on stdout; "Unknown scheme: <name>" on stderr
# Returns:   calls `exit 1` for an unknown name. That ends the whole shell
#            when called directly, or only the subshell when called via $(...).
#######################################
get_head_schedule() {
  local scheme=$1
  local steps

  case "${scheme}" in
    sparse_8)   steps="0 7 14 21 28 35 42 49" ;;
    sparse_4)   steps="0 16 32 49" ;;
    tail_only)  steps="40 43 46 49" ;;
    tail_heavy) steps="0 32 40 44 47 49" ;;
    *)
      printf '%s\n' "Unknown scheme: ${scheme}" >&2
      exit 1
      ;;
  esac

  printf '%s\n' "${steps}"
}
# ${res_dir:?} refuses to run with an empty res_dir, which would otherwise
# create output_logs directly under "/".
mkdir -p "${res_dir:?res_dir must be non-empty}/output_logs"

# Count of schemes actually executed, so a SCHEME_FILTER that matches nothing
# is reported instead of ending silently.
schemes_run=0

for scheme_name in "${scheme_names[@]}"; do
  # SCHEME_FILTER is a substring match; an empty filter selects every scheme.
  if [[ -n "${scheme_filter}" && "${scheme_name}" != *"${scheme_filter}"* ]]; then
    continue
  fi

  # Look the schedule up in a plain assignment so that an unknown scheme
  # (get_head_schedule exits non-zero) aborts the script under `set -e`.
  # A `read ... <<< "$(...)"` would discard that status and carry on with an
  # empty schedule.
  schedule_str="$(get_head_schedule "${scheme_name}")"
  read -r -a head_schedule <<< "${schedule_str}"
  if (( ${#head_schedule[@]} == 0 )); then
    echo "Empty head schedule for scheme: ${scheme_name}" >&2
    exit 1
  fi

  scheme_savedir="${res_dir}/output/${scheme_name}"
  scheme_log_path="${res_dir}/output_logs/${scheme_name}.log"

  # Inference entry point and its arguments. Using an array keeps each value
  # a single word and lets the exact command be written to the log.
  eval_args=(
    scripts/evaluation/world_model_interaction.py
    --seed 123
    --ckpt_path ckpts/unifolm_wma_dual.ckpt
    --config configs/inference/world_model_interaction.yaml
    --savedir "${scheme_savedir}"
    --bs 1 --height 320 --width 512
    --unconditional_guidance_scale 1.0
    --ddim_steps 50
    --ddim_eta 1.0
    --prompt_dir "${res_dir}/world_model_interaction_prompts"
    --dataset "${dataset}"
    --video_length 16
    --frame_stride "${frame_stride}"
    --n_action_steps 16
    --exe_steps 16
    --n_iter "${n_iter}"
    --timestep_spacing uniform_trailing
    --guidance_rescale 0.7
    --perframe_ae
    --analysis_log_metrics
    --analysis_reference_steps 50
    --head_schedule_steps "${head_schedule[@]}"
    --head_skip_mode reuse_prediction
    --head_log_steps "${head_log_steps[@]}"
  )

  # Everything the group writes to stdout/stderr goes to the terminal and to
  # the per-scheme log. With `set -e` and `pipefail`, a failing Python step
  # fails the pipeline and stops the script.
  {
    echo "============================================================"
    echo "Running ${res_dir} with ${scheme_name}"
    echo "============================================================"
    echo "Using Python: ${python_bin}"
    printf 'Command: CUDA_VISIBLE_DEVICES=%q' "${gpu_devices}"
    printf ' %q' "${python_bin}" "${eval_args[@]}"
    printf '\n'
    time CUDA_VISIBLE_DEVICES="${gpu_devices}" "${python_bin}" "${eval_args[@]}"

    if [[ "${run_analysis}" == "1" ]]; then
      "${python_bin}" scripts/evaluation/analyze_metrics.py \
        --input_dir "${scheme_savedir}/inference" \
        --output_dir "${scheme_savedir}/inference/analysis"
    fi
  } 2>&1 | tee "${scheme_log_path}"

  schemes_run=$((schemes_run + 1))
done

if (( schemes_run == 0 )); then
  echo "Warning: SCHEME_FILTER='${scheme_filter}' matched none of: ${scheme_names[*]}" >&2
fi