Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
158 lines
4.8 KiB
Python
158 lines
4.8 KiB
Python
#!/usr/bin/env python
|
|
"""Benchmark qredtea/qtealeaves SVD control modes.

This isolates the tensor split used by MPS updates: a random rank-2 tensor is
split with its singular values contracted into either the left or the right
factor, and the two factors are multiplied back together to measure the
numerical reconstruction error alongside the timing of each split.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import gc
|
|
import statistics
|
|
import time
|
|
|
|
import torch
|
|
|
|
import qmatchatea
|
|
from qredtea.torchapi import QteaTorchTensor
|
|
|
|
|
|
def _dtype(name: str):
|
|
return {
|
|
"complex64": torch.complex64,
|
|
"complex128": torch.complex128,
|
|
"float64": torch.float64,
|
|
"float32": torch.float32,
|
|
}[name]
|
|
|
|
|
|
def _random_matrix(shape, dtype, seed):
|
|
gen = torch.Generator(device="cpu")
|
|
gen.manual_seed(seed)
|
|
if dtype.is_complex:
|
|
real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
|
|
real = torch.randn(shape, dtype=real_dtype, generator=gen)
|
|
imag = torch.randn(shape, dtype=real_dtype, generator=gen)
|
|
return torch.complex(real, imag).to(dtype)
|
|
return torch.randn(shape, dtype=dtype, generator=gen)
|
|
|
|
|
|
def _sync():
|
|
if torch.cuda.is_available():
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
def _reconstruction_stats(matrix, left, right, singvals):
    """Return ``(rel_error, kept)`` for one ``split_svd`` result.

    ``left`` is reshaped to ``matrix.shape[0]`` rows and ``right`` to
    ``matrix.shape[1]`` columns. Their product is compared with ``matrix`` via
    the norm of the flattened difference (the Frobenius norm), relative to the
    norm of ``matrix``. ``kept`` is the element count of ``singvals``.
    """
    left_matrix = left.elem.reshape(matrix.shape[0], -1)
    right_matrix = right.elem.reshape(-1, matrix.shape[1])
    recon = left_matrix @ right_matrix
    rel_error = (
        torch.linalg.vector_norm(matrix - recon)
        / torch.linalg.vector_norm(matrix)
    ).item()
    return rel_error, int(singvals.numel())


def run_one(matrix, ctrl, max_bond, contract_singvals, repeats):
    """Benchmark ``QteaTorchTensor.split_svd`` for a single SVD control mode.

    The rank-2 ``matrix`` is split between legs ``[0]`` and ``[1]``
    ``repeats`` times. Each split uses ``svd_ctrl=ctrl``,
    ``max_bond_dimension=max_bond`` and ``cut_ratio=0.0``. Every call is
    preceded by ``gc.collect()``, bracketed by ``_sync()`` and timed with
    ``time.perf_counter``. The factors from the final call are multiplied back
    together to measure the reconstruction error.

    Args:
        matrix: 2-D torch tensor to split.
        ctrl: ``svd_ctrl`` value for ``qmatchatea.QCConvergenceParameters``.
        max_bond: maximum bond dimension kept by the split.
        contract_singvals: side that absorbs the singular values (the caller
            passes ``"L"`` or ``"R"``).
        repeats: number of timed calls; must be at least 1.

    Returns:
        A dict with keys ``ctrl``, ``contract_singvals``, ``status`` (``"ok"``
        or ``"error"``), ``median_ms`` / ``min_ms`` (NaN when nothing was
        timed), ``rel_error`` / ``kept`` (``None`` unless every step
        succeeded) and ``error`` (``""`` on success).

    Failures during setup, in any split, or during the reconstruction are
    recorded in ``status``/``error`` instead of raised, so one bad control
    mode does not abort the rest of the sweep.
    """
    times = []
    rel_error = None
    kept = None
    status = "ok"
    error = ""

    if repeats < 1:
        # Previously this silently produced an "ok" row with no measurements.
        status = "error"
        error = f"repeats must be >= 1, got {repeats!r}"
    else:
        try:
            # Setup sits inside the try as well: a control mode the library
            # rejects must not take the remaining rows down with it.
            conv = qmatchatea.QCConvergenceParameters(
                max_bond_dimension=max_bond,
                cut_ratio=0.0,
                svd_ctrl=ctrl,
            )
            qtensor = QteaTorchTensor.from_elem_array(
                matrix, dtype=matrix.dtype, device="cpu"
            )
            split = None
            for _ in range(repeats):
                # Release the previous result before collecting so it is not
                # kept alive during the next timed call.
                split = None
                gc.collect()
                _sync()
                t0 = time.perf_counter()
                split = qtensor.split_svd(
                    [0],
                    [1],
                    contract_singvals=contract_singvals,
                    conv_params=conv,
                )
                _sync()
                times.append(time.perf_counter() - t0)
            left, right, singvals, _ = split
            rel_error, kept = _reconstruction_stats(matrix, left, right, singvals)
        except Exception as exc:  # noqa: BLE001 - benchmark should keep going
            status = "error"
            error = repr(exc)
            rel_error = None
            kept = None

    return {
        "ctrl": ctrl,
        "contract_singvals": contract_singvals,
        "status": status,
        "median_ms": float("nan") if not times else statistics.median(times) * 1000,
        "min_ms": float("nan") if not times else min(times) * 1000,
        "rel_error": rel_error,
        "kept": kept,
        "error": error,
    }
|
|
|
|
|
|
def _positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def _parse_shape(text):
    """Parse ``"MxN"`` (``x`` case-insensitive) into ``(M, N)`` with M, N >= 1.

    Raises ``ValueError`` for anything else, e.g. ``"256"``, ``"0x5"`` or
    ``"2x3x4"``.
    """
    m_text, sep, n_text = text.lower().partition("x")
    if sep:
        try:
            rows, cols = int(m_text), int(n_text)
        except ValueError:
            pass
        else:
            if rows >= 1 and cols >= 1:
                return rows, cols
    raise ValueError(f"expected MxN with positive integers, got {text!r}")


def main(argv=None):
    """Run the SVD benchmark sweep configured on the command line.

    Args:
        argv: argument list to parse; ``None`` (the default) uses ``sys.argv``.

    First prints a ``svd_benchmark`` header and a ``columns`` legend. Then
    prints one ``row`` line per (shape, contract side, control mode), flushing
    each so progress is visible while the sweep runs. Invalid arguments
    produce an argparse usage error (exit status 2) before any benchmarking.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512"))
    parser.add_argument("--max-bond", type=_positive_int, default=128)
    parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128")
    parser.add_argument("--threads", type=_positive_int, default=8)
    parser.add_argument("--repeats", type=_positive_int, default=3)
    parser.add_argument(
        "--controls",
        nargs="+",
        default=("A", "D", "V", "R", "E", "E!", "X", "X!"),
    )
    args = parser.parse_args(argv)

    # Validate every shape up front. A typo then fails fast with a usage error
    # instead of a traceback part-way through the sweep.
    shapes = []
    for shape_text in args.shapes:
        try:
            shapes.append((shape_text, _parse_shape(shape_text)))
        except ValueError as exc:
            parser.error(f"argument --shapes: {exc}")

    torch.set_num_threads(args.threads)
    dtype = _dtype(args.dtype)

    print(
        "svd_benchmark "
        f"dtype={args.dtype} threads={torch.get_num_threads()} "
        f"max_bond={args.max_bond} repeats={args.repeats}",
        flush=True,
    )
    print(
        "columns shape contract ctrl status median_ms min_ms kept rel_error error",
        flush=True,
    )

    for shape_text, shape in shapes:
        # NOTE: the seed depends only on M + N, so MxN and NxM share a seed.
        matrix = _random_matrix(shape, dtype, seed=sum(shape))
        for contract_singvals in ("L", "R"):
            for ctrl in args.controls:
                result = run_one(
                    matrix,
                    ctrl=ctrl,
                    max_bond=args.max_bond,
                    contract_singvals=contract_singvals,
                    repeats=args.repeats,
                )
                print(
                    f"row shape={shape_text} "
                    f"contract={contract_singvals} "
                    f"ctrl={ctrl} "
                    f"status={result['status']} "
                    f"median_ms={result['median_ms']:.3f} "
                    f"min_ms={result['min_ms']:.3f} "
                    f"kept={result['kept']} "
                    f"rel_error={result['rel_error']} "
                    f"error={result['error']}",
                    flush=True,
                )


if __name__ == "__main__":
    main()
|