15 Commits
main ... tn

Author SHA1 Message Date
5479574502 优化 torch CPU 张量网络收缩路径
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
- torch CPU 收缩默认走 cotengra matmul lowering
  - 复用 mm/bmm/matmul 输出缓冲区,降低中间张量分配压力
  - 仅回收 contiguous tensor,避免非连续 view 进入 workspace
  - 调整 cotengra 中间节点 index 顺序,减少 reshape 触发 clone/copy
  - qibotn MPI 分片收缩显式使用 backend=torch
  - rank 内分片结果先在 torch 中累加,最后再转 numpy 做 Reduce
  - 统一 quimb 后端 torch 数组转换为 CPU contiguous complex128
2026-05-08 11:57:18 +08:00
cec0ba272a 完善脚本功能,添加计时估计 2026-05-08 10:20:03 +08:00
8b71ff96c8 误传profiling删除
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-08 00:16:28 +08:00
49b27a5840 完善时间剪枝功能
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-08 00:12:32 +08:00
c818ac7a6e mpi完善
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-07 23:37:15 +08:00
0a96553bd8 多机测试
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-07 23:35:23 +08:00
2f5c863952 并行化支持完善
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-07 23:26:53 +08:00
fbae48eb3d 期望值计算支持;更新后端调用
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-07 11:19:45 +08:00
f776fbb04f 修复时间剪枝功能
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-05 23:31:23 +08:00
5a692033a6 添加MPI并行TN benchmark及辅助脚本,移除旧benchmark
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-05 19:04:09 +08:00
a3f39a1d67 删除tn脚本implementation 2026-05-03 19:06:21 +08:00
dd222587b7 tn脚本更新
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-05-03 18:54:05 +08:00
740828872e 添加tn脚本
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-04-28 23:07:39 +08:00
80d9c1de5a benchmark测试,发现瓶颈:路径搜索
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-04-27 18:59:54 +08:00
2c54840e7b 1.完成mps态脚本,与原始qibojit结果比对确定bond demension和cut off值;2.更新了官方库;3.新大陆
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
2026-04-27 11:03:57 +08:00
19 changed files with 1064 additions and 158 deletions

8
.gitignore vendored
View File

@@ -2,10 +2,14 @@
__pycache__/
*.py[cod]
*$py.class
data/
# C extensions
*.so
bak/
path/
profiles/
vtune_expval/
perf*
# Distribution / packaging
.Python
build/

View File

@@ -0,0 +1,56 @@
"""MPI parallel sliced contraction using pre-sliced tree."""
import time, pickle, os
import numpy as np
from mpi4py import MPI
NQUBITS, NLAYERS, NCORES = 25, 10, 48
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
os.environ['OMP_NUM_THREADS'] = str(NCORES)
os.environ['MKL_NUM_THREADS'] = str(NCORES)
import torch
import qibo, quimb as qu
from qibotn.observables import build_random_circuit
torch.set_num_threads(NCORES)
circuit = build_random_circuit(NQUBITS, NLAYERS)
qibo.set_backend("qibotn", platform="quimb")
backend = qibo.get_backend()
backend.configure_tn_simulation(ansatz="tn")
qc = backend._qibo_circuit_to_quimb(circuit, backend.circuit_ansatz)
tn = qc.local_expectation(qu.pauli('x') & qu.pauli('z'), (0, 1), rehearse='tn')
if rank == 0:
with open(f"data/tree_q{NQUBITS}_l{NLAYERS}_sliced.pkl", 'rb') as f:
tree = pickle.load(f)
else:
tree = None
tree = comm.bcast(tree, root=0)
arrays = [torch.from_numpy(np.asarray(t._data)) for t in tn.tensors]
n_slices = tree.multiplicity
if rank == 0:
print(f"Slices: {n_slices}, Ranks: {size}, "
f"Peak: {tree.max_size() * 16 / 1e9:.2f} GB, "
f"Threads/rank: {NCORES}, Backend: torch")
t0 = time.time()
result = None
for i in range(rank, n_slices, size):
val = tree.contract_slice(arrays, i, backend='torch')
val_np = val.cpu().numpy().reshape(-1)
result = val_np if result is None else result + val_np
if result is None:
result = np.zeros(1, dtype=np.complex128)
total = np.zeros_like(result) if rank == 0 else None
comm.Reduce(result, total, root=0)
if rank == 0:
print(f"Contract: {time.time() - t0:.4f}s Expectation: {0.5 * total[0].real:.10f}")

34
benchmark_search.py Normal file
View File

@@ -0,0 +1,34 @@
"""Search contraction path and save."""
import time, os, pickle
from qibotn.parallel import parallel_path_search
from qibotn.observables import build_random_circuit
import qibo, quimb as qu
from mpi4py import MPI
NQUBITS, NLAYERS, WORKERS = 20, 10, 96
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
method = 'mpi' if size > 1 else 'processpool'
circuit = build_random_circuit(NQUBITS, NLAYERS)
qibo.set_backend("qibotn", platform="quimb")
backend = qibo.get_backend()
backend.configure_tn_simulation(ansatz="tn")
qc = backend._qibo_circuit_to_quimb(circuit, backend.circuit_ansatz)
tn = qc.local_expectation(qu.pauli('x') & qu.pauli('z'), (0, 1), rehearse='tn')
if rank == 0:
print(f"Searching {NQUBITS}q {NLAYERS}l, method={method}, ranks={size}, workers/rank={WORKERS}...")
t0 = time.time()
tree = parallel_path_search(tn, tn.outer_inds(), method=method,
total_repeats=1024, max_time=300, n_workers=WORKERS,trial_timeout=60)
t_search = time.time() - t0
if rank == 0:
os.makedirs('data', exist_ok=True)
path = f"data/tree_q{NQUBITS}_l{NLAYERS}.pkl"
with open(path, 'wb') as f:
pickle.dump(tree, f)
print(f"Search: {t_search:.2f}s Peak: {tree.max_size() * 16 / 1e9:.2f} GB Saved: {path}")

16
benchmark_slice.py Normal file
View File

@@ -0,0 +1,16 @@
"""Slice saved tree and save."""
import pickle
NQUBITS, NLAYERS = 25, 10
with open(f"data/tree_q{NQUBITS}_l{NLAYERS}.pkl", 'rb') as f:
tree = pickle.load(f)
print(f"Original peak: {tree.max_size() * 16 / 1e9:.2f} GB")
tree_sliced = tree.slice_and_reconfigure(target_size=2**28)
with open(f"data/tree_q{NQUBITS}_l{NLAYERS}_sliced.pkl", 'wb') as f:
pickle.dump(tree_sliced, f)
print(f"Sliced peak: {tree_sliced.max_size() * 16 / 1e9:.2f} GB Slices: {tree_sliced.multiplicity}")

378
benchmark_tn_mpi.py Normal file
View File

@@ -0,0 +1,378 @@
"""MPI-parallel TN benchmark: path search + contraction via MPI."""
import json
import pickle
import time
import argparse
import numpy as np
import cotengra as ctg
import qibo
from qibo import Circuit, gates
from mpi4py import MPI
from concurrent.futures import ProcessPoolExecutor, as_completed
from qibotn.observables import check_observable, extract_gates_and_qubits
def _load_observable(observable_file=None, observable_json=None):
if observable_file:
with open(observable_file, "r", encoding="utf8") as f:
return json.load(f)
if observable_json:
return json.loads(observable_json)
return None
def _term_to_quimb_operator(term):
"""Convert one extracted Hamiltonian term to a quimb operator."""
import quimb as qu
coeff = complex(term[0][2]) if term else 1.0
op = None
where = []
for qubit, gate_name, _ in term:
qubit = int(qubit)
gate_name = str(gate_name).upper()
if gate_name == "I":
continue
where.append(qubit)
op = qu.pauli(gate_name.lower()) if op is None else op & qu.pauli(gate_name.lower())
return complex(coeff), op, tuple(where)
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time):
import pickle, cotengra as ctg, random
random.seed(seed)
tn = pickle.loads(tn_bytes)
opt = ctg.HyperOptimizer(
methods=['kahypar', 'kahypar-agglom', 'spinglass'],
max_repeats=repeats,
parallel=False,
minimize='combo-256',
max_time=max_time,
optlib="random",
slicing_opts={'target_size': 2**29, 'allow_outer': True},
progbar=False,
)
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
return tree.combo_cost(factor=256), tree
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
timeout):
import pickle, os, signal
from concurrent.futures import ProcessPoolExecutor, as_completed
tn_bytes = pickle.dumps(tn)
if n_workers <= 1:
return _run_serial_search(
tn_bytes, output_inds, total_repeats, 0, num_slices, n_ranks, timeout
)[1]
repeats_per = max(1, total_repeats // n_workers)
best_cost, best_tree = float('inf'), None
pool = ProcessPoolExecutor(max_workers=n_workers)
futures = [
pool.submit(_run_serial_search, tn_bytes, output_inds,
repeats_per, seed, num_slices, n_ranks, timeout)
for seed in range(n_workers)
]
try:
for fut in as_completed(futures, timeout=timeout + 5):
try:
cost, tree = fut.result()
if cost < best_cost:
best_cost, best_tree = cost, tree
except Exception as e:
print(f" [worker failed] {e}")
except TimeoutError:
pass
finally:
for fut in futures:
fut.cancel()
for pid in list(pool._processes.keys()):
try:
os.kill(pid, signal.SIGKILL)
except ProcessLookupError:
pass
pool.shutdown(wait=False)
return best_tree
def make_circuit(circuit_type, nqubits, nlayers=1):
c = Circuit(nqubits)
if circuit_type == "qft":
from qibo.models import QFT
return QFT(nqubits)
elif circuit_type == "variational":
for layer in range(nlayers):
for q in range(nqubits):
c.add(gates.RY(q, theta=np.random.uniform(0, 2 * np.pi)))
offset = layer % 2
for q in range(offset, nqubits - 1, 2):
c.add(gates.CZ(q, q + 1))
elif circuit_type == "ghz":
c.add(gates.H(0))
for q in range(nqubits - 1):
c.add(gates.CNOT(q, q + 1))
elif circuit_type == "brickwork":
for q in range(nqubits):
c.add(gates.H(q))
for layer in range(nlayers):
offset = layer % 2
for q in range(offset, nqubits - 1, 2):
c.add(gates.CNOT(q, q + 1))
c.add(gates.RZ(q, theta=np.random.uniform(0, 2 * np.pi)))
c.add(gates.RZ(q + 1, theta=np.random.uniform(0, 2 * np.pi)))
else:
raise ValueError(f"Unknown circuit: {circuit_type}")
return c
def _contract_mpi(tree, arrays, comm, root=0):
rank = comm.Get_rank()
size = comm.Get_size()
is_torch = type(arrays[0]).__module__.startswith("torch")
result_np = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i)
x_np = np.asfortranarray(x.detach().cpu().numpy() if is_torch else np.asarray(x))
result_np = x_np if result_np is None else result_np + x_np
if result_np is None:
result_np = np.zeros(1, dtype=np.complex128)
result = np.zeros_like(result_np) if rank == root else None
comm.Reduce(result_np, result, root=root)
if rank == root:
import torch
return torch.from_numpy(np.asarray(result)) if is_torch else result
return None
def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
load_path=None, save_path=None):
"""Each MPI rank runs serial path search over total_repeats/size trials,
rank 0 picks the global best, then all ranks contract in parallel."""
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
qibo.set_backend("qibotn", platform="quimb")
b = qibo.get_backend()
b.configure_tn_simulation(ansatz="tn")
import torch
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
gate_opts={"max_bond": None, "cutoff": 1e-10})
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128)
# --- path search: each rank serial, gather best to rank 0 ---
if load_path:
if rank == 0:
with open(load_path, "rb") as f:
saved = pickle.load(f)
tree, psi, t_search = saved["tree"], saved["psi"], 0.0
print(f" [path loaded] {load_path}")
else:
tree = psi = None
t_search = 0.0
else:
rank_repeats = max(1, total_repeats // size)
t0 = time.time()
# get TN object first (no contraction), then run parallel search
psi_tn = qc.to_dense(rehearse="tn")
local_tree = parallel_search(
psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48,
num_slices=num_slices, n_ranks=size, timeout=600,
)
t_search = time.time() - t0
local_psi = psi_tn
all_results = comm.gather((local_tree.combo_cost(factor=256), local_tree, local_psi), root=0)
if rank == 0:
_, tree, psi = min(all_results, key=lambda x: x[0])
print(f" [path search] {t_search:.3f}s "
f"flops~2^{tree.contraction_cost(log=2):.2f} "
f"size~2^{tree.contraction_width():.2f} "
f"slices={tree.multiplicity}")
if save_path:
with open(save_path, "wb") as f:
pickle.dump({"tree": tree, "psi": psi}, f)
print(f" [path saved] {save_path}")
else:
tree = psi = None
if save_path:
t_search = comm.bcast(t_search, root=0)
return None, t_search
tree = comm.bcast(tree, root=0)
psi = comm.bcast(psi, root=0)
t_search = comm.bcast(t_search, root=0)
# --- contraction: all ranks work in parallel ---
import torch
torch.set_num_threads(max(1, 96 // size))
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays]
t0 = time.time()
sv = _contract_mpi(tree, arrays, comm, root=0)
t_contract = time.time() - t0
if rank == 0:
print(f" [contraction] {t_contract:.3f}s")
return np.array(sv).reshape(-1), t_search + t_contract
return None, t_search + t_contract
def run_mpi_expval(
circuit,
nqubits,
observable=None,
total_repeats=1024,
search_workers=1,
search_timeout=300,
):
"""Compute a Hamiltonian expectation value directly from TN via MPI.
MPI parallelizes over Hamiltonian terms; ProcessPool optionally helps
path search for each term."""
import torch
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
qibo.set_backend("qibotn", platform="quimb")
b = qibo.get_backend()
b.configure_tn_simulation(ansatz="tn")
observable = check_observable(observable, nqubits)
ham_gate_map = extract_gates_and_qubits(observable)
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
gate_opts={"max_bond": None, "cutoff": 1e-10})
my_terms = ham_gate_map[rank::size]
torch.set_num_threads(max(1, 96 // size))
t0 = time.time()
my_exp = 0.0 + 0.0j
for term in my_terms:
coeff, op, where = _term_to_quimb_operator(term)
if op is None:
my_exp += coeff
continue
tn = qc.local_expectation_tn(op, where=where)
if len(tn.outer_inds()) == 0:
val = complex(tn.contract())
else:
tree = parallel_search(
tn,
tn.outer_inds(),
total_repeats,
n_workers=search_workers,
num_slices=1,
n_ranks=size,
timeout=search_timeout,
)
if tree is None:
raise RuntimeError("Failed to find a contraction tree for expectation TN.")
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in tn.arrays]
acc = sum(tree.contract_slice(arrays, i) for i in range(tree.multiplicity))
val = complex(acc.item() if hasattr(acc, 'item') else acc)
my_exp += coeff * val
t_total = time.time() - t0
all_results = comm.gather(my_exp, root=0)
if rank == 0:
total_exp = sum(all_results)
print(f"\n[TN expval] time={t_total:.4f}s expval={total_exp.real:.12f}")
return np.real_if_close(total_exp), t_total
return None, t_total
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--nqubits", type=int, default=30)
parser.add_argument("--circuit", type=str, default="qft",
choices=["qft", "variational", "ghz", "brickwork"])
parser.add_argument("--nlayers", type=int, default=3)
parser.add_argument("--num-slices", type=int, default=1)
parser.add_argument("--total-repeats", type=int, default=1024)
parser.add_argument("--search-workers", type=int, default=1)
parser.add_argument("--search-timeout", type=int, default=300)
parser.add_argument("--observable-file", type=str, default=None)
parser.add_argument("--observable-json", type=str, default=None)
parser.add_argument("--save-path", type=str, default=None)
parser.add_argument("--load-path", type=str, default=None)
parser.add_argument("--no-compare", action="store_true")
parser.add_argument("--mode", type=str, default="sv", choices=["sv", "expval"])
args = parser.parse_args()
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
if rank == 0:
print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, "
f"nlayers={args.nlayers}, ranks={comm.Get_size()}")
np.random.seed(42)
circuit = make_circuit(args.circuit, args.nqubits, args.nlayers)
observable = _load_observable(args.observable_file, args.observable_json)
if args.mode == "expval":
try:
expval, t_total = run_mpi_expval(
circuit,
args.nqubits,
observable=observable,
total_repeats=args.total_repeats,
search_workers=args.search_workers,
search_timeout=args.search_timeout,
)
except Exception as e:
if rank == 0:
print(f"[FAILED] {e}")
raise
if rank == 0:
np.save(f"data/expval_tn_{args.circuit}{args.nqubits}.npy", np.asarray(expval))
if not args.no_compare:
print("No built-in reference comparison for arbitrary observables.")
return
try:
sv, t_total = run_mpi(circuit, args.nqubits, args.num_slices,
total_repeats=args.total_repeats,
load_path=args.load_path, save_path=args.save_path)
except Exception as e:
if rank == 0:
print(f"[FAILED] {e}")
raise
if rank == 0 and sv is not None:
print(f"\n[quimb TN MPI] time={t_total:.4f}s shape={sv.shape}")
np.save(f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy", sv)
if not args.no_compare:
from qibotn.bak.benchmark_tn import run_qibojit
import gc
np.random.seed(42)
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
sv_ref, t_ref = run_qibojit(circuit_ref)
np.save(f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy", sv_ref)
print(f"[qibojit] time={t_ref:.4f}s")
# free memory before loading via mmap for expval comparison
del sv, sv_ref
gc.collect()
from compare_jit_tn_quimb import check_results
ref_path = f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy"
tn_path = f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy"
check_results(ref_path, tn_path, args.nqubits)
if t_total > 0:
print(f"Speedup : {t_ref/t_total:.2f}x")
if __name__ == "__main__":
main()

25
check_tree.py Normal file
View File

@@ -0,0 +1,25 @@
"""Check contraction tree statistics."""
import pickle, sys
path = sys.argv[1] if len(sys.argv) > 1 else "data/tree_q25_l10.pkl"
with open(path, 'rb') as f:
tree = pickle.load(f)
# Intel 8558P: 96 cores, 2.1GHz, AVX-512 (16 FP64/cycle), FMA x2
# complex128 multiply-add = 6 real FLOPs
CORES = 96
FREQ = 2.1e9
AVX512_FP64 = 16
TFLOPS = CORES * FREQ * AVX512_FP64 * 2 / 1e12 # ~6.45 TFLOPS real FP64
COMPLEX_FLOPS = TFLOPS / 6 # complex128 effective
flops = tree.total_flops()
slices = tree.multiplicity
est_seconds = flops * slices / (COMPLEX_FLOPS * 1e12)
print(f"File: {path}")
print(f"Peak memory (GB): {tree.max_size() * 16 / 1e9:.2f}")
print(f"Total FLOPs: {flops:.2e} x{slices} slices = {flops*slices:.2e}")
print(f"Contraction width: {tree.contraction_width()}")
print(f"Multiplicity (slices): {slices}")
print(f"Estimated time (96 cores): {est_seconds:.1f}s ({est_seconds/3600:.2f}h)")

BIN
data/tree_q25_l10.pkl Normal file

Binary file not shown.

Binary file not shown.

BIN
data/tree_q30_l10.pkl Normal file

Binary file not shown.

Binary file not shown.

View File

@@ -1,35 +1,35 @@
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=source
set BUILDDIR=build
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.https://www.sphinx-doc.org/
exit /b 1
)
if "%1" == "" goto help
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd

2
hostfile Normal file
View File

@@ -0,0 +1,2 @@
10.20.6.74
#10.20.6.102

6
poetry.lock generated
View File

@@ -1733,14 +1733,14 @@ files = [
[[package]]
name = "mako"
version = "1.3.10"
version = "1.3.11"
description = "A super-fast templating language that borrows the best ideas from the existing templating languages."
optional = false
python-versions = ">=3.8"
groups = ["main"]
files = [
{file = "mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59"},
{file = "mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28"},
{file = "mako-1.3.11-py3-none-any.whl", hash = "sha256:e372c6e333cf004aa736a15f425087ec977e1fcbd2966aae7f17c8dc1da27a77"},
{file = "mako-1.3.11.tar.gz", hash = "sha256:071eb4ab4c5010443152255d77db7faa6ce5916f35226eb02dc34479b6858069"},
]
[package.dependencies]

View File

@@ -38,6 +38,36 @@ GATE_MAP = {
}
def _torch_cpu_array(data, dtype=None):
"""Convert array-like data to a contiguous CPU torch tensor."""
import numpy as np
import torch
if isinstance(data, torch.Tensor):
x = data
else:
array = np.asarray(data)
if any(stride < 0 for stride in array.strides):
array = np.ascontiguousarray(array)
x = torch.from_numpy(array)
if x.device.type != "cpu":
x = x.cpu()
if dtype is not None and x.dtype != dtype:
x = x.to(dtype)
if not x.is_contiguous():
x = x.contiguous()
return x
def _arrays_to_backend(arrays, backend, engine):
if backend == "torch":
import torch
return [_torch_cpu_array(array, dtype=torch.complex128) for array in arrays]
return [engine.asarray(array) for array in arrays]
def __init__(self, quimb_backend="numpy", contraction_optimizer="auto-hq"):
super(self.__class__, self).__init__()
@@ -167,7 +197,7 @@ def execute_circuit(
raise_error(ValueError, "Initial state not None supported only for MPS ansatz.")
circ_quimb = self.circuit_ansatz.from_openqasm2_str(
circuit.to_qasm(), psi0=initial_state
circuit.to_qasm(), psi0=initial_state, gate_opts={"max_bond": self.max_bond_dimension, "cutoff": self.svd_cutoff}
)
if nshots:
@@ -186,7 +216,16 @@ def execute_circuit(
else:
frequencies = None
measured_probabilities = None
'''
if return_array:
if self.ansatz == "mps":
psi = circ_quimb.psi
statevector = psi.to_dense().reshape(-1)
else:
statevector = circ_quimb.to_dense(backend=self.backend, optimize=self.contractions_optimizer)
else:
statevector = None
'''
statevector = (
circ_quimb.to_dense(backend=self.backend, optimize=self.contractions_optimizer)
if return_array
@@ -291,6 +330,15 @@ def _qibo_circuit_to_quimb(
quimb_gate_name = GATE_MAP.get(gate_name, None)
if quimb_gate_name == "measure":
continue
if gate_name == "cu1":
theta = gate.parameters[0]
c, t = gate.qubits
circ.apply_gate("RZ", theta / 2, c)
circ.apply_gate("RZ", theta / 2, t)
circ.apply_gate("CNOT", c, t)
circ.apply_gate("RZ", -theta / 2, t)
circ.apply_gate("CNOT", c, t)
continue
if quimb_gate_name is None:
raise_error(ValueError, f"Gate {gate_name} not supported in Quimb backend.")
@@ -334,6 +382,172 @@ def _string_to_quimb_operator(self, op_str):
return op
def expectation(self, circuit, observable, parallel=None, parallel_opts=None):
"""
Compute expectation value with optional parallel acceleration.
Parameters
----------
circuit : qibo.models.Circuit
The quantum circuit.
observable : qibo.hamiltonians.SymbolicHamiltonian or form
The observable to measure.
parallel : str, optional
Parallelization method: 'mpi', 'processpool', or None (default).
parallel_opts : dict, optional
Options for parallel execution:
- max_repeats: int (default 1024)
- max_time: int (default 300)
- search_workers: int (default 48, processpool only)
- mpi_contract: bool (default False, use MPI for contraction)
Returns
-------
float
The expectation value.
"""
from qibotn.observables import check_observable, extract_gates_and_qubits
if parallel_opts is None:
parallel_opts = {}
observable = check_observable(observable, circuit.nqubits)
if parallel is None:
# Use original implementation
from qibotn.observables import extract_gates_and_qubits
all_terms = extract_gates_and_qubits(observable)
qc = self._qibo_circuit_to_quimb(
circuit,
quimb_circuit_type=self.circuit_ansatz,
gate_opts={"max_bond": self.max_bond_dimension, "cutoff": self.svd_cutoff},
)
exp_val = 0.0
for coeff, factors in all_terms:
op = None
where = []
for qubit, gate_name in factors:
p = qu.pauli(gate_name.lower())
op = p if op is None else op & p
where.append(qubit)
val = qc.local_expectation(
op, tuple(where),
backend=self.backend,
optimize=self.contractions_optimizer,
simplify_sequence="ADCRS",
simplify_atol=1e-12,
)
exp_val += coeff * val
return self.real(exp_val)
else:
# Use parallel implementation
return self._expectation_parallel(circuit, observable, parallel, parallel_opts)
def _expectation_parallel(self, circuit, observable, method, opts):
"""Parallel expectation value computation."""
from qibotn.observables import extract_gates_and_qubits
from qibotn.parallel import parallel_path_search, parallel_contract
import torch
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD if method == 'mpi' else None
rank = comm.Get_rank() if comm else 0
size = comm.Get_size() if comm else 1
except ImportError:
comm, rank, size = None, 0, 1
max_repeats = opts.get('max_repeats', 1024)
max_time = opts.get('max_time', 300)
search_workers = opts.get('search_workers', 48)
mpi_contract = opts.get('mpi_contract', False)
torch_threads = opts.get('torch_threads', None)
slicing_opts = opts.get('slicing_opts', None)
trial_timeout = opts.get('trial_timeout', None)
qc = self._qibo_circuit_to_quimb(
circuit,
quimb_circuit_type=self.circuit_ansatz,
gate_opts={"max_bond": self.max_bond_dimension, "cutoff": self.svd_cutoff},
)
all_terms = extract_gates_and_qubits(observable)
my_terms = all_terms[rank::size]
if method == 'mpi' and comm:
torch.set_num_threads(max(1, 96 // size))
elif torch_threads:
torch.set_num_threads(torch_threads)
my_exp = 0.0
for coeff, factors in my_terms:
op = None
where = []
for qubit, gate_name in factors:
p = qu.pauli(gate_name.lower())
op = p if op is None else op & p
where.append(qubit)
tn = qc.local_expectation(op, tuple(where), rehearse='tn')
tree = parallel_path_search(
tn, tn.outer_inds(),
method=method,
total_repeats=max_repeats,
max_time=max_time,
n_workers=search_workers,
slicing_opts=slicing_opts,
trial_timeout=trial_timeout,
)
if tree is None:
continue
if mpi_contract and comm and size > 1:
arrays = _arrays_to_backend(tn.arrays, self.backend, self.engine)
val = parallel_contract(tree, arrays, method='mpi', comm=comm)
else:
if self.backend == "torch":
for tensor in tn.tensors:
tensor._data = _torch_cpu_array(
tensor._data, dtype=torch.complex128
)
val = complex(
tn.contract(
all,
output_inds=(),
optimize=tree,
backend="torch",
)
)
else:
val = complex(
tn.contract(
all,
output_inds=(),
optimize=tree,
backend=self.backend,
)
)
my_exp += coeff * complex(val)
if comm:
all_exp = comm.gather(my_exp, root=0)
if rank == 0:
total_exp = sum(all_exp)
return self.real(total_exp)
return 0.0
return self.real(my_exp)
CLASSES_ROOTS = {"numpy": "Numpy", "torch": "PyTorch", "jax": "Jax"}
METHODS = {
@@ -344,6 +558,8 @@ METHODS = {
"exp_value_observable_symbolic": exp_value_observable_symbolic,
"_qibo_circuit_to_quimb": _qibo_circuit_to_quimb,
"_string_to_quimb_operator": _string_to_quimb_operator,
"expectation": expectation,
"_expectation_parallel": _expectation_parallel,
"circuit_ansatz": circuit_ansatz,
}

View File

@@ -4,83 +4,16 @@ from cupy.cuda import nccl
from cupy.cuda.runtime import getDeviceCount
from cuquantum.tensornet import Network, contract
from mpi4py import MPI
from qibo import hamiltonians
from qibo.symbols import I, X, Y, Z
from qibotn.circuit_convertor import QiboCircuitToEinsum
from qibotn.circuit_to_mps import QiboCircuitToMPS
from qibotn.mps_contraction_helper import MPSContractionHelper
def check_observable(observable, circuit_nqubit):
"""Checks the type of observable and returns the appropriate Hamiltonian."""
if observable is None:
return build_observable(circuit_nqubit)
elif isinstance(observable, dict):
return create_hamiltonian_from_dict(observable, circuit_nqubit)
elif isinstance(observable, hamiltonians.SymbolicHamiltonian):
# TODO: check if the observable is compatible with the circuit
return observable
else:
raise TypeError("Invalid observable type.")
def build_observable(circuit_nqubit):
"""Helper function to construct a target observable."""
hamiltonian_form = 0
for i in range(circuit_nqubit):
hamiltonian_form += 0.5 * X(i % circuit_nqubit) * Z((i + 1) % circuit_nqubit)
hamiltonian = hamiltonians.SymbolicHamiltonian(form=hamiltonian_form)
return hamiltonian
def create_hamiltonian_from_dict(data, circuit_nqubit):
"""Create a Qibo SymbolicHamiltonian from a dictionary representation.
Ensures that each Hamiltonian term explicitly acts on all circuit qubits
by adding identity (`I`) gates where needed.
Args:
data (dict): Dictionary containing Hamiltonian terms.
circuit_nqubit (int): Total number of qubits in the quantum circuit.
Returns:
hamiltonians.SymbolicHamiltonian: The constructed Hamiltonian.
"""
PAULI_GATES = {"X": X, "Y": Y, "Z": Z}
terms = []
for term in data["terms"]:
coeff = term["coefficient"]
operators = term["operators"] # List of tuples like [("Z", 0), ("X", 1)]
# Convert the operator list into a dictionary {qubit_index: gate}
operator_dict = {q: PAULI_GATES[g] for g, q in operators}
# Build the full term ensuring all qubits are covered
full_term_expr = [
operator_dict[q](q) if q in operator_dict else I(q)
for q in range(circuit_nqubit)
]
# Multiply all operators together to form a single term
term_expr = full_term_expr[0]
for op in full_term_expr[1:]:
term_expr *= op
# Scale by the coefficient
final_term = coeff * term_expr
terms.append(final_term)
if not terms:
raise ValueError("No valid Hamiltonian terms were added.")
# Combine all terms
hamiltonian_form = sum(terms)
return hamiltonians.SymbolicHamiltonian(hamiltonian_form)
from qibotn.observables import (
build_observable,
check_observable,
create_hamiltonian_from_dict,
extract_gates_and_qubits,
)
def get_ham_gates(pauli_map, dtype="complex128", backend=cp):
@@ -111,45 +44,6 @@ def get_ham_gates(pauli_map, dtype="complex128", backend=cp):
return gates
def extract_gates_and_qubits(hamiltonian):
"""
Extracts the gates and their corresponding qubits from a Qibo Hamiltonian.
Parameters:
hamiltonian (qibo.hamiltonians.Hamiltonian or qibo.hamiltonians.SymbolicHamiltonian):
A Qibo Hamiltonian object.
Returns:
list of tuples: [(coefficient, [(gate, qubit), ...]), ...]
- coefficient: The prefactor of the term.
- list of (gate, qubit): Each term's gates and the qubits they act on.
"""
extracted_terms = []
if isinstance(hamiltonian, hamiltonians.SymbolicHamiltonian):
for term in hamiltonian.terms:
coeff = term.coefficient # Extract coefficient
gate_qubit_list = []
# Extract gate and qubit information
for factor in term.factors:
gate_name = str(factor)[
0
] # Extract the gate type (X, Y, Z) from 'X0', 'Z1'
qubit = int(str(factor)[1:]) # Extract the qubit index
gate_qubit_list.append((qubit, gate_name, coeff))
coeff = 1.0
extracted_terms.append(gate_qubit_list)
else:
raise ValueError(
"Unsupported Hamiltonian type. Must be SymbolicHamiltonian or Hamiltonian."
)
return extracted_terms
def initialize_mpi():
"""Initialize MPI communication and device selection."""
comm = MPI.COMM_WORLD

86
src/qibotn/observables.py Normal file
View File

@@ -0,0 +1,86 @@
"""Observable helpers shared by tensor-network backends and benchmarks."""
from qibo import hamiltonians
from qibo.symbols import I, X, Y, Z
def check_observable(observable, circuit_nqubit):
"""Checks the type of observable and returns the appropriate Hamiltonian."""
if observable is None:
return build_observable(circuit_nqubit)
if isinstance(observable, dict):
return create_hamiltonian_from_dict(observable, circuit_nqubit)
if isinstance(observable, hamiltonians.SymbolicHamiltonian):
return observable
raise TypeError("Invalid observable type.")
def build_observable(circuit_nqubit):
"""Construct the default benchmark observable used by qibotn."""
hamiltonian_form = 0
for i in range(circuit_nqubit):
hamiltonian_form += 0.5 * X(i % circuit_nqubit) * Z((i + 1) % circuit_nqubit)
return hamiltonians.SymbolicHamiltonian(form=hamiltonian_form)
def create_hamiltonian_from_dict(data, circuit_nqubit):
"""Create a Qibo SymbolicHamiltonian from the qibotn dict representation."""
pauli_gates = {"X": X, "Y": Y, "Z": Z}
terms = []
for term in data["terms"]:
coeff = term["coefficient"]
operators = term["operators"]
operator_dict = {q: pauli_gates[g] for g, q in operators}
full_term_expr = [
operator_dict[q](q) if q in operator_dict else I(q)
for q in range(circuit_nqubit)
]
term_expr = full_term_expr[0]
for op in full_term_expr[1:]:
term_expr *= op
terms.append(coeff * term_expr)
if not terms:
raise ValueError("No valid Hamiltonian terms were added.")
return hamiltonians.SymbolicHamiltonian(sum(terms))
def build_random_circuit(nqubits, nlayers, seed=42):
"""Build a random circuit with RY+RZ+CNOT layers for benchmarks."""
import numpy as np
from qibo import Circuit, gates
np.random.seed(seed)
c = Circuit(nqubits)
for _ in range(nlayers):
for q in range(nqubits):
c.add(gates.RY(q, theta=np.random.uniform(0, 2*np.pi)))
c.add(gates.RZ(q, theta=np.random.uniform(0, 2*np.pi)))
for q in range(nqubits):
c.add(gates.CNOT(q % nqubits, (q + 1) % nqubits))
return c
def extract_gates_and_qubits(hamiltonian):
"""Extract per-term Pauli factors from a Qibo SymbolicHamiltonian.
Returns list of terms, where each term is (coefficient, [(qubit, gate_name), ...]).
"""
extracted_terms = []
if not isinstance(hamiltonian, hamiltonians.SymbolicHamiltonian):
raise ValueError(
"Unsupported Hamiltonian type. Must be SymbolicHamiltonian or Hamiltonian."
)
for term in hamiltonian.terms:
coeff = term.coefficient
factors = [(int(str(f)[1:]), str(f)[0]) for f in term.factors]
extracted_terms.append((coeff, factors))
return extracted_terms

195
src/qibotn/parallel.py Normal file
View File

@@ -0,0 +1,195 @@
"""Parallel path search and contraction utilities for tensor networks."""
import os
import pickle
import signal
import numpy as np
from concurrent.futures import ProcessPoolExecutor, as_completed
try:
from mpi4py import MPI
_HAVE_MPI = True
except ImportError:
_HAVE_MPI = False
MPI = None
def _run_single_trial(tn_bytes, output_inds, seed, slicing_opts):
import random, cotengra as ctg
random.seed(seed)
tn = pickle.loads(tn_bytes)
opt = ctg.HyperOptimizer(
methods=["kahypar", "kahypar-agglom", "spinglass"],
max_repeats=1,
parallel=False,
minimize="combo-256",
optlib="random",
slicing_opts=slicing_opts,
progbar=False,
)
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
return tree.combo_cost(factor=256), tree
def _kill_pool(pool):
for pid in list(pool._processes.keys()):
try:
os.kill(pid, signal.SIGKILL)
except ProcessLookupError:
pass
pool.shutdown(wait=False)
def _serial_search(tn_bytes, output_inds, repeats, seed, max_time, slicing_opts=None, trial_timeout=None):
import time
if trial_timeout is None:
import random, cotengra as ctg
random.seed(seed)
tn = pickle.loads(tn_bytes)
opt = ctg.HyperOptimizer(
methods=["kahypar", "kahypar-agglom", "spinglass"],
max_repeats=repeats,
parallel=False,
minimize="combo-256",
max_time=max_time,
optlib="random",
slicing_opts=slicing_opts,
progbar=False,
)
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
return tree.combo_cost(factor=256), tree
deadline = time.time() + max_time
best_cost, best_tree = float("inf"), None
for i in range(repeats):
if time.time() >= deadline:
break
timeout = min(trial_timeout, deadline - time.time())
pool = ProcessPoolExecutor(max_workers=1)
fut = pool.submit(_run_single_trial, tn_bytes, output_inds, seed * 10000 + i, slicing_opts)
try:
cost, tree = fut.result(timeout=timeout)
if cost < best_cost:
best_cost, best_tree = cost, tree
except Exception:
pass
finally:
_kill_pool(pool)
return best_cost, best_tree
def _processpool_search(tn, output_inds, total_repeats, n_workers, max_time, slicing_opts=None, trial_timeout=None):
tn_bytes = pickle.dumps(tn)
repeats_per = max(1, total_repeats // n_workers)
pool = ProcessPoolExecutor(max_workers=n_workers)
futures = [
pool.submit(_serial_search, tn_bytes, output_inds, repeats_per, seed, max_time, slicing_opts, trial_timeout)
for seed in range(n_workers)
]
best_cost, best_tree = float("inf"), None
try:
for fut in as_completed(futures, timeout=max_time + 5):
try:
cost, tree = fut.result()
if cost < best_cost:
best_cost, best_tree = cost, tree
except Exception:
pass
except TimeoutError:
pass
finally:
for fut in futures:
fut.cancel()
_kill_pool(pool)
return best_tree
def _mpi_search(tn, output_inds, total_repeats, max_time, n_workers=None, slicing_opts=None, trial_timeout=None):
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
tn_bytes = pickle.dumps(tn)
repeats_per = max(1, total_repeats // size)
if n_workers and n_workers > 1:
local_tree = _processpool_search(
tn, output_inds, repeats_per, n_workers, max_time, slicing_opts, trial_timeout
)
local_cost = local_tree.combo_cost(factor=256) if local_tree else float("inf")
else:
local_cost, local_tree = _serial_search(
tn_bytes, output_inds, repeats_per, rank, max_time, slicing_opts, trial_timeout
)
all_results = comm.gather((local_cost, local_tree), root=0)
best_tree = None
if rank == 0:
best_cost = float("inf")
for cost, tree in all_results:
if tree is not None and cost < best_cost:
best_cost, best_tree = cost, tree
return comm.bcast(best_tree, root=0)
def parallel_path_search(tn, output_inds, method='processpool', total_repeats=1024,
max_time=300, n_workers=48, slicing_opts=None, trial_timeout=None):
"""Parallel contraction path search.
Args:
method: 'processpool' | 'mpi' | 'serial'
total_repeats: Total optimization repeats across all workers
max_time: Global timeout per worker (seconds)
n_workers: Workers per MPI rank (or total for processpool)
slicing_opts: cotengra slicing options for memory control
trial_timeout: Per-trial timeout (seconds); kills and skips hung trials
"""
if method == 'serial':
tn_bytes = pickle.dumps(tn)
_, tree = _serial_search(tn_bytes, output_inds, total_repeats, 0, max_time, slicing_opts, trial_timeout)
return tree
elif method == 'mpi':
if not _HAVE_MPI:
raise ImportError("mpi4py not available")
return _mpi_search(tn, output_inds, total_repeats, max_time, n_workers, slicing_opts, trial_timeout)
elif method == 'processpool':
return _processpool_search(tn, output_inds, total_repeats, n_workers, max_time, slicing_opts, trial_timeout)
else:
raise ValueError(f"Unknown method: {method}")
def parallel_contract(tree, arrays, method='mpi', comm=None):
if method == 'mpi':
if not _HAVE_MPI or comm is None:
raise ValueError("MPI method requires mpi4py and comm")
return _contract_mpi(tree, arrays, comm)
raise ValueError(f"Unknown method: {method}")
def _contract_mpi(tree, arrays, comm, root=0):
rank, size = comm.Get_rank(), comm.Get_size()
is_torch = type(arrays[0]).__module__.startswith("torch")
if is_torch:
result_torch = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i, backend="torch").reshape(-1)
result_torch = x if result_torch is None else result_torch + x
if result_torch is None:
result_np = np.zeros(1, dtype=np.complex128)
else:
result_np = result_torch.detach().cpu().numpy()
else:
result_np = None
for i in range(rank, tree.multiplicity, size):
x = tree.contract_slice(arrays, i)
x_np = np.asarray(x).reshape(-1)
result_np = x_np if result_np is None else result_np + x_np
if result_np is None:
result_np = np.zeros(1, dtype=np.complex128)
result = np.zeros_like(result_np) if rank == root else None
comm.Reduce(result_np, result, root=root)
return result

View File

@@ -57,10 +57,10 @@ class TensorNetworkResult:
return self.measures
def state(self):
"""Return the statevector if the number of qubits is less than 20."""
if self.nqubits < 20:
"""Return the statevector if the number of qubits is less than 35."""
if self.nqubits < 35:
return self.statevector
raise_error(
NotImplementedError,
f"Tensor network simulation cannot be used to reconstruct statevector for >= 20 .",
f"Tensor network simulation cannot be used to reconstruct statevector for >= 35 .",
)

View File

@@ -35,7 +35,7 @@ def test_observable_expval(backend, nqubits):
numpy_backend = construct_backend("numpy")
ham, ham_form = build_observable(nqubits)
circ = build_circuit(nqubits=nqubits, nlayers=1)
exact_expval = numpy_backend.calculate_expectation_state(
hamiltonian=ham,
state=circ().state(),