Files
qibotn/baseline_mps_expectation.py
jaunatisblue ff96e36cfc
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
更新脚本和后端
2026-05-09 18:36:23 +08:00

121 lines
4.1 KiB
Python

"""Baseline MPS expectation scan with the qmatchatea backend."""
import argparse
import logging
import math
import time
from qibo import Circuit, gates, hamiltonians
from qibo.symbols import X, Z
from qibotn.backends.qmatchatea import QMatchaTeaBackend
def parse_bonds(value):
    """Parse a comma-separated list of bond dimensions such as ``"2,4,8"``.

    Empty items (from stray or trailing commas) are skipped and whitespace
    around items is allowed, so ``" 2, 4 ,8,"`` gives ``[2, 4, 8]``.

    Args:
        value: Comma-separated text of integer bond dimensions.

    Returns:
        list[int]: The bond dimensions, in the order given.

    Raises:
        ValueError: If an item is not an integer (raised by ``int``; argparse
            reports it as an invalid value).
        argparse.ArgumentTypeError: If a bond dimension is not positive, or
            if no bond dimension is given at all.
    """
    bonds = []
    for item in value.split(","):
        item = item.strip()
        if not item:
            continue
        bond = int(item)
        # A maximum bond dimension below 1 cannot describe any MPS.
        if bond < 1:
            raise argparse.ArgumentTypeError(
                f"bond dimensions must be positive, got {bond}"
            )
        bonds.append(bond)
    # An empty list would make the scan loop silently do nothing.
    if not bonds:
        raise argparse.ArgumentTypeError("expected at least one bond dimension")
    return bonds
def build_circuit(nqubits, nlayers, seed):
    """Return a seeded random brickwork circuit on ``nqubits`` qubits.

    Each of the ``nlayers`` layers applies, on every qubit, an RY rotation and
    then an RZ rotation. Their angles are drawn uniformly from ``[-pi, pi)``
    by ``numpy.random.default_rng(seed)``. The layer then applies CNOTs on the
    even pairs (0, 1), (2, 3), ... followed by the odd pairs (1, 2), (3, 4), ....
    Angles are drawn in a fixed order, so a given ``seed`` always reproduces
    the same circuit.
    """
    import numpy as np

    generator = np.random.default_rng(seed)
    half_turn = math.pi
    circ = Circuit(nqubits)
    for _layer in range(nlayers):
        for q in range(nqubits):
            # Draw the RY angle before the RZ angle to keep the seeded sequence stable.
            ry_angle = generator.uniform(-half_turn, half_turn)
            circ.add(gates.RY(q, theta=ry_angle))
            rz_angle = generator.uniform(-half_turn, half_turn)
            circ.add(gates.RZ(q, theta=rz_angle))
        # Entangling brickwork: even-offset pairs first, then odd-offset pairs.
        for offset in (0, 1):
            for q in range(offset, nqubits - 1, 2):
                circ.add(gates.CNOT(q, q + 1))
    return circ
def build_observable(nqubits):
    """Return the scan observable ``0.5 * sum_q Z_q Z_{q+1} + 0.25 * X_0``.

    The ZZ chain is open: it couples qubits q and q + 1 for q = 0 .. nqubits - 2,
    with no term between the last and the first qubit. The result is a
    ``SymbolicHamiltonian``.
    """
    zz_chain = sum(0.5 * Z(q) * Z(q + 1) for q in range(nqubits - 1))
    return hamiltonians.SymbolicHamiltonian(form=zz_chain + 0.25 * X(0))
def exact_expectation(circuit, nqubits):
    """Return the exact expectation value of the scan observable for ``circuit``.

    Runs ``circuit`` to a dense statevector via ``circuit().state(numpy=True)``.
    It then evaluates ``0.5 * sum_q <Z_q Z_{q+1}> + 0.25 * <X_0>`` directly,
    which is the same observable that ``build_observable`` constructs. Work
    and memory grow as ``2**nqubits``.

    Args:
        circuit: Circuit to execute. Calling it must return an object whose
            ``state(numpy=True)`` yields the statevector.
        nqubits: Number of qubits. Qubit 0 is treated as the most significant
            bit of the basis-state index.

    Returns:
        float: The expectation value.
    """
    import numpy as np

    amplitudes = circuit().state(numpy=True).reshape(-1)
    weights = np.abs(amplitudes) ** 2
    basis = np.arange(amplitudes.size)

    def z_sign(qubit):
        # +1 where the qubit's bit is 0, -1 where it is 1 (qubit 0 = MSB).
        return 1 - 2 * ((basis >> (nqubits - 1 - qubit)) & 1)

    total = 0.0
    for q in range(nqubits - 1):
        total += 0.5 * np.sum(weights * z_sign(q) * z_sign(q + 1))
    # X_0 maps each basis index to its partner with qubit 0's bit flipped,
    # so <psi|X_0|psi> is an overlap between psi and its permuted copy.
    partner = basis ^ (1 << (nqubits - 1))
    total += 0.25 * np.vdot(amplitudes[partner], amplitudes).real
    return float(total)
def _build_parser():
    """Return the command-line parser for the bond-dimension scan."""
    parser = argparse.ArgumentParser(
        description="Scan MPS bond dimensions with the qmatchatea backend and "
        "report the expectation value of 0.5*sum(Z_q Z_q+1) + 0.25*X_0."
    )
    parser.add_argument(
        "--nqubits", type=int, default=20, help="number of qubits (>= 1)"
    )
    parser.add_argument(
        "--nlayers",
        type=int,
        default=8,
        help="number of rotation + CNOT layers (>= 0)",
    )
    parser.add_argument(
        "--bonds",
        type=parse_bonds,
        default=parse_bonds("2,4,8,16,32"),
        help="comma-separated maximum bond dimensions to scan",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="seed for the random rotation angles"
    )
    parser.add_argument(
        "--cut-ratio",
        type=float,
        default=1e-12,
        help="forwarded to the backend as cut_ratio",
    )
    parser.add_argument(
        "--svd-control", default="V", help="forwarded to the backend as svd_control"
    )
    parser.add_argument(
        "--tensor-module",
        choices=("numpy", "torch"),
        default="numpy",
        help="forwarded to the backend as tensor_module",
    )
    parser.add_argument(
        "--torch-threads",
        type=int,
        help="if given, passed to torch.set_num_threads (>= 1)",
    )
    parser.add_argument(
        "--exact",
        action="store_true",
        help="also compute the exact statevector value and report the errors",
    )
    parser.add_argument(
        "--exact-max-qubits",
        type=int,
        default=24,
        help="largest --nqubits allowed together with --exact",
    )
    parser.add_argument(
        "--preprocess",
        action="store_true",
        help="forwarded to backend.expectation as preprocess",
    )
    return parser


def main():
    """Run the scan and print one result row per requested bond dimension.

    Steps:

    - Build the seeded random circuit and the observable.
    - With ``--exact``, compute the exact statevector reference.
    - For every bond dimension, reconfigure a single ``QMatchaTeaBackend``
      for an MPS simulation, time ``backend.expectation``, and print
      ``bond_dim expval abs_error rel_error seconds``.

    The error columns are ``nan`` when no exact reference was computed.

    Invalid command-line values are reported through ``parser.error``, which
    prints a usage message and exits with status 2, instead of a traceback.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.nqubits < 1:
        parser.error("--nqubits must be at least 1.")
    if args.nlayers < 0:
        parser.error("--nlayers must be non-negative.")
    if args.torch_threads is not None and args.torch_threads < 1:
        parser.error("--torch-threads must be at least 1.")
    # Check before building anything, so an oversized --exact run fails fast.
    if args.exact and args.nqubits > args.exact_max_qubits:
        parser.error(
            f"--exact is limited to {args.exact_max_qubits} qubits "
            f"(got --nqubits {args.nqubits}); raise --exact-max-qubits to override."
        )

    # Drop log records below ERROR so that only the result table is printed.
    logging.getLogger("qibo.config").setLevel(logging.ERROR)
    logging.getLogger("qtealeaves").setLevel(logging.ERROR)
    if args.torch_threads is not None:
        import torch

        torch.set_num_threads(args.torch_threads)

    circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
    observable = build_observable(args.nqubits)
    exact = exact_expectation(circuit, args.nqubits) if args.exact else None

    print(
        f"nqubits={args.nqubits} nlayers={args.nlayers} "
        f"seed={args.seed} preprocess={args.preprocess} "
        f"tensor_module={args.tensor_module}"
    )
    if exact is not None:
        print(f"exact={exact:.16e}")
    print("bond_dim expval abs_error rel_error seconds")

    # One backend instance is reused and reconfigured for every bond dimension.
    backend = QMatchaTeaBackend()
    for bond in args.bonds:
        backend.configure_tn_simulation(
            ansatz="MPS",
            max_bond_dimension=bond,
            cut_ratio=args.cut_ratio,
            svd_control=args.svd_control,
            tensor_module=args.tensor_module,
        )
        # Only the expectation call is timed; reconfiguration is excluded.
        start = time.perf_counter()
        value = float(
            backend.expectation(circuit, observable, preprocess=args.preprocess).real
        )
        elapsed = time.perf_counter() - start
        if exact is None:
            abs_error = rel_error = float("nan")
        else:
            abs_error = abs(value - exact)
            # Floor the denominator so that a vanishing exact value cannot divide by zero.
            rel_error = abs_error / max(abs(exact), 1e-15)
        print(f"{bond:d} {value:.16e} {abs_error:.6e} {rel_error:.6e} {elapsed:.3f}")
if __name__ == "__main__":
    # Run the scan only when executed as a script, not when imported.
    main()