Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
183 lines
5.4 KiB
Python
183 lines
5.4 KiB
Python
#!/usr/bin/env python
|
|
"""Probe MPI rank placement and whether torch CPU ops use multiple threads.

Run this under mpirun/mpiexec to check:

* which CPUs each rank is allowed to run on,
* whether torch sees the requested intra-op thread count, and
* whether a large CPU tensor op actually consumes more CPU time than wall time.

The script is intentionally small and self-contained so it can be used to debug
MPI launcher affinity and torch OpenMP behavior independently from the TN code
path.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import os
|
|
import socket
|
|
import time
|
|
from pathlib import Path
|
|
|
|
from mpi4py import MPI
|
|
|
|
|
|
def _dtype_from_name(name):
|
|
import torch
|
|
|
|
mapping = {
|
|
"float32": torch.float32,
|
|
"float64": torch.float64,
|
|
"complex64": torch.complex64,
|
|
"complex128": torch.complex128,
|
|
}
|
|
return mapping[name]
|
|
|
|
|
|
def _make_tensor(shape, dtype):
|
|
import torch
|
|
|
|
if dtype in (torch.complex64, torch.complex128):
|
|
base = torch.float32 if dtype == torch.complex64 else torch.float64
|
|
return torch.complex(
|
|
torch.randn(shape, dtype=base),
|
|
torch.randn(shape, dtype=base),
|
|
)
|
|
return torch.randn(shape, dtype=dtype)
|
|
|
|
|
|
def _bench(label, fn, iters, warmup=2):
|
|
for _ in range(warmup):
|
|
fn()
|
|
|
|
start_wall = time.perf_counter()
|
|
start_cpu = time.process_time()
|
|
checksum = 0.0
|
|
for _ in range(iters):
|
|
value = fn()
|
|
checksum += float(value)
|
|
wall = time.perf_counter() - start_wall
|
|
cpu = time.process_time() - start_cpu
|
|
ratio = cpu / wall if wall > 0 else float("inf")
|
|
print(
|
|
f"{label} wall={wall:.3f}s cpu={cpu:.3f}s cpu_over_wall={ratio:.2f} "
|
|
f"checksum={checksum:.6e}",
|
|
flush=True,
|
|
)
|
|
|
|
|
|
def _visible_numa_nodes():
|
|
nodes = []
|
|
for path in sorted(Path("/sys/devices/system/node").glob("node[0-9]*")):
|
|
cpulist = path / "cpulist"
|
|
if cpulist.exists():
|
|
nodes.append(f"{path.name}:{cpulist.read_text(encoding='utf-8').strip()}")
|
|
return ",".join(nodes) if nodes else "unknown"
|
|
|
|
|
|
def _dtype_nbytes(name):
|
|
return {
|
|
"float32": 4,
|
|
"float64": 8,
|
|
"complex64": 8,
|
|
"complex128": 16,
|
|
}[name]
|
|
|
|
|
|
def _format_gib(nbytes):
|
|
return f"{nbytes / (1024 ** 3):.2f}GiB"
|
|
|
|
|
|
def main():
    """Parse options, print per-rank placement diagnostics, then benchmark.

    Every rank prints one line with its host, PID, CPU affinity, torch thread
    settings, the OpenMP/MKL environment variables and the NUMA nodes listed in
    sysfs.  Rank 0 additionally prints torch's parallel build information and
    an estimate of the memory needed for the benchmark matrices.  Unless
    ``--affinity-only`` is given, each rank then creates two ``n x n`` random
    matrices and times ``matmul`` and/or ``tensordot`` on them.

    Invalid ``--threads`` / ``--n`` values terminate through
    ``ArgumentParser.error`` (exit status 2).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--threads", type=int, default=48)
    parser.add_argument("--n", type=int, default=4096)
    parser.add_argument("--iters", type=int, default=4)
    parser.add_argument(
        "--dtype",
        choices=("float32", "float64", "complex64", "complex128"),
        default="float32",
    )
    parser.add_argument("--op", choices=("matmul", "tensordot", "both"), default="both")
    parser.add_argument(
        "--affinity-only",
        action="store_true",
        help="Print MPI/torch placement diagnostics without allocating tensors.",
    )
    args = parser.parse_args()

    # Reject values that would otherwise only fail later inside torch with a
    # less helpful error.
    if args.threads < 1:
        parser.error("--threads must be at least 1")
    if args.n < 0:
        parser.error("--n must not be negative")

    # Set before torch is imported, presumably so the OpenMP/MKL runtimes see
    # these values when they initialise.  setdefault keeps anything the
    # launcher or user already exported.
    os.environ.setdefault("OMP_NUM_THREADS", str(args.threads))
    os.environ.setdefault("MKL_NUM_THREADS", str(args.threads))
    os.environ.setdefault("OMP_PROC_BIND", "close")
    os.environ.setdefault("OMP_PLACES", "cores")

    import torch

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    torch.set_num_threads(args.threads)
    try:
        torch.set_num_interop_threads(1)
    except RuntimeError:
        # torch refuses to change this once inter-op work has started; the
        # value actually in effect is reported in the diagnostics line below.
        pass

    dtype = _dtype_from_name(args.dtype)

    # sched_getaffinity is not available on every platform; this is a
    # diagnostic, so report "unknown" instead of crashing.
    try:
        affinity_len = len(os.sched_getaffinity(0))
    except AttributeError:
        affinity_len = "unknown"

    # Best-effort read of the kernel's view of the allowed CPUs.
    allowed_list = ""
    try:
        with open("/proc/self/status", encoding="utf-8") as f:
            for line in f:
                if line.startswith("Cpus_allowed_list:"):
                    allowed_list = line.split(":", 1)[1].strip()
                    break
    except OSError:
        pass

    print(
        f"rank={rank}/{size} host={socket.gethostname()} pid={os.getpid()} "
        f"affinity_len={affinity_len} allowed={allowed_list} "
        f"torch_threads={torch.get_num_threads()} "
        f"torch_interop={torch.get_num_interop_threads()} "
        f"OMP_NUM_THREADS={os.environ.get('OMP_NUM_THREADS')} "
        f"MKL_NUM_THREADS={os.environ.get('MKL_NUM_THREADS')} "
        f"OMP_PROC_BIND={os.environ.get('OMP_PROC_BIND')} "
        f"OMP_PLACES={os.environ.get('OMP_PLACES')} "
        f"visible_numa={_visible_numa_nodes()}",
        flush=True,
    )

    if rank == 0:
        print(torch.__config__.parallel_info(), flush=True)
        input_bytes = args.n * args.n * _dtype_nbytes(args.dtype)
        # Two operands plus one result matrix.
        min_live_bytes = 3 * input_bytes
        print(
            f"matrix_n={args.n} dtype={args.dtype} "
            f"one_matrix={_format_gib(input_bytes)} "
            f"approx_min_live_per_rank={_format_gib(min_live_bytes)} "
            f"approx_min_live_all_ranks={_format_gib(min_live_bytes * size)}",
            flush=True,
        )

    # Wait for every rank to finish its diagnostics before any rank continues.
    comm.Barrier()
    if args.affinity_only:
        return

    a = _make_tensor((args.n, args.n), dtype)
    b = _make_tensor((args.n, args.n), dtype)

    def _as_real_scalar(total):
        # _bench folds results with float(), so complex sums contribute only
        # their real part.
        return total.real.item() if total.is_complex() else total.item()

    def run_matmul():
        return _as_real_scalar((a @ b).sum())

    def run_tensordot():
        return _as_real_scalar(torch.tensordot(a, b, dims=1).sum())

    # Tag each benchmark line with the rank so interleaved output from
    # several ranks can be attributed.
    tag = f"rank={rank}/{size}"
    if args.op in ("matmul", "both"):
        _bench(f"{tag} matmul", run_matmul, args.iters)
    if args.op in ("tensordot", "both"):
        _bench(f"{tag} tensordot", run_tensordot, args.iters)
|
|
|
|
|
|
# Run the probe only when executed as a script (e.g. under mpirun/mpiexec),
# not when the module is imported.
if __name__ == "__main__":
    main()
|