Files
qibotn/tools/benchmark_qredtea_svd_controls.py
jaunatisblue ef3d7e9ee6
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
决赛现场脚本
2026-05-18 01:37:19 +08:00

158 lines
4.8 KiB
Python

#!/usr/bin/env python
"""Benchmark qredtea/qtealeaves SVD control modes.
This isolates the tensor split used by MPS updates: a rank-2 tensor is split
with singular values contracted either left or right, then reconstructed to
measure numerical error and timing.
"""
from __future__ import annotations
import argparse
import gc
import statistics
import time
import torch
import qmatchatea
from qredtea.torchapi import QteaTorchTensor
def _dtype(name: str):
return {
"complex64": torch.complex64,
"complex128": torch.complex128,
"float64": torch.float64,
"float32": torch.float32,
}[name]
def _random_matrix(shape, dtype, seed):
gen = torch.Generator(device="cpu")
gen.manual_seed(seed)
if dtype.is_complex:
real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
real = torch.randn(shape, dtype=real_dtype, generator=gen)
imag = torch.randn(shape, dtype=real_dtype, generator=gen)
return torch.complex(real, imag).to(dtype)
return torch.randn(shape, dtype=dtype, generator=gen)
def _sync():
if torch.cuda.is_available():
torch.cuda.synchronize()
def run_one(matrix, ctrl, max_bond, contract_singvals, repeats):
    """Time ``split_svd`` on ``matrix`` for one SVD control mode and check accuracy.

    The matrix is wrapped as a CPU ``QteaTorchTensor`` and split into leg ``[0]``
    versus leg ``[1]`` ``repeats`` times. Each split uses ``svd_ctrl=ctrl``,
    ``max_bond_dimension=max_bond`` and ``cut_ratio=0.0``; the latter presumably
    disables relative-threshold truncation (TODO confirm against qmatchatea docs).
    If every repeat succeeds, the factors from the last split are multiplied back
    together and compared with ``matrix``.

    Args:
        matrix: 2-D torch tensor to split.
        ctrl: SVD control mode string forwarded as ``svd_ctrl``.
        max_bond: maximum bond dimension passed to the convergence parameters.
        contract_singvals: side ("L" or "R") the singular values are absorbed into.
        repeats: number of timed repetitions; values < 1 yield NaN timings.

    Returns:
        dict with keys ``ctrl``, ``contract_singvals``, ``status`` ("ok" or
        "error"), ``median_ms`` / ``min_ms`` (NaN if nothing was timed),
        ``rel_error`` (relative Frobenius reconstruction error, or None),
        ``kept`` (number of singular values returned, or None) and ``error``
        (``repr`` of the first exception raised, else "").
    """
    conv = qmatchatea.QCConvergenceParameters(
        max_bond_dimension=max_bond,
        cut_ratio=0.0,
        svd_ctrl=ctrl,
    )
    qtensor = QteaTorchTensor.from_elem_array(matrix, dtype=matrix.dtype, device="cpu")
    times = []
    status = "ok"
    error = ""
    last_split = None
    for _repeat in range(repeats):
        # Release the previous repeat's factors before collecting; otherwise they
        # remain referenced while the next split runs and gc.collect() can't free them.
        last_split = None
        gc.collect()
        _sync()
        t0 = time.perf_counter()
        try:
            left, right, singvals, _trunc = qtensor.split_svd(
                [0],
                [1],
                contract_singvals=contract_singvals,
                conv_params=conv,
            )
        except Exception as exc:  # noqa: BLE001 - benchmark should keep going
            status = "error"
            error = repr(exc)
            break
        _sync()
        times.append(time.perf_counter() - t0)
        last_split = (left, right, singvals)
        del left, right, singvals, _trunc

    rel_error = None
    kept = None
    # Accuracy is only reported when every repeat succeeded; it is computed once,
    # outside the timed region, from the final split.
    if status == "ok" and last_split is not None:
        left, right, singvals = last_split
        left_matrix = left.elem.reshape(matrix.shape[0], -1)
        right_matrix = right.elem.reshape(-1, matrix.shape[1])
        recon = left_matrix @ right_matrix
        # vector_norm over all entries of a 2-D tensor is the Frobenius norm.
        rel_error = (
            torch.linalg.vector_norm(matrix - recon)
            / torch.linalg.vector_norm(matrix)
        ).item()
        kept = int(singvals.numel())

    return {
        "ctrl": ctrl,
        "contract_singvals": contract_singvals,
        "status": status,
        "median_ms": float("nan") if not times else statistics.median(times) * 1000,
        "min_ms": float("nan") if not times else min(times) * 1000,
        "rel_error": rel_error,
        "kept": kept,
        "error": error,
    }
