Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
import numpy as np
|
|
|
|
from qibotn.parallel import _split_repeats, contract_tree_slices, mpi_slice_plan
|
|
|
|
|
|
def test_mpi_slice_plan_block_balances_contiguous_ranges():
    """Block assignment hands each rank one contiguous run of slice indices.

    With 10 slices over 4 ranks the two surplus slices land on the lowest
    ranks, so ranks 0-1 get three indices and ranks 2-3 get two.
    """
    n_slices, n_ranks = 10, 4
    per_rank = [
        mpi_slice_plan(n_slices, r, n_ranks, assignment="block")
        for r in range(n_ranks)
    ]

    observed = [p.indices for p in per_rank]
    expected = [
        (0, 1, 2),
        (3, 4, 5),
        (6, 7),
        (8, 9),
    ]
    assert observed == expected
|
|
|
|
|
|
def test_mpi_slice_plan_cyclic_balances_round_robin():
    """Cyclic assignment deals slice indices to ranks round-robin.

    Rank ``r`` receives every index congruent to ``r`` modulo the rank count,
    so with 10 slices over 4 ranks ranks 0-1 get three indices and ranks 2-3
    get two.
    """
    n_slices, n_ranks = 10, 4
    per_rank = [
        mpi_slice_plan(n_slices, r, n_ranks, assignment="cyclic")
        for r in range(n_ranks)
    ]

    observed = [p.indices for p in per_rank]
    expected = [
        (0, 4, 8),
        (1, 5, 9),
        (2, 6),
        (3, 7),
    ]
    assert observed == expected
|
|
|
|
|
|
class DummyTree:
    """Tiny stand-in for a contraction tree, passed to ``contract_tree_slices``.

    Its only method scales the first input array by ``i + 1``. Each slice
    index therefore contributes a distinct, easily predicted value.
    """

    def contract_slice(self, arrays, i, backend=None):
        # ``backend`` exists only to match the expected call signature; it is
        # not used by this stub.
        scale = i + 1
        first = arrays[0]
        return first * scale
|
|
|
|
|
|
def test_contract_tree_slices_sums_numpy_slices():
    """Contributions of slices 0, 2 and 3 are summed on the numpy backend.

    ``DummyTree`` multiplies the single operand ``[2]`` by ``i + 1``, so the
    expected total is ``2 * (1 + 3 + 4) == 16``.
    """
    stub = DummyTree()
    operands = [np.asarray([2.0 + 0.0j])]
    selected = (0, 2, 3)

    total = contract_tree_slices(stub, operands, selected, backend="numpy")

    expected = np.asarray([16.0 + 0.0j])
    np.testing.assert_allclose(total, expected)
|
|
|
|
|
|
def test_split_repeats_balances_workers():
    """Repeats are split as evenly as possible across workers.

    The second case shows that when there are fewer repeats than workers, the
    result only lists workers that receive at least one repeat.
    """
    cases = (
        ((10, 4), [3, 3, 2, 2]),
        ((2, 4), [1, 1]),
    )
    for args, want in cases:
        assert _split_repeats(*args) == want
|