Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
247 lines
8.8 KiB
Python
247 lines
8.8 KiB
Python
"""MPI-parallel TN benchmark: path search + contraction via MPI."""
|
|
import pickle
|
|
import time
|
|
import argparse
|
|
import numpy as np
|
|
import cotengra as ctg
|
|
import qibo
|
|
from qibo import Circuit, gates
|
|
from mpi4py import MPI
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
|
|
|
|
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time=600):
    """Run one single-process cotengra hyper-optimisation and return its result.

    Meant to be submitted to a ``ProcessPoolExecutor``: the tensor network
    arrives pickled and the required modules are imported inside the body.

    Args:
        tn_bytes: ``pickle.dumps`` of a quimb tensor network (trusted input --
            it is unpickled here).
        output_inds: indices left open by the contraction.
        repeats: trial budget, passed as ``max_repeats``.
        seed: value given to ``random.seed`` before the search starts.
        num_slices: accepted for call-site compatibility; unused here.
        n_ranks: accepted for call-site compatibility; unused here.
        max_time: wall-clock search budget in seconds (``max_time``).

    Returns:
        ``(cost, tree)`` where ``cost`` is ``tree.combo_cost(factor=256)``,
        the same objective the optimiser minimises (``'combo-256'``).
    """
    import pickle
    import random
    import cotengra as ctg

    random.seed(seed)
    network = pickle.loads(tn_bytes)

    # Single-process search (parallel=False); the caller parallelises by
    # launching several of these with different seeds.
    search_settings = dict(
        methods=['kahypar', 'kahypar-agglom', 'spinglass'],
        max_repeats=repeats,
        parallel=False,
        minimize='combo-256',
        max_time=max_time,
        optlib="random",
        slicing_opts={'target_size': 2**29, 'allow_outer': True},
        progbar=False,
    )
    optimizer = ctg.HyperOptimizer(**search_settings)

    found_tree = network.contraction_tree(optimize=optimizer, output_inds=output_inds)
    score = found_tree.combo_cost(factor=256)
    return score, found_tree
