From de92060180a3da24d39b03636cf3a2f3dc27e956 Mon Sep 17 00:00:00 2001 From: vinitha-balachandran Date: Fri, 23 Feb 2024 14:48:30 +0800 Subject: [PATCH] Adding feature to pass MPS parameters in quimb --- src/qibotn/backends/quimb.py | 8 +++++--- src/qibotn/eval_qu.py | 11 +++-------- tests/test_quimb_backend.py | 15 ++++++++++++--- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/src/qibotn/backends/quimb.py b/src/qibotn/backends/quimb.py index 31e24c7..0a34eef 100644 --- a/src/qibotn/backends/quimb.py +++ b/src/qibotn/backends/quimb.py @@ -16,9 +16,11 @@ class QuimbBackend(NumpyBackend): mps_enabled_value = runcard.get("MPS_enabled") if mps_enabled_value is True: - self.MPS_enabled = True + self.mps_opts = {"method": "svd", "cutoff": 1e-6, "cutoff_mod": "abs"} elif mps_enabled_value is False: - self.MPS_enabled = False + self.mps_opts = False + elif isinstance(mps_enabled_value, dict): + self.mps_opts = mps_enabled_value else: raise TypeError("MPS_enabled has an unexpected type") @@ -74,7 +76,7 @@ class QuimbBackend(NumpyBackend): ) state = eval.dense_vector_tn_qu( - circuit.to_qasm(), initial_state, is_mps=self.MPS_enabled, backend="numpy" + circuit.to_qasm(), initial_state, self.mps_opts, backend="numpy" ) if return_array: diff --git a/src/qibotn/eval_qu.py b/src/qibotn/eval_qu.py index 6124152..0a32781 100644 --- a/src/qibotn/eval_qu.py +++ b/src/qibotn/eval_qu.py @@ -10,7 +10,7 @@ def init_state_tn(nqubits, init_state_sv): return qtn.tensor_1d.MatrixProductState.from_dense(init_state_sv, dims) -def dense_vector_tn_qu(qasm: str, initial_state, is_mps, backend="numpy"): +def dense_vector_tn_qu(qasm: str, initial_state, mps_opts, backend="numpy"): """Evaluate QASM with Quimb. backend (quimb): numpy, cupy, jax. Passed to ``opt_einsum``. @@ -20,14 +20,9 @@ def dense_vector_tn_qu(qasm: str, initial_state, is_mps, backend="numpy"): nqubits = int(np.log2(len(initial_state))) initial_state = init_state_tn(nqubits, initial_state) - if is_mps: - gate_opt = {} - gate_opt["method"] = "svd" - gate_opt["cutoff"] = 1e-6 - gate_opt["cutoff_mode"] = "abs" - + if mps_opts: circ_quimb = qtn.circuit.CircuitMPS.from_openqasm2_str( - qasm, psi0=initial_state, gate_opts=gate_opt + qasm, psi0=initial_state, gate_opts=mps_opts ) else: diff --git a/tests/test_quimb_backend.py b/tests/test_quimb_backend.py index 15ba652..2b77ab6 100644 --- a/tests/test_quimb_backend.py +++ b/tests/test_quimb_backend.py @@ -50,9 +50,18 @@ def test_eval(nqubits: int, tolerance: float, is_mps: bool): qasm_circ = qibo_circ.to_qasm() # Test quimb - result_tn = qibotn.eval_qu.dense_vector_tn_qu( - qasm_circ, init_state_tn, is_mps, backend=config.quimb.backend - ).flatten() + if is_mps: + gate_opt = {} + gate_opt["method"] = "svd" + gate_opt["cutoff"] = 1e-6 + gate_opt["cutoff_mode"] = "abs" + result_tn = qibotn.eval_qu.dense_vector_tn_qu( + qasm_circ, init_state_tn, gate_opt, backend=config.quimb.backend + ).flatten() + else: + result_tn = qibotn.eval_qu.dense_vector_tn_qu( + qasm_circ, init_state_tn, is_mps, backend=config.quimb.backend + ).flatten() assert np.allclose( result_sv, result_tn, atol=tolerance