Lifted shape function to class scope

This commit is contained in:
tankya2
2023-04-21 11:16:39 +08:00
parent dedf1dd6fd
commit 3cb799ddf1

View File

@@ -17,12 +17,6 @@ class QiboCircuitToEinsum:
"""
def __init__(self, circuit, dtype="complex128"):
def op_shape_from_qubits(nqubits):
"""This function is to modify the shape of the tensor to the required format by cuQuantum
(qubit_states,) * input_output * qubits_involved
"""
return (2, 2) * nqubits
self.backend = cp
self.dtype = getattr(self.backend, dtype)
@@ -35,7 +29,7 @@ class QiboCircuitToEinsum:
# self.gate_tensors is to extract into a list the gate matrix together with the qubit id that it is acting on
# https://github.com/NVIDIA/cuQuantum/blob/6b6339358f859ea930907b79854b90b2db71ab92/python/cuquantum/cutensornet/_internal/circuit_parser_utils_cirq.py#L32
required_shape = op_shape_from_qubits(len(gate_qubits))
required_shape = self.op_shape_from_qubits(len(gate_qubits))
self.gate_tensors.append(
(
cp.asarray(gate.matrix).reshape(required_shape),
@@ -108,3 +102,9 @@ class QiboCircuitToEinsum:
next_frontier += 1
mode_labels.append(output_mode_labels + input_mode_labels)
return mode_labels, operands
def op_shape_from_qubits(self, nqubits):
"""This function is to modify the shape of the tensor to the required format by cuQuantum
(qubit_states,) * input_output * qubits_involved
"""
return (2, 2) * nqubits