|
|
|
|
|
|
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
                    timeout=60):
    """Run ``n_workers`` independent path searches in subprocesses; keep the cheapest.

    The tensor network is pickled once and shipped to every worker. Each
    worker runs :func:`_run_serial_search` with its own seed
    (``0 .. n_workers - 1``) and ``max(1, total_repeats // n_workers)`` trials.
    Results are collected for at most ``timeout + 5`` seconds. After that,
    remaining futures are cancelled and live worker processes are killed with
    ``SIGKILL``.

    Args:
        tn: tensor network to find a contraction path for (must be picklable).
        output_inds: indices to keep open.
        total_repeats: total trial budget, split evenly over the workers.
        n_workers: number of worker processes / independent searches.
        num_slices: forwarded unchanged to :func:`_run_serial_search`.
        n_ranks: forwarded unchanged to :func:`_run_serial_search`.
        timeout: per-worker search budget in seconds (``max_time``).

    Returns:
        The tree with the lowest cost reported by a worker, or ``None`` if no
        worker delivered a result before the deadline.
    """
    import os
    import pickle
    import signal
    from concurrent import futures

    tn_bytes = pickle.dumps(tn)
    repeats_per = max(1, total_repeats // n_workers)
    best_cost, best_tree = float('inf'), None

    pool = futures.ProcessPoolExecutor(max_workers=n_workers)
    submitted = []
    try:
        # Submitting inside the try guarantees the pool is torn down even if
        # a submit call fails part-way.
        submitted = [
            pool.submit(_run_serial_search, tn_bytes, output_inds,
                        repeats_per, seed, num_slices, n_ranks, timeout)
            for seed in range(n_workers)
        ]
        try:
            for fut in futures.as_completed(submitted, timeout=timeout + 5):
                try:
                    cost, tree = fut.result()
                except Exception as e:
                    # Best effort: one failing worker must not sink the search.
                    print(f" [worker failed] {e}")
                    continue
                if cost < best_cost:
                    best_cost, best_tree = cost, tree
        except futures.TimeoutError:
            # Deadline reached: keep whatever finished. concurrent.futures'
            # TimeoutError is only an alias of the builtin from Python 3.11 on.
            pass
    finally:
        for fut in submitted:
            fut.cancel()
        # Running searches cannot be interrupted cooperatively, so kill the
        # worker processes directly (relies on the private ``_processes`` map).
        live = getattr(pool, "_processes", None) or {}
        for pid in list(live.keys()):
            try:
                os.kill(pid, signal.SIGKILL)
            except ProcessLookupError:
                pass
        pool.shutdown(wait=False, cancel_futures=True)

    return best_tree
|
|
|
|
|
|
def make_circuit(circuit_type, nqubits, nlayers=1):
    """Return a benchmark qibo circuit of the requested family.

    Families:
        ``"qft"``         -- ``qibo.models.QFT(nqubits)``.
        ``"ghz"``         -- H on qubit 0 followed by a CNOT ladder.
        ``"variational"`` -- ``nlayers`` layers of random-angle RY on every
                             qubit, then CZ on alternating neighbour pairs.
        ``"brickwork"``   -- H on every qubit, then ``nlayers`` layers of a
                             CNOT plus two random-angle RZ gates on
                             alternating neighbour pairs.

    Random angles are drawn from NumPy's global RNG, so seed it beforehand
    (``np.random.seed``) to get a reproducible circuit. ``nlayers`` is only
    used by ``"variational"`` and ``"brickwork"``.

    Raises:
        ValueError: if ``circuit_type`` is not one of the families above.
    """
    circuit = Circuit(nqubits)

    if circuit_type == "qft":
        from qibo.models import QFT
        return QFT(nqubits)

    if circuit_type == "ghz":
        circuit.add(gates.H(0))
        for control in range(nqubits - 1):
            circuit.add(gates.CNOT(control, control + 1))
        return circuit

    if circuit_type == "variational":
        for layer in range(nlayers):
            for qubit in range(nqubits):
                angle = np.random.uniform(0, 2 * np.pi)
                circuit.add(gates.RY(qubit, theta=angle))
            # Even layers pair (0,1),(2,3),...; odd layers pair (1,2),(3,4),...
            for left in range(layer % 2, nqubits - 1, 2):
                circuit.add(gates.CZ(left, left + 1))
        return circuit

    if circuit_type == "brickwork":
        for qubit in range(nqubits):
            circuit.add(gates.H(qubit))
        for layer in range(nlayers):
            for left in range(layer % 2, nqubits - 1, 2):
                right = left + 1
                circuit.add(gates.CNOT(left, right))
                circuit.add(gates.RZ(left, theta=np.random.uniform(0, 2 * np.pi)))
                circuit.add(gates.RZ(right, theta=np.random.uniform(0, 2 * np.pi)))
        return circuit

    raise ValueError(f"Unknown circuit: {circuit_type}")
|
|
|
|
|
|
def _contract_mpi(tree, arrays, comm, root=0):
|
|
rank = comm.Get_rank()
|
|
size = comm.Get_size()
|
|
is_torch = type(arrays[0]).__module__.startswith("torch")
|
|
|
|
result_np = None
|
|
for i in range(rank, tree.multiplicity, size):
|
|
x = tree.contract_slice(arrays, i)
|
|
x_np = np.asfortranarray(x.detach().cpu().numpy() if is_torch else np.asarray(x))
|
|
result_np = x_np if result_np is None else result_np + x_np
|
|
|
|
if result_np is None:
|
|
result_np = np.zeros(1, dtype=np.complex128)
|
|
|
|
result = np.zeros_like(result_np) if rank == root else None
|
|
comm.Reduce(result_np, result, root=root)
|
|
|
|
if rank == root:
|
|
import torch
|
|
return torch.from_numpy(np.asarray(result)) if is_torch else result
|
|
return None
|
|
|
|
|
|
def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
            load_path=None, save_path=None, n_cores=48, search_timeout=60):
    """Find a contraction path for ``circuit`` and contract its state over MPI.

    Path search: unless ``load_path`` is given, every rank builds the tensor
    network and runs :func:`parallel_search` over
    ``max(1, total_repeats // size)`` trials. Rank 0 gathers all ranks'
    results and keeps the tree with the lowest ``combo_cost(factor=256)``,
    together with the network it was found on. With ``load_path``, rank 0
    instead unpickles a previously saved ``{"tree": ..., "psi": ...}`` dict.

    If ``save_path`` is given, rank 0 pickles the winning tree and network
    (search branch only) and every rank returns ``(None, t_search)`` without
    contracting. Otherwise the tree and network are broadcast, and all ranks
    contract their share of the slices via :func:`_contract_mpi`.

    Args:
        circuit: qibo circuit to simulate.
        nqubits: accepted for call-site compatibility; unused here.
        num_slices: forwarded to :func:`parallel_search`.
        total_repeats: total number of search trials across all ranks.
        load_path: pickle file to read the path from (rank 0 only). It is
            read with ``pickle.load``, so only use trusted files.
        save_path: pickle file to write the found path to; skips contraction.
        n_cores: search worker processes per rank. Also divided by the number
            of ranks to set torch's thread count. Defaults to 48.
        search_timeout: per-worker search budget in seconds. Defaults to 60.

    Returns:
        ``(statevector, elapsed)`` on rank 0, where ``statevector`` is a flat
        NumPy array, and ``(None, elapsed)`` on the other ranks. ``elapsed`` is
        search time plus contraction time. Returns ``(None, t_search)`` on all
        ranks when ``save_path`` is set.

    Raises:
        RuntimeError: on every rank, if no contraction tree is available.
    """
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    qibo.set_backend("qibotn", platform="quimb")
    backend = qibo.get_backend()
    backend.configure_tn_simulation(ansatz="tn")

    import torch

    qc = backend._qibo_circuit_to_quimb(circuit, quimb_circuit_type=backend.circuit_ansatz,
                                        gate_opts={"max_bond": None, "cutoff": 1e-10})
    qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128)

    # --- path search (or load): each rank searches, best tree lands on rank 0 ---
    tree = psi = None
    if load_path:
        t_search = 0.0
        if rank == 0:
            with open(load_path, "rb") as f:
                saved = pickle.load(f)
            tree, psi = saved["tree"], saved["psi"]
            print(f" [path loaded] {load_path}")
    else:
        t0 = time.time()
        # Build the TN object only (no contraction), then search on it.
        psi_tn = qc.to_dense(rehearse="tn")
        local_tree = parallel_search(
            psi_tn, psi_tn.outer_inds(), max(1, total_repeats // size),
            n_workers=n_cores, num_slices=num_slices, n_ranks=size,
            timeout=search_timeout,
        )
        t_search = time.time() - t0

        # A rank whose search produced nothing still joins the gather (with an
        # infinite cost); crashing here would leave the other ranks blocked.
        local_cost = (float("inf") if local_tree is None
                      else local_tree.combo_cost(factor=256))
        # Each tree travels with the network it was found on, keeping the
        # tree and its input arrays paired.
        all_results = comm.gather((local_cost, local_tree, psi_tn), root=0)

        if rank == 0:
            candidates = [res for res in all_results if res[1] is not None]
            if candidates:
                _, tree, psi = min(candidates, key=lambda res: res[0])
                print(f" [path search] {t_search:.3f}s "
                      f"flops~2^{tree.contraction_cost(log=2):.2f} "
                      f"size~2^{tree.contraction_width():.2f} "
                      f"slices={tree.multiplicity}")
                if save_path:
                    with open(save_path, "wb") as f:
                        pickle.dump({"tree": tree, "psi": psi}, f)
                    print(f" [path saved] {save_path}")

    # All ranks must agree on success before any further collective call;
    # otherwise a missing tree on rank 0 would deadlock the others.
    if not comm.bcast(tree is not None, root=0):
        raise RuntimeError("no contraction tree available: path search "
                           "produced no result on any rank")

    if save_path:
        t_search = comm.bcast(t_search, root=0)
        return None, t_search

    tree = comm.bcast(tree, root=0)
    psi = comm.bcast(psi, root=0)
    t_search = comm.bcast(t_search, root=0)

    # --- contraction: all ranks work in parallel ---
    torch.set_num_threads(max(1, n_cores // size))
    arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays]

    t0 = time.time()
    sv = _contract_mpi(tree, arrays, comm, root=0)
    t_contract = time.time() - t0

    if rank == 0:
        print(f" [contraction] {t_contract:.3f}s")
        return np.array(sv).reshape(-1), t_search + t_contract
    return None, t_search + t_contract
|
|
|
|
|
|
def main():
    """Command-line entry point for the MPI tensor-network benchmark.

    Every rank builds the same circuit (NumPy seeded with 42) and calls
    :func:`run_mpi`. Rank 0 then saves the state vector to
    ``data/sv_tn_<circuit><nqubits>_mpi.npy``. Unless ``--no-compare`` is
    given, it also compares the result against ``benchmark_tn.run_qibojit``
    (fidelity, L2 error, speedup).
    """
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument("--nqubits", type=int, default=30)
    parser.add_argument("--circuit", type=str, default="qft",
                        choices=["qft", "variational", "ghz", "brickwork"])
    parser.add_argument("--nlayers", type=int, default=3)
    parser.add_argument("--num-slices", type=int, default=1)
    parser.add_argument("--total-repeats", type=int, default=1024)
    parser.add_argument("--save-path", type=str, default=None)
    parser.add_argument("--load-path", type=str, default=None)
    parser.add_argument("--no-compare", action="store_true")
    args = parser.parse_args()

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()

    if rank == 0:
        print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, "
              f"nlayers={args.nlayers}, ranks={comm.Get_size()}")

    # Same seed on every rank so random-angle circuits are identical.
    np.random.seed(42)
    circuit = make_circuit(args.circuit, args.nqubits, args.nlayers)

    try:
        sv, t_total = run_mpi(circuit, args.nqubits, args.num_slices,
                              total_repeats=args.total_repeats,
                              load_path=args.load_path, save_path=args.save_path)
    except Exception as e:
        if rank == 0:
            print(f"[FAILED] {e}")
        raise

    if rank != 0 or sv is None:
        return

    print(f"\n[quimb TN MPI] time={t_total:.4f}s shape={sv.shape}")
    # np.save does not create parent directories; without this a missing
    # data/ directory would lose the result of the whole run.
    os.makedirs("data", exist_ok=True)
    np.save(f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy", sv)

    if args.no_compare:
        return

    from benchmark_tn import run_qibojit

    # Re-seed so the reference circuit gets exactly the same random angles.
    np.random.seed(42)
    circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
    sv_ref, t_ref = run_qibojit(circuit_ref)
    fid = abs(np.dot(sv_ref.conj(), sv)) ** 2
    print(f"[qibojit] time={t_ref:.4f}s")
    print(f"Fidelity : {fid:.8f}")
    print(f"L2 error : {np.linalg.norm(sv_ref - sv):.2e}")
    if t_total > 0:
        print(f"Speedup : {t_ref/t_total:.2f}x")


if __name__ == "__main__":
    main()
|