"""CLI for CPU TN/MPS expectation benchmarks.""" from __future__ import annotations import argparse import os import subprocess from pathlib import Path from urllib.parse import urlparse from qibotn.benchmark_cases import ( CIRCUITS, OBSERVABLES, build_circuit, observable_terms, parse_names, terms_to_dict, ) from qibotn.expectation_runner import ( ExpectationConfig, exact_for_observable, run_cpu_expectation, ) def optional_int(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return int(text) def optional_float(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return float(text) def format_optional(value, fmt="g"): return "None" if value is None else format(value, fmt) def should_stop_dask(args): return ( not args.keep_dask and args.tn_search_backend == "dask" and args.dask_address is not None and args.tn_load_tree is None ) def stop_dask_cluster(args, rank): if rank != 0 or not should_stop_dask(args): return script = Path(__file__).resolve().parent / "tools" / "manage_tn_dask_cluster.sh" if not script.exists(): print(f"dask_stop_skipped reason=missing_script path={script}", flush=True) return env = os.environ.copy() parsed = urlparse(args.dask_address) if parsed.hostname: env.setdefault("SCHEDULER_HOST", parsed.hostname) if parsed.port: env.setdefault("SCHEDULER_PORT", str(parsed.port)) print("dask_stop_after_search start", flush=True) subprocess.run([str(script), "stop"], cwd=str(script.parent.parent), env=env, check=False) print("dask_stop_after_search done", flush=True) def build_parallel_opts(args): slicing_opts = {} if args.tn_target_slices is not None: slicing_opts["target_slices"] = args.tn_target_slices if args.tn_target_size is not None: slicing_opts["target_size"] = args.tn_target_size opts = { "slicing_opts": slicing_opts or None, "search_workers": args.tn_search_workers or args.torch_threads, "max_repeats": args.tn_search_repeats, "max_time": args.tn_search_time, "print_stats": not args.no_tn_stats, } if args.tn_search_backend is not None: opts["search_backend"] = args.tn_search_backend if args.dask_address is not None: opts["dask_address"] = args.dask_address if args.tn_save_tree is not None: opts["save_tree_path"] = args.tn_save_tree if args.tn_load_tree is not None: opts["load_tree_path"] = args.tn_load_tree if args.tn_search_only: opts["search_only"] = True if args.tn_debug_trials: opts["debug_trials"] = True if args.tn_contract_implementation is not None: opts["contract_implementation"] = args.tn_contract_implementation if args.dask_close_workers: opts["dask_close_workers"] = True return opts def main(): parser = argparse.ArgumentParser() parser.add_argument("--nqubits", type=int, default=40) parser.add_argument("--nlayers", type=int, default=30) parser.add_argument("--bond", "--bonds", dest="bond", type=optional_int, default=1024) parser.add_argument("--cut-ratio", type=optional_float, default=1e-12) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--torch-threads", type=int, default=8) parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch") parser.add_argument( "--dtype", choices=("complex128", "complex64"), default="complex128", ) parser.add_argument("--ansatz", choices=("tn", "mps"), default=None) parser.add_argument("--mps", action="store_true") parser.add_argument("--mpi", action="store_true") parser.add_argument("--exact", action="store_true") parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--circuits", nargs="+", default=["brickwall_cnot"]) parser.add_argument("--observables", nargs="+", default=["ring_xz"]) parser.add_argument("--pauli-pattern") parser.add_argument("--tn-target-slices", type=int) parser.add_argument("--tn-target-size", type=int,default=2**32) parser.add_argument("--tn-search-workers", type=int) parser.add_argument("--tn-search-repeats", type=int, default=128) parser.add_argument("--tn-search-time", type=float, default=60.0) parser.add_argument( "--no-tn-stats", action="store_true", help="Do not print per-term TN search/contraction diagnostics.", ) parser.add_argument( "--tn-search-backend", choices=("processpool", "dask"), default="dask", help="Path-search backend. In MPI mode, dask search runs only on rank 0 and broadcasts the tree.", ) parser.add_argument( "--dask-address", help="Dask scheduler address, for example tcp://host:8786. If omitted with dask search, a local cluster is created.", ) parser.add_argument( "--dask-close-workers", action="store_true", help="After dask path search, ask the scheduler to close all currently connected workers.", ) parser.add_argument( "--keep-dask", action="store_true", help=( "Keep an external dask cluster running after search. By default, " "tools/manage_tn_dask_cluster.sh stop is called after search when " "--dask-address is used." ), ) parser.add_argument( "--tn-save-tree", help="Save searched cotengra contraction tree(s) to this pickle file.", ) parser.add_argument( "--tn-load-tree", help="Load cotengra contraction tree(s) from this pickle file and skip path search.", ) parser.add_argument( "--tn-search-only", action="store_true", help="Only run path search and optional --tn-save-tree; skip contraction.", ) parser.add_argument( "--tn-debug-trials", action="store_true", help="Print dask worker summary and per-trial worker start/done logs.", ) parser.add_argument( "--tn-contract-implementation", choices=("auto", "cotengra", "autoray", "cpp"), help="cotengra contraction implementation for TN contraction.", ) args = parser.parse_args() ansatz = "mps" if args.mps else (args.ansatz or "tn") circuits = parse_names(args.circuits, CIRCUITS, "circuits") observables = [] if args.pauli_pattern else parse_names( args.observables, OBSERVABLES, "observables" ) rank = 0 if args.mpi: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() config = ExpectationConfig( ansatz=ansatz, mpi=args.mpi, bond=args.bond, cut_ratio=args.cut_ratio, tensor_module="torch", quimb_backend=args.quimb_backend, dtype=args.dtype, torch_threads=args.torch_threads, parallel_opts=build_parallel_opts(args), ) if rank == 0: mode = "MPI" if args.mpi else "serial" print( f"backend=cpu ansatz={ansatz.upper()} mode={mode} " f"nqubits={args.nqubits} nlayers={args.nlayers} " f"bond={format_optional(args.bond)} " f"cut_ratio={format_optional(args.cut_ratio)} seed={args.seed} " f"quimb_backend={args.quimb_backend} dtype={args.dtype} " f"torch_threads={args.torch_threads} " f"tn_search_backend={args.tn_search_backend}" ) print("circuit observable exact value abs_error rel_error seconds") try: for circuit_kind in circuits: circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed) named_observables = ( [(f"pattern:{args.pauli_pattern}", {"pauli_string_pattern": args.pauli_pattern})] if args.pauli_pattern else [ (obs_kind, terms_to_dict(observable_terms(obs_kind, args.nqubits))) for obs_kind in observables ] ) for obs_name, observable in named_observables: exact = None if args.exact and rank == 0: if args.nqubits > args.exact_max_qubits: raise ValueError( f"--exact is limited to {args.exact_max_qubits} qubits by default." ) exact = exact_for_observable(circuit, observable, args.nqubits) result = run_cpu_expectation(circuit, observable, config) if args.mpi and result.rank != 0: continue abs_error = float("nan") if exact is None else abs(result.value - exact) rel_error = ( float("nan") if exact is None else abs_error / max(abs(exact), 1e-15) ) exact_text = "nan" if exact is None else f"{exact:.16e}" print( f"{circuit_kind} {obs_name} {exact_text} {result.value:.16e} " f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}" ) for stat in result.parallel_stats or (): cost = stat["path_cost"] search_stats = stat.get("search_stats", {}) print( "tn_term_summary " f"term={stat.get('term_index', 0)} " f"search_seconds={stat.get('search_seconds', float('nan')):.3f} " f"contract_seconds={stat.get('contract_seconds', float('nan')):.3f} " f"completed_trials={search_stats.get('completed_trials', 'na')} " f"finite_trials={search_stats.get('finite_trials', 'na')} " f"failed_trials={search_stats.get('failed_trials', 'na')} " f"requested_trials={search_stats.get('requested_trials', 'na')} " f"best_score={search_stats.get('best_score', float('nan')):.6g} " f"slices={cost['nslices']} " f"log10_flops={cost['log10_flops']:.3f} " f"log10_write={cost['log10_write']:.3f} " f"log2_size={cost['log2_size']:.3f} " f"log10_combo={cost['log10_combo']:.3f} " f"peak_memory_gib={cost['peak_memory_gib']:.6g} " f"slicing_overhead={cost['slicing_overhead']:.6g} " f"rank_slices={stat.get('rank_slices', 'na')}" ) finally: stop_dask_cluster(args, rank) if __name__ == "__main__": main()