Files
qibotn/tools/mpi_torch_thread_probe.py
jaunatisblue ef3d7e9ee6
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
决赛现场脚本
2026-05-18 01:37:19 +08:00

183 lines
5.4 KiB
Python

#!/usr/bin/env python
"""Probe MPI rank placement and whether torch CPU ops use multiple threads.
Run this under mpirun/mpiexec to check:
* which CPUs each rank is allowed to run on,
* whether torch sees the requested intra-op thread count, and
* whether a large CPU tensor op actually consumes more CPU time than wall time.
The script is intentionally small and self-contained so it can be used to debug
MPI launcher affinity and torch OpenMP behavior independently from the TN code
path.
"""
from __future__ import annotations
import argparse
import os
import socket
import time
from pathlib import Path
from mpi4py import MPI
def _dtype_from_name(name):
    """Translate a dtype name into the corresponding ``torch`` dtype object.

    The recognised names are ``float32``, ``float64``, ``complex64`` and
    ``complex128``. Any other name raises ``KeyError``.
    """
    # torch is imported inside the function rather than at module level.
    import torch

    known_names = ("float32", "float64", "complex64", "complex128")
    # Each supported name is also the attribute name of the dtype on ``torch``.
    table = {key: getattr(torch, key) for key in known_names}
    return table[name]
def _make_tensor(shape, dtype):
    """Create a tensor of ``torch.randn`` samples with the given shape and dtype.

    For ``torch.complex64`` and ``torch.complex128`` the real and imaginary
    parts are drawn separately (real part first) with the matching real dtype
    (float32 / float64) and joined with ``torch.complex``. Every other dtype is
    handed directly to ``torch.randn``.
    """
    import torch

    if dtype == torch.complex64:
        component_dtype = torch.float32
    elif dtype == torch.complex128:
        component_dtype = torch.float64
    else:
        # Non-complex dtype: a single randn call is enough.
        return torch.randn(shape, dtype=dtype)

    real_part = torch.randn(shape, dtype=component_dtype)
    imag_part = torch.randn(shape, dtype=component_dtype)
    return torch.complex(real_part, imag_part)
def _bench(label, fn, iters, warmup=2):
    """Time repeated calls of ``fn`` and print wall-clock versus CPU time.

    Args:
        label: Text placed at the start of the printed result line.
        fn: Zero-argument callable; its return value must be convertible
            with ``float()``.
        iters: Number of timed calls.
        warmup: Number of untimed calls made first (results discarded).

    Prints one line with the elapsed ``time.perf_counter`` time, the elapsed
    ``time.process_time`` time, their ratio (``inf`` when no wall time
    elapsed) and the sum of the timed return values. Returns ``None``.
    """
    # Untimed warm-up calls.
    for _ in range(warmup):
        fn()

    wall_t0 = time.perf_counter()
    cpu_t0 = time.process_time()
    total = 0.0
    for _ in range(iters):
        # Accumulating the result forces a Python float out of every call.
        total += float(fn())
    wall_elapsed = time.perf_counter() - wall_t0
    cpu_elapsed = time.process_time() - cpu_t0

    if wall_elapsed > 0:
        cpu_per_wall = cpu_elapsed / wall_elapsed
    else:
        cpu_per_wall = float("inf")

    timing_part = (
        f"{label} wall={wall_elapsed:.3f}s cpu={cpu_elapsed:.3f}s "
        f"cpu_over_wall={cpu_per_wall:.2f} "
    )
    checksum_part = f"checksum={total:.6e}"
    print(timing_part + checksum_part, flush=True)
def _visible_numa_nodes(sysfs_root="/sys/devices/system/node"):
    """Summarise the NUMA nodes listed under sysfs and their CPU lists.

    Args:
        sysfs_root: Directory that holds the ``node<N>`` sub-directories.
            Defaults to the Linux sysfs location; overridable for testing.

    Returns:
        A comma-separated string of ``node<N>:<cpulist>`` entries ordered by
        node number, or ``"unknown"`` when no readable ``cpulist`` file is
        found (for example when the directory does not exist).
    """

    def _node_order(node_dir):
        # Sort numerically so that node10 follows node9 instead of node1;
        # entries with a non-numeric suffix go last, ordered by name.
        suffix = node_dir.name[len("node"):]
        if suffix.isascii() and suffix.isdigit():
            return (0, int(suffix), node_dir.name)
        return (1, 0, node_dir.name)

    nodes = []
    for path in sorted(Path(sysfs_root).glob("node[0-9]*"), key=_node_order):
        try:
            cpus = (path / "cpulist").read_text(encoding="utf-8").strip()
        except OSError:
            # Missing or unreadable cpulist: this is a best-effort diagnostic,
            # so skip the node rather than abort the probe.
            continue
        nodes.append(f"{path.name}:{cpus}")
    return ",".join(nodes) if nodes else "unknown"
def _dtype_nbytes(name):
    """Return the size in bytes of one element of the named dtype.

    Recognises ``float32``, ``float64``, ``complex64`` and ``complex128``;
    any other name raises ``KeyError``.
    """
    bytes_per_element = dict(float32=4, float64=8, complex64=8, complex128=16)
    return bytes_per_element[name]
