diff --git a/benchmark_tn_mpi.py b/benchmark_tn_mpi.py index 400eff9..ae76eca 100644 --- a/benchmark_tn_mpi.py +++ b/benchmark_tn_mpi.py @@ -10,8 +10,7 @@ from mpi4py import MPI from concurrent.futures import ProcessPoolExecutor, as_completed -def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks): - """Run one serial HyperOptimizer in a subprocess, return (width, tree).""" +def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time=600): import pickle, cotengra as ctg, random random.seed(seed) tn = pickle.loads(tn_bytes) @@ -19,48 +18,49 @@ def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks methods=['kahypar', 'kahypar-agglom', 'spinglass'], max_repeats=repeats, parallel=False, - minimize='flops', - max_time=600, + minimize='combo-256', + max_time=max_time, optlib="random", - slicing_opts={'target_size': 2**30, 'allow_outer': False}, + slicing_opts={'target_size': 2**29, 'allow_outer': True}, progbar=False, ) tree = tn.contraction_tree(optimize=opt, output_inds=output_inds) - return tree.contraction_width(), tree + return tree.combo_cost(factor=256), tree def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks, - timeout=None): - """Launch n_workers subprocesses each running serial search, return best tree.""" + timeout=60): import pickle, os, signal from concurrent.futures import ProcessPoolExecutor, as_completed tn_bytes = pickle.dumps(tn) repeats_per = max(1, total_repeats // n_workers) - best_width, best_tree = float('inf'), None + best_cost, best_tree = float('inf'), None - with ProcessPoolExecutor(max_workers=n_workers) as pool: - futures = { - pool.submit(_run_serial_search, tn_bytes, output_inds, - repeats_per, seed, num_slices, n_ranks): seed - for seed in range(n_workers) - } - pids = {f: p.pid for f, p in zip(futures, pool._processes.values())} - try: - for fut in as_completed(futures, timeout=timeout): - try: - width, tree = fut.result() - if width < best_width: - best_width, best_tree = width, tree - except Exception as e: - print(f" [worker failed] {e}") - except TimeoutError: - pass - for fut, pid in pids.items(): - if not fut.done(): - try: - os.kill(pid, signal.SIGKILL) - except ProcessLookupError: - pass + pool = ProcessPoolExecutor(max_workers=n_workers) + futures = [ + pool.submit(_run_serial_search, tn_bytes, output_inds, + repeats_per, seed, num_slices, n_ranks, timeout) + for seed in range(n_workers) + ] + try: + for fut in as_completed(futures, timeout=timeout + 5): + try: + cost, tree = fut.result() + if cost < best_cost: + best_cost, best_tree = cost, tree + except Exception as e: + print(f" [worker failed] {e}") + except TimeoutError: + pass + finally: + for fut in futures: + fut.cancel() + for pid in list(pool._processes.keys()): + try: + os.kill(pid, signal.SIGKILL) + except ProcessLookupError: + pass + pool.shutdown(wait=False) return best_tree @@ -107,7 +107,7 @@ def _contract_mpi(tree, arrays, comm, root=0): result_np = x_np if result_np is None else result_np + x_np if result_np is None: - result_np = np.zeros(1, dtype=np.complex64) + result_np = np.zeros(1, dtype=np.complex128) result = np.zeros_like(result_np) if rank == root else None comm.Reduce(result_np, result, root=root) @@ -133,7 +133,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024, import torch qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz, gate_opts={"max_bond": None, "cutoff": 1e-10}) - qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64) + qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128) # --- path search: each rank serial, gather best to rank 0 --- if load_path: @@ -152,16 +152,16 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024, psi_tn = qc.to_dense(rehearse="tn") local_tree = parallel_search( psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48, - num_slices=num_slices, n_ranks=size, timeout=630, + num_slices=num_slices, n_ranks=size, timeout=60, ) t_search = time.time() - t0 local_psi = psi_tn - all_results = comm.gather((local_tree.contraction_width(), local_tree, local_psi), root=0) + all_results = comm.gather((local_tree.combo_cost(factor=256), local_tree, local_psi), root=0) if rank == 0: _, tree, psi = min(all_results, key=lambda x: x[0]) print(f" [path search] {t_search:.3f}s " - f"flops~2^{tree.contraction_cost():.2f} " + f"flops~2^{tree.contraction_cost(log=2):.2f} " f"size~2^{tree.contraction_width():.2f} " f"slices={tree.multiplicity}") if save_path: @@ -182,7 +182,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024, # --- contraction: all ranks work in parallel --- import torch torch.set_num_threads(max(1, 48 // size)) - arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex64) for a in psi.arrays] + arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays] t0 = time.time() sv = _contract_mpi(tree, arrays, comm, root=0) t_contract = time.time() - t0