#!/usr/bin/env bash
#
# Single-node CPU scale probes for expectation benchmarks.
#
# Intended for one 96-core / ~500 GiB RAM node. Start with the "probe" mode,
# which runs moderate MPS and TN cases. Larger modes are available after
# checking runtime and memory from the probe output. With no argument (or
# "help") the script prints usage and exits with status 2.
#
# Usage: tools/run_cpu_single_cases.sh [probe|mps-medium|mps-long|tn-medium|tn-long]
#
# Environment overrides (defaults in parentheses):
#   PYTHON_BIN (.venv/bin/python)  interpreter; a relative path is resolved
#                                  against the parent of this script's directory
#   PYTHON_FLAGS (-u)              whitespace-separated interpreter flags
#   MPIEXEC (mpiexec)              MPI launcher, invoked as "$MPIEXEC -n <ranks>"
#   TIME_BIN (/usr/bin/time)       timer, invoked with -v when it supports it
#   MPS_RANKS (8)  MPS_THREADS (12)
#   TN_RANKS (8)   TN_THREADS (12)
#   OMP_NUM_THREADS (1)  MKL_NUM_THREADS (1)   exported to child processes

set -euo pipefail

# All relative paths below (PYTHON_BIN default, benchmark script) are relative
# to the directory one level above this script.
ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
cd "$ROOT_DIR"

PYTHON_BIN="${PYTHON_BIN:-.venv/bin/python}"
PYTHON_FLAGS="${PYTHON_FLAGS:--u}"
MPIEXEC="${MPIEXEC:-mpiexec}"
TIME_BIN="${TIME_BIN:-/usr/bin/time}"

MPS_RANKS="${MPS_RANKS:-8}"
MPS_THREADS="${MPS_THREADS:-12}"
TN_RANKS="${TN_RANKS:-8}"
TN_THREADS="${TN_THREADS:-12}"

export OMP_NUM_THREADS="${OMP_NUM_THREADS:-1}"
export MKL_NUM_THREADS="${MKL_NUM_THREADS:-1}"

# Split PYTHON_FLAGS on whitespace exactly once, so the flags can be expanded
# quoted later (word splitting is wanted, pathname expansion is not).
# `read -d ''` consumes the whole value and therefore always reports EOF with a
# non-zero status even though the array was filled; `|| true` keeps `set -e`
# from treating that as a failure.
PYTHON_ARGS=()
read -r -d '' -a PYTHON_ARGS <<<"$PYTHON_FLAGS" || true

readonly HEAVY_RULE="================================================================================"
readonly LIGHT_RULE="--------------------------------------------------------------------------------"

#######################################
# Print a blank line followed by a text line framed by two rule lines.
# Arguments: $1 - rule string, $2 - text to frame
# Outputs:   four lines on stdout
#######################################
print_banner() {
  local rule="$1"
  local text="$2"
  printf '\n%s\n%s\n%s\n' "$rule" "$text" "$rule"
}

#######################################
# Print a rough resident-memory figure for an MPS run.
# The figure is nqubits * 2 * bond^2 * 16 bytes, shown as a total and as that
# total divided by MPS_RANKS.
# NOTE(review): the factor 16 is presumably bytes per complex128 element, and
# the per-rank figure assumes an even split across ranks - confirm against the
# benchmark implementation.
# Globals:   PYTHON_BIN, MPS_RANKS (read)
# Arguments: $1 - number of qubits, $2 - bond dimension
# Outputs:   one line on stdout
#######################################
estimate_mps_memory() {
  local nqubits="$1"
  local bond="$2"
  "$PYTHON_BIN" - "$nqubits" "$bond" "$MPS_RANKS" <<'PY'
import sys

n = int(sys.argv[1])
chi = int(sys.argv[2])
ranks = int(sys.argv[3])
resident = n * 2 * chi * chi * 16
per_rank = resident / ranks
print(
    "MPS rough resident memory: "
    f"total={resident / 1024**3:.1f} GiB "
    f"per_rank={per_rank / 1024**3:.1f} GiB "
    "(temporary eig/SVD workspaces are additional)"
)
PY
}

#######################################
# Echo a command line, then run it under "$TIME_BIN -v".
# If TIME_BIN is missing or rejects -v (the flag is a GNU time extension), the
# command still runs, without the resource report, after a warning on stderr.
# Globals:   TIME_BIN, LIGHT_RULE (read)
# Arguments: the command and its arguments
# Returns:   exit status of the command (as relayed by the timer, if used)
#######################################
run_timed() {
  print_banner "$LIGHT_RULE" "$*"
  # Probe with a no-op so an unusable timer is detected before the real launch.
  if "$TIME_BIN" -v true >/dev/null 2>&1; then
    "$TIME_BIN" -v "$@"
  else
    printf 'warning: %s is missing or does not accept -v; running without a resource report\n' \
      "$TIME_BIN" >&2
    "$@"
  fi
}

#######################################
# Run one MPS benchmark case under MPI.
# Globals:   PYTHON_BIN, PYTHON_ARGS, MPIEXEC, MPS_RANKS, MPS_THREADS,
#            OMP_NUM_THREADS, MKL_NUM_THREADS (read)
# Arguments: $1 - label for the banner
#            $2 - number of qubits
#            $3 - number of layers
#            $4 - bond dimension
#            remaining - passed through to benchmark_cpu_expectation.py
#######################################
run_mps_case() {
  local label="$1"
  local nqubits="$2"
  local nlayers="$3"
  local bond="$4"
  shift 4

  print_banner "$HEAVY_RULE" "$label"
  echo "PYTHON_BIN=$PYTHON_BIN MPIEXEC=$MPIEXEC"
  echo "MPS_RANKS=$MPS_RANKS MPS_THREADS=$MPS_THREADS"
  echo "OMP_NUM_THREADS=$OMP_NUM_THREADS MKL_NUM_THREADS=$MKL_NUM_THREADS"
  estimate_mps_memory "$nqubits" "$bond"

  # ${arr[@]+...} keeps an empty PYTHON_ARGS from tripping `set -u` on
  # bash releases older than 4.4.
  run_timed "$MPIEXEC" -n "$MPS_RANKS" "$PYTHON_BIN" \
    ${PYTHON_ARGS[@]+"${PYTHON_ARGS[@]}"} \
    benchmark_cpu_expectation.py \
    --mpi --mps \
    --nqubits "$nqubits" \
    --nlayers "$nlayers" \
    --bond "$bond" \
    --torch-threads "$MPS_THREADS" \
    "$@"
}

#######################################
# Run one TN benchmark case under MPI.
# Globals:   PYTHON_BIN, PYTHON_ARGS, MPIEXEC, TN_RANKS, TN_THREADS,
#            OMP_NUM_THREADS, MKL_NUM_THREADS (read)
# Arguments: $1 - label for the banner
#            $2 - number of qubits
#            $3 - number of layers
#            remaining - passed through to benchmark_cpu_expectation.py
#######################################
run_tn_case() {
  local label="$1"
  local nqubits="$2"
  local nlayers="$3"
  shift 3

  print_banner "$HEAVY_RULE" "$label"
  echo "PYTHON_BIN=$PYTHON_BIN MPIEXEC=$MPIEXEC"
  echo "TN_RANKS=$TN_RANKS TN_THREADS=$TN_THREADS"
  echo "OMP_NUM_THREADS=$OMP_NUM_THREADS MKL_NUM_THREADS=$MKL_NUM_THREADS"
  echo "TN memory is contraction-tree dependent; increase --tn-target-slices if RSS is high."

  run_timed "$MPIEXEC" -n "$TN_RANKS" "$PYTHON_BIN" \
    ${PYTHON_ARGS[@]+"${PYTHON_ARGS[@]}"} \
    benchmark_cpu_expectation.py \
    --mpi \
    --nqubits "$nqubits" \
    --nlayers "$nlayers" \
    --torch-threads "$TN_THREADS" \
    "$@"
}

#######################################
# Print the usage text to stderr.
#######################################
usage() {
  cat >&2 <<'EOF'
Usage: tools/run_cpu_single_cases.sh [probe|mps-medium|mps-long|tn-medium|tn-long]

Common overrides:
  PYTHON_BIN=.venv/bin/python
  MPIEXEC=mpiexec
  MPS_RANKS=8 MPS_THREADS=12
  TN_RANKS=8 TN_THREADS=12
  OMP_NUM_THREADS=1 MKL_NUM_THREADS=1
EOF
}

case "${1:-help}" in
  probe)
    run_mps_case "MPS probe: n=40 layers=30 bond=2048" 40 30 2048 \
      --circuits brickwall_cnot \
      --observables ring_xz
    run_tn_case "TN probe: n=28 layers=12 target_slices=8" 28 12 \
      --circuits brickwall_cnot \
      --observables ring_xz \
      --tn-target-slices 8
    ;;
  mps-medium)
    run_mps_case "MPS medium: n=56 layers=40 bond=3072" 56 40 3072 \
      --circuits brickwall_cnot reversed_cnot shifted_cz rxx_rzz \
      --observables ring_xz open_zz mixed_local range2_xx
    ;;
  mps-long)
    run_mps_case "MPS long: n=64 layers=48 bond=4096" 64 48 4096 \
      --circuits brickwall_cnot reversed_cnot shifted_cz rxx_rzz \
      --observables ring_xz open_zz mixed_local range2_xx
    ;;
  tn-medium)
    run_tn_case "TN medium: n=32 layers=16 target_slices=16" 32 16 \
      --circuits brickwall_cnot shifted_cz rxx_rzz \
      --observables ring_xz open_zz range2_xx \
      --tn-target-slices 16
    ;;
  tn-long)
    run_tn_case "TN long: n=36 layers=20 target_slices=32" 36 20 \
      --circuits brickwall_cnot shifted_cz rxx_rzz \
      --observables ring_xz open_zz range2_xx \
      --tn-target-slices 32
    ;;
  help)
    usage
    exit 2
    ;;
  *)
    # Reached only with a non-empty $1, since an unset or empty argument is
    # mapped to "help" above.
    printf 'Unknown mode: %s\n\n' "$1" >&2
    usage
    exit 2
    ;;
esac