def _format_gib(nbytes):
    """Render a byte count in GiB (2**30 bytes) with two decimals, e.g. ``1.50GiB``."""
    gib = nbytes / (1 << 30)
    return "{:.2f}GiB".format(gib)
def main(argv=None):
    """Print per-rank placement diagnostics, then optionally time torch CPU ops.

    Steps, executed by every MPI rank:

    1. Parse the command line and reject values torch would refuse later.
    2. Set default OpenMP/MKL environment variables, import torch and apply
       the requested intra-op thread count.
    3. Print one line with rank, host, pid, CPU affinity, torch thread
       settings, the relevant environment variables and the NUMA layout.
       Rank 0 additionally prints ``torch.__config__.parallel_info()``.
    4. Print an estimate of the memory needed for the benchmark matrices.
    5. Synchronise on a barrier and, unless ``--affinity-only`` was given,
       benchmark ``a @ b`` and/or ``torch.tensordot(a, b, dims=1)``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse for invalid arguments or ``--help``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--threads", type=int, default=48)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--iters", type=int, default=4)
    parser.add_argument(
        "--dtype",
        choices=("float32", "float64", "complex64", "complex128"),
        default="float32",
    )
    parser.add_argument("--op", choices=("matmul", "tensordot", "both"), default="both")
    parser.add_argument(
        "--affinity-only",
        action="store_true",
        help="Print MPI/torch placement diagnostics without allocating tensors.",
    )
    args = parser.parse_args(argv)
    # Fail with a clean usage error instead of a torch traceback on every rank.
    if args.threads < 1:
        parser.error("--threads must be a positive integer")
    if args.n < 0:
        parser.error("--n must be non-negative")

    # setdefault keeps any value already exported by the launcher or the user.
    # These are set before torch is imported below, presumably so the
    # OpenMP/MKL runtimes see them when they are loaded.
    os.environ.setdefault("OMP_NUM_THREADS", str(args.threads))
    os.environ.setdefault("MKL_NUM_THREADS", str(args.threads))
    os.environ.setdefault("OMP_PROC_BIND", "close")
    os.environ.setdefault("OMP_PLACES", "cores")

    import torch

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    torch.set_num_threads(args.threads)
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError:
        # Best effort: torch refuses to change this setting once it is fixed.
        # The effective value is reported in the diagnostic line below.
        pass

    # os.sched_getaffinity is not available on every platform; fall back to
    # all CPUs reported by the OS so the probe still runs there.
    if hasattr(os, "sched_getaffinity"):
        affinity = sorted(os.sched_getaffinity(0))
    else:
        affinity = list(range(os.cpu_count() or 1))

    # Human-readable CPU list as reported by the kernel; stays empty when
    # /proc is unavailable or the field is missing.
    allowed_list = ""
    try:
        with open("/proc/self/status", encoding="utf-8") as f:
            for line in f:
                if line.startswith("Cpus_allowed_list:"):
                    allowed_list = line.split(":", 1)[1].strip()
                    break
    except OSError:
        pass

    print(
        f"rank={rank}/{size} host={socket.gethostname()} pid={os.getpid()} "
        f"affinity_len={len(affinity)} allowed={allowed_list} "
        f"torch_threads={torch.get_num_threads()} "
        f"torch_interop={torch.get_num_interop_threads()} "
        f"OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')} "
        f"MKL_NUM_THREADS={os.environ.get('MKL_NUM_THREADS')} "
        f"OMP_PROC_BIND={os.environ.get('OMP_PROC_BIND')} "
        f"OMP_PLACES={os.environ.get('OMP_PLACES')} "
        f"visible_numa={_visible_numa_nodes()}",
        flush=True,
    )
    if rank == 0:
        print(torch.__config__.parallel_info(), flush=True)

    input_bytes = args.n * args.n * _dtype_nbytes(args.dtype)
    # Three n x n matrices: presumably the two inputs plus the product.
    min_live_bytes = 3 * input_bytes
    print(
        f"matrix_n={args.n} dtype={args.dtype} "
        f"one_matrix={_format_gib(input_bytes)} "
        f"approx_min_live_per_rank={_format_gib(min_live_bytes)} "
        f"approx_min_live_all_ranks={_format_gib(min_live_bytes * size)}",
        flush=True,
    )

    # Wait until every rank has printed its diagnostics before benchmarking.
    comm.Barrier()
    if args.affinity_only:
        return

    dtype = _dtype_from_name(args.dtype)
    a = _make_tensor((args.n, args.n), dtype)
    b = _make_tensor((args.n, args.n), dtype)

    def run_matmul():
        # Reduce to a scalar so _bench can fold the result into its checksum.
        value = (a @ b).sum()
        return value.real.item() if value.is_complex() else value.item()

    def run_tensordot():
        value = torch.tensordot(a, b, dims=1)
        value = value.sum()
        return value.real.item() if value.is_complex() else value.item()

    if args.op in ("matmul", "both"):
        _bench("matmul", run_matmul, args.iters)
    if args.op in ("tensordot", "both"):
        _bench("tensordot", run_tensordot, args.iters)
# Script entry point: run the probe only when this file is executed directly
# (e.g. under mpirun/mpiexec), not when it is imported as a module.
if __name__ == "__main__":
    main()