修复时间剪枝功能
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -10,8 +10,7 @@ from mpi4py import MPI
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
|
||||
|
||||
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks):
|
||||
"""Run one serial HyperOptimizer in a subprocess, return (width, tree)."""
|
||||
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time=600):
|
||||
import pickle, cotengra as ctg, random
|
||||
random.seed(seed)
|
||||
tn = pickle.loads(tn_bytes)
|
||||
@@ -19,48 +18,49 @@ def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks
|
||||
methods=['kahypar', 'kahypar-agglom', 'spinglass'],
|
||||
max_repeats=repeats,
|
||||
parallel=False,
|
||||
minimize='flops',
|
||||
max_time=600,
|
||||
minimize='combo-256',
|
||||
max_time=max_time,
|
||||
optlib="random",
|
||||
slicing_opts={'target_size': 2**30, 'allow_outer': False},
|
||||
slicing_opts={'target_size': 2**29, 'allow_outer': True},
|
||||
progbar=False,
|
||||
)
|
||||
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
|
||||
return tree.contraction_width(), tree
|
||||
return tree.combo_cost(factor=256), tree
|
||||
|
||||
|
||||
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
|
||||
timeout=None):
|
||||
"""Launch n_workers subprocesses each running serial search, return best tree."""
|
||||
timeout=60):
|
||||
import pickle, os, signal
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
tn_bytes = pickle.dumps(tn)
|
||||
repeats_per = max(1, total_repeats // n_workers)
|
||||
best_width, best_tree = float('inf'), None
|
||||
best_cost, best_tree = float('inf'), None
|
||||
|
||||
with ProcessPoolExecutor(max_workers=n_workers) as pool:
|
||||
futures = {
|
||||
pool = ProcessPoolExecutor(max_workers=n_workers)
|
||||
futures = [
|
||||
pool.submit(_run_serial_search, tn_bytes, output_inds,
|
||||
repeats_per, seed, num_slices, n_ranks): seed
|
||||
repeats_per, seed, num_slices, n_ranks, timeout)
|
||||
for seed in range(n_workers)
|
||||
}
|
||||
pids = {f: p.pid for f, p in zip(futures, pool._processes.values())}
|
||||
]
|
||||
try:
|
||||
for fut in as_completed(futures, timeout=timeout):
|
||||
for fut in as_completed(futures, timeout=timeout + 5):
|
||||
try:
|
||||
width, tree = fut.result()
|
||||
if width < best_width:
|
||||
best_width, best_tree = width, tree
|
||||
cost, tree = fut.result()
|
||||
if cost < best_cost:
|
||||
best_cost, best_tree = cost, tree
|
||||
except Exception as e:
|
||||
print(f" [worker failed] {e}")
|
||||
except TimeoutError:
|
||||
pass
|
||||
for fut, pid in pids.items():
|
||||
if not fut.done():
|
||||
finally:
|
||||
for fut in futures:
|
||||
fut.cancel()
|
||||
for pid in list(pool._processes.keys()):
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
pool.shutdown(wait=False)
|
||||
|
||||
return best_tree
|
||||
|
||||
@@ -107,7 +107,7 @@ def _contract_mpi(tree, arrays, comm, root=0):
|
||||
result_np = x_np if result_np is None else result_np + x_np
|
||||
|
||||
if result_np is None:
|
||||
result_np = np.zeros(1, dtype=np.complex64)
|
||||
result_np = np.zeros(1, dtype=np.complex128)
|
||||
|
||||
result = np.zeros_like(result_np) if rank == root else None
|
||||
comm.Reduce(result_np, result, root=root)
|
||||
@@ -133,7 +133,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
import torch
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64)
|
||||
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128)
|
||||
|
||||
# --- path search: each rank serial, gather best to rank 0 ---
|
||||
if load_path:
|
||||
@@ -152,16 +152,16 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
psi_tn = qc.to_dense(rehearse="tn")
|
||||
local_tree = parallel_search(
|
||||
psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48,
|
||||
num_slices=num_slices, n_ranks=size, timeout=630,
|
||||
num_slices=num_slices, n_ranks=size, timeout=60,
|
||||
)
|
||||
t_search = time.time() - t0
|
||||
local_psi = psi_tn
|
||||
|
||||
all_results = comm.gather((local_tree.contraction_width(), local_tree, local_psi), root=0)
|
||||
all_results = comm.gather((local_tree.combo_cost(factor=256), local_tree, local_psi), root=0)
|
||||
if rank == 0:
|
||||
_, tree, psi = min(all_results, key=lambda x: x[0])
|
||||
print(f" [path search] {t_search:.3f}s "
|
||||
f"flops~2^{tree.contraction_cost():.2f} "
|
||||
f"flops~2^{tree.contraction_cost(log=2):.2f} "
|
||||
f"size~2^{tree.contraction_width():.2f} "
|
||||
f"slices={tree.multiplicity}")
|
||||
if save_path:
|
||||
@@ -182,7 +182,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
# --- contraction: all ranks work in parallel ---
|
||||
import torch
|
||||
torch.set_num_threads(max(1, 48 // size))
|
||||
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex64) for a in psi.arrays]
|
||||
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays]
|
||||
t0 = time.time()
|
||||
sv = _contract_mpi(tree, arrays, comm, root=0)
|
||||
t_contract = time.time() - t0
|
||||
|
||||
Reference in New Issue
Block a user