Files
qibotn/tests/test_parallel.py
jaunatisblue 72f95599bb
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
2026-05-12 15:44:19 +08:00

47 lines
1.1 KiB
Python

import numpy as np
from qibotn.parallel import _split_repeats, contract_tree_slices, mpi_slice_plan
def test_mpi_slice_plan_block_balances_contiguous_ranges():
    """Block assignment gives each rank one contiguous run of slice indices.

    With 10 slices over 4 ranks, the expected split is 3, 3, 2, 2: the
    leftover slices go to the lowest ranks.
    """
    expected = [(0, 1, 2), (3, 4, 5), (6, 7), (8, 9)]
    observed = []
    for rank in range(4):
        plan = mpi_slice_plan(10, rank, 4, assignment="block")
        observed.append(plan.indices)
    assert observed == expected
def test_mpi_slice_plan_cyclic_balances_round_robin():
    """Cyclic assignment deals slice indices out to ranks round-robin.

    With 10 slices over 4 ranks, rank ``r`` is expected to own
    ``r, r + 4, r + 8`` (only the entries below 10).
    """
    n_slices, n_ranks = 10, 4
    observed = [
        mpi_slice_plan(n_slices, r, n_ranks, assignment="cyclic").indices
        for r in range(n_ranks)
    ]
    assert observed == [(0, 4, 8), (1, 5, 9), (2, 6), (3, 7)]
class DummyTree:
    """Minimal stand-in for a contraction tree, used by the slice tests.

    ``contract_slice`` scales the first input array by ``i + 1``. This makes
    the total over any set of slice indices easy to compute by hand.
    """

    def contract_slice(self, arrays, i, backend=None):
        """Return ``arrays[0]`` multiplied by the 1-based slice number ``i + 1``.

        ``backend`` is accepted and ignored. Presumably this keeps the
        signature compatible with how ``contract_tree_slices`` calls real
        trees (TODO confirm).
        """
        scale = i + 1
        return arrays[0] * scale
def test_contract_tree_slices_sums_numpy_slices():
    """Contracting slices 0, 2 and 3 of ``DummyTree`` is expected to give 16.

    ``DummyTree`` scales the operand ``2`` by ``i + 1``, so the expected
    total is ``2 * (1 + 3 + 4) == 16``.
    """
    operand = np.asarray([2.0 + 0.0j])
    slice_ids = (0, 2, 3)
    total = contract_tree_slices(
        DummyTree(),
        [operand],
        slice_ids,
        backend="numpy",
    )
    expected = np.asarray([16.0 + 0.0j])
    np.testing.assert_allclose(total, expected)
def test_split_repeats_balances_workers():
    """``_split_repeats`` spreads repeats evenly across workers.

    Any remainder goes to the first workers. When there are fewer repeats
    than workers, the result is expected to hold only one entry per repeat
    (no zero-sized shares).
    """
    cases = {
        (10, 4): [3, 3, 2, 2],
        (2, 4): [1, 1],
    }
    for args, expected in cases.items():
        assert _split_repeats(*args) == expected