From ff96e36cfcacf89304ecf652f5baca727a5c6419 Mon Sep 17 00:00:00 2001 From: jaunatisblue Date: Sat, 9 May 2026 18:36:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=84=9A=E6=9C=AC=E5=92=8C?= =?UTF-8?q?=E5=90=8E=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- baseline_mps_expectation.py | 10 +++++++++- src/qibotn/backends/qmatchatea.py | 3 +++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/baseline_mps_expectation.py b/baseline_mps_expectation.py index a72a3f2..8985bf7 100644 --- a/baseline_mps_expectation.py +++ b/baseline_mps_expectation.py @@ -65,12 +65,18 @@ def main(): parser.add_argument("--seed", type=int, default=42) parser.add_argument("--cut-ratio", type=float, default=1e-12) parser.add_argument("--svd-control", default="V") + parser.add_argument("--tensor-module", choices=("numpy", "torch"), default="numpy") + parser.add_argument("--torch-threads", type=int) parser.add_argument("--exact", action="store_true") parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--preprocess", action="store_true") args = parser.parse_args() logging.getLogger("qibo.config").setLevel(logging.ERROR) logging.getLogger("qtealeaves").setLevel(logging.ERROR) + if args.torch_threads is not None: + import torch + + torch.set_num_threads(args.torch_threads) circuit = build_circuit(args.nqubits, args.nlayers, args.seed) observable = build_observable(args.nqubits) @@ -84,7 +90,8 @@ def main(): print( f"nqubits={args.nqubits} nlayers={args.nlayers} " - f"seed={args.seed} preprocess={args.preprocess}" + f"seed={args.seed} preprocess={args.preprocess} " + f"tensor_module={args.tensor_module}" ) if exact is not None: print(f"exact={exact:.16e}") @@ -97,6 +104,7 @@ def main(): max_bond_dimension=bond, cut_ratio=args.cut_ratio, svd_control=args.svd_control, + tensor_module=args.tensor_module, ) start = time.perf_counter() value = float( diff --git a/src/qibotn/backends/qmatchatea.py b/src/qibotn/backends/qmatchatea.py index cacc71f..4894f8b 100644 --- a/src/qibotn/backends/qmatchatea.py +++ b/src/qibotn/backends/qmatchatea.py @@ -39,6 +39,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend): trunc_tracking_mode: str = "C", svd_control: str = "A", ini_bond_dimension: int = 1, + tensor_module: str = "numpy", ): """Configure TN simulation given Quantum Matcha Tea interface. @@ -76,6 +77,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend): ini_bond_dimension=ini_bond_dimension, ) self.ansatz = ansatz + self.tensor_module = tensor_module if hasattr(self, "qmatchatea_backend"): self._setup_backend_specifics() @@ -96,6 +98,7 @@ class QMatchaTeaBackend(QibotnBackend, NumpyBackend): precision=qmatchatea_precision, device=qmatchatea_device, ansatz=self.ansatz, + tensor_module=self.tensor_module, ) def execute_circuit(