def _parse_shape(shape_text):
    """Parse an ``MxN`` string (case-insensitive ``x``) into two positive ints.

    Raises:
        ValueError: if the text is not two integers separated by a single ``x``,
            or if either dimension is not positive.
    """
    m_text, sep, n_text = shape_text.lower().partition("x")
    if not sep:
        raise ValueError(f"missing 'x' separator in {shape_text!r}")
    rows, cols = int(m_text), int(n_text)
    if rows <= 0 or cols <= 0:
        raise ValueError(f"non-positive dimension in {shape_text!r}")
    return rows, cols


def main():
    """Parse CLI options and print one benchmark row per shape/side/control.

    Output is line-oriented ``key=value`` text: an ``svd_benchmark`` header, a
    ``columns`` legend, then one ``row`` line per (shape, contract side, ctrl)
    combination. Each line is flushed as soon as it is printed so progress is
    visible during long runs.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512"))
    parser.add_argument("--max-bond", type=int, default=128)
    parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128")
    parser.add_argument("--threads", type=int, default=8)
    parser.add_argument("--repeats", type=int, default=3)
    parser.add_argument(
        "--controls",
        nargs="+",
        default=("A", "D", "V", "R", "E", "E!", "X", "X!"),
    )
    args = parser.parse_args()

    # Validate all inputs before any (possibly long) benchmarking starts, so a
    # bad value produces a usage message instead of a traceback mid-run.
    shapes = []
    for shape_text in args.shapes:
        try:
            shapes.append((shape_text, _parse_shape(shape_text)))
        except ValueError:
            parser.error(
                f"invalid --shapes entry {shape_text!r}: expected MxN with "
                "positive integers, e.g. 256x1024"
            )
    if args.repeats < 1:
        parser.error("--repeats must be at least 1")
    if args.threads < 1:
        parser.error("--threads must be at least 1")

    torch.set_num_threads(args.threads)
    dtype = _dtype(args.dtype)
    print(
        "svd_benchmark "
        f"dtype={args.dtype} threads={torch.get_num_threads()} "
        f"max_bond={args.max_bond} repeats={args.repeats}",
        flush=True,
    )
    print(
        "columns shape contract ctrl status median_ms min_ms kept rel_error error",
        flush=True,
    )
    for shape_text, shape in shapes:
        # Seed depends only on the dimensions, so MxN and NxM share a seed.
        matrix = _random_matrix(shape, dtype, seed=sum(shape))
        for contract_singvals in ("L", "R"):
            for ctrl in args.controls:
                result = run_one(
                    matrix,
                    ctrl=ctrl,
                    max_bond=args.max_bond,
                    contract_singvals=contract_singvals,
                    repeats=args.repeats,
                )
                print(
                    f"row shape={shape_text} "
                    f"contract={contract_singvals} "
                    f"ctrl={ctrl} "
                    f"status={result['status']} "
                    f"median_ms={result['median_ms']:.3f} "
                    f"min_ms={result['min_ms']:.3f} "
                    f"kept={result['kept']} "
                    f"rel_error={result['rel_error']} "
                    f"error={result['error']}",
                    flush=True,
                )
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()