"""MPI parallel sliced contraction using pre-sliced tree.

Every rank builds the tensor network for the local expectation of
``pauli('x') & pauli('z')`` on qubits (0, 1) of a random circuit. Rank 0 loads
a pre-sliced contraction tree from disk and broadcasts it. Each rank contracts
a round-robin subset of the slices with the torch backend, and the per-rank
partial sums are reduced onto rank 0, which prints timing and the result.
"""
import os
import pickle
import time

NQUBITS, NLAYERS, NCORES = 25, 10, 48

# OpenMP / MKL / OpenBLAS read these variables when the libraries are first
# loaded, so they must be set *before* numpy (and torch) are imported; setting
# them after ``import numpy`` does not affect numpy's already-loaded BLAS.
# NOTE(review): every rank requests NCORES threads; if several ranks share a
# node this oversubscribes its cores -- confirm the intended ranks-per-node.
os.environ['OMP_NUM_THREADS'] = str(NCORES)
os.environ['MKL_NUM_THREADS'] = str(NCORES)

import numpy as np  # noqa: E402  (must follow the env setup above)
from mpi4py import MPI  # noqa: E402
import torch  # noqa: E402
import qibo  # noqa: E402
import quimb as qu  # noqa: E402
from qibotn.observables import build_random_circuit  # noqa: E402

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

torch.set_num_threads(NCORES)

# Every rank builds the network itself; only the contraction tree is
# broadcast below.
circuit = build_random_circuit(NQUBITS, NLAYERS)
qibo.set_backend("qibotn", platform="quimb")
backend = qibo.get_backend()
backend.configure_tn_simulation(ansatz="tn")
# NOTE(review): _qibo_circuit_to_quimb is a private qibotn helper and may
# change between qibotn releases.
qc = backend._qibo_circuit_to_quimb(circuit, backend.circuit_ansatz)
# rehearse='tn' asks quimb for the tensor network (its tensors are used
# below) rather than the contracted value.
tn = qc.local_expectation(qu.pauli('x') & qu.pauli('z'), (0, 1), rehearse='tn')

if rank == 0:
    # The tree file is trusted local output; never unpickle untrusted data.
    # The path is relative to the current working directory.
    with open(f"data/tree_q{NQUBITS}_l{NLAYERS}_sliced.pkl", 'rb') as f:
        tree = pickle.load(f)
else:
    tree = None
tree = comm.bcast(tree, root=0)

# Arrays are passed to the tree positionally, so this assumes the pickled tree
# was built from this same network with the same tensor order -- TODO confirm.
arrays = [torch.from_numpy(np.asarray(t.data)) for t in tn.tensors]
n_slices = tree.multiplicity

if rank == 0:
    # The GB estimate assumes 16 bytes per element (complex128).
    print(f"Slices: {n_slices}, Ranks: {size}, "
          f"Peak: {tree.max_size() * 16 / 1e9:.2f} GB, "
          f"Threads/rank: {NCORES}, Backend: torch")

# Start the clock only once every rank has finished its setup, so the timing
# reflects the sliced contraction and reduction, not setup skew.
comm.Barrier()
t0 = time.perf_counter()

# Every rank, including ones that receive no slice when size > n_slices,
# accumulates into the same dtype and shape. This keeps the Reduce buffers
# consistent across ranks and sums the slices in double precision. A
# non-scalar slice result raises a broadcast error here instead of silently
# mismatching buffer sizes in the reduction.
result = np.zeros(1, dtype=np.complex128)
for i in range(rank, n_slices, size):  # round-robin slice assignment
    val = tree.contract_slice(arrays, i, backend='torch')
    result += val.cpu().numpy().reshape(-1)

total = np.zeros_like(result) if rank == 0 else None
comm.Reduce(result, total, op=MPI.SUM, root=0)

if rank == 0:
    # NOTE(review): the 0.5 prefactor is not explained in this file --
    # confirm where it comes from.
    print(f"Contract: {time.perf_counter() - t0:.4f}s Expectation: {0.5 * total[0].real:.10f}")