fixing grad nan
This commit is contained in:
@@ -30,13 +30,14 @@ def build_circuit(nqubits, nlayers):
|
||||
for q in range(nqubits):
|
||||
circ.add(gates.RY(q=q, theta=0.))
|
||||
circ.add(gates.RZ(q=q, theta=0.))
|
||||
[circ.add(gates.CNOT(q%nqubits, (q+1)%nqubits) for q in range(nqubits))]
|
||||
[circ.add(gates.CZ(q % nqubits, (q + 1) % nqubits)) for q in range(nqubits)]
|
||||
circ.add(gates.M(*range(nqubits)))
|
||||
return circ
|
||||
|
||||
nqubits = 4
|
||||
circuit = build_circuit(nqubits=nqubits, nlayers=3)
|
||||
|
||||
quimb_circuit = quimb_backend._qibo_circuit_to_quimb(circuit)
|
||||
|
||||
def f(params):
|
||||
circuit.set_parameters(params)
|
||||
@@ -46,5 +47,4 @@ def f(params):
|
||||
)
|
||||
|
||||
parameters = np.random.uniform(-np.pi, np.pi, size=len(circuit.get_parameters()))
|
||||
print(f(parameters))
|
||||
print(jax.value_and_grad(f)(parameters))
|
||||
print(jax.value_and_grad(f)(parameters))
|
||||
|
||||
Reference in New Issue
Block a user