Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
244 lines
8.5 KiB
Python
244 lines
8.5 KiB
Python
#!/usr/bin/env python
|
|
"""Run TN expectation for a user-provided circuit and observable.
|
|
|
|
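The case module must define ``build_circuit`` and may define ``build_observable``; each builder receives whichever of ``nqubits``, ``nlayers`` and ``seed`` it declares: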
|
|
|
|
def build_circuit(nqubits, nlayers, seed): ...
|
|
def build_observable(nqubits, seed): ...
|
|
|
|
``build_observable`` may return a Qibo ``SymbolicHamiltonian`` (or its ``form``
expression), or a qibotn dict of the form:
|
|
|
|
{"terms": [
|
|
{"coefficient": 1.0, "operators": [("X", 0), ("Z", 1)]},
|
|
]}
|
|
|
|
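For a single repeated Pauli string, pass ``--pauli-pattern`` instead of
defining ``build_observable``. Alternatively, pass ``--observable-json`` with a
file in the dict form above, or define a module-level ``OBSERVABLE``. Sources
are tried in this order: ``--pauli-pattern``, ``--observable-json``,
``build_observable``, ``OBSERVABLE``.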
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import importlib.util
|
|
import inspect
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Put ``<grandparent of this file>/src`` at the front of ``sys.path`` (once),
# presumably so ``qibotn`` can be imported from a source checkout without
# installing it.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")
if not any(entry == str(SRC) for entry in sys.path):
    sys.path[:0] = [str(SRC)]
|
|
|
|
from qibotn.expectation_runner import ( # noqa: E402
|
|
ExpectationConfig,
|
|
exact_for_observable,
|
|
run_cpu_expectation,
|
|
)
|
|
|
|
|
|
def optional_int(text):
    """Parse an integer CLI value that may be switched off (argparse ``type=``).

    The case-insensitive words ``none``, ``null``, ``inf`` and ``unlimited``
    yield ``None``. Any other value is handed to :func:`int`, which raises
    ``ValueError`` on malformed input.
    """
    is_off_word = isinstance(text, str) and text.lower() in ("none", "null", "inf", "unlimited")
    return None if is_off_word else int(text)
|
|
|
|
|
|
def optional_float(text):
    """Parse a float CLI value that may be switched off (argparse ``type=``).

    Returns ``None`` for the case-insensitive words ``none``, ``null``,
    ``inf`` and ``unlimited``. Note that ``inf`` means "off" here, not
    ``float("inf")``. Every other value goes through :func:`float`.
    """
    if not isinstance(text, str):
        return float(text)
    if text.lower() in ("none", "null", "inf", "unlimited"):
        return None
    return float(text)
|
|
|
|
|
|
def load_module(path):
    """Import a Python source file as a module and return it.

    The module is registered in :data:`sys.modules` before its code runs, as
    in the importlib "import a source file directly" recipe. Without this,
    code in the case file that looks up its own module by name breaks. For
    example, ``@dataclass`` classes under ``from __future__ import
    annotations`` raise ``AttributeError``, and the module's classes and
    functions cannot be pickled by reference.

    Args:
        path: ``str`` or :class:`~pathlib.Path` of the source file.

    Returns:
        The executed module object.

    Raises:
        RuntimeError: if importlib cannot build an import spec for ``path``.
        Any exception raised while executing the module is propagated, and
        the module is then removed from ``sys.modules`` again.
    """
    path = Path(path).resolve()

    module_name = path.stem
    if module_name in sys.modules:
        # Never shadow an already-imported module (e.g. a case file named
        # ``json.py``); fall back to a private, prefixed name instead.
        module_name = f"_qibotn_case_{module_name}"

    spec = importlib.util.spec_from_file_location(module_name, path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Cannot import case module from {path}.")

    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Like the regular import system, don't leave a half-initialised
        # module registered after a failed import.
        sys.modules.pop(module_name, None)
        raise
    return module
|
|
|
|
|
|
def call_builder(fn, **kwargs):
    """Call a user-supplied builder with only the arguments it declares.

    This lets case modules declare any subset of the offered keywords
    (``nqubits``/``nlayers``/``seed``). If ``fn`` accepts ``**kwargs``, every
    offered argument is forwarded. Leading positional-only parameters
    (``def f(nqubits, /, ...)``) are passed positionally, because passing
    them by keyword raises ``TypeError``.

    Args:
        fn: callable whose signature is read with :func:`inspect.signature`.
        **kwargs: candidate arguments, matched to ``fn``'s parameters by name.

    Returns:
        Whatever ``fn`` returns.
    """
    params = list(inspect.signature(fn).parameters.values())
    remaining = dict(kwargs)

    # Positional-only parameters cannot be given by keyword: feed the leading
    # run of them positionally, stopping at the first one we have no value for.
    positional = []
    for param in params:
        if param.kind is not inspect.Parameter.POSITIONAL_ONLY or param.name not in remaining:
            break
        positional.append(remaining.pop(param.name))

    if any(param.kind is inspect.Parameter.VAR_KEYWORD for param in params):
        return fn(*positional, **remaining)

    keyword_names = {
        param.name
        for param in params
        if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    }
    accepted = {name: value for name, value in remaining.items() if name in keyword_names}
    return fn(*positional, **accepted)
|
|
|
|
|
|
def load_observable(args, module):
    """Resolve the observable to measure from the CLI or the case module.

    Sources are tried in order and the first available one wins:

    1. ``--pauli-pattern``: returned as ``{"pauli_string_pattern": <pattern>}``.
    2. ``--observable-json``: the file is parsed as UTF-8 JSON and returned as-is.
    3. ``module.build_observable``: called through :func:`call_builder` with
       ``nqubits``, ``nlayers`` and ``seed`` (only those it declares).
    4. ``module.OBSERVABLE``: returned as-is.

    Args:
        args: parsed CLI namespace (reads ``pauli_pattern``, ``observable_json``,
            ``nqubits``, ``nlayers``, ``seed``).
        module: the loaded case module.

    Raises:
        ValueError: if none of the sources above is available.
    """
    if args.pauli_pattern:
        return {"pauli_string_pattern": args.pauli_pattern}

    if args.observable_json:
        # JSON text is UTF-8 by specification; don't depend on the locale's
        # default encoding (e.g. cp1252 on Windows).
        with Path(args.observable_json).open(encoding="utf-8") as f:
            return json.load(f)

    if hasattr(module, "build_observable"):
        return call_builder(
            module.build_observable,
            nqubits=args.nqubits,
            nlayers=args.nlayers,
            seed=args.seed,
        )

    if hasattr(module, "OBSERVABLE"):
        return module.OBSERVABLE

    raise ValueError(
        "No observable supplied. Define build_observable/OBSERVABLE in the case "
        "module, or pass --pauli-pattern / --observable-json."
    )
|
|
|
|
|
|
def build_parallel_opts(args):
    """Translate the ``--tn-*`` / ``--dask-*`` CLI flags into ``parallel_opts``.

    The keys ``slicing_opts``, ``search_workers``, ``max_repeats``,
    ``max_time`` and ``print_stats`` are always present:

    - ``slicing_opts`` is ``None`` when neither slicing target is set.
    - ``search_workers`` falls back to ``--torch-threads`` when
      ``--tn-search-workers`` is unset or 0.

    Other keys are added only when their flag was given.
    """
    slicing_opts = {
        key: value
        for key, value in (
            ("target_slices", args.tn_target_slices),
            ("target_size", args.tn_target_size),
        )
        if value is not None
    }

    opts = dict(
        slicing_opts=slicing_opts if slicing_opts else None,
        search_workers=args.tn_search_workers if args.tn_search_workers else args.torch_threads,
        max_repeats=args.tn_search_repeats,
        max_time=args.tn_search_time,
        print_stats=not args.no_tn_stats,
    )

    # (key, value, include?) in the order the keys should appear.
    extras = (
        ("search_backend", args.tn_search_backend, args.tn_search_backend is not None),
        ("dask_address", args.dask_address, args.dask_address is not None),
        ("dask_close_workers", True, bool(args.dask_close_workers)),
        ("save_tree_path", args.tn_save_tree, args.tn_save_tree is not None),
        ("load_tree_path", args.tn_load_tree, args.tn_load_tree is not None),
        ("search_only", True, bool(args.tn_search_only)),
    )
    opts.update((key, value) for key, value, include in extras if include)
    return opts
|
|
|
|
|
|
def _build_arg_parser():
    """Return the argument parser describing this script's command line."""
    parser = argparse.ArgumentParser(
        description="Run CPU TN expectation for a custom qibo circuit module."
    )
    parser.add_argument("case_module", help="Python file defining build_circuit.")
    parser.add_argument("--nqubits", type=int, required=True)
    parser.add_argument("--nlayers", type=int, default=0)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--mpi", action="store_true")
    parser.add_argument(
        "--exact",
        action="store_true",
        help="Also compute a reference value on rank 0 and report the error.",
    )
    parser.add_argument(
        "--exact-max-qubits",
        type=int,
        default=24,
        help="Refuse --exact above this many qubits.",
    )
    parser.add_argument(
        "--bond",
        "--bonds",
        dest="bond",
        type=optional_int,
        default=1024,
        help="Bond dimension; none/null/inf/unlimited pass None.",
    )
    parser.add_argument("--cut-ratio", type=optional_float, default=1e-12)
    parser.add_argument("--torch-threads", type=int, default=8)
    parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch")
    parser.add_argument("--dtype", choices=("complex128", "complex64"), default="complex128")
    parser.add_argument("--pauli-pattern")
    parser.add_argument("--observable-json")
    # optional_int lets "none" be passed explicitly, so these can be left out
    # of slicing_opts (build_parallel_opts skips None values).
    parser.add_argument("--tn-target-slices", type=optional_int)
    parser.add_argument(
        "--tn-target-size",
        type=optional_int,
        default=2**32,
        help="Slicing target size; 'none' omits it from slicing_opts.",
    )
    parser.add_argument("--tn-search-workers", type=int)
    parser.add_argument("--tn-search-repeats", type=int, default=128)
    parser.add_argument("--tn-search-time", type=float, default=60.0)
    parser.add_argument("--tn-search-backend", choices=("processpool", "dask"))
    parser.add_argument("--dask-address")
    parser.add_argument("--dask-close-workers", action="store_true")
    parser.add_argument("--tn-save-tree")
    parser.add_argument("--tn-load-tree")
    parser.add_argument("--tn-search-only", action="store_true")
    parser.add_argument("--no-tn-stats", action="store_true")
    return parser


def _format_number(value, spec):
    """Format ``value`` with format ``spec``, rendering ``None`` as ``nan``."""
    return format(float("nan") if value is None else value, spec)


def _print_tn_stats(parallel_stats):
    """Print one ``tn_term_summary`` line per per-term statistics dict."""
    for stat in parallel_stats or ():
        # Missing or None entries print as nan/na/None instead of raising
        # after the (possibly expensive) contraction has already finished.
        cost = stat.get("path_cost") or {}
        search_stats = stat.get("search_stats") or {}
        print(
            "tn_term_summary "
            f"term={stat.get('term_index', 0)} "
            f"search_seconds={_format_number(stat.get('search_seconds'), '.3f')} "
            f"contract_seconds={_format_number(stat.get('contract_seconds'), '.3f')} "
            f"completed_trials={search_stats.get('completed_trials', 'na')} "
            f"finite_trials={search_stats.get('finite_trials', 'na')} "
            f"failed_trials={search_stats.get('failed_trials', 'na')} "
            f"requested_trials={search_stats.get('requested_trials', 'na')} "
            f"best_score={_format_number(search_stats.get('best_score'), '.6g')} "
            f"slices={cost.get('slices')} "
            f"log10_flops={_format_number(cost.get('log10_flops'), '.3f')} "
            f"log10_write={_format_number(cost.get('log10_write'), '.3f')} "
            f"log2_size={_format_number(cost.get('log2_size'), '.3f')} "
            f"peak_memory_gib={_format_number(cost.get('peak_memory_gib'), '.3g')} "
            f"rank_slices={stat.get('rank_slices')}",
            flush=True,
        )


def main():
    """Run a CPU tensor-network expectation for a user-provided case module.

    Steps:

    1. Parse the command line.
    2. Load ``case_module`` and build the circuit with its ``build_circuit``.
    3. Resolve the observable with :func:`load_observable`.
    4. Run :func:`run_cpu_expectation` and print a one-line result, with the
       optional ``--exact`` comparison.
    5. Print the per-term TN statistics.

    With ``--mpi``, the header is printed by MPI rank 0 only, and the result
    lines only where ``result.rank == 0``.

    Raises:
        ValueError: if ``--exact`` is requested for more than
            ``--exact-max-qubits`` qubits, or the case module does not define
            ``build_circuit``.
    """
    args = _build_arg_parser().parse_args()

    # Validate on every rank before doing any work. If only rank 0 raised,
    # the other MPI ranks would continue into run_cpu_expectation (presumably
    # a collective run) and could hang.
    if args.exact and args.nqubits > args.exact_max_qubits:
        raise ValueError(
            f"--exact is limited to {args.exact_max_qubits} qubits by default."
        )

    rank = 0
    if args.mpi:
        from mpi4py import MPI  # only needed (and possibly only installed) for MPI runs

        rank = MPI.COMM_WORLD.Get_rank()

    module = load_module(args.case_module)
    if not hasattr(module, "build_circuit"):
        raise ValueError("case_module must define build_circuit.")

    circuit = call_builder(
        module.build_circuit,
        nqubits=args.nqubits,
        nlayers=args.nlayers,
        seed=args.seed,
    )
    observable = load_observable(args, module)

    config = ExpectationConfig(
        ansatz="tn",
        mpi=args.mpi,
        bond=args.bond,
        cut_ratio=args.cut_ratio,
        tensor_module="torch",
        quimb_backend=args.quimb_backend,
        dtype=args.dtype,
        torch_threads=args.torch_threads,
        parallel_opts=build_parallel_opts(args),
    )

    if rank == 0:
        mode = "MPI" if args.mpi else "serial"
        print(
            f"backend=cpu ansatz=TN mode={mode} case={Path(args.case_module).name} "
            f"nqubits={args.nqubits} nlayers={args.nlayers} seed={args.seed} "
            f"quimb_backend={args.quimb_backend} dtype={args.dtype} "
            f"torch_threads={args.torch_threads}",
            flush=True,
        )
        print("observable exact value abs_error rel_error seconds", flush=True)

    # The reference value is only computed (and later reported) on rank 0.
    exact = None
    if args.exact and rank == 0:
        exact = exact_for_observable(circuit, observable, args.nqubits)

    result = run_cpu_expectation(circuit, observable, config)
    if args.mpi and result.rank != 0:
        return

    abs_error = float("nan") if exact is None else abs(result.value - exact)
    # Floor the denominator so an exact value of 0 doesn't divide by zero.
    rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15)
    exact_text = "nan" if exact is None else f"{exact:.16e}"
    print(
        f"custom {exact_text} {result.value:.16e} "
        f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}",
        flush=True,
    )

    _print_tn_stats(result.parallel_stats)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point; importing this file does not run anything.
    main()
|