Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
444 lines
15 KiB
Python
444 lines
15 KiB
Python
#!/usr/bin/env python
|
|
"""Contest-style CPU TN path search and contraction runner.
|
|
|
|
This file is intentionally self-contained: define contest circuits and
|
|
observables here, run path search once, then load the saved trees for repeated
|
|
MPI contractions.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import math
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse
|
|
|
|
import numpy as np
|
|
from qibo import Circuit, gates, hamiltonians
|
|
from qibo.symbols import X, Y, Z
|
|
|
|
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# In-tree package sources live under <ROOT>/src.
SRC = ROOT / "src"
# Put the in-tree sources at the front of sys.path (once) so the project
# import that follows is searched there first.
if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))
|
|
|
|
from qibotn.expectation_runner import ( # noqa: E402
|
|
ExpectationConfig,
|
|
exact_for_observable,
|
|
run_cpu_expectation,
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
class CaseSpec:
    """Immutable defaults describing one named contest case.

    The size fields and ``target_slices`` are only defaults: the command
    line may override each of them (see ``apply_case_defaults``).
    """

    # Circuit family name handed to build_circuit.
    circuit_kind: str
    # Observable names evaluated when none are selected on the command line.
    observables: tuple[str, ...]
    # Default number of qubits.
    nqubits: int
    # Default number of circuit layers.
    nlayers: int
    # Default RNG seed used when building the circuit.
    seed: int
    # Default slicing target; None leaves --tn-target-slices unset.
    target_slices: int | None = None
|
|
|
|
|
|
# Registry of contest cases, keyed by the name accepted by --case.
CASES = {
    # 37 qubits, 20 layers of the RXX/RZZ brickwork chain.
    "main1": CaseSpec(
        circuit_kind="rxx_rzz_chain",
        observables=("ring_xz",),
        nqubits=37,
        nlayers=20,
        seed=31001,
        target_slices=None,
    ),
    # 36 qubits, 18 layers of the chain that also inserts SWAPs.
    "main2": CaseSpec(
        circuit_kind="scramble_chain",
        observables=("open_zz", "range2_xx"),
        nqubits=36,
        nlayers=18,
        seed=31002,
        target_slices=None,
    ),
    # 40 qubits, 24 layers of alternating-direction CNOTs.
    "strong": CaseSpec(
        circuit_kind="reversed_cnot",
        observables=("ring_xz", "long_z_string"),
        nqubits=40,
        nlayers=24,
        seed=41001,
        target_slices=None,
    ),
}
|
|
|
|
|
|
def optional_int(text):
    """Parse an integer option where some spellings mean "no limit".

    The strings "none", "null", "inf" and "unlimited" (in any letter case)
    produce ``None``; every other value is converted with ``int``.
    """
    no_limit_words = ("none", "null", "inf", "unlimited")
    means_no_limit = isinstance(text, str) and text.lower() in no_limit_words
    return None if means_no_limit else int(text)
|
|
|
|
|
|
def optional_float(text):
    """Parse a float option where some spellings mean "no limit".

    The strings "none", "null", "inf" and "unlimited" (in any letter case)
    produce ``None``; every other value is converted with ``float``.
    """
    if isinstance(text, str):
        lowered = text.lower()
        # Note that "inf" deliberately maps to None, not to float("inf").
        if lowered in ("none", "null", "inf", "unlimited"):
            return None
    return float(text)
|
|
|
|
|
|
def set_torch_threads(nthreads):
    """Best-effort: ask torch to use ``nthreads`` threads.

    This never raises. A missing torch installation is ignored silently;
    any other failure (for example a thread count torch rejects) is
    reported as a ``RuntimeWarning`` instead of being swallowed.
    """
    import warnings

    try:
        import torch

        torch.set_num_threads(nthreads)
    except ImportError:
        # torch is not importable: nothing to configure.
        return
    except Exception as exc:  # best-effort helper, must not abort the run
        warnings.warn(
            f"Could not set torch thread count to {nthreads!r}: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )
|
|
|
|
|
|
def add_single_qubit_layer(circuit, nqubits, rng, include_rx=False):
    """Append one layer of random single-qubit rotations to ``circuit``.

    For each qubit in ascending order an RY and then an RZ gate are added,
    followed by an RX gate when ``include_rx`` is true. Every angle is drawn
    from ``rng.uniform(-pi, pi)`` immediately before its gate is created, so
    the order of random draws is qubit-major, then RY, RZ, (RX).
    """
    rotations = [gates.RY, gates.RZ]
    if include_rx:
        rotations.append(gates.RX)

    for target in range(nqubits):
        for rotation in rotations:
            angle = rng.uniform(-math.pi, math.pi)
            circuit.add(rotation(target, theta=angle))
|
|
|
|
|
|
def build_circuit(kind, nqubits, nlayers, seed):
    """Define contest circuits here.

    Builds a ``Circuit`` on ``nqubits`` qubits made of ``nlayers`` layers of
    the family named by ``kind``; all random angles come from a NumPy
    generator seeded with ``seed``.

    Supported kinds:
      * ``"rxx_rzz_chain"``: RY/RZ/RX on every qubit, then RXX and RZZ on
        neighbouring pairs whose starting qubit alternates with the layer.
      * ``"scramble_chain"``: same pattern with angles in [-0.8, 0.8], plus
        a SWAP on each pair in every fifth layer.
      * ``"reversed_cnot"``: RY/RZ on every qubit, then CNOTs on even and
        odd pairs whose control/target direction flips with the layer.

    Raises:
        ValueError: if ``kind`` is not one of the supported kinds. This is
            checked before anything is built, so it is also reported when
            ``nlayers`` is zero.
    """
    known_kinds = ("rxx_rzz_chain", "scramble_chain", "reversed_cnot")
    if kind not in known_kinds:
        raise ValueError(f"Unknown circuit kind {kind!r}.")

    rng = np.random.default_rng(seed)
    circuit = Circuit(nqubits)

    for layer in range(nlayers):
        if kind == "rxx_rzz_chain":
            add_single_qubit_layer(circuit, nqubits, rng, include_rx=True)
            # Brickwork: even layers start at qubit 0, odd layers at qubit 1.
            for qubit in range(layer % 2, nqubits - 1, 2):
                circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9)))
                circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9)))

        elif kind == "scramble_chain":
            add_single_qubit_layer(circuit, nqubits, rng, include_rx=True)
            for qubit in range(layer % 2, nqubits - 1, 2):
                circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8)))
                circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8)))
                # Every fifth layer additionally swaps each coupled pair.
                if layer % 5 == 4:
                    circuit.add(gates.SWAP(qubit, qubit + 1))

        elif kind == "reversed_cnot":
            add_single_qubit_layer(circuit, nqubits, rng)
            # Even pairs: control is the lower qubit on even layers.
            for qubit in range(0, nqubits - 1, 2):
                gate = gates.CNOT(qubit + 1, qubit) if layer % 2 else gates.CNOT(qubit, qubit + 1)
                circuit.add(gate)
            # Odd pairs use the opposite direction to the even pairs.
            for qubit in range(1, nqubits - 1, 2):
                gate = gates.CNOT(qubit + 1, qubit) if layer % 2 == 0 else gates.CNOT(qubit, qubit + 1)
                circuit.add(gate)

    return circuit
|
|
|
|
|
|
def pauli_sum_observable(kind, nqubits, seed):
    """Define contest observables here.

    Returns the ``SymbolicHamiltonian`` named by ``kind`` on ``nqubits``
    qubits and raises ``ValueError`` for an unknown ``kind``. ``seed`` is
    accepted but unused.

    TN path currently expects Pauli products / SymbolicHamiltonian terms.
    Keep production contest observables Hermitian unless complex output is
    explicitly required by the scoring rule.
    """
    del seed  # none of the observables below is randomized

    if kind == "ring_xz":
        # Periodic chain: the last X pairs with Z on qubit 0.
        form = sum(
            0.5 * X(site) * Z((site + 1) % nqubits) for site in range(nqubits)
        )
    elif kind == "open_zz":
        weight = 1.0 / max(1, nqubits - 1)
        form = sum(weight * Z(site) * Z(site + 1) for site in range(nqubits - 1))
    elif kind == "range2_xx":
        weight = 1.0 / max(1, nqubits - 2)
        form = sum(weight * X(site) * X(site + 2) for site in range(nqubits - 2))
    elif kind == "long_z_string":
        # One Z on every `step`-th qubit, multiplied into a single product.
        step = max(1, nqubits // 16)
        form = None
        for site in range(0, nqubits, step):
            factor = Z(site)
            form = factor if form is None else form * factor
    elif kind == "mixed_local":
        quarter = nqubits // 4
        half = nqubits // 2
        three_quarters = (3 * nqubits) // 4
        form = (
            0.25 * X(0)
            - 0.5 * Z(nqubits - 1)
            + 0.125 * X(quarter) * Z(half) * Y(three_quarters)
        )
    else:
        raise ValueError(f"Unknown observable kind {kind!r}.")

    return hamiltonians.SymbolicHamiltonian(form=form)
|
|
|
|
|
|
def tree_path(tree_dir, case_name, obs_name, nqubits, nlayers, target_slices):
    """Return the pickle path of the tree file for one configuration.

    The file name is ``<case>_<obs>_<nqubits>q<nlayers>l_<slices>.pkl``
    inside ``tree_dir``, where ``<slices>`` is ``auto`` when
    ``target_slices`` is ``None`` and ``s<target_slices>`` otherwise.
    """
    if target_slices is None:
        slice_tag = "auto"
    else:
        slice_tag = f"s{target_slices}"
    file_name = f"{case_name}_{obs_name}_{nqubits}q{nlayers}l_{slice_tag}.pkl"
    return Path(tree_dir) / file_name
|
|
|
|
|
|
def build_parallel_opts(args, tree_file=None, search_only=False):
    """Translate parsed CLI arguments into the ``parallel_opts`` dict.

    Args:
        args: Parsed arguments; the ``tn_*``, ``dask_*`` and
            ``torch_threads`` attributes are read.
        tree_file: Tree file to save to (``search_only=True``) or to load
            from (``search_only=False``). May be ``None`` only when not
            searching, in which case no tree path is set.
        search_only: When true, request a search-only run that saves the
            tree to ``tree_file``.

    Returns:
        dict: Options for ``ExpectationConfig(parallel_opts=...)``.

    Raises:
        ValueError: if ``search_only`` is true but no ``tree_file`` is
            given. Previously this stored the literal string ``"None"`` as
            the save path.
    """
    if search_only and tree_file is None:
        raise ValueError("search_only=True requires a tree_file to save the tree to.")

    slicing_opts = {}
    if args.tn_target_slices is not None:
        slicing_opts["target_slices"] = args.tn_target_slices
    if args.tn_target_size is not None:
        slicing_opts["target_size"] = args.tn_target_size

    opts = {
        # An empty slicing dict is passed on as None.
        "slicing_opts": slicing_opts or None,
        # Fall back to the torch thread count when no worker count is set.
        "search_workers": args.tn_search_workers or args.torch_threads,
        "max_repeats": args.tn_search_repeats,
        "max_time": args.tn_search_time,
        "print_stats": False,
    }

    # Optional keys are only added when the corresponding option is set.
    if args.tn_search_backend is not None:
        opts["search_backend"] = args.tn_search_backend
    if args.dask_address is not None:
        opts["dask_address"] = args.dask_address
    if args.dask_expected_workers is not None:
        opts["dask_expected_workers"] = args.dask_expected_workers
    if args.dask_close_workers:
        opts["dask_close_workers"] = True
    if args.tn_debug_trials:
        opts["debug_trials"] = True

    if search_only:
        opts["search_only"] = True
        opts["save_tree_path"] = str(tree_file)
    elif tree_file is not None:
        opts["load_tree_path"] = str(tree_file)
    return opts
|
|
|
|
|
|
def _print_term_stats(obs_name, stat):
    """Print one ``tn_term_summary`` line for a single per-term stats mapping."""
    # A missing or None record must not abort the run after the (expensive)
    # contraction has already finished; absent numbers are printed as nan.
    cost = stat.get("path_cost") or {}
    search_stats = stat.get("search_stats") or {}
    print(
        "tn_term_summary "
        f"observable={obs_name} "
        f"term={stat.get('term_index', 0)} "
        f"search_seconds={stat.get('search_seconds', float('nan')):.3f} "
        f"contract_seconds={stat.get('contract_seconds', float('nan')):.3f} "
        f"completed_trials={search_stats.get('completed_trials', 'na')} "
        f"finite_trials={search_stats.get('finite_trials', 'na')} "
        f"failed_trials={search_stats.get('failed_trials', 'na')} "
        f"requested_trials={search_stats.get('requested_trials', 'na')} "
        f"best_score={search_stats.get('best_score', float('nan')):.6g} "
        f"slices={cost.get('nslices')} "
        f"log10_flops={cost.get('log10_flops', float('nan')):.3f} "
        f"log10_write={cost.get('log10_write', float('nan')):.3f} "
        f"log2_size={cost.get('log2_size', float('nan')):.3f} "
        f"peak_memory_gib={cost.get('peak_memory_gib', float('nan')):.3g} "
        f"rank_slices={stat.get('rank_slices')}",
        flush=True,
    )


def run_one(args, case_name, obs_name, mode):
    """Run the path search or the contraction for one observable of a case.

    Args:
        args: Parsed CLI arguments with case defaults already applied.
        case_name: Key into ``CASES``.
        obs_name: Observable name passed to ``pauli_sum_observable``.
        mode: ``"search"`` saves a tree file; any other mode loads the tree
            and reports the expectation value.

    Raises:
        FileNotFoundError: in ``"contract"`` mode when the tree file is missing.
        ValueError: when ``--exact`` is requested for more qubits than
            ``args.exact_max_qubits``.

    Under MPI only rank 0 prints. The per-term ``tn_term_summary`` lines are
    skipped when ``args.no_tn_stats`` is set (``--no-tn-stats``).
    """
    case = CASES[case_name]
    circuit = build_circuit(case.circuit_kind, args.nqubits, args.nlayers, args.seed)
    observable = pauli_sum_observable(obs_name, args.nqubits, args.seed)
    path = tree_path(
        args.tree_dir,
        case_name,
        obs_name,
        args.nqubits,
        args.nlayers,
        args.tn_target_slices,
    )
    path.parent.mkdir(parents=True, exist_ok=True)

    rank = 0
    if args.mpi:
        # Imported lazily so non-MPI runs do not need mpi4py.
        from mpi4py import MPI

        rank = MPI.COMM_WORLD.Get_rank()

    if rank == 0:
        print("=" * 88, flush=True)
        print(
            f"mode={mode} case={case_name} circuit={case.circuit_kind} "
            f"observable={obs_name} nqubits={args.nqubits} nlayers={args.nlayers} "
            f"seed={args.seed} gates={len(circuit.queue)} tree={path}",
            flush=True,
        )

    if mode == "contract" and not path.exists():
        raise FileNotFoundError(f"Missing tree file: {path}. Run search first.")

    # The exact reference is only computed on rank 0 and never while searching.
    exact = None
    if args.exact and rank == 0 and mode != "search":
        if args.nqubits > args.exact_max_qubits:
            raise ValueError(
                f"--exact is limited to {args.exact_max_qubits} qubits by default."
            )
        exact = exact_for_observable(circuit, observable, args.nqubits)

    config = ExpectationConfig(
        ansatz="tn",
        mpi=args.mpi,
        bond=args.bond,
        cut_ratio=args.cut_ratio,
        tensor_module="torch",
        quimb_backend=args.quimb_backend,
        dtype=args.dtype,
        torch_threads=args.torch_threads,
        parallel_opts=build_parallel_opts(
            args,
            tree_file=path,
            search_only=(mode == "search"),
        ),
    )
    result = run_cpu_expectation(circuit, observable, config)
    if args.mpi and result.rank != 0:
        return

    if mode == "search":
        print(f"searched observable={obs_name} tree={path}", flush=True)
    else:
        abs_error = float("nan") if exact is None else abs(result.value - exact)
        # The denominator is floored to avoid dividing by a zero reference.
        rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15)
        exact_text = "nan" if exact is None else f"{exact:.16e}"
        print(
            f"result observable={obs_name} exact={exact_text} "
            f"value={result.value:.16e} abs_error={abs_error:.6e} "
            f"rel_error={rel_error:.6e} seconds={result.seconds:.3f}",
            flush=True,
        )

    # Honour --no-tn-stats; getattr keeps hand-built namespaces working.
    if getattr(args, "no_tn_stats", False):
        return

    for stat in result.parallel_stats or ():
        _print_term_stats(obs_name, stat)
|
|
|
|
|
|
def selected_observables(args, case):
    """Pick the observable names to run, in order of precedence.

    ``args.observables`` wins when non-empty; otherwise a non-empty
    ``args.obs_filter`` is split on commas (blank entries dropped);
    otherwise the case's own ``observables`` are returned unchanged.
    """
    explicit = args.observables
    if explicit:
        return tuple(explicit)

    if not args.obs_filter:
        return case.observables

    stripped = (piece.strip() for piece in args.obs_filter.split(","))
    return tuple(piece for piece in stripped if piece)
|
|
|
|
|
|
def apply_case_defaults(args):
    """Fill unset CLI options on ``args`` from the selected case, in place.

    ``nqubits``, ``nlayers``, ``seed`` and ``tn_target_slices`` are replaced
    by the case's values only when they are ``None``. ``args.observables``
    is always overwritten with the result of ``selected_observables``.
    """
    spec = CASES[args.case]
    fallbacks = (
        ("nqubits", spec.nqubits),
        ("nlayers", spec.nlayers),
        ("seed", spec.seed),
        ("tn_target_slices", spec.target_slices),
    )
    for attr_name, fallback in fallbacks:
        if getattr(args, attr_name) is None:
            setattr(args, attr_name, fallback)
    args.observables = selected_observables(args, spec)
|
|
|
|
|
|
def stop_dask_cluster(args):
    """Run the cluster management script with ``stop`` after a search.

    Nothing happens when ``--keep-dask`` is set, when the search backend is
    not ``dask``, when no ``--dask-address`` was given, on non-zero MPI
    ranks, or when ``tools/manage_tn_dask_cluster.sh`` does not exist (the
    latter prints a ``dask_stop_skipped`` line). The script's exit status
    is not checked.
    """
    if args.keep_dask:
        return
    if args.tn_search_backend != "dask" or not args.dask_address:
        return

    if args.mpi:
        from mpi4py import MPI

        # Only one rank should run the stop script.
        if MPI.COMM_WORLD.Get_rank() != 0:
            return

    script = ROOT / "tools" / "manage_tn_dask_cluster.sh"
    if not script.exists():
        print(f"dask_stop_skipped reason=missing_script path={script}", flush=True)
        return

    # Pass the scheduler location on through the environment, without
    # overriding values that are already set there.
    address = urlparse(args.dask_address)
    env = os.environ.copy()
    scheduler_settings = (
        ("SCHEDULER_HOST", address.hostname),
        ("SCHEDULER_PORT", address.port),
    )
    for variable, value in scheduler_settings:
        if value:
            env.setdefault(variable, str(value))

    print("dask_stop_after_search start", flush=True)
    subprocess.run([str(script), "stop"], cwd=str(ROOT), env=env, check=False)
    print("dask_stop_after_search done", flush=True)
|
|
|
|
|
|
def _build_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("mode", choices=("search", "contract", "all", "validate", "list"))
    add("--case", choices=sorted(CASES), default="main1")
    add("--observables", nargs="+")
    add("--obs-filter", default="")
    add("--tree-dir", default="trees/contest_tn")
    add("--nqubits", type=int)
    add("--nlayers", type=int)
    add("--seed", type=int)
    add("--mpi", action="store_true")
    add("--exact", action="store_true")
    add("--exact-max-qubits", type=int, default=24)
    add("--bond", "--bonds", dest="bond", type=optional_int, default=1024)
    add("--cut-ratio", type=optional_float, default=1e-12)
    add("--torch-threads", type=int, default=8)
    add("--quimb-backend", choices=("numpy", "torch"), default="torch")
    add("--dtype", choices=("complex128", "complex64"), default="complex64")
    add("--tn-target-slices", type=int)
    add("--tn-target-size", type=int, default=2**34)
    add("--tn-search-workers", type=int)
    add("--tn-search-repeats", type=int, default=2048)
    add("--tn-search-time", type=float, default=300.0)
    add(
        "--tn-search-backend",
        choices=("processpool", "dask"),
        default="dask",
        help=(
            "Path-search backend. Defaults to dask. Without --dask-address, "
            "non-MPI search starts a local dask cluster."
        ),
    )
    add("--dask-address")
    add("--dask-expected-workers", type=int)
    add("--dask-close-workers", action="store_true")
    add(
        "--keep-dask",
        action="store_true",
        help=(
            "Keep an external dask cluster running after search. By default, "
            "tools/manage_tn_dask_cluster.sh stop is called after search when "
            "--dask-address is used."
        ),
    )
    add(
        "--tn-debug-trials",
        action="store_true",
        help="Print dask worker summary and per-trial start/done logs.",
    )
    add("--no-tn-stats", action="store_true")
    return parser


def main():
    """Command-line entry point.

    ``list`` prints the registered cases and exits. ``search`` and
    ``contract`` run that single phase; ``all`` runs search then contract;
    ``validate`` does the same with ``--exact`` forced on and the qubit
    count capped at ``--exact-max-qubits``. Each phase runs every selected
    observable, and ``stop_dask_cluster`` is called once the search phase
    is done.
    """
    args = _build_parser().parse_args()

    if args.mode == "list":
        for name, case in CASES.items():
            print(
                f"{name}: circuit={case.circuit_kind} "
                f"observables={','.join(case.observables)} "
                f"nqubits={case.nqubits} nlayers={case.nlayers} "
                f"seed={case.seed} target_slices={case.target_slices}"
            )
        return

    apply_case_defaults(args)
    set_torch_threads(args.torch_threads)

    if args.mode == "validate":
        # Validation compares against the exact result, so shrink the problem.
        args.exact = True
        args.nqubits = min(args.nqubits, args.exact_max_qubits)
        phases = ("search", "contract")
    elif args.mode == "all":
        phases = ("search", "contract")
    else:
        phases = (args.mode,)

    for phase in phases:
        for obs_name in args.observables:
            run_one(args, args.case, obs_name, phase)
        if phase == "search":
            stop_dask_cluster(args)
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|