优化 torch CPU 张量网络收缩路径
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
- torch CPU 收缩默认走 cotengra matmul lowering - 复用 mm/bmm/matmul 输出缓冲区,降低中间张量分配压力 - 仅回收 contiguous tensor,避免非连续 view 进入 workspace - 调整 cotengra 中间节点 index 顺序,减少 reshape 触发 clone/copy - qibotn MPI 分片收缩显式使用 backend=torch - rank 内分片结果先在 torch 中累加,最后再转 numpy 做 Reduce - 统一 quimb 后端 torch 数组转换为 CPU contiguous complex128
This commit is contained in:
@@ -38,6 +38,36 @@ GATE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _torch_cpu_array(data, dtype=None):
|
||||
"""Convert array-like data to a contiguous CPU torch tensor."""
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
x = data
|
||||
else:
|
||||
array = np.asarray(data)
|
||||
if any(stride < 0 for stride in array.strides):
|
||||
array = np.ascontiguousarray(array)
|
||||
x = torch.from_numpy(array)
|
||||
|
||||
if x.device.type != "cpu":
|
||||
x = x.cpu()
|
||||
if dtype is not None and x.dtype != dtype:
|
||||
x = x.to(dtype)
|
||||
if not x.is_contiguous():
|
||||
x = x.contiguous()
|
||||
return x
|
||||
|
||||
|
||||
def _arrays_to_backend(arrays, backend, engine):
|
||||
if backend == "torch":
|
||||
import torch
|
||||
|
||||
return [_torch_cpu_array(array, dtype=torch.complex128) for array in arrays]
|
||||
return [engine.asarray(array) for array in arrays]
|
||||
|
||||
|
||||
def __init__(self, quimb_backend="numpy", contraction_optimizer="auto-hq"):
|
||||
super(self.__class__, self).__init__()
|
||||
|
||||
@@ -480,12 +510,31 @@ def _expectation_parallel(self, circuit, observable, method, opts):
|
||||
continue
|
||||
|
||||
if mpi_contract and comm and size > 1:
|
||||
arrays = [self.engine.asarray(a) for a in tn.arrays]
|
||||
arrays = _arrays_to_backend(tn.arrays, self.backend, self.engine)
|
||||
val = parallel_contract(tree, arrays, method='mpi', comm=comm)
|
||||
else:
|
||||
for tensor in tn.tensors:
|
||||
tensor._data = torch.from_numpy(self.engine.asarray(tensor._data)).to(torch.complex128)
|
||||
val = complex(tn.contract(all, output_inds=(), optimize=tree))
|
||||
if self.backend == "torch":
|
||||
for tensor in tn.tensors:
|
||||
tensor._data = _torch_cpu_array(
|
||||
tensor._data, dtype=torch.complex128
|
||||
)
|
||||
val = complex(
|
||||
tn.contract(
|
||||
all,
|
||||
output_inds=(),
|
||||
optimize=tree,
|
||||
backend="torch",
|
||||
)
|
||||
)
|
||||
else:
|
||||
val = complex(
|
||||
tn.contract(
|
||||
all,
|
||||
output_inds=(),
|
||||
optimize=tree,
|
||||
backend=self.backend,
|
||||
)
|
||||
)
|
||||
|
||||
my_exp += coeff * complex(val)
|
||||
|
||||
|
||||
@@ -170,14 +170,25 @@ def _contract_mpi(tree, arrays, comm, root=0):
|
||||
rank, size = comm.Get_rank(), comm.Get_size()
|
||||
is_torch = type(arrays[0]).__module__.startswith("torch")
|
||||
|
||||
result_np = None
|
||||
for i in range(rank, tree.multiplicity, size):
|
||||
x = tree.contract_slice(arrays, i)
|
||||
x_np = np.asarray(x.detach().cpu().numpy() if is_torch else x).reshape(-1)
|
||||
result_np = x_np if result_np is None else result_np + x_np
|
||||
if is_torch:
|
||||
result_torch = None
|
||||
for i in range(rank, tree.multiplicity, size):
|
||||
x = tree.contract_slice(arrays, i, backend="torch").reshape(-1)
|
||||
result_torch = x if result_torch is None else result_torch + x
|
||||
|
||||
if result_np is None:
|
||||
result_np = np.zeros(1, dtype=np.complex128)
|
||||
if result_torch is None:
|
||||
result_np = np.zeros(1, dtype=np.complex128)
|
||||
else:
|
||||
result_np = result_torch.detach().cpu().numpy()
|
||||
else:
|
||||
result_np = None
|
||||
for i in range(rank, tree.multiplicity, size):
|
||||
x = tree.contract_slice(arrays, i)
|
||||
x_np = np.asarray(x).reshape(-1)
|
||||
result_np = x_np if result_np is None else result_np + x_np
|
||||
|
||||
if result_np is None:
|
||||
result_np = np.zeros(1, dtype=np.complex128)
|
||||
|
||||
result = np.zeros_like(result_np) if rank == root else None
|
||||
comm.Reduce(result_np, result, root=root)
|
||||
|
||||
Reference in New Issue
Block a user