diff --git a/.venv/lib/python3.12/site-packages/cotengra/contract.py b/.venv/lib/python3.12/site-packages/cotengra/contract.py new file mode 100644 index 0000000..0e9a46c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/cotengra/contract.py @@ -0,0 +1,1288 @@ +"""Functionality relating to actually contracting.""" + +import functools +import itertools +import operator +import contextlib +import os + +from autoray import do, shape, infer_backend_multi, get_lib_fn + +from .utils import node_from_single + + +DEFAULT_IMPLEMENTATION = "auto" + + +def _torch_workspace_enabled(): + value = os.environ.get("QIBOTN_TORCH_WORKSPACE", "0").strip().lower() + return value in {"1", "true", "yes", "on", "enable", "enabled"} + + +def _torch_arena_enabled(): + value = os.environ.get("QIBOTN_TORCH_ARENA", "0").strip().lower() + return value in {"1", "true", "yes", "on", "enable", "enabled"} + + +def _parse_size_bytes(value, default): + if value is None: + return default + + text = str(value).strip().lower() + scale = 1 + for suffix, multiplier in ( + ("gib", 1024**3), + ("gb", 1000**3), + ("gi", 1024**3), + ("g", 1024**3), + ("mib", 1024**2), + ("mb", 1000**2), + ("mi", 1024**2), + ("m", 1024**2), + ): + if text.endswith(suffix): + text = text[: -len(suffix)] + scale = multiplier + break + + return int(float(text) * scale) + + +class _TorchArena: + """Simple first-fit arena for torch CPU contraction intermediates.""" + + def __init__(self, size_bytes=None): + self.size_bytes = _parse_size_bytes( + size_bytes or os.environ.get("QIBOTN_TORCH_ARENA_BYTES"), + 70 * 1024**3, + ) + self.buffer = None + self.dtype = None + self.device = None + self.free = [] + self.allocated = {} + + def _setup(self, dtype, device): + import torch + + element_size = torch.empty((), dtype=dtype, device=device).element_size() + nelements = self.size_bytes // element_size + if nelements <= 0: + raise MemoryError("QIBOTN_TORCH_ARENA_BYTES is too small.") + self.buffer = torch.empty(nelements, dtype=dtype, device=device) + self.dtype = dtype + self.device = device + self.free = [(0, nelements)] + self.allocated = {} + + def alloc(self, shape, dtype, device): + import torch + + if self.buffer is None: + self._setup(dtype, device) + elif (dtype != self.dtype) or (device != self.device): + return torch.empty(shape, dtype=dtype, device=device) + + numel = functools.reduce(operator.mul, shape, 1) + element_size = self.buffer.element_size() + align = max(1, 64 // element_size) + + for i, (offset, length) in enumerate(self.free): + aligned = ((offset + align - 1) // align) * align + padding = aligned - offset + if length - padding < numel: + continue + + new_blocks = [] + if padding: + new_blocks.append((offset, padding)) + tail_offset = aligned + numel + tail_length = (offset + length) - tail_offset + if tail_length: + new_blocks.append((tail_offset, tail_length)) + self.free[i : i + 1] = new_blocks + self.allocated[aligned] = numel + return self.buffer.narrow(0, aligned, numel).view(shape) + + return torch.empty(shape, dtype=dtype, device=device) + + def release(self, x): + if self.buffer is None: + return + if not hasattr(x, "untyped_storage"): + return + try: + if x.untyped_storage().data_ptr() != self.buffer.untyped_storage().data_ptr(): + return + offset = int(x.storage_offset()) + except Exception: + return + + length = self.allocated.pop(offset, None) + if length is None: + return + + self.free.append((offset, length)) + self.free.sort() + merged = [] + for block_offset, block_length in self.free: + if merged and merged[-1][0] + merged[-1][1] == block_offset: + prev_offset, prev_length = merged[-1] + merged[-1] = (prev_offset, prev_length + block_length) + else: + merged.append((block_offset, block_length)) + self.free = merged + + +def set_default_implementation(impl): + global DEFAULT_IMPLEMENTATION + DEFAULT_IMPLEMENTATION = impl + + +def get_default_implementation(): + return DEFAULT_IMPLEMENTATION + + +@contextlib.contextmanager +def default_implementation(impl): + """Context manager for temporarily setting the default implementation.""" + global DEFAULT_IMPLEMENTATION + old_impl = DEFAULT_IMPLEMENTATION + DEFAULT_IMPLEMENTATION = impl + try: + yield + finally: + DEFAULT_IMPLEMENTATION = old_impl + + +@functools.lru_cache(2**12) +def _sanitize_equation(eq): + """Get the input and output indices of an equation, computing the output + implicitly as the sorted sequence of every index that appears exactly once + if it is not provided. + """ + # remove spaces + eq = eq.replace(" ", "") + + if "..." in eq: + raise NotImplementedError("Ellipsis not supported.") + + if "->" not in eq: + lhs = eq + tmp_subscripts = lhs.replace(",", "") + out = "".join( + # sorted sequence of indices + s + for s in sorted(set(tmp_subscripts)) + # that appear exactly once + if tmp_subscripts.count(s) == 1 + ) + else: + lhs, out = eq.split("->") + return lhs, out + + +@functools.lru_cache(2**12) +def _parse_einsum_single(eq, shape): + """Cached parsing of a single term einsum equation into the necessary + sequence of arguments for axes diagonals, sums, and transposes. + """ + lhs, out = _sanitize_equation(eq) + + # parse each index + need_to_diag = [] + need_to_sum = [] + seen = set() + for ix in lhs: + if ix in need_to_diag: + continue + if ix in seen: + need_to_diag.append(ix) + continue + seen.add(ix) + if ix not in out: + need_to_sum.append(ix) + + # first handle diagonal reductions + if need_to_diag: + diag_sels = [] + sizes = dict(zip(lhs, shape)) + while need_to_diag: + ixd = need_to_diag.pop() + dinds = tuple(range(sizes[ixd])) + + # construct advanced indexing object + selector = tuple(dinds if ix == ixd else slice(None) for ix in lhs) + diag_sels.append(selector) + + # after taking the diagonal what are new indices? + ixd_contig = ixd * lhs.count(ixd) + if ixd_contig in lhs: + # contig axes, new axis is at same position + lhs = lhs.replace(ixd_contig, ixd) + else: + # non-contig, new axis is at beginning + lhs = ixd + lhs.replace(ixd, "") + else: + diag_sels = None + + # then sum reductions + if need_to_sum: + sum_axes = tuple(map(lhs.index, need_to_sum)) + for ix in need_to_sum: + lhs = lhs.replace(ix, "") + else: + sum_axes = None + + # then transposition + if lhs == out: + perm = None + else: + perm = tuple(lhs.index(ix) for ix in out) + + return diag_sels, sum_axes, perm + + +def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out): + """If there are no contracted indices, then we can directly transpose and + insert singleton dimensions into ``a`` and ``b`` such that (broadcast) + elementwise multiplication performs the einsum. + + No need to cache this as it is within the cached + ``_parse_eq_to_batch_matmul``. + + """ + desired_a = "" + desired_b = "" + new_shape_a = [] + new_shape_b = [] + for ix in out: + if ix in a_term: + desired_a += ix + new_shape_a.append(shape_a[a_term.index(ix)]) + else: + new_shape_a.append(1) + if ix in b_term: + desired_b += ix + new_shape_b.append(shape_b[b_term.index(ix)]) + else: + new_shape_b.append(1) + + if desired_a != a_term: + eq_a = f"{a_term}->{desired_a}" + else: + eq_a = None + if desired_b != b_term: + eq_b = f"{b_term}->{desired_b}" + else: + eq_b = None + + return ( + eq_a, + eq_b, + new_shape_a, + new_shape_b, + None, # new_shape_ab, not needed since not fusing + None, # perm_ab, not needed as we transpose a and b first + True, # pure_multiplication=True + ) + + +@functools.lru_cache(2**12) +def _parse_eq_to_batch_matmul(eq, shape_a, shape_b): + """Cached parsing of a two term einsum equation into the necessary + sequence of arguments for contracttion via batched matrix multiplication. + The steps we need to specify are: + + 1. Remove repeated and trivial indices from the left and right terms, + and transpose them, done as a single einsum. + 2. Fuse the remaining indices so we have two 3D tensors. + 3. Perform the batched matrix multiplication. + 4. Unfuse the output to get the desired final index order. + + """ + lhs, out = eq.split("->") + a_term, b_term = lhs.split(",") + + if len(a_term) != len(shape_a): + raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.") + if len(b_term) != len(shape_b): + raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.") + + bat_inds = [] # appears on A, B, O + con_inds = [] # appears on A, B, . + a_keep = [] # appears on A, ., O + b_keep = [] # appears on ., B, O + sizes = {} + singletons = set() + + # parse left term + seen = set() + for ix, d in zip(a_term, shape_a): + if d == 1: + # everything (including broadcasting) works nicely if simply ignore + # such dimensions, but we do need to track if they appear in output + # and thus should be reintroduced later + singletons.add(ix) + continue + + # set or check size + if sizes.setdefault(ix, d) != d: + raise ValueError( + f"Index {ix} has mismatched sizes {sizes[ix]} and {d}." + ) + + if ix in seen: + continue + seen.add(ix) + + if ix in b_term: + if ix in out: + bat_inds.append(ix) + else: + con_inds.append(ix) + elif ix in out: + a_keep.append(ix) + + # parse right term + seen.clear() + for ix, d in zip(b_term, shape_b): + if d == 1: + singletons.add(ix) + continue + # broadcast indices don't appear as singletons in output + singletons.discard(ix) + + # set or check size + if sizes.setdefault(ix, d) != d: + raise ValueError( + f"Index {ix} has mismatched sizes {sizes[ix]} and {d}." + ) + + if ix in seen: + continue + seen.add(ix) + + if ix not in a_term: + if ix in out: + b_keep.append(ix) + + if not con_inds: + # contraction is pure multiplication, prepare inputs differently + return _parse_eq_to_pure_multiplication( + a_term, shape_a, b_term, shape_b, out + ) + + # only need the size one indices that appear in the output + singletons = [ix for ix in out if ix in singletons] + + # take diagonal, remove any trivial axes and transpose left + desired_a = "".join((*bat_inds, *a_keep, *con_inds)) + if a_term != desired_a: + if set(a_term) == set(desired_a): + # only need to transpose, don't invoke einsum + eq_a = tuple(a_term.index(ix) for ix in desired_a) + else: + eq_a = f"{a_term}->{desired_a}" + else: + eq_a = None + + # take diagonal, remove any trivial axes and transpose right + desired_b = "".join((*bat_inds, *con_inds, *b_keep)) + if b_term != desired_b: + if set(b_term) == set(desired_b): + # only need to transpose, don't invoke einsum + eq_b = tuple(b_term.index(ix) for ix in desired_b) + else: + eq_b = f"{b_term}->{desired_b}" + else: + eq_b = None + + # then we want to reshape + if bat_inds: + lgroups = (bat_inds, a_keep, con_inds) + rgroups = (bat_inds, con_inds, b_keep) + ogroups = (bat_inds, a_keep, b_keep) + else: + # avoid size 1 batch dimension if no batch indices + lgroups = (a_keep, con_inds) + rgroups = (con_inds, b_keep) + ogroups = (a_keep, b_keep) + + if any(len(group) != 1 for group in lgroups): + # need to fuse 'kept' and contracted indices + # (though could allow batch indices to be broadcast) + new_shape_a = tuple( + functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1) + for ix_group in lgroups + ) + else: + new_shape_a = None + + if any(len(group) != 1 for group in rgroups): + # need to fuse 'kept' and contracted indices + # (though could allow batch indices to be broadcast) + new_shape_b = tuple( + functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1) + for ix_group in rgroups + ) + else: + new_shape_b = None + + if any(len(group) != 1 for group in ogroups) or singletons: + new_shape_ab = (1,) * len(singletons) + tuple( + sizes[ix] for ix_group in ogroups for ix in ix_group + ) + else: + new_shape_ab = None + + # then we want to permute the matmul produced output: + out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep)) + perm_ab = tuple(out_produced.index(ix) for ix in out) + if perm_ab == tuple(range(len(perm_ab))): + perm_ab = None + + return ( + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + False, # pure_multiplication=False + ) + + +def _einsum_single(eq, x, backend=None): + """Einsum on a single tensor, via three steps: diagonal selection + (via advanced indexing), axes summations, transposition. The logic for each + is cached based on the equation and array shape, and each step is only + performed if necessary. + """ + try: + return do("einsum", eq, x, like=backend) + except ImportError: + pass + + diag_sels, sum_axes, perm = _parse_einsum_single(eq, shape(x)) + + if diag_sels is not None: + # diagonal reduction via advanced indexing + # e.g ababbac->abc + for selector in diag_sels: + x = x[selector] + + if sum_axes is not None: + # trivial removal of axes via summation + # e.g. abc->c + x = do("sum", x, sum_axes, like=backend) + + if perm is not None: + # transpose to desired output + # e.g. abc->cba + x = do("transpose", x, perm, like=backend) + + return x + + +def _do_contraction_via_bmm( + a, + b, + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + backend, +): + # prepare left + if eq_a is not None: + if isinstance(eq_a, tuple): + # only transpose + a = do("transpose", a, eq_a, like=backend) + else: + # diagonals, sums, and tranpose + a = _einsum_single(eq_a, a) + if new_shape_a is not None: + a = do("reshape", a, new_shape_a, like=backend) + + # prepare right + if eq_b is not None: + if isinstance(eq_b, tuple): + # only transpose + b = do("transpose", b, eq_b, like=backend) + else: + # diagonals, sums, and tranpose + b = _einsum_single(eq_b, b) + if new_shape_b is not None: + b = do("reshape", b, new_shape_b, like=backend) + + if pure_multiplication: + # no contracted indices + return do("multiply", a, b) + + # do the contraction! + ab = do("matmul", a, b, like=backend) + + # prepare the output + if new_shape_ab is not None: + ab = do("reshape", ab, new_shape_ab, like=backend) + if perm_ab is not None: + ab = do("transpose", ab, perm_ab, like=backend) + + return ab + + +def _torch_workspace_pop(shape, dtype, device): + try: + pool = _TORCH_WORKSPACE_POOL + except NameError: + return None + try: + return pool[(tuple(shape), dtype, device)].pop() + except (KeyError, IndexError): + return None + + +def _torch_workspace_push(x): + try: + pool = _TORCH_WORKSPACE_POOL + except NameError: + return + if not x.is_contiguous(): + return + pool.setdefault((tuple(x.shape), x.dtype, x.device), []).append(x) + + +def _torch_matmul_workspace(a, b): + import torch + + shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2]) + ( + a.shape[-2], + b.shape[-1], + ) + try: + arena = _TORCH_ARENA + except NameError: + arena = None + + if arena is not None: + out = arena.alloc(shape, a.dtype, a.device) + else: + out = _torch_workspace_pop(shape, a.dtype, a.device) + if out is None: + out = torch.empty(shape, dtype=a.dtype, device=a.device) + + if a.ndim == 2 and b.ndim == 2: + torch.mm(a, b, out=out) + elif a.ndim == 3 and b.ndim == 3: + torch.bmm(a, b, out=out) + else: + torch.matmul(a, b, out=out) + + return out + + +def _torch_multiply_workspace(a, b): + import torch + + shape = torch.broadcast_shapes(a.shape, b.shape) + try: + arena = _TORCH_ARENA + except NameError: + arena = None + + if arena is not None: + out = arena.alloc(shape, a.dtype, a.device) + torch.mul(a, b, out=out) + return out + + return do("multiply", a, b) + + +def _torch_reshape_workspace(x, new_shape, backend): + try: + arena = _TORCH_ARENA + except NameError: + arena = None + if arena is None: + return do("reshape", x, new_shape, like=backend) + + if not hasattr(x, "view"): + return do("reshape", x, new_shape, like=backend) + + try: + return x.view(new_shape) + except RuntimeError: + pass + + try: + out = arena.alloc(tuple(new_shape), x.dtype, x.device) + out.view(tuple(x.shape)).copy_(x) + return out + except Exception: + return do("reshape", x, new_shape, like=backend) + + +def _do_contraction_via_bmm_torch_workspace( + a, + b, + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + backend, +): + import torch + + if eq_a is not None: + if isinstance(eq_a, tuple): + a = do("transpose", a, eq_a, like=backend) + else: + a = _einsum_single(eq_a, a) + if new_shape_a is not None: + a = _torch_reshape_workspace(a, new_shape_a, backend) + + if eq_b is not None: + if isinstance(eq_b, tuple): + b = do("transpose", b, eq_b, like=backend) + else: + b = _einsum_single(eq_b, b) + if new_shape_b is not None: + b = _torch_reshape_workspace(b, new_shape_b, backend) + + if pure_multiplication: + return _torch_multiply_workspace(a, b) + + ab = _torch_matmul_workspace(a, b) + + if new_shape_ab is not None: + ab = _torch_reshape_workspace(ab, new_shape_ab, backend) + if perm_ab is not None: + ab = do("transpose", ab, perm_ab, like=backend) + + return ab + + +def einsum(eq, a, b=None, *, backend=None): + """Perform arbitrary single and pairwise einsums using only `matmul`, + `transpose`, `reshape` and `sum`. The logic for each is cached based on + the equation and array shape, and each step is only performed if necessary. + + Parameters + ---------- + eq : str + The einsum equation. + a : array_like + The first array to contract. + b : array_like, optional + The second array to contract. + backend : str, optional + The backend to use for array operations. If ``None``, dispatch + automatically based on ``a`` and ``b``. + + Returns + ------- + array_like + """ + if b is None: + return _einsum_single(eq, a, backend=backend) + + ( + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + ) = _parse_eq_to_batch_matmul(eq, shape(a), shape(b)) + + do_contraction = ( + _do_contraction_via_bmm_torch_workspace + if backend == "torch" + and (_torch_workspace_enabled() or _torch_arena_enabled()) + else _do_contraction_via_bmm + ) + + return do_contraction( + a, + b, + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + backend, + ) + + +def gen_nice_inds(): + """Generate the indices from [a-z, A-Z, reasonable unicode...].""" + for i in range(26): + yield chr(ord("a") + i) + for i in range(26): + yield chr(ord("A") + i) + for i in itertools.count(192): + yield chr(i) + + +@functools.lru_cache(2**12) +def _parse_tensordot_axes_to_matmul(axes, shape_a, shape_b): + """Parse a tensordot specification into the necessary sequence of arguments + for contracttion via matrix multiplication. This just converts ``axes`` + into an ``einsum`` eq string then calls ``_parse_eq_to_batch_matmul``. + """ + ndim_a = len(shape_a) + ndim_b = len(shape_b) + + if isinstance(axes, int): + axes_a = tuple(range(ndim_a - axes, ndim_a)) + axes_b = tuple(range(axes)) + else: + axes_a, axes_b = axes + + num_con = len(axes_a) + if num_con != len(axes_b): + raise ValueError( + f"Axes should have the same length, got {axes_a} and {axes_b}." + ) + + possible_inds = gen_nice_inds() + inds_a = [next(possible_inds) for _ in range(ndim_a)] + inds_b = [] + inds_out = inds_a.copy() + + for axb in range(ndim_b): + if axb not in axes_b: + # right uncontracted index + ind = next(possible_inds) + inds_out.append(ind) + else: + # contracted index + axa = axes_a[axes_b.index(axb)] + # check that the shapes match + if shape_a[axa] != shape_b[axb]: + raise ValueError( + f"Dimension mismatch between axes {axa} of {shape_a} and " + f"{axb} of {shape_b}: {shape_a[axa]} != {shape_b[axb]}." + ) + ind = inds_a[axa] + inds_out.remove(ind) + inds_b.append(ind) + + eq = f"{''.join(inds_a)},{''.join(inds_b)}->{''.join(inds_out)}" + + return _parse_eq_to_batch_matmul(eq, shape_a, shape_b) + + +def tensordot(a, b, axes=2, *, backend=None): + """Perform a tensordot using only `matmul`, `transpose`, `reshape`. The + logic for each is cached based on the equation and array shape, and each + step is only performed if necessary. + + Parameters + ---------- + a, b : array_like + The arrays to contract. + axes : int or tuple of (sequence[int], sequence[int]) + The number of axes to contract, or the axes to contract. If an int, + the last ``axes`` axes of ``a`` and the first ``axes`` axes of ``b`` + are contracted. If a tuple, the axes to contract for ``a`` and ``b`` + respectively. + backend : str or None, optional + The backend to use for array operations. If ``None``, dispatch + automatically based on ``a`` and ``b``. + + Returns + ------- + array_like + """ + try: + # ensure hashable + axes = tuple(map(int, axes[0])), tuple(map(int, axes[1])) + except IndexError: + axes = int(axes) + + ( + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + ) = _parse_tensordot_axes_to_matmul(axes, shape(a), shape(b)) + + do_contraction = ( + _do_contraction_via_bmm_torch_workspace + if backend == "torch" + and (_torch_workspace_enabled() or _torch_arena_enabled()) + else _do_contraction_via_bmm + ) + + return do_contraction( + a, + b, + eq_a, + eq_b, + new_shape_a, + new_shape_b, + new_shape_ab, + perm_ab, + pure_multiplication, + backend, + ) + + +def extract_contractions( + tree, + order=None, + prefer_einsum=False, +): + """Extract just the information needed to perform the contraction. + + Parameters + ---------- + order : str or callable, optional + Supplied to :meth:`ContractionTree.traverse`. + prefer_einsum : bool, optional + Prefer to use ``einsum`` for pairwise contractions, even if + ``tensordot`` can perform the contraction. + + Returns + ------- + contractions : tuple + A tuple of tuples, each containing the information needed to + perform a pairwise contraction. Each tuple contains: + + - ``p``: the parent node, + - ``l``: the left child node, + - ``r``: the right child node, + - ``tdot``: whether to use ``tensordot`` or ``einsum``, + - ``arg``: the argument to pass to ``tensordot`` or ``einsum`` + i.e. ``axes`` or ``eq``, + - ``perm``: the permutation required after the contraction, if + any (only applies to tensordot). + + If both ``l`` and ``r`` are ``None``, the the operation is a single + term simplification performed with ``einsum``. + """ + contractions = [] + + # pairwise contractions + contractions.extend( + (p, l, r, False, tree.get_einsum_eq(p), None) + if (prefer_einsum or not tree.get_can_dot(p)) + else ( + p, + l, + r, + True, + tree.get_tensordot_axes(p), + tree.get_tensordot_perm(p), + ) + for p, l, r in tree.traverse(order=order) + ) + + if tree.preprocessing: + # inplace single term simplifications + # n.b. these are populated lazily when the other information is + # computed above, so we do it after + pre_contractions = ( + (node_from_single(i), None, None, False, eq, None) + for i, eq in tree.preprocessing.items() + ) + return (*pre_contractions, *contractions) + + return tuple(contractions) + + +class Contractor: + """Default cotengra network contractor. + + Parameters + ---------- + contractions : tuple[tuple] + The sequence of contractions to perform. Each contraction should be a + tuple containing: + + - ``p``: the parent node, + - ``l``: the left child node, + - ``r``: the right child node, + - ``tdot``: whether to use ``tensordot`` or ``einsum``, + - ``arg``: the argument to pass to ``tensordot`` or ``einsum`` + i.e. ``axes`` or ``eq``, + - ``perm``: the permutation required after the contraction, if + any (only applies to tensordot). + + e.g. built by calling ``extract_contractions(tree)``. + + strip_exponent : bool, optional + If ``True``, eagerly strip the exponent (in log10) from + intermediate tensors to control numerical problems from leaving the + range of the datatype. This method then returns the scaled + 'mantissa' output array and the exponent separately. + check_zero : bool, optional + If ``True``, when ``strip_exponent=True``, explicitly check for + zero-valued intermediates that would otherwise produce ``nan``, + instead terminating early if encounteredand returning + ``(0.0, 0.0)``. + backend : str, optional + What library to use for ``tensordot``, ``einsum`` and + ``transpose``, it will be automatically inferred from the input + arrays if not given. + progbar : bool, optional + Whether to show a progress bar. + """ + + __slots__ = ( + "contractions", + "strip_exponent", + "check_zero", + "implementation", + "backend", + "progbar", + "__weakref__", + ) + + def __init__( + self, + contractions, + strip_exponent=False, + check_zero=False, + implementation="auto", + backend=None, + progbar=False, + ): + self.contractions = contractions + self.strip_exponent = strip_exponent + self.check_zero = check_zero + self.implementation = implementation + self.backend = backend + self.progbar = progbar + + def __call__(self, *arrays, **kwargs): + """Contract ``arrays`` using operations listed in ``contractions``. + + Parameters + ---------- + arrays : sequence of array-like + The arrays to contract. + kwargs : dict + Override the default settings for this contraction only. + + Returns + ------- + output : array + The contracted output, it will be scaled if ``strip_exponent==True``. + exponent : float + The exponent of the output in base 10, returned only if + ``strip_exponent==True``. + """ + backend = kwargs.pop("backend", self.backend) + progbar = kwargs.pop("progbar", self.progbar) + check_zero = kwargs.pop("check_zero", self.check_zero) + strip_exponent = kwargs.pop("strip_exponent", self.strip_exponent) + implementation = kwargs.pop("implementation", self.implementation) + if kwargs: + raise TypeError(f"Unknown keyword arguments: {kwargs}.") + + if backend is None: + backend = infer_backend_multi(*arrays) + + if implementation == "auto": + if (backend == "numpy") or ( + backend == "torch" + and all( + getattr(getattr(x, "device", None), "type", "cpu") == "cpu" + for x in arrays + if hasattr(x, "device") + ) + ): + # by default replace numpy's einsum/tensordot, and do the + # same for torch CPU to control bmm outputs and workspace reuse + implementation = "cotengra" + else: + implementation = "autoray" + + if implementation == "cotengra": + _einsum, _tensordot = einsum, tensordot + elif implementation == "autoray": + try: + _einsum = get_lib_fn(backend, "einsum") + except ImportError: + # fallback to cotengra (matmul) implementation + _einsum = einsum + + try: + _tensordot = get_lib_fn(backend, "tensordot") + except ImportError: + # fallback to cotengra (matmul) implementation + _tensordot = tensordot + else: + # manually supplied + _einsum, _tensordot = implementation + + using_torch_arena = (backend == "torch") and _torch_arena_enabled() + if using_torch_arena: + global _TORCH_ARENA + _TORCH_ARENA = _TorchArena() + + using_torch_workspace = ( + (backend == "torch") + and (_einsum is einsum) + and _torch_workspace_enabled() + and not using_torch_arena + ) + if using_torch_workspace: + global _TORCH_WORKSPACE_POOL + _TORCH_WORKSPACE_POOL = {} + + # temporary storage for intermediates + N = len(arrays) + temps = { + leaf: array + for leaf, array in zip(map(node_from_single, range(N)), arrays) + } + + exponent = 0.0 if (strip_exponent is not False) else None + + if progbar: + import tqdm + + contractions = tqdm.tqdm(self.contractions, total=N - 1) + else: + contractions = self.contractions + + p_array = next(iter(temps.values())) + for p, l, r, tdot, arg, perm in contractions: + if (l is None) and (r is None): + # single term simplification, perform inplace with einsum + temps[p] = _einsum(arg, temps[p]) + p_array = temps[p] + continue + + # get input arrays for this contraction + l_array = temps.pop(l) + r_array = temps.pop(r) + + if tdot: + p_array = _tensordot(l_array, r_array, arg) + if perm: + p_array = do("transpose", p_array, perm, like=backend) + else: + p_array = _einsum(arg, l_array, r_array) + + if exponent is not None: + factor = do( + "max", do("abs", p_array, like=backend), like=backend + ) + if check_zero and float(factor) == 0.0: + if using_torch_arena: + _TORCH_ARENA = None + return 0.0, float("-inf") + exponent = exponent + do("log10", factor, like=backend) + p_array = p_array / factor + + # insert the new intermediate array + temps[p] = p_array + + if using_torch_workspace: + if (len(l) != 1) and hasattr(l_array, "device"): + _torch_workspace_push(l_array) + if (len(r) != 1) and hasattr(r_array, "device"): + _torch_workspace_push(r_array) + + if using_torch_arena: + _TORCH_ARENA.release(l_array) + _TORCH_ARENA.release(r_array) + + if using_torch_arena: + # The final output may be a view into the arena. Clone it before + # dropping the arena so a scalar result doesn't keep the whole + # workspace storage alive. + if hasattr(p_array, "clone"): + p_array = p_array.clone() + _TORCH_ARENA = None + + if exponent is not None: + return p_array, exponent + + return p_array + + +class CuQuantumContractor: + def __init__( + self, + tree, + handle_slicing=False, + autotune=False, + **kwargs, + ): + if kwargs.pop("strip_exponent", None): + raise ValueError( + "strip_exponent=True not supported with cuQuantum" + ) + + if tree.has_preprocessing(): + raise ValueError("Preprocessing not supported with cuQuantum yet.") + + if kwargs.pop("progbar", None): + import warnings + + warnings.warn("Progress bar not supported with cuQuantum yet.") + + if handle_slicing: + self.eq = tree.get_eq() + self.shapes = tree.get_shapes() + else: + self.eq = tree.get_eq_sliced() + self.shapes = tree.get_shapes_sliced() + + if tree.is_complete(): + kwargs.setdefault("optimize", {}) + kwargs["optimize"].setdefault("path", tree.get_path()) + + if handle_slicing and tree.sliced_inds: + kwargs["optimize"].setdefault( + "slicing", + [(ix, tree.size_dict[ix] - 1) for ix in tree.sliced_inds], + ) + + self.kwargs = kwargs + self.autotune = 3 if autotune is True else autotune + self.handle = None + self.network = None + + def setup(self, *arrays): + import cuquantum + + if hasattr(cuquantum, "bindings"): + # cuquantum-python >= 25.03 + from cuquantum.tensornet import Network + else: + # for cuquantum < 25.03 + from cuquantum import Network + + self.network = Network( + self.eq, + *arrays, + ) + self.network.contract_path(**self.kwargs) + if self.autotune: + self.network.autotune(iterations=self.autotune) + + def __call__( + self, + *arrays, + check_zero=False, + backend=None, + progbar=False, + ): + # can't handle these yet + assert not check_zero + assert not progbar + assert backend is None + + if self.network is None: + self.setup(*arrays) + else: + self.network.reset_operands(*arrays) + + return self.network.contract() + + def __del__(self): + if self.network is not None: + self.network.free() + + +def make_contractor( + tree, + order=None, + prefer_einsum=False, + strip_exponent=False, + check_zero=False, + implementation=None, + autojit=False, + progbar=False, +): + """Get a reusable function which performs the contraction corresponding + to ``tree``. The various options provide defaults that can also be overrode + when calling the standard contractor. + + Parameters + ---------- + tree : ContractionTree + The contraction tree. + order : str or callable, optional + Supplied to :meth:`ContractionTree.traverse`, the order in which + to perform the pairwise contractions given by the tree. + prefer_einsum : bool, optional + Prefer to use ``einsum`` for pairwise contractions, even if + ``tensordot`` can perform the contraction. + strip_exponent : bool, optional + If ``True``, the function will strip the exponent from the output + array and return it separately. + check_zero : bool, optional + If ``True``, when ``strip_exponent=True``, explicitly check for + zero-valued intermediates that would otherwise produce ``nan``, + instead terminating early if encountered and returning + ``(0.0, 0.0)``. + implementation : str or tuple[callable, callable], optional + What library to use to actually perform the contractions. Options are + + - "auto": let cotengra choose + - "autoray": dispatch with autoray, using the ``tensordot`` and + ``einsum`` implementation of the backend + - "cotengra": use the ``tensordot`` and ``einsum`` implementation of + cotengra, which is based on batch matrix multiplication. This is + faster for some backends like numpy, and also enables libraries + which don't yet provide ``tensordot`` and ``einsum`` to be used. + - "cuquantum": use the cuquantum library to perform the whole + contraction (not just individual contractions). + - tuple[callable, callable]: manually supply the ``tensordot`` and + ``einsum`` implementations to use. + + autojit : bool, optional + If ``True``, use :func:`autoray.autojit` to compile the contraction + function. + progbar : bool, optional + Whether to show progress through the contraction by default. + + Returns + ------- + fn : callable + The contraction function, with signature ``fn(*arrays)``. + """ + if implementation is None: + implementation = get_default_implementation() + + if implementation == "cuquantum": + fn = CuQuantumContractor( + tree, + strip_exponent=strip_exponent, + check_zero=check_zero, + progbar=progbar, + ) + else: + fn = Contractor( + contractions=extract_contractions(tree, order, prefer_einsum), + strip_exponent=strip_exponent, + check_zero=check_zero, + implementation=implementation, + progbar=progbar, + ) + if autojit: + from autoray import autojit as _autojit + + fn = _autojit(fn) + + return fn diff --git a/.venv/lib/python3.12/site-packages/cotengra/core.py b/.venv/lib/python3.12/site-packages/cotengra/core.py new file mode 100644 index 0000000..59ce75d --- /dev/null +++ b/.venv/lib/python3.12/site-packages/cotengra/core.py @@ -0,0 +1,4130 @@ +"""Core contraction tree data structure and methods.""" + +import collections +import functools +import itertools +import math +import warnings +from dataclasses import dataclass +from typing import Optional + +from autoray import do + +from .contract import make_contractor +from .hypergraph import get_hypergraph +from .parallel import ( + can_scatter, + maybe_leave_pool, + maybe_rejoin_pool, + parse_parallel_arg, + scatter, + submit, +) +from .pathfinders.path_simulated_annealing import ( + parallel_temper_tree, + simulated_anneal_tree, +) +from .plot import ( + plot_contractions, + plot_contractions_alt, + plot_hypergraph, + plot_tree_circuit, + plot_tree_flat, + plot_tree_ring, + plot_tree_rubberband, + plot_tree_span, + plot_tree_tent, +) +from .scoring import ( + DEFAULT_COMBO_FACTOR, + CompressedStatsTracker, + get_score_fn, +) +from .utils import ( + MaxCounter, + compute_size_by_dict, + deprecated, + get_rng, + get_symbol, + groupby, + inputs_output_to_eq, + interleave, + is_valid_node, + node_from_seq, + node_from_single, + node_get_single_el, + node_supremum, + oset, + prod, + unique, +) + + +def cached_node_property(name): + """Decorator for caching information about nodes.""" + + def wrapper(meth): + @functools.wraps(meth) + def getter(self, node): + try: + return self.info[node][name] + except KeyError: + self.info[node][name] = value = meth(self, node) + return value + + return getter + + return wrapper + + +def union_it(bs): + """Non-variadic version of various set type unions.""" + b0, *bs = bs + return b0.union(*bs) + + +def legs_union(legs_seq): + """Combine a sequence of legs into a single set of legs, summing their + appearances. + """ + new_legs, *rem_legs = legs_seq + new_legs = new_legs.copy() + for legs in rem_legs: + for ix, ix_count in legs.items(): + new_legs[ix] = new_legs.get(ix, 0) + ix_count + return new_legs + + +def legs_without(legs, ind): + """Discard ``ind`` from legs to create a new set of legs.""" + new_legs = legs.copy() + new_legs.pop(ind, None) + return new_legs + + +def get_with_default(k, obj, default): + return obj.get(k, default) + + +@dataclass(order=True, frozen=True) +class SliceInfo: + inner: bool + ind: str + size: int + project: Optional[int] + + @property + def sliced_range(self): + if self.project is None: + return range(self.size) + else: + return [self.project] + + +def get_slice_strides(sliced_inds): + """Compute the 'strides' given the (ordered) dictionary of sliced indices.""" + slice_infos = list(sliced_inds.values()) + nsliced = len(slice_infos) + strides = [1] * nsliced + # backwards cumulative product + for i in range(nsliced - 2, -1, -1): + strides[i] = strides[i + 1] * slice_infos[i + 1].size + return strides + + +def add_maybe_exponent_stripped(x, y): + """Add two arrays, or tuples of (array, exponent) together in a stable + and branchless way. + """ + xistup = isinstance(x, tuple) + yistup = isinstance(y, tuple) + if not (xistup or yistup): + # simple sum without exponent + return x + y + + if xistup: + xm, xe = x + else: + xm = x + xe = 0.0 + + if yistup: + ym, ye = y + else: + ym = y + ye = 0.0 + + # perform branchless for jit etc. + e = max(xe, ye) + m = xm * 10 ** (xe - e) + ym * 10 ** (ye - e) + + return (m, e) + + +class ContractionTree: + """Binary tree representing a tensor network contraction. + + Parameters + ---------- + inputs : sequence of str + The list of input tensor's indices. + output : str + The output indices. + size_dict : dict[str, int] + The size of each index. + track_childless : bool, optional + Whether to dynamically keep track of which nodes are childless. Useful + if you are 'divisively' building the tree. + track_flops : bool, optional + Whether to dynamically keep track of the total number of flops. If + ``False`` You can still compute this once the tree is complete. + track_write : bool, optional + Whether to dynamically keep track of the total number of elements + written. If ``False`` You can still compute this once the tree is + complete. + track_size : bool, optional + Whether to dynamically keep track of the largest tensor so far. If + ``False`` You can still compute this once the tree is complete. + objective : str or Objective, optional + An default objective function to use for further optimization and + scoring, for example reconfiguring or computing the combo cost. If not + supplied the default is to create a flops objective when needed. + + Attributes + ---------- + children : dict[node, tuple[node]] + Mapping of each node to two children. + info : dict[node, dict] + Information about the tree nodes. The key is the set of inputs (a + set of inputs indices) the node contains. Or in other words, the + subgraph of the node. The value is a dictionary to cache information + about effective 'leg' indices, size, flops of formation etc. + """ + + def __init__( + self, + inputs, + output, + size_dict, + track_childless=False, + track_flops=False, + track_write=False, + track_size=False, + objective=None, + ): + self.inputs = inputs + self.output = output + + if isinstance(self.inputs[0], set) or isinstance(self.output, set): + warnings.warn( + "The inputs or output of this tree are not ordered." + "Costs will be accurate but actually contracting requires " + "ordered indices corresponding to array axes." + ) + + if not isinstance(next(iter(size_dict.values()), 1), int): + # make sure we are working with python integers to avoid overflow + # comparison errors with inf etc. + self.size_dict = {k: int(v) for k, v in size_dict.items()} + else: + self.size_dict = size_dict + + self.N = len(self.inputs) + + # the index representation for each input is an ordered mapping of + # each index to the number of times it has appeared on children. By + # also tracking the total number of appearances one can efficiently + # and locally compute which indices should be kept or contracted + self.appearances = {} + for term in self.inputs: + for ix in term: + self.appearances[ix] = self.appearances.get(ix, 0) + 1 + # adding output appearances ensures these are never contracted away, + # N.B. if after this step every appearance count is exactly 2, + # then there are no 'hyper' indices in the contraction + for ix in self.output: + self.appearances[ix] = self.appearances.get(ix, 0) + 1 + + # this stores potentialy preprocessing steps that are not part of the + # main contraction tree, but assumed to have been applied, for example + # tracing or summing over indices that appear only once + self.preprocessing = {} + + # mapping of parents to children - the core binary tree object + self.children = {} + + # information about all the nodes + self.info = {} + + # add constant nodes: the leaves + for leaf in self.gen_leaves(): + self._add_node(leaf) + # and the root or top node + self.root = node_supremum(self.N) + self._add_node(self.root) + + # whether to keep track of dangling nodes/subgraphs + self.track_childless = track_childless + if self.track_childless: + # the set of dangling nodes + self.childless = oset([self.root]) + + # running largest_intermediate and total flops + self._track_flops = track_flops + if track_flops: + self._flops = 0 + + self._track_write = track_write + if track_write: + self._write = 0 + + self._track_size = track_size + if track_size: + self._sizes = MaxCounter() + + # container for caching subtree reconfiguration condidates + self.already_optimized = dict() + + # info relating to slicing (base constructor is always unsliced) + self.multiplicity = 1 + self.sliced_inds = {} + self.sliced_inputs = frozenset() + + # cache for compiled contraction cores + self.contraction_cores = {} + + # a default objective function useful for + # further optimization and scoring + self._default_objective = objective + + def set_state_from(self, other): + """Set the internal state of this tree to that of ``other``.""" + # immutable or never mutated properties + for attr in ( + "appearances", + "inputs", + "multiplicity", + "N", + "output", + "root", + "size_dict", + "sliced_inputs", + "_default_objective", + ): + setattr(self, attr, getattr(other, attr)) + + # mutable properties + for attr in ( + "children", + "contraction_cores", + "sliced_inds", + "preprocessing", + ): + setattr(self, attr, getattr(other, attr).copy()) + + # dicts of mutable + for attr in ("info", "already_optimized"): + setattr( + self, + attr, + {k: v.copy() for k, v in getattr(other, attr).items()}, + ) + + self.track_childless = other.track_childless + if other.track_childless: + self.childless = other.childless.copy() + + self._track_flops = other._track_flops + if other._track_flops: + self._flops = other._flops + + self._track_write = other._track_write + if other._track_write: + self._write = other._write + + self._track_size = other._track_size + if other._track_size: + self._sizes = other._sizes.copy() + + def copy(self): + """Create a copy of this ``ContractionTree``.""" + tree = object.__new__(self.__class__) + tree.set_state_from(self) + return tree + + def set_default_objective(self, objective): + """Set the objective function for this tree.""" + self._default_objective = get_score_fn(objective) + + def get_default_objective(self): + """Get the objective function for this tree.""" + if self._default_objective is None: + self._default_objective = get_score_fn("flops") + return self._default_objective + + def get_default_combo_factor(self): + """Get the default combo factor for this tree.""" + objective = self.get_default_objective() + try: + return objective.factor + except AttributeError: + return DEFAULT_COMBO_FACTOR + + def get_score(self, objective=None): + """Score this tree using the default objective function.""" + from .scoring import get_score_fn + + if objective is None: + objective = self.get_default_objective() + + objective = get_score_fn(objective) + + return objective({"tree": self}) + + @property + def nslices(self): + """Simple alias for how many independent contractions this tree + represents overall. + """ + return self.multiplicity + + @property + def nchunks(self): + """The number of 'chunks' - determined by the number of sliced output + indices. + """ + return prod( + si.size for si in self.sliced_inds.values() if not si.inner + ) + + def node_to_terms(self, node): + """Turn a node -- a frozen set of ints -- into the corresponding terms + -- a sequence of sets of str corresponding to input indices. + """ + return (self.get_legs(node_from_single(i)) for i in node) + + def gen_leaves(self): + """Generate the nodes representing leaves of the contraction tree, i.e. + of size 1 each corresponding to a single input tensor. + """ + return map(node_from_single, range(self.N)) + + def get_incomplete_nodes(self): + """Get the set of current nodes that have no children and the set of + nodes that have no parents. These are the 'childless' and 'parentless' + nodes respectively, that need to be contracted to complete the tree. + The parentless nodes are grouped into the childless nodes that contain + them as subgraphs. + + Returns + ------- + groups : dict[frozenet[int], list[frozenset[int]]] + A mapping of childless nodes to the list of parentless nodes are + beneath them. + + See Also + -------- + autocomplete + """ + childless = dict.fromkeys( + node + for node in self.info + # start wth all but leaves + if len(node) != 1 + ) + parentless = dict.fromkeys( + node + for node in self.info + # start with all but root + if len(node) != self.N + ) + for p, (l, r) in self.children.items(): + parentless.pop(l) + parentless.pop(r) + childless.pop(p) + + groups = {node: [] for node in childless} + for node in parentless: + # get the smallest node that contains this node + ancestor = min( + filter(node.issubset, childless), + key=len, + ) + groups[ancestor].append(node) + + return groups + + def autocomplete(self, **contract_opts): + """Contract all remaining node groups (as computed by + ``tree.get_incomplete_nodes``) in the tree to complete it. + + Parameters + ---------- + contract_opts + Options to pass to ``tree.contract_nodes``. + + See Also + -------- + get_incomplete_nodes, contract_nodes + """ + groups = self.get_incomplete_nodes() + for _, parentless_subnodes in groups.items(): + self.contract_nodes(parentless_subnodes, **contract_opts) + + @classmethod + def from_path( + cls, + inputs, + output, + size_dict, + *, + path=None, + ssa_path=None, + edge_path=None, + optimize="auto-hq", + autocomplete="auto", + check=False, + **kwargs, + ): + """Create a (completed) ``ContractionTree`` from the usual inputs plus + a standard contraction path or 'ssa_path' - you need to supply one. + + Parameters + ---------- + inputs : Sequence[Sequence[str]] + The input indices of each tensor, as single unicode characters. + output : Sequence[str] + The output indices. + size_dict : dict[str, int] + The size of each index. + path : Sequence[Sequence[int]], optional + The contraction path, a sequence of pairs of tensor ids to + contract. The ids are linear indices into the list of temporary + tensors, which are recycled as each contraction pops a pair and + appends the result. One of ``path``, ``ssa_path`` or ``edge_path`` + must be supplied. + ssa_path : Sequence[Sequence[int]], optional + The contraction path, a sequence of pairs of indices to contract. + The indices are single use, as if the result of each contraction is + appended to the end of the list of temporary tensors without + popping. One of ``path``, ``ssa_path`` or ``edge_path`` must be + supplied. + edge_path : Sequence[str], optional + The contraction path, a sequence of indices to contract in order. + One of ``path``, ``ssa_path`` or ``edge_path`` must be supplied. + optimize : str, optional + If a contraction within the path contains 3 or more tensors, how to + optimize this subcontraction into a binary tree. + autocomplete : "auto" or bool, optional + Whether to automatically complete the path, i.e. contract all + remaining nodes. If "auto" then a warning is issued if the path is + not complete. + check : bool, optional + Whether to perform some basic checks while creating the contraction + nodes. + + Returns + ------- + ContractionTree + """ + if (path is None) + (ssa_path is None) + (edge_path is None) != 2: + raise ValueError( + "Exactly one of ``path`` or ``ssa_path`` must be supplied." + ) + + contract_opts = {"optimize": optimize, "check": check} + + if edge_path is not None: + from .pathfinders.path_basic import edge_path_to_ssa + + ssa_path = edge_path_to_ssa(edge_path, inputs) + + if ssa_path is not None: + path = ssa_path + + tree = cls(inputs, output, size_dict, **kwargs) + + if ssa_path is not None: + # ssa path (single use ids) + nodes = dict(enumerate(tree.gen_leaves())) + ssa = len(nodes) + for p in path: + merge = [nodes.pop(i) for i in p] + nodes[ssa] = tree.contract_nodes(merge, **contract_opts) + ssa += 1 + nodes = nodes.values() + else: + # regular path ('recycled' ids) + nodes = list(tree.gen_leaves()) + for p in path: + merge = [nodes.pop(i) for i in sorted(p, reverse=True)] + nodes.append(tree.contract_nodes(merge, **contract_opts)) + + if len(nodes) > 1 and autocomplete: + if autocomplete == "auto": + # warn that we are completing + warnings.warn( + "Path was not complete - contracting all remaining. " + "You can silence this warning with `autocomplete=True`." + "Or produce an incomplete tree with `autocomplete=False`." + ) + + tree.contract_nodes(nodes, **contract_opts) + + return tree + + @classmethod + def from_info(cls, info, **kwargs): + """Create a ``ContractionTree`` from an ``opt_einsum.PathInfo`` object.""" + return cls.from_path( + inputs=info.input_subscripts.split(","), + output=info.output_subscript, + size_dict=info.size_dict, + path=info.path, + **kwargs, + ) + + @classmethod + def from_eq(cls, eq, size_dict, **kwargs): + """Create a empty ``ContractionTree`` directly from an equation and set + of shapes. + + Parameters + ---------- + eq : str + The einsum string equation. + size_dict : dict[str, int] + The size of each index. + """ + lhs, output = eq.split("->") + inputs = lhs.split(",") + return cls(inputs, output, size_dict, **kwargs) + + def get_eq(self): + """Get the einsum equation corresponding to this tree. Note that this + is the total (or original) equation, so includes indices which have + been sliced. + + Returns + ------- + eq : str + """ + return inputs_output_to_eq(self.inputs, self.output) + + def get_shapes(self): + """Get the shapes of the input tensors corresponding to this tree. + + Returns + ------- + shapes : tuple[tuple[int]] + """ + return tuple( + tuple(self.size_dict[ix] for ix in term) for term in self.inputs + ) + + def get_inputs_sliced(self): + """Get the input indices corresponding to a single slice of this tree, + i.e. with sliced indices removed. + + Returns + ------- + inputs : tuple[tuple[str]] + """ + return tuple( + tuple(ix for ix in term if ix not in self.sliced_inds) + for term in self.inputs + ) + + def get_output_sliced(self): + """Get the output indices corresponding to a single slice of this tree, + i.e. with sliced indices removed. + + Returns + ------- + output : tuple[str] + """ + return tuple(ix for ix in self.output if ix not in self.sliced_inds) + + def get_eq_sliced(self): + """Get the einsum equation corresponding to a single slice of this + tree, i.e. with sliced indices removed. + + Returns + ------- + eq : str + """ + return inputs_output_to_eq( + self.get_inputs_sliced(), self.get_output_sliced() + ) + + def get_shapes_sliced(self): + """Get the shapes of the input tensors corresponding to a single slice + of this tree, i.e. with sliced indices removed. + + Returns + ------- + shapes : tuple[tuple[int]] + """ + return tuple( + tuple( + self.size_dict[ix] for ix in term if ix not in self.sliced_inds + ) + for term in self.inputs + ) + + @classmethod + def from_edge_path( + cls, + edge_path, + inputs, + output, + size_dict, + optimize="auto-hq", + autocomplete="auto", + check=False, + **kwargs, + ): + """Create a ``ContractionTree`` from an edge elimination ordering.""" + warnings.warn( + "ContractionTree.from_edge_path(edge_path, ...) is deprecated. Use" + " ContractionTree.from_path(edge_path=edge_path, ...) instead." + ) + return cls.from_path( + inputs, + output, + size_dict, + edge_path=edge_path, + optimize=optimize, + autocomplete=autocomplete, + check=check, + **kwargs, + ) + + def _add_node(self, node, check=False): + if check: + if len(self.info) > 2 * self.N - 1: + raise ValueError("There are too many children already.") + if len(self.children) > self.N - 1: + raise ValueError("There are too many branches already.") + if not is_valid_node(node): + raise ValueError("{} is not a valid node.".format(node)) + + self.info.setdefault(node, dict()) + + def _remove_node(self, node): + """Remove ``node`` from this tree and update the flops and maximum size + if tracking them respectively, as well as input pre-processing. + """ + node_extent = len(node) + + if node_extent == 1: + # leaf nodes should always exist + self.info[node].clear() + # input: remove any associated preprocessing + self.preprocessing.pop(node_get_single_el(node), None) + else: + # only non-leaf nodes contribute to size, flops and write + if self._track_size: + self._sizes.discard(self.get_size(node)) + + if self._track_flops: + self._flops -= self.get_flops(node) + + if self._track_write: + self._write -= self.get_size(node) + + del self.children[node] + if node_extent == self.N: + # root node should always exist + self.info[node].clear() + else: + del self.info[node] + + def compute_leaf_legs(self, i): + """Compute the effective 'outer' indices for the ith input tensor. This + is not always simply the ith input indices, due to A) potential slicing + and B) potential preprocessing. + """ + # indices of input tensor (after slicing which is done immediately) + if self.sliced_inds: + term = tuple( + ix for ix in self.inputs[i] if ix not in self.sliced_inds + ) + else: + term = self.inputs[i] + + legs = {} + for ix in term: + legs[ix] = legs.get(ix, 0) + 1 + + # check for single term simplifications, these are treated as a simple + # preprocessing step that only is taken into account during actual + # contraction, and are not represented in the binary tree + # N.B. need to compute simplifiability *after* slicing + is_simplifiable = ( + # repeated indices (diag or traces) + (len(term) != len(legs)) + or + # reduced indices (are summed immediately) + any( + ix_count == self.appearances[ix] + for ix, ix_count in legs.items() + ) + ) + + if is_simplifiable: + # compute the simplified legs -> the new effective input legs + legs = { + ix: ix_count + for ix, ix_count in legs.items() + if ix_count != self.appearances[ix] + } + # add a preprocessing step to the list of contractions + eq = inputs_output_to_eq((term,), legs, canonicalize=True) + self.preprocessing[i] = eq + + return legs + + def has_preprocessing(self): + # touch all inputs legs, since preprocessing is lazily computed + for node in self.gen_leaves(): + self.get_legs(node) + return bool(self.preprocessing) + + def has_hyper_indices(self): + """Check if there are any 'hyper' indices in the contraction, i.e. + indices that don't appear exactly twice, when considering the inputs + and output. + """ + return any(ix_count != 2 for ix_count in self.appearances.values()) + + @cached_node_property("legs") + def get_legs(self, node): + """Get the effective 'outer' indices for the collection of tensors + in ``node``. + """ + node_extent = len(node) + + if node_extent == 1: + # leaf legs are inputs + return self.compute_leaf_legs(node_get_single_el(node)) + elif node_extent == self.N: + # root legs are output, after slicing + # n.b. the index counts are irrelevant for the output + return {ix: 0 for ix in self.output if ix not in self.sliced_inds} + + try: + involved = self.get_involved(node) + except KeyError: + involved = legs_union(self.node_to_terms(node)) + + return { + ix: ix_count + for ix, ix_count in involved.items() + if ix_count < self.appearances[ix] + } + + @cached_node_property("involved") + def get_involved(self, node): + """Get all the indices involved in the formation of subgraph ``node``.""" + if len(node) == 1: + return {} + sub_legs = map(self.get_legs, self.children[node]) + return legs_union(sub_legs) + + @cached_node_property("size") + def get_size(self, node): + """Get the tensor size of ``node``.""" + return compute_size_by_dict(self.get_legs(node), self.size_dict) + + @cached_node_property("flops") + def get_flops(self, node): + """Get the FLOPs for the pairwise contraction that will create + ``node``. + """ + if len(node) == 1: + return 0 + involved = self.get_involved(node) + return compute_size_by_dict(involved, self.size_dict) + + @cached_node_property("can_dot") + def get_can_dot(self, node): + """Get whether this contraction can be performed as a dot product (i.e. + with ``tensordot``), or else requires ``einsum``, as it has indices + that don't appear exactly twice in either the inputs or the output. + """ + l, r = self.children[node] + sp, sl, sr = map(self.get_legs, (node, l, r)) + return set(sp) == set(sl).symmetric_difference(sr) + + @cached_node_property("inds") + def get_inds(self, node): + """Get the indices of this node - an ordered string version of + ``get_legs`` that starts with ``tree.inputs`` and maintains the order + they appear in each contraction 'ABC,abc->ABCabc', to match tensordot. + """ + # NB: self.inputs and self.output contain the full (unsliced) indices + # thus we filter even the input legs and output legs + + if len(node) in (1, self.N): + return "".join(self.get_legs(node)) + + legs = self.get_legs(node) + l_inds, r_inds = map(self.get_inds, self.children[node]) + default_inds = "".join( + unique(filter(legs.__contains__, itertools.chain(l_inds, r_inds))) + ) + + # The default ordering can put right-only output indices between + # batches of left-only indices. On torch CPU this tends to create + # high-rank permuted views that have to be cloned before GEMM/BMM. + # Group output indices by how this contraction naturally produces + # them: shared kept indices first, then left-only, then right-only. + r_inds_set = set(r_inds) + l_inds_set = set(l_inds) + + batch = [ + ix for ix in l_inds + if (ix in legs) and (ix in r_inds_set) + ] + left_keep = [ + ix for ix in l_inds + if (ix in legs) and (ix not in r_inds_set) + ] + right_keep = [ + ix for ix in r_inds + if (ix in legs) and (ix not in l_inds_set) + ] + + ordered = list(unique(itertools.chain(batch, left_keep, right_keep))) + seen = set(ordered) + ordered.extend(ix for ix in default_inds if ix not in seen) + return "".join(ordered) + + @cached_node_property("tensordot_axes") + def get_tensordot_axes(self, node): + """Get the ``axes`` arg for a tensordot ocontraction that produces + ``node``. The pairs are sorted in order of appearance on the left + input. + """ + l_inds, r_inds = map(self.get_inds, self.children[node]) + l_axes, r_axes = [], [] + for i, ind in enumerate(l_inds): + j = r_inds.find(ind) + if j != -1: + l_axes.append(i) + r_axes.append(j) + return tuple(l_axes), tuple(r_axes) + + @cached_node_property("tensordot_perm") + def get_tensordot_perm(self, node): + """Get the permutation required, if any, to bring the tensordot output + of this nodes contraction into line with ``self.get_inds(node)``. + """ + l_inds, r_inds = map(self.get_inds, self.children[node]) + # the target output inds + p_inds = self.get_inds(node) + # the tensordot output inds + td_inds = "".join(sorted(p_inds, key=f"{l_inds}{r_inds}".find)) + if td_inds == p_inds: + return None + return tuple(map(td_inds.find, p_inds)) + + @cached_node_property("einsum_eq") + def get_einsum_eq(self, node): + """Get the einsum string describing the contraction that produces + ``node``, unlike ``get_inds`` the characters are mapped into [a-zA-Z], + for compatibility with ``numpy.einsum`` for example. + """ + l, r = self.children[node] + l_inds, r_inds, p_inds = map(self.get_inds, (l, r, node)) + # we need to map any extended unicode characters into ascii + char_mapping = { + ord(ix): get_symbol(i) + for i, ix in enumerate(unique(itertools.chain(l_inds, r_inds))) + } + return f"{l_inds},{r_inds}->{p_inds}".translate(char_mapping) + + def get_centrality(self, node): + try: + return self.info[node]["centrality"] + except KeyError: + self.compute_centralities() + return self.info[node]["centrality"] + + def total_flops(self, dtype=None, log=None): + """Sum the flops contribution from every node in the tree. + + Parameters + ---------- + dtype : {'float', 'complex', None}, optional + Scale the answer depending on the assumed data type. + """ + if self._track_flops: + C = self.multiplicity * self._flops + + else: + self._flops = 0 + for node, _, _ in self.traverse(): + self._flops += self.get_flops(node) + + self._track_flops = True + C = self.multiplicity * self._flops + + if dtype is None: + pass + elif "float" in dtype: + C *= 2 + elif "complex" in dtype: + C *= 4 + else: + raise ValueError(f"Unknown dtype {dtype}") + + if log is not None: + C = math.log(C, log) + + return C + + def total_write(self): + """Sum the total amount of memory that will be created and operated on.""" + if not self._track_write: + self._write = 0 + for node, _, _ in self.traverse(): + self._write += self.get_size(node) + + self._track_write = True + + return self.multiplicity * self._write + + def combo_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None): + t = 0 + for p in self.children: + f = self.get_flops(p) + w = self.get_size(p) + t += combine((f, factor * w)) + + t *= self.multiplicity + + if log is not None: + t = math.log(t, log) + + return t + + total_cost = combo_cost + + def max_size(self, log=None): + """The size of the largest intermediate tensor.""" + if self.N == 1: + return self.get_size(self.root) + + if not self._track_size: + self._sizes = MaxCounter() + for node, _, _ in self.traverse(): + self._sizes.add(self.get_size(node)) + self._track_size = True + + size = self._sizes.max() + + if log is not None: + size = math.log(size, log) + + return size + + def peak_size(self, order=None, log=None): + """Get the peak concurrent size of tensors needed - this depends on the + traversal order, i.e. the exact contraction path, not just the + contraction tree. + """ + tot_size = sum(self.get_size(node) for node in self.gen_leaves()) + peak = tot_size + for p, l, r in self.traverse(order=order): + tot_size += self.get_size(p) + # measure peak assuming we need both inputs and output + peak = max(peak, tot_size) + tot_size -= self.get_size(l) + tot_size -= self.get_size(r) + + if log is not None: + peak = math.log(peak, log) + + return peak + + def contract_stats(self, force=False): + """Simulteneously compute the total flops, write and size of the + contraction tree. This is more efficient than calling each of the + individual methods separately. Once computed, each quantity is then + automatically tracked. + + Returns + ------- + stats : dict[str, int] + The total flops, write and size. + """ + if force or not ( + self._track_flops and self._track_write and self._track_size + ): + self._flops = self._write = 0 + self._sizes = MaxCounter() + + for node, _, _ in self.traverse(): + self._flops += self.get_flops(node) + node_size = self.get_size(node) + self._write += node_size + self._sizes.add(node_size) + + self._track_flops = self._track_write = self._track_size = True + + return { + "flops": self.multiplicity * self._flops, + "write": self.multiplicity * self._write, + "size": self._sizes.max(), + } + + def arithmetic_intensity(self): + """The ratio of total flops to total write - the higher the better for + extracting good computational performance. + """ + return self.total_flops(dtype=None) / self.total_write() + + def contraction_scaling(self): + """This is computed simply as the maximum number of indices involved + in any single contraction, which will match the scaling assuming that + all dimensions are equal. + """ + return max(len(self.get_involved(node)) for node in self.info) + + def contraction_cost(self, log=None): + """Get the total number of scalar operations ~ time complexity.""" + return self.total_flops(dtype=None, log=log) + + def contraction_width(self, log=2): + """Get log2 of the size of the largest tensor.""" + return self.max_size(log=log) + + def compressed_contract_stats( + self, + chi=None, + order="surface_order", + compress_late=None, + ): + if chi is None: + chi = self.get_default_chi() + + if compress_late is None: + compress_late = self.get_default_compress_late() + + hg = self.get_hypergraph(accel="auto") + + # conversion between tree nodes <-> hypergraph nodes during contraction + tree_map = dict(zip(self.gen_leaves(), range(hg.get_num_nodes()))) + + tracker = CompressedStatsTracker(hg, chi) + + for p, l, r in self.traverse(order): + li = tree_map[l] + ri = tree_map[r] + + tracker.update_pre_step() + + if compress_late: + tracker.update_pre_compress(hg, li, ri) + # compress just before we contract tensors + hg.compress(chi=chi, edges=hg.get_node(li)) + hg.compress(chi=chi, edges=hg.get_node(ri)) + tracker.update_post_compress(hg, li, ri) + + tracker.update_pre_contract(hg, li, ri) + pi = tree_map[p] = hg.contract(li, ri) + tracker.update_post_contract(hg, pi) + + if not compress_late: + # compress as soon as we can after contracting tensors + tracker.update_pre_compress(hg, pi) + hg.compress(chi=chi, edges=hg.get_node(pi)) + tracker.update_post_compress(hg, pi) + + tracker.update_post_step() + + return tracker + + def total_flops_compressed( + self, + chi=None, + order="surface_order", + compress_late=None, + dtype=None, + log=None, + ): + """Estimate the total flops for a compressed contraction of this tree + with maximum bond size ``chi``. This includes basic estimates of the + ops to perform contractions, QRs and SVDs. + """ + if dtype is not None: + raise ValueError( + "Can only estimate cost in terms of " + "number of abstract scalar ops." + ) + + F = self.compressed_contract_stats( + chi=chi, + order=order, + compress_late=compress_late, + ).flops + + if log is not None: + F = math.log(F, log) + + return F + + contraction_cost_compressed = total_flops_compressed + + def total_write_compressed( + self, + chi=None, + order="surface_order", + compress_late=None, + accel="auto", + log=None, + ): + """Compute the total size of all intermediate tensors when a + compressed contraction is performed with maximum bond size ``chi``, + ordered by ``order``. This is relevant maybe for time complexity and + e.g. autodiff space complexity (since every intermediate is kept). + """ + W = self.compressed_contract_stats( + chi=chi, + order=order, + compress_late=compress_late, + ).write + + if log is not None: + W = math.log(W, log) + + return W + + def combo_cost_compressed( + self, + chi=None, + order="surface_order", + compress_late=None, + factor=None, + log=None, + ): + if factor is None: + factor = self.get_default_combo_factor() + + C = self.total_flops_compressed( + chi=chi, order=order, compress_late=compress_late + ) + factor * self.total_write_compressed( + chi=chi, order=order, compress_late=compress_late + ) + + if log is not None: + C = math.log(C, log) + + return C + + total_cost_compressed = combo_cost_compressed + + def max_size_compressed( + self, chi=None, order="surface_order", compress_late=None, log=None + ): + """Compute the maximum sized tensor produced when a compressed + contraction is performed with maximum bond size ``chi``, ordered by + ``order``. This is close to the ideal space complexity if only + tensors that are being directly operated on are kept in memory. + """ + S = self.compressed_contract_stats( + chi=chi, + order=order, + compress_late=compress_late, + ).max_size + + if log is not None: + S = math.log(S, log) + + return S + + def peak_size_compressed( + self, + chi=None, + order="surface_order", + compress_late=None, + accel="auto", + log=None, + ): + """Compute the peak size of combined intermediate tensors when a + compressed contraction is performed with maximum bond size ``chi``, + ordered by ``order``. This is the practical space complexity if one is + not swapping intermediates in and out of memory. + """ + P = self.compressed_contract_stats( + chi=chi, + order=order, + compress_late=compress_late, + ).peak_size + + if log is not None: + P = math.log(P, log) + + return P + + def contraction_width_compressed( + self, chi=None, order="surface_order", compress_late=None, log=2 + ): + """Compute log2 of the maximum sized tensor produced when a compressed + contraction is performed with maximum bond size ``chi``, ordered by + ``order``. + """ + return self.max_size_compressed(chi, order, compress_late, log=log) + + def _update_tracked(self, node): + if self._track_flops: + self._flops += self.get_flops(node) + if self._track_write: + self._write += self.get_size(node) + if self._track_size: + self._sizes.add(self.get_size(node)) + + def contract_nodes_pair( + self, + x, + y, + legs=None, + cost=None, + size=None, + check=False, + ): + """Contract node ``x`` with node ``y`` in the tree to create a new + parent node, which is returned. + + Parameters + ---------- + x : frozenset[int] + The first node to contract. + y : frozenset[int] + The second node to contract. + legs : dict[str, int], optional + The effective 'legs' of the new node if already known. If not + given, this is computed from the inputs of ``x`` and ``y``. + cost : int, optional + The cost of the contraction if already known. If not given, this is + computed from the inputs of ``x`` and ``y``. + size : int, optional + The size of the new node if already known. If not given, this is + computed from the inputs of ``x`` and ``y``. + check : bool, optional + Whether to check the inputs are valid. + + Returns + ------- + parent : frozenset[int] + The new parent node of ``x`` and ``y``. + """ + parent = x.union(y) + + # make sure info entries exist for all (default dict) + for node in (x, y, parent): + self._add_node(node, check=check) + + # enforce left ordering of 'heaviest' subtrees + nx, ny = len(x), len(y) + if nx == ny: + # deterministically break ties + sortx = -min(x) + sorty = -min(y) + else: + sortx = nx + sorty = ny + + if sortx > sorty: + lr = (x, y) + else: + lr = (y, x) + + self.children[parent] = lr + + if self.track_childless: + self.childless.discard(parent) + if x not in self.children and nx > 1: + self.childless.add(x) + if y not in self.children and ny > 1: + self.childless.add(y) + + # pre-computed information + if legs is not None: + self.info[parent]["legs"] = legs + if cost is not None: + self.info[parent]["flops"] = cost + if size is not None: + self.info[parent]["size"] = size + + self._update_tracked(parent) + + return parent + + def contract_nodes( + self, + nodes, + optimize="auto-hq", + check=False, + extra_opts=None, + ): + """Contract an arbitrary number of ``nodes`` in the tree to build up a + subtree. The root of this subtree (a new intermediate) is returned. + """ + if len(nodes) == 1: + return next(iter(nodes)) + + if len(nodes) == 2: + return self.contract_nodes_pair(*nodes, check=check) + + from .interface import find_path + + # create the bottom and top nodes + grandparent = union_it(nodes) + self._add_node(grandparent, check=check) + for node in nodes: + self._add_node(node, check=check) + + # if more than two nodes need to find the path to fill in between + # \ + # GN <- 'grandparent' + # / \ + # ????????? + # ????????????? <- to be filled with 'temp nodes' + # / \ / / \ + # N0 N1 N2 N3 N4 <- ``nodes``, or, subgraphs + # / \ / / \ + path_inputs = [tuple(self.get_legs(x)) for x in nodes] + path_output = tuple(self.get_legs(grandparent)) + + path = find_path( + path_inputs, + path_output, + self.size_dict, + optimize=optimize, + **(extra_opts or {}), + ) + + # now we have path create the nodes in between + temp_nodes = list(nodes) + for p in path: + to_contract = [temp_nodes.pop(i) for i in sorted(p, reverse=True)] + temp_nodes.append(self.contract_nodes(to_contract, check=check)) + + (parent,) = temp_nodes + + if check: + # final remaining temp input should be the 'grandparent' + assert parent == grandparent + + return parent + + def is_complete(self): + """Check every node has two children, unless it is a leaf.""" + too_many_nodes = len(self.info) > 2 * self.N - 1 + too_many_branches = len(self.children) > self.N - 1 + + if too_many_nodes or too_many_branches: + raise ValueError("Contraction tree seems to be over complete!") + + queue = [self.root] + while queue: + x = queue.pop() + if len(x) == 1: + continue + try: + queue.extend(self.children[x]) + except KeyError: + return False + + return True + + def get_default_order(self): + return "dfs" + + def _traverse_dfs(self): + """Traverse the tree in a depth first, non-recursive, order.""" + ready = set(self.gen_leaves()) + queue = [self.root] + + while queue: + node = queue[-1] + l, r = self.children[node] + + # both node's children are ready -> we can yield this contraction + if (l in ready) and (r in ready): + ready.add(queue.pop()) + yield node, l, r + continue + + if r not in ready: + queue.append(r) + if l not in ready: + queue.append(l) + + def _traverse_ordered(self, order): + """Traverse the tree in the order that minimizes ``order(node)``, but + still constrained to produce children before parents. + """ + from bisect import bisect + + if order == "surface_order": + order = self.surface_order + + seen = set() + queue = [self.root] + scores = [order(self.root)] + + while len(seen) != len(self.children): + i = 0 + while i < len(queue): + node = queue[i] + if node not in seen: + for child in self.children[node]: + if len(child) > 1: + # insert child into queue by score + before parent + score = order(child) + ci = bisect(scores[:i], score) + scores.insert(ci, score) + queue.insert(ci, child) + # parent moves extra place to right + i += 1 + seen.add(node) + i += 1 + + for node in queue: + yield (node, *self.children[node]) + + def traverse(self, order=None): + """Generate, in order, all the node merges in this tree. Non-recursive! + This ensures children are always visited before their parent. + + Parameters + ---------- + order : None, "dfs", or callable, optional + How to order the contractions within the tree. If a callable is + given (which should take a node as its argument), try to contract + nodes that minimize this function first. + + Returns + ------- + generator[tuple[node]] + The bottom up ordered sequence of tree merges, each a + tuple of ``(parent, left_child, right_child)``. + + See Also + -------- + descend + """ + if self.N == 1: + return + + if order is None: + order = self.get_default_order() + + if order == "dfs": + yield from self._traverse_dfs() + else: + yield from self._traverse_ordered(order=order) + + def descend(self, mode="dfs"): + """Generate, from root to leaves, all the node merges in this tree. + Non-recursive! This ensures parents are visited before their children. + + Parameters + ---------- + mode : {'dfs', bfs}, optional + How expand from a parent. + + Returns + ------- + generator[tuple[node] + The top down ordered sequence of tree merges, each a + tuple of ``(parent, left_child, right_child)``. + + See Also + -------- + traverse + """ + queue = [self.root] + while queue: + if mode == "dfs": + parent = queue.pop(-1) + elif mode == "bfs": + parent = queue.pop(0) + l, r = self.children[parent] + yield parent, l, r + if len(l) > 1: + queue.append(l) + if len(r) > 1: + queue.append(r) + + def get_subtree(self, node, size, search="bfs", seed=None): + """Get a subtree spanning down from ``node`` which will have ``size`` + leaves (themselves not necessarily leaves of the actual tree). + + Parameters + ---------- + node : node + The node of the tree to start with. + size : int + How many subtree leaves to aim for. + search : {'bfs', 'dfs', 'random'}, optional + How to build the tree: + + - 'bfs': breadth first expansion + - 'dfs': depth first expansion (largest nodes first) + - 'random': random expansion + + seed : None, int or random.Random, optional + Random number generator seed, if ``search`` is 'random'. + + Returns + ------- + sub_leaves : tuple[node] + Nodes which are subtree leaves. + branches : tuple[node] + Nodes which are between the subtree leaves and root. + """ + # nodes which are subtree leaves + branches = [] + + # actual tree leaves - can't expand + real_leaves = [] + + # nodes to expand + queue = [node] + + if search == "random": + rng = get_rng(seed) + else: + rng = None + if search == "bfs": + i = 0 + elif search == "dfs": + i = -1 + + while (len(queue) + len(real_leaves) < size) and queue: + if rng is not None: + i = rng.randint(0, len(queue) - 1) + + p = queue.pop(i) + if len(p) == 1: + real_leaves.append(p) + continue + + # the left child is always >= in weight that right child + # if we append it last then ``.pop(-1)`` above perform the + # depth first search sorting by node subgraph size + l, r = self.children[p] + + queue.append(r) + queue.append(l) + branches.append(p) + + # nodes at the bottom of the subtree + sub_leaves = queue + real_leaves + + return tuple(sub_leaves), tuple(branches) + + def remove_ind(self, ind, project=None, inplace=False): + """Remove (i.e. by default slice) index ``ind`` from this contraction + tree, taking care to update all relevant information about each node. + """ + tree = self if inplace else self.copy() + + if ind in tree.sliced_inds: + raise ValueError(f"Index {ind} already sliced.") + + # make sure all flops and size information has been populated + tree.contract_stats() + + d = tree.size_dict[ind] + if project is None: + # we are slicing the index + si = SliceInfo(ind not in tree.output, ind, d, None) + tree.multiplicity = tree.multiplicity * d + else: + si = SliceInfo(ind not in tree.output, ind, 1, project) + + # update the ordered slice information dictionary, but maintain the + # order such that output sliced indices always appear first -> + # enforced by the dataclass SliceInfo ordering + tree.sliced_inds = { + si.ind: si for si in sorted((*tree.sliced_inds.values(), si)) + } + + for node, node_info in tree.info.items(): + if len(node) == 1: + # handle leaves separately + i = node_get_single_el(node) + term = tree.inputs[i] + if ind in term: + # n.b. leaves don't contribute to size, flops or write + # simply recalculate all information, incl. preprocessing + tree._remove_node(node) + tree.sliced_inputs = tree.sliced_inputs | frozenset([i]) + else: + involved = tree.get_involved(node) + if ind not in involved: + # if ind doesn't feature in this node (contraction) + # -> nothing to do + continue + + # else update all the relevant information about this node + # -> flops changes for all involved indices + node_info["involved"] = legs_without(involved, ind) + old_flops = tree.get_flops(node) + new_flops = old_flops // d + node_info["flops"] = new_flops + tree._flops += new_flops - old_flops + + # -> size and write only changes for node legs (output) indices + legs = tree.get_legs(node) + if ind in legs: + node_info["legs"] = legs_without(legs, ind) + old_size = tree.get_size(node) + tree._sizes.discard(old_size) + new_size = old_size // d + tree._sizes.add(new_size) + node_info["size"] = new_size + tree._write += new_size - old_size + + # delete info we can't change + for k in ( + "inds", + "einsum_eq", + "can_dot", + "tensordot_axes", + "tensordot_perm", + ): + tree.info[node].pop(k, None) + + tree.already_optimized.clear() + tree.contraction_cores.clear() + + return tree + + remove_ind_ = functools.partialmethod(remove_ind, inplace=True) + + def restore_ind(self, ind, inplace=False): + """Restore (unslice or un-project) index ``ind`` to this contraction + tree, taking care to update all relevant information about each node. + + Parameters + ---------- + ind : str + The index to restore. + inplace : bool, optional + Whether to perform the restoration inplace or not. + + Returns + ------- + ContractionTree + """ + tree = self if inplace else self.copy() + + # pop sliced index info + si = tree.sliced_inds.pop(ind) + + # make sure all flops and size information has been populated + tree.contract_stats() + tree.multiplicity //= si.size + + # handle inputs + for i, term in enumerate(tree.inputs): + # this is the original term with all indices + if ind in term: + tree._remove_node(node_from_single(i)) + if all(ix not in tree.sliced_inds for ix in term): + # mark this input as not sliced + tree.sliced_inputs = tree.sliced_inputs - frozenset([i]) + + # delete and re-add dependent intermediates + for p, l, r in tree.traverse(): + if ind in tree.get_legs(l) or ind in tree.get_legs(r): + tree._remove_node(p) + tree.contract_nodes_pair(l, r) + + # reset caches + tree.already_optimized.clear() + tree.contraction_cores.clear() + + return tree + + restore_ind_ = functools.partialmethod(restore_ind, inplace=True) + + def unslice_rand(self, seed=None, inplace=False): + """Unslice (restore) a random index from this contraction tree. + + Parameters + ---------- + seed : None, int or random.Random, optional + Random number generator seed. + inplace : bool, optional + Whether to perform the unslicing inplace or not. + + Returns + ------- + ContractionTree + """ + rng = get_rng(seed) + ix = rng.choice(tuple(self.sliced_inds)) + return self.restore_ind(ix, inplace=inplace) + + unslice_rand_ = functools.partialmethod(unslice_rand, inplace=True) + + def unslice_all(self, inplace=False): + """Unslice (restore) all sliced indices from this contraction tree. + + Parameters + ---------- + inplace : bool, optional + Whether to perform the unslicing inplace or not. + + Returns + ------- + ContractionTree + """ + tree = self if inplace else self.copy() + + for ind in tuple(tree.sliced_inds): + tree.restore_ind_(ind) + + return tree + + unslice_all_ = functools.partialmethod(unslice_all, inplace=True) + + def calc_subtree_candidates(self, pwr=2, what="flops"): + candidates = list(self.children) + + if what == "size": + weights = [self.get_size(x) for x in candidates] + + elif what == "flops": + weights = [self.get_flops(x) for x in candidates] + + max_weight = max(weights) + + # can be bigger than numpy int/float allows + weights = [float(w / max_weight) ** (1 / pwr) for w in weights] + + # sort by descending score + candidates, weights = zip( + *sorted(zip(candidates, weights), key=lambda x: -x[1]) + ) + + return list(candidates), list(weights) + + def subtree_reconfigure( + self, + subtree_size=8, + subtree_search="bfs", + weight_what="flops", + weight_pwr=2, + select="max", + maxiter=500, + seed=None, + minimize=None, + optimize=None, + inplace=False, + progbar=False, + ): + """Reconfigure subtrees of this tree with locally optimal paths. + + Parameters + ---------- + subtree_size : int, optional + The size of subtree to consider. Cost is exponential in this. + subtree_search : {'bfs', 'dfs', 'random'}, optional + How to build the subtrees: + + - 'bfs': breadth-first-search creating balanced subtrees + - 'dfs': depth-first-search creating imbalanced subtrees + - 'random': random subtree building + + weight_what : {'flops', 'size'}, optional + When assessing nodes to build and optimize subtrees from whether to + score them by the (local) contraction cost, or tensor size. + weight_pwr : int, optional + When assessing nodes to build and optimize subtrees from, how to + scale their score into a probability: ``score**(1 / weight_pwr)``. + The larger this is the more explorative the algorithm is when + ``select='random'``. + select : {'max', 'min', 'random'}, optional + What order to select node subtrees to optimize: + + - 'max': choose the highest score first + - 'min': choose the lowest score first + - 'random': choose randomly weighted on score -- see + ``weight_pwr``. + + maxiter : int, optional + How many subtree optimizations to perform, the algorithm can + terminate before this if all subtrees have been optimized. + seed : int, optional + A random seed (seeds python system random module). + minimize : {'flops', 'size'}, optional + Whether to minimize with respect to contraction flops or size. + inplace : bool, optional + Whether to perform the reconfiguration inplace or not. + progbar : bool, optional + Whether to show live progress of the reconfiguration. + + Returns + ------- + ContractionTree + """ + tree = self if inplace else self.copy() + + # ensure these have been computed and thus are being tracked + tree.contract_stats() + + if minimize is None: + minimize = self.get_default_objective() + scorer = get_score_fn(minimize) + + if optimize is None: + from .pathfinders.path_basic import OptimalOptimizer + + opt = OptimalOptimizer( + minimize=scorer.get_dynamic_programming_minimize() + ) + else: + opt = optimize + + node_cost = getattr(scorer, "cost_local_tree_node", lambda _: 2) + + # different caches as we might want to reconfigure one before other + tree.already_optimized.setdefault(minimize, set()) + already_optimized = tree.already_optimized[minimize] + + if select == "random": + rng = get_rng(seed) + else: + if select == "max": + i = 0 + elif select == "min": + i = -1 + rng = None + + candidates, weights = tree.calc_subtree_candidates( + pwr=weight_pwr, what=weight_what + ) + + if progbar: + import tqdm + + pbar = tqdm.tqdm() + pbar.set_description(_describe_tree(tree), refresh=False) + + r = 0 + try: + while candidates and r < maxiter: + if rng is not None: + (i,) = rng.choices(range(len(candidates)), weights=weights) + + weights.pop(i) + sub_root = candidates.pop(i) + + # get a subtree to possibly reconfigure + sub_leaves, sub_branches = tree.get_subtree( + sub_root, size=subtree_size, search=subtree_search + ) + + sub_leaves = frozenset(sub_leaves) + + # check if its already been optimized + if sub_leaves in already_optimized: + continue + + # else remove the branches, keeping track of current cost + current_cost = node_cost(tree, sub_root) + for node in sub_branches: + if minimize == "size": + current_cost = max(current_cost, node_cost(tree, node)) + else: + current_cost += node_cost(tree, node) + tree._remove_node(node) + + # make the optimizer more efficient by supplying accurate cap + opt.cost_cap = max(2, current_cost) + + # and reoptimize the leaves + tree.contract_nodes(sub_leaves, optimize=opt) + already_optimized.add(sub_leaves) + + r += 1 + + if progbar: + pbar.update() + pbar.set_description(_describe_tree(tree), refresh=False) + + # if we have reconfigured simply re-add all candidates + candidates, weights = tree.calc_subtree_candidates( + pwr=weight_pwr, what=weight_what + ) + finally: + if progbar: + pbar.close() + + # invalidate any compiled contractions + tree.contraction_cores.clear() + + return tree + + subtree_reconfigure_ = functools.partialmethod( + subtree_reconfigure, inplace=True + ) + + def subtree_reconfigure_forest( + self, + num_trees=8, + num_restarts=10, + restart_fraction=0.5, + subtree_maxiter=100, + subtree_size=10, + subtree_search=("random", "bfs"), + subtree_select=("random",), + subtree_weight_what=("flops", "size"), + subtree_weight_pwr=(2,), + parallel="auto", + parallel_maxiter_steps=4, + minimize=None, + seed=None, + progbar=False, + inplace=False, + ): + """'Forested' version of ``subtree_reconfigure`` which is more + explorative and can be parallelized. It stochastically generates + a 'forest' reconfigured trees, then only keeps some fraction of these + to generate the next forest. + + Parameters + ---------- + num_trees : int, optional + The number of trees to reconfigure at each stage. + num_restarts : int, optional + The number of times to halt, prune and then restart the + tree reconfigurations. + restart_fraction : float, optional + The fraction of trees to keep at each stage and generate the next + forest from. + subtree_maxiter : int, optional + Number of subtree reconfigurations per step. + ``num_restarts * subtree_maxiter`` is the max number of total + subtree reconfigurations for the final tree produced. + subtree_size : int, optional + The size of subtrees to search for and reconfigure. + subtree_search : tuple[{'random', 'bfs', 'dfs'}], optional + Tuple of options for the ``search`` kwarg of + :meth:`ContractionTree.subtree_reconfigure` to randomly sample. + subtree_select : tuple[{'random', 'max', 'min'}], optional + Tuple of options for the ``select`` kwarg of + :meth:`ContractionTree.subtree_reconfigure` to randomly sample. + subtree_weight_what : tuple[{'flops', 'size'}], optional + Tuple of options for the ``weight_what`` kwarg of + :meth:`ContractionTree.subtree_reconfigure` to randomly sample. + subtree_weight_pwr : tuple[int], optional + Tuple of options for the ``weight_pwr`` kwarg of + :meth:`ContractionTree.subtree_reconfigure` to randomly sample. + parallel : 'auto', False, True, int, or distributed.Client + Whether to parallelize the search. + parallel_maxiter_steps : int, optional + If parallelizing, how many steps to break each reconfiguration into + in order to evenly saturate many processes. + minimize : {'flops', 'size', ..., Objective}, optional + Whether to minimize the total flops or maximum size of the + contraction tree. + seed : None, int or random.Random, optional + A random seed to use. + progbar : bool, optional + Whether to show live progress. + inplace : bool, optional + Whether to perform the subtree reconfiguration inplace. + + Returns + ------- + ContractionTree + """ + tree = self if inplace else self.copy() + + # some of these might be unpicklable + tree.contraction_cores.clear() + + # candidate trees + num_keep = max(1, int(num_trees * restart_fraction)) + + # how to rank the trees + if minimize is None: + minimize = self.get_default_objective() + score = get_score_fn(minimize) + + rng = get_rng(seed) + + # set up the initial 'forest' and parallel machinery + pool = parse_parallel_arg(parallel) + is_scatter_pool = can_scatter(pool) + if is_scatter_pool: + is_worker = maybe_leave_pool(pool) + # store the trees as futures for the entire process + forest = [scatter(pool, tree)] + maxiter = subtree_maxiter // parallel_maxiter_steps + else: + forest = [tree] + maxiter = subtree_maxiter + + if progbar: + import tqdm + + pbar = tqdm.tqdm(total=num_restarts) + pbar.set_description(_describe_tree(tree), refresh=False) + + try: + for _ in range(num_restarts): + # on the next round take only the best trees + forest = itertools.cycle(forest[:num_keep]) + + # select some random configurations + saplings = [ + { + "tree": next(forest), + "maxiter": maxiter, + "minimize": minimize, + "subtree_size": subtree_size, + "subtree_search": rng.choice(subtree_search), + "select": rng.choice(subtree_select), + "weight_pwr": rng.choice(subtree_weight_pwr), + "weight_what": rng.choice(subtree_weight_what), + } + for _ in range(num_trees) + ] + + if pool is None: + forest = [_reconfigure_tree(**s) for s in saplings] + res = [{"tree": t, **_get_tree_info(t)} for t in forest] + elif not is_scatter_pool: + forest_futures = [ + submit(pool, _reconfigure_tree, **s) for s in saplings + ] + forest = [f.result() for f in forest_futures] + res = [{"tree": t, **_get_tree_info(t)} for t in forest] + else: + # submit in smaller steps to saturate processes + for _ in range(parallel_maxiter_steps): + for s in saplings: + s["tree"] = submit(pool, _reconfigure_tree, **s) + + # compute scores remotely then gather + forest_futures = [s["tree"] for s in saplings] + res_futures = [ + submit(pool, _get_tree_info, t) for t in forest_futures + ] + res = [ + {"tree": tree_future, **res_future.result()} + for tree_future, res_future in zip( + forest_futures, res_futures + ) + ] + + # update the order of the new forest + res.sort(key=score) + forest = [r["tree"] for r in res] + + if progbar: + pbar.update() + if pool is None: + d = _describe_tree(forest[0]) + else: + d = submit(pool, _describe_tree, forest[0]).result() + pbar.set_description(d, refresh=False) + + finally: + if progbar: + pbar.close() + + if is_scatter_pool: + tree.set_state_from(forest[0].result()) + maybe_rejoin_pool(is_worker, pool) + else: + tree.set_state_from(forest[0]) + + return tree + + subtree_reconfigure_forest_ = functools.partialmethod( + subtree_reconfigure_forest, inplace=True + ) + + simulated_anneal = simulated_anneal_tree + simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True) + parallel_temper = parallel_temper_tree + parallel_temper_ = functools.partialmethod(parallel_temper, inplace=True) + + def slice( + self, + target_size=None, + target_overhead=None, + target_slices=None, + temperature=0.01, + minimize=None, + allow_outer=True, + max_repeats=16, + reslice=False, + seed=None, + inplace=False, + ): + """Slice this tree (turn some indices into indices which are explicitly + summed over rather than being part of contractions). The indices are + stored in ``tree.sliced_inds``, and the contraction width updated to + take account of the slicing. Calling ``tree.contract(arrays)`` moreover + which automatically perform the slicing and summation. + + Parameters + ---------- + target_size : int, optional + The target number of entries in the largest tensor of the sliced + contraction. The search algorithm will terminate after this is + reached. + target_slices : int, optional + The target or minimum number of 'slices' to consider - individual + contractions after slicing indices. The search algorithm will + terminate after this is breached. This is on top of the current + number of slices. + target_overhead : float, optional + The target increase in total number of floating point operations. + For example, a value of ``2.0`` will terminate the search just + before the cost of computing all the slices individually breaches + twice that of computing the original contraction all at once. + temperature : float, optional + How much to randomize the repeated search. + minimize : {'flops', 'size', ..., Objective}, optional + Which metric to score the overhead increase against. + allow_outer : bool, optional + Whether to allow slicing of outer indices. + max_repeats : int, optional + How many times to repeat the search with a slight randomization. + reslice : bool, optional + Whether to reslice the tree, i.e. first remove all currently + sliced indices and start the search again. Generally any 'good' + sliced indices will be easily found again. + seed : None, int or random.Random, optional + A random seed or generator to use for the search. + inplace : bool, optional + Whether the remove the indices from this tree inplace or not. + + Returns + ------- + ContractionTree + + See Also + -------- + SliceFinder, ContractionTree.slice_and_reconfigure + """ + from .slicer import SliceFinder + + if minimize is None: + minimize = self.get_default_objective() + + tree = self if inplace else self.copy() + + if reslice: + if target_slices is not None: + target_slices *= tree.nslices + tree.unslice_all_() + + sf = SliceFinder( + tree, + target_size=target_size, + target_overhead=target_overhead, + target_slices=target_slices, + temperature=temperature, + minimize=minimize, + allow_outer=allow_outer, + seed=seed, + ) + + ix_sl, _ = sf.search(max_repeats) + for ix in ix_sl: + tree.remove_ind_(ix) + + return tree + + slice_ = functools.partialmethod(slice, inplace=True) + + def slice_and_reconfigure( + self, + target_size, + step_size=2, + temperature=0.01, + minimize=None, + allow_outer=True, + max_repeats=16, + reslice=False, + reconf_opts=None, + progbar=False, + inplace=False, + ): + """Interleave slicing (removing indices into an exterior sum) with + subtree reconfiguration to minimize the overhead induced by this + slicing. + + Parameters + ---------- + target_size : int + Slice the tree until the maximum intermediate size is this or + smaller. + step_size : int, optional + The minimum size reduction to try and achieve before switching to a + round of subtree reconfiguration. + temperature : float, optional + The temperature to supply to ``SliceFinder`` for searching for + indices. + minimize : {'flops', 'size', ..., Objective}, optional + The metric to minimize when slicing and reconfiguring subtrees. + max_repeats : int, optional + The number of slicing attempts to perform per search. + progbar : bool, optional + Whether to show live progress. + inplace : bool, optional + Whether to perform the slicing and reconfiguration inplace. + reconf_opts : None or dict, optional + Supplied to + :meth:`ContractionTree.subtree_reconfigure` or + :meth:`ContractionTree.subtree_reconfigure_forest`, depending on + `'forested'` key value. + """ + tree = self if inplace else self.copy() + + reconf_opts = {} if reconf_opts is None else dict(reconf_opts) + + if minimize is None: + minimize = self.get_default_objective() + minimize = get_score_fn(minimize) + + reconf_opts.setdefault("minimize", minimize) + forested_reconf = reconf_opts.pop("forested", False) + + if progbar: + import tqdm + + pbar = tqdm.tqdm() + pbar.set_description(_describe_tree(tree), refresh=False) + + try: + while tree.max_size() > target_size: + tree.slice_( + temperature=temperature, + target_slices=step_size, + minimize=minimize, + allow_outer=allow_outer, + max_repeats=max_repeats, + reslice=reslice, + ) + if forested_reconf: + tree.subtree_reconfigure_forest_(**reconf_opts) + else: + tree.subtree_reconfigure_(**reconf_opts) + + if progbar: + pbar.update() + pbar.set_description(_describe_tree(tree), refresh=False) + finally: + if progbar: + pbar.close() + + return tree + + slice_and_reconfigure_ = functools.partialmethod( + slice_and_reconfigure, inplace=True + ) + + def slice_and_reconfigure_forest( + self, + target_size, + step_size=2, + num_trees=8, + restart_fraction=0.5, + temperature=0.02, + max_repeats=32, + reslice=False, + minimize=None, + allow_outer=True, + parallel="auto", + progbar=False, + inplace=False, + reconf_opts=None, + ): + """'Forested' version of :meth:`ContractionTree.slice_and_reconfigure`. + This maintains a 'forest' of trees with different slicing and subtree + reconfiguration attempts, pruning the worst at each step and generating + a new forest from the best. + + Parameters + ---------- + target_size : int + Slice the tree until the maximum intermediate size is this or + smaller. + step_size : int, optional + The minimum size reduction to try and achieve before switching to a + round of subtree reconfiguration. + num_restarts : int, optional + The number of times to halt, prune and then restart the + tree reconfigurations. + restart_fraction : float, optional + The fraction of trees to keep at each stage and generate the next + forest from. + temperature : float, optional + The temperature at which to randomize the sliced index search. + max_repeats : int, optional + The number of slicing attempts to perform per search. + parallel : 'auto', False, True, int, or distributed.Client + Whether to parallelize the search. + progbar : bool, optional + Whether to show live progress. + inplace : bool, optional + Whether to perform the slicing and reconfiguration inplace. + reconf_opts : None or dict, optional + Supplied to + :meth:`ContractionTree.slice_and_reconfigure`. + + Returns + ------- + ContractionTree + """ + tree = self if inplace else self.copy() + + # some of these might be unpicklable + tree.contraction_cores.clear() + + # candidate trees + num_keep = max(1, int(num_trees * restart_fraction)) + + # how to rank the trees + if minimize is None: + minimize = self.get_default_objective() + score = get_score_fn(minimize) + + # set up the initial 'forest' and parallel machinery + pool = parse_parallel_arg(parallel) + is_scatter_pool = can_scatter(pool) + if is_scatter_pool: + is_worker = maybe_leave_pool(pool) + # store the trees as futures for the entire process + forest = [scatter(pool, tree)] + else: + forest = [tree] + + if progbar: + import tqdm + + pbar = tqdm.tqdm() + pbar.set_description(_describe_tree(tree), refresh=False) + + next_size = tree.max_size() + + try: + while True: + next_size //= step_size + + # on the next round take only the best trees + forest = itertools.cycle(forest[:num_keep]) + + saplings = [ + { + "tree": next(forest), + "target_size": next_size, + "step_size": step_size, + "temperature": temperature, + "max_repeats": max_repeats, + "reconf_opts": reconf_opts, + "allow_outer": allow_outer, + "reslice": reslice, + } + for _ in range(num_trees) + ] + + if pool is None: + forest = [ + _slice_and_reconfigure_tree(**s) for s in saplings + ] + res = [{"tree": t, **_get_tree_info(t)} for t in forest] + + elif not is_scatter_pool: + # simple pool with no pass by reference + forest_futures = [ + submit(pool, _slice_and_reconfigure_tree, **s) + for s in saplings + ] + forest = [f.result() for f in forest_futures] + res = [{"tree": t, **_get_tree_info(t)} for t in forest] + + else: + forest_futures = [ + submit(pool, _slice_and_reconfigure_tree, **s) + for s in saplings + ] + + # compute scores remotely then gather + res_futures = [ + submit(pool, _get_tree_info, t) for t in forest_futures + ] + res = [ + {"tree": tree_future, **res_future.result()} + for tree_future, res_future in zip( + forest_futures, res_futures + ) + ] + + # we want to sort by flops, but also favour sampling as + # many different sliced index combos as possible + # ~ [1, 1, 1, 2, 2, 3] -> [1, 2, 3, 1, 2, 1] + res.sort(key=score) + res = list( + interleave( + groupby(lambda r: r["sliced_ind_set"], res).values() + ) + ) + + # update the order of the new forest + forest = [r["tree"] for r in res] + + if progbar: + pbar.update() + if pool is None: + d = _describe_tree(forest[0]) + else: + d = submit(pool, _describe_tree, forest[0]).result() + pbar.set_description(d, refresh=False) + + if res[0]["size"] <= target_size: + break + + finally: + if progbar: + pbar.close() + + if is_scatter_pool: + tree.set_state_from(forest[0].result()) + maybe_rejoin_pool(is_worker, pool) + else: + tree.set_state_from(forest[0]) + + return tree + + slice_and_reconfigure_forest_ = functools.partialmethod( + slice_and_reconfigure_forest, inplace=True + ) + + def compressed_reconfigure( + self, + minimize=None, + order_only=False, + max_nodes="auto", + max_time=None, + local_score=None, + exploration_power=0, + best_score=None, + progbar=False, + inplace=False, + ): + """Reconfigure this tree according to ``peak_size_compressed``. + + Parameters + ---------- + chi : int + The maximum bond dimension to consider. + order_only : bool, optional + Whether to only consider the ordering of the current tree + contractions, or all possible contractions, starting with the + current. + max_nodes : int, optional + Set the maximum number of contraction steps to consider. + max_time : float, optional + Set the maximum time to spend on the search. + local_score : callable, optional + A function that assigns a score to a potential contraction, with a + lower score giving more priority to explore that contraction + earlier. It should have signature:: + + local_score(step, new_score, dsize, new_size) + + where ``step`` is the number of steps so far, ``new_score`` is the + score of the contraction so far, ``dsize`` is the change in memory + by the current step, and ``new_size`` is the new memory size after + contraction. + exploration_power : float, optional + If not ``0.0``, the inverse power to which the step is raised in + the default local score function. Higher values favor exploring + more promising branches early on - at the cost of increased memory. + Ignored if ``local_score`` is supplied. + best_score : float, optional + Manually specify an upper bound for best score found so far. + progbar : bool, optional + If ``True``, display a progress bar. + inplace : bool, optional + Whether to perform the reconfiguration inplace on this tree. + + Returns + ------- + ContractionTree + """ + from .experimental.path_compressed_branchbound import ( + CompressedExhaustive, + ) + + if minimize is None: + minimize = self.get_default_objective() + + if max_nodes == "auto": + if max_time is None: + max_nodes = max(10_000, self.N**2) + else: + max_nodes = float("inf") + + opt = CompressedExhaustive( + minimize=minimize, + local_score=local_score, + max_nodes=max_nodes, + max_time=max_time, + exploration_power=exploration_power, + best_score=best_score, + progbar=progbar, + ) + opt.setup(self.inputs, self.output, self.size_dict) + opt.explore_path(self.get_path_surface(), restrict=order_only) + + # rtree = opt.search(self.inputs, self.output, self.size_dict) + + opt.run(self.inputs, self.output, self.size_dict) + ssa_path = opt.ssa_path + # ssa_path = opt(self.inputs, self.output, self.size_dict) + rtree = self.__class__.from_path( + self.inputs, + self.output, + self.size_dict, + ssa_path=ssa_path, + objective=minimize, + ) + if inplace: + self.set_state_from(rtree) + rtree = self + + rtree.contraction_cores.clear() + return rtree + + compressed_reconfigure_ = functools.partialmethod( + compressed_reconfigure, inplace=True + ) + + def windowed_reconfigure( + self, + minimize=None, + order_only=False, + window_size=20, + max_iterations=100, + max_window_tries=1000, + score_temperature=0.0, + queue_temperature=1.0, + scorer=None, + queue_scorer=None, + seed=None, + inplace=False, + progbar=False, + **kwargs, + ): + from .pathfinders.path_compressed import WindowedOptimizer + + if minimize is None: + minimize = self.get_default_objective() + + wo = WindowedOptimizer( + self.inputs, + self.output, + self.size_dict, + minimize=minimize, + ssa_path=self.get_ssa_path(), + seed=seed, + ) + + wo.refine( + window_size=window_size, + max_iterations=max_iterations, + order_only=order_only, + max_window_tries=max_window_tries, + score_temperature=score_temperature, + queue_temperature=queue_temperature, + scorer=scorer, + queue_scorer=queue_scorer, + progbar=progbar, + **kwargs, + ) + ssa_path = wo.get_ssa_path() + + rtree = self.__class__.from_path( + self.inputs, + self.output, + self.size_dict, + ssa_path=ssa_path, + objective=minimize, + ) + + if inplace: + self.set_state_from(rtree) + rtree = self + + rtree.contraction_cores.clear() + return rtree + + windowed_reconfigure_ = functools.partialmethod( + windowed_reconfigure, inplace=True + ) + + def flat_tree(self, order=None): + """Create a nested tuple representation of the contraction tree like:: + + ((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9)))) + + Such that the contraction will progress like:: + + ((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9)))) + ((0, 12), (34, ((5, 67), 89))) + (012, (34, (567, 89))) + (012, (34, 56789)) + (012, 3456789) + 0123456789 + + Where each integer represents a leaf (i.e. single element node). + """ + tups = dict(zip(self.gen_leaves(), range(self.N))) + + for parent, l, r in self.traverse(order=order): + tups[parent] = tups[l], tups[r] + + return tups[self.root] + + def get_leaves_ordered(self): + """Return the list of leaves as ordered by the contraction tree. + + Returns + ------- + tuple[frozenset[str]] + """ + if not self.is_complete(): + raise ValueError("Can't order the leaves until tree is complete.") + + return tuple( + nd + for nd in itertools.chain.from_iterable(self.traverse()) + if len(nd) == 1 + ) + + def get_path(self, order=None): + """Generate a standard path (with linear recycled ids) from the + contraction tree. + + Parameters + ---------- + order : None, "dfs", or callable, optional + How to order the contractions within the tree. If a callable is + given (which should take a node as its argument), try to contract + nodes that minimize this function first. + + Returns + ------- + path: tuple[tuple[int, int]] + """ + from bisect import bisect_left + + ssa = self.N + ssas = list(range(ssa)) + node_to_ssa = dict(zip(self.gen_leaves(), ssas)) + path = [] + + for parent, left, right in self.traverse(order=order): + # map nodes to ssas + lssa = node_to_ssa[left] + rssa = node_to_ssa[right] + # map ssas to linear indices, using bisection + i, j = sorted((bisect_left(ssas, lssa), bisect_left(ssas, rssa))) + # 'contract' nodes + ssas.pop(j) + ssas.pop(i) + path.append((i, j)) + ssas.append(ssa) + # update mapping + node_to_ssa[parent] = ssa + ssa += 1 + + return tuple(path) + + path = deprecated(get_path, "path", "get_path") + + def get_numpy_path(self, order=None): + """Generate a path compatible with the `optimize` kwarg of + `numpy.einsum`. + """ + return ["einsum_path", *self.get_path(order=order)] + + def get_ssa_path(self, order=None): + """Generate a single static assignment path from the contraction tree. + + Parameters + ---------- + order : None, "dfs", or callable, optional + How to order the contractions within the tree. If a callable is + given (which should take a node as its argument), try to contract + nodes that minimize this function first. + + Returns + ------- + ssa_path: tuple[tuple[int, int]] + """ + ssa_path = [] + pos = dict(zip(self.gen_leaves(), range(self.N))) + + for parent, l, r in self.traverse(order=order): + i, j = sorted((pos[l], pos[r])) + ssa_path.append((i, j)) + pos[parent] = len(ssa_path) + self.N - 1 + + return tuple(ssa_path) + + ssa_path = deprecated(get_ssa_path, "ssa_path", "get_ssa_path") + + def surface_order(self, node): + return (len(node), self.get_centrality(node)) + + def set_surface_order_from_path(self, ssa_path): + o = {} + nodes = list(self.gen_leaves()) + for j, p in enumerate(ssa_path): + l, r = (nodes[i] for i in p) + p = l.union(r) + nodes.append(p) + o[p] = j + + self.surface_order = functools.partial( + get_with_default, obj=o, default=float("inf") + ) + + def get_path_surface(self): + return self.get_path(order=self.surface_order) + + path_surface = deprecated( + get_path_surface, "path_surface", "get_path_surface" + ) + + def get_ssa_path_surface(self): + return self.get_ssa_path(order=self.surface_order) + + ssa_path_surface = deprecated( + get_ssa_path_surface, "ssa_path_surface", "get_ssa_path_surface" + ) + + def get_spans(self): + """Get all (which could mean none) potential embeddings of this + contraction tree into a spanning tree of the original graph. + + Returns + ------- + tuple[dict[frozenset[int], frozenset[int]]] + """ + ind_to_term = collections.defaultdict(set) + for i, term in enumerate(self.inputs): + for ix in term: + ind_to_term[ix].add(i) + + def boundary_pairs(node): + """Get nodes along the boundary of the bipartition represented by + ``node``. + """ + pairs = set() + involved = self.get_involved(node) + legs = self.get_legs(node) + removed = [ix for ix in involved if ix not in legs] + for ix in removed: + # for every index across the contraction + l1, l2 = ind_to_term[ix] + + # can either span from left to right or right to left + pairs.add((l1, l2)) + pairs.add((l2, l1)) + + return pairs + + # first span choice is any nodes across the top level bipart + candidates = [ + { + # which intermedate nodes map to which leaf nodes + "map": {self.root: node_from_single(l2)}, + # the leaf nodes in the spanning tree + "spine": {l1, l2}, + } + for l1, l2 in boundary_pairs(self.root) + ] + + for _, l, r in self.descend(): + for child in (r, l): + # for each current candidate check all the possible extensions + for _ in range(len(candidates)): + cand = candidates.pop(0) + + # don't need to do anything for + if len(child) == 1: + candidates.append( + { + "map": {child: child, **cand["map"]}, + "spine": cand["spine"].copy(), + } + ) + + for l1, l2 in boundary_pairs(child): + if (l1 in cand["spine"]) or (l2 not in cand["spine"]): + # pair does not merge inwards into spine + continue + + # valid extension of spanning tree + candidates.append( + { + "map": { + child: node_from_single(l2), + **cand["map"], + }, + "spine": cand["spine"] | {l1, l2}, + } + ) + + return tuple(c["map"] for c in candidates) + + def compute_centralities(self, combine="mean"): + """Compute a centrality for every node in this contraction tree.""" + hg = self.get_hypergraph(accel="auto") + cents = hg.simple_centrality() + + for i, leaf in enumerate(self.gen_leaves()): + self.info[leaf]["centrality"] = cents[i] + + combine = { + "mean": lambda x, y: (x + y) / 2, + "sum": lambda x, y: (x + y), + "max": max, + "min": min, + }.get(combine, combine) + + for p, l, r in self.traverse("dfs"): + self.info[p]["centrality"] = combine( + self.info[l]["centrality"], self.info[r]["centrality"] + ) + + def get_hypergraph(self, accel=False): + """Get a hypergraph representing the uncontracted network (i.e. the + leaves). + """ + return get_hypergraph(self.inputs, self.output, self.size_dict, accel) + + def reset_contraction_indices(self): + """Reset all information regarding the explicit contraction indices + ordering. + """ + # delete all derived information + for node in self.children: + for k in ( + "inds", + "einsum_eq", + "can_dot", + "tensordot_axes", + "tensordot_perm", + ): + self.info[node].pop(k, None) + + # invalidate any compiled contractions + self.contraction_cores.clear() + + def sort_contraction_indices( + self, + priority="flops", + make_output_contig=True, + make_contracted_contig=True, + reset=True, + ): + """Set explicit orders for the contraction indices of this self to + optimize for one of two things: contiguity in contracted ('k') indices, + or contiguity of left and right output ('m' and 'n') indices. + + Parameters + ---------- + priority : {'flops', 'size', 'root', 'leaves'}, optional + Which order to process the intermediate nodes in. Later nodes + re-sort previous nodes so are more likely to keep their ordering. + E.g. for 'flops' the mostly costly contracton will be process last + and thus will be guaranteed to have its indices exactly sorted. + make_output_contig : bool, optional + When processing a pairwise contraction, sort the parent contraction + indices so that the order of indices is the order they appear + from left to right in the two child (input) tensors. + make_contracted_contig : bool, optional + When processing a pairwise contraction, sort the child (input) + tensor indices so that all contracted indices appear contiguously. + reset : bool, optional + Reset all indices to the default order before sorting. + """ + if reset: + self.reset_contraction_indices() + + if priority == "flops": + nodes = sorted( + self.children.items(), key=lambda x: self.get_flops(x[0]) + ) + elif priority == "size": + nodes = sorted( + self.children.items(), key=lambda x: self.get_size(x[0]) + ) + elif priority == "root": + nodes = ((p, (l, r)) for p, l, r in self.traverse()) + elif priority == "leaves": + nodes = ((p, (l, r)) for p, l, r in self.descend()) + else: + raise ValueError(priority) + + for p, (l, r) in nodes: + p_inds, l_inds, r_inds = map(self.get_inds, (p, l, r)) + + if make_output_contig and len(p) != self.N: + # sort indices by whether they appear in the left or right + # whether this happens before or after the sort below depends + # on the order we are processing the nodes + # (avoid root as don't want to modify output) + + def psort(ix): + # group by whether in left or right input + return (r_inds.find(ix), l_inds.find(ix)) + + p_inds = "".join(sorted(p_inds, key=psort)) + self.info[p]["inds"] = p_inds + + if make_contracted_contig: + # sort indices by: + # 1. if they are going to be contracted + # 2. what order they appear in the parent indices + # (but ignore leaf indices) + if len(l) != 1: + + def lsort(ix): + return (r_inds.find(ix), p_inds.find(ix)) + + l_inds = "".join(sorted(self.get_legs(l), key=lsort)) + self.info[l]["inds"] = l_inds + + if len(r) != 1: + + def rsort(ix): + return (p_inds.find(ix), l_inds.find(ix)) + + r_inds = "".join(sorted(self.get_legs(r), key=rsort)) + self.info[r]["inds"] = r_inds + + # invalidate any compiled contractions + self.contraction_cores.clear() + + def print_contractions(self, sort=None, show_brackets=True): + """Print each pairwise contraction, with colorized indices (if + `colorama` is installed), and other information. The color codes are: + + - blue: index appears on left and is kept + - green: index appears on right and is kept + - red: contracted index: appears on both sides and is removed + - pink: batch index: appears on both sides and is kept + + Any trivial indices that appear only on one term and not in the output + are removed and shown by the preprocessing steps. + + Parameters + ---------- + sort : {'flops', 'size'}, optional + Sort the contractions by either the number of floating point + operations or the size of the intermediate tensor. By default the + contraction are show in the order they are performed. + show_brackets : bool, optional + Whether to show the brackets around contiguous sections of the same + type of indices. + """ + try: + from colorama import Fore + + RESET = Fore.RESET + GREY = Fore.WHITE + PINK = Fore.MAGENTA + RED = Fore.RED + BLUE = Fore.BLUE + GREEN = Fore.GREEN + except ImportError: + RESET = GREY = PINK = RED = BLUE = GREEN = "" + + entries = [] + + if self.has_preprocessing(): + for pi, eq in self.preprocessing.items(): + # eq is with canonical indices, reinsert original inputs + replacer = dict(zip(eq.split("->")[0], self.inputs[pi])) + eq = "".join(replacer.get(c, c) for c in eq) + print(f"{GREY}preprocess input {pi}: {RESET}{eq}") + print() + + for i, (p, l, r) in enumerate(self.traverse()): + p_legs, l_legs, r_legs = map(self.get_legs, [p, l, r]) + p_inds, l_inds, r_inds = map(self.get_inds, [p, l, r]) + + # print sizes and flops + p_flops = self.get_flops(p) + p_sz, l_sz, r_sz = ( + math.log2(self.get_size(node)) for node in [p, l, r] + ) + # print whether tensordottable + if self.get_can_dot(p): + type_msg = "tensordot" + perm = self.get_tensordot_perm(p) + if perm is not None: + # and whether indices match tensordot + type_msg += "+perm" + else: + type_msg = "einsum" + + kpt_brck_l = "(" if show_brackets else "" + kpt_brck_r = ")" if show_brackets else "" + con_brck_l = "[" if show_brackets else "" + con_brck_r = "]" if show_brackets else "" + hyp_brck_l = "{" if show_brackets else "" + hyp_brck_r = "}" if show_brackets else "" + + pa = ( + "".join( + PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}" + if (ix in l_legs) and (ix in r_legs) + else GREEN + f"{kpt_brck_l}{ix}{kpt_brck_r}" + if ix in r_legs + else BLUE + ix + for ix in p_inds + ) + .replace(f"){GREEN}(", "") + .replace(f"}}{PINK}{{", "") + ) + la = ( + "".join( + PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}" + if (ix in p_legs) and (ix in r_legs) + else RED + f"{con_brck_l}{ix}{con_brck_r}" + if ix in r_legs + else BLUE + ix + for ix in l_inds + ) + .replace(f"]{RED}[", "") + .replace(f"}}{PINK}{{", "") + ) + ra = ( + "".join( + PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}" + if (ix in p_legs) and (ix in l_legs) + else RED + f"{con_brck_l}{ix}{con_brck_r}" + if ix in l_legs + else GREEN + ix + for ix in r_inds + ) + .replace(f"]{RED}[", "") + .replace(f"}}{PINK}{{", "") + ) + + entries.append( + ( + p, + f"{GREY}({i}) cost: {RESET}{p_flops:.1e} " + f"{GREY}widths: {RESET}{l_sz:.1f},{r_sz:.1f}->{p_sz:.1f} " + f"{GREY}type: {RESET}{type_msg}\n" + f"{GREY}inputs: {la},{ra}{RESET}->\n" + f"{GREY}output: {pa}\n", + ) + ) + + if sort == "flops": + entries.sort(key=lambda x: self.get_flops(x[0]), reverse=True) + if sort == "size": + entries.sort(key=lambda x: self.get_size(x[0]), reverse=True) + + entries.append((None, f"{RESET}")) + + o = "\n".join(entry for _, entry in entries) + print(o) + + # --------------------- Performing the Contraction ---------------------- # + + def get_contractor( + self, + order=None, + prefer_einsum=False, + strip_exponent=False, + check_zero=False, + implementation=None, + autojit=False, + progbar=False, + ): + """Get a reusable function which performs the contraction corresponding + to this tree, cached. + + Parameters + ---------- + tree : ContractionTree + The contraction tree. + order : str or callable, optional + Supplied to :meth:`ContractionTree.traverse`, the order in which + to perform the pairwise contractions given by the tree. + prefer_einsum : bool, optional + Prefer to use ``einsum`` for pairwise contractions, even if + ``tensordot`` can perform the contraction. + strip_exponent : bool, optional + If ``True``, the function will eagerly strip the exponent (in + log10) from intermediate tensors to control numerical problems from + leaving the range of the datatype. This method then returns the + scaled 'mantissa' output array and the exponent separately. + check_zero : bool, optional + If ``True``, when ``strip_exponent=True``, explicitly check for + zero-valued intermediates that would otherwise produce ``nan``, + instead terminating early if encountered and returning + ``(0.0, 0.0)``. + implementation : str or tuple[callable, callable], optional + What library to use to actually perform the contractions. Options + are: + + - None: let cotengra choose. + - "autoray": dispatch with autoray, using the ``tensordot`` and + ``einsum`` implementation of the backend. + - "cotengra": use the ``tensordot`` and ``einsum`` implementation + of cotengra, which is based on batch matrix multiplication. This + is faster for some backends like numpy, and also enables + libraries which don't yet provide ``tensordot`` and ``einsum`` to + be used. + - "cuquantum": use the cuquantum library to perform the whole + contraction (not just individual contractions). + - tuple[callable, callable]: manually supply the ``tensordot`` and + ``einsum`` implementations to use. + + autojit : bool, optional + If ``True``, use :func:`autoray.autojit` to compile the contraction + function. + progbar : bool, optional + Whether to show progress through the contraction by default. + + Returns + ------- + fn : callable + The contraction function, with signature ``fn(*arrays)``. + """ + key = ( + autojit, + order, + prefer_einsum, + strip_exponent, + check_zero, + implementation, + progbar, + ) + try: + fn = self.contraction_cores[key] + except KeyError: + fn = self.contraction_cores[key] = make_contractor( + tree=self, + order=order, + prefer_einsum=prefer_einsum, + strip_exponent=strip_exponent, + check_zero=check_zero, + implementation=implementation, + autojit=autojit, + progbar=progbar, + ) + + return fn + + def contract_core( + self, + arrays, + order=None, + prefer_einsum=False, + strip_exponent=False, + check_zero=False, + backend=None, + implementation=None, + autojit="auto", + progbar=False, + ): + """Contract ``arrays`` with this tree. The order of the axes and + output is assumed to be that of ``tree.inputs`` and ``tree.output``, + but with sliced indices removed. This functon contracts the core tree + and thus if indices have been sliced the arrays supplied need to be + sliced as well. + + Parameters + ---------- + arrays : sequence of array + The arrays to contract. + order : str or callable, optional + Supplied to :meth:`ContractionTree.traverse`. + prefer_einsum : bool, optional + Prefer to use ``einsum`` for pairwise contractions, even if + ``tensordot`` can perform the contraction. + backend : str, optional + What library to use for ``einsum`` and ``transpose``, will be + automatically inferred from the arrays if not given. + autojit : "auto" or bool, optional + Whether to use ``autoray.autojit`` to jit compile the expression. + If "auto", then let ``cotengra`` choose. + progbar : bool, optional + Show progress through the contraction. + """ + if autojit == "auto": + # choose for the user + autojit = backend == "jax" + + fn = self.get_contractor( + order=order, + prefer_einsum=prefer_einsum, + strip_exponent=strip_exponent is not False, + implementation=implementation, + autojit=autojit, + check_zero=check_zero, + progbar=progbar, + ) + return fn(*arrays, backend=backend) + + def slice_key(self, i, strides=None): + """Get the combination of sliced index values for overall slice ``i``. + + Parameters + ---------- + i : int + The overall slice index. + + Returns + ------- + key : dict[str, int] + The value each sliced index takes for slice ``i``. + """ + if strides is None: + strides = get_slice_strides(self.sliced_inds) + + key = {} + for (ind, info), stride in zip(self.sliced_inds.items(), strides): + if info.project is None: + key[ind] = i // stride + i %= stride + else: + # size is 1 and i doesn't change + key[ind] = info.project + + return key + + def slice_arrays(self, arrays, i): + """Take ``arrays`` and slice the relevant inputs according to + ``tree.sliced_inds`` and the dynary representation of ``i``. + """ + temp_arrays = list(arrays) + + # e.g. {'a': 2, 'd': 7, 'z': 0} + locations = self.slice_key(i) + + for c in self.sliced_inputs: + # the indexing object, e.g. [:, :, 7, :, 2, :, :, 0] + selector = tuple( + locations.get(ix, slice(None)) for ix in self.inputs[c] + ) + # re-insert the sliced array + temp_arrays[c] = temp_arrays[c][selector] + + return temp_arrays + + def contract_slice(self, arrays, i, **kwargs): + """Get slices ``i`` of ``arrays`` and then contract them.""" + return self.contract_core(self.slice_arrays(arrays, i), **kwargs) + + def gather_slices(self, slices, backend=None, progbar=False): + """Gather all the output contracted slices into a single full result. + If none of the sliced indices appear in the output, then this is a + simple sum - otherwise the slices need to be partially summed and + partially stacked. + """ + if progbar: + import tqdm + + slices = tqdm.tqdm(slices, total=self.multiplicity) + + output_pos = { + ix: i for i, ix in enumerate(self.output) if ix in self.sliced_inds + } + + if not output_pos: + # we can just sum everything + return functools.reduce(add_maybe_exponent_stripped, slices) + + # first we sum over non-output sliced indices + chunks = {} + for i, s in enumerate(slices): + key_slice = self.slice_key(i) + key = tuple(key_slice[ix] for ix in output_pos) + try: + chunks[key] = add_maybe_exponent_stripped(chunks[key], s) + except KeyError: + chunks[key] = s + + if isinstance(next(iter(chunks.values())), tuple): + # have stripped exponents, need to scale to largest + emax = max(v[1] for v in chunks.values()) + chunks = { + k: mi * 10 ** (ei - emax) for k, (mi, ei) in chunks.items() + } + else: + emax = None + + # then we stack these summed chunks over output sliced indices + def recursively_stack_chunks(loc, remaining): + if not remaining: + return chunks[loc] + arrays = [ + recursively_stack_chunks(loc + (d,), remaining[1:]) + for d in self.sliced_inds[remaining[0]].sliced_range + ] + axes = output_pos[remaining[0]] - len(loc) + return do("stack", arrays, axes, like=backend) + + result = recursively_stack_chunks((), tuple(output_pos)) + + if emax is not None: + # strip_exponent was True, return the exponent separately + return result, emax + + return result + + def gen_output_chunks( + self, arrays, with_key=False, progbar=False, **contract_opts + ): + """Generate each output chunk of the contraction - i.e. take care of + summing internally sliced indices only first. This assumes that the + ``sliced_inds`` are sorted by whether they appear in the output or not + (the default order). Useful for performing some kind of reduction over + the final tensor object like ``fn(x).sum()`` without constructing the + entire thing. + + Parameters + ---------- + arrays : sequence of array + The arrays to contract. + with_key : bool, optional + Whether to yield the output index configuration key along with the + chunk. + progbar : bool, optional + Show progress through the contraction chunks. + + Yields + ------ + chunk : array + A chunk of the contracted result. + key : dict[str, int] + The value each sliced output index takes for this chunk. + """ + # consecutive slices of size ``stepsize`` all belong to the same output + # block because the sliced indices are sorted output first + stepsize = prod( + si.size for si in self.sliced_inds.values() if si.inner + ) + + if progbar: + import tqdm + + it = tqdm.trange(self.nslices // stepsize) + else: + it = range(self.nslices // stepsize) + + for o in it: + chunk = self.contract_slice(arrays, o * stepsize, **contract_opts) + + if with_key: + output_key = { + ix: x + for ix, x in self.slice_key(o * stepsize).items() + if ix in self.output + } + + for j in range(1, stepsize): + i = o * stepsize + j + chunk = chunk + self.contract_slice(arrays, i, **contract_opts) + + if with_key: + yield chunk, output_key + else: + yield chunk + + def contract( + self, + arrays, + order=None, + prefer_einsum=False, + strip_exponent=False, + check_zero=False, + backend=None, + implementation=None, + autojit="auto", + progbar=False, + ): + """Contract ``arrays`` with this tree. This function takes *unsliced* + arrays and handles the slicing, contractions and gathering. The order + of the axes and output is assumed to match that of ``tree.inputs`` and + ``tree.output``. + + Parameters + ---------- + arrays : sequence of array + The arrays to contract. + order : str or callable, optional + Supplied to :meth:`ContractionTree.traverse`. + prefer_einsum : bool, optional + Prefer to use ``einsum`` for pairwise contractions, even if + ``tensordot`` can perform the contraction. + strip_exponent : bool, optional + If ``True``, eagerly strip the exponent (in log10) from + intermediate tensors to control numerical problems from leaving the + range of the datatype. This method then returns the scaled + 'mantissa' output array and the exponent separately. + check_zero : bool, optional + If ``True``, when ``strip_exponent=True``, explicitly check for + zero-valued intermediates that would otherwise produce ``nan``, + instead terminating early if encountered and returning + ``(0.0, 0.0)``. + backend : str, optional + What library to use for ``tensordot``, ``einsum`` and + ``transpose``, it will be automatically inferred from the input + arrays if not given. + autojit : bool, optional + Whether to use the 'autojit' feature of `autoray` to compile the + contraction expression. + progbar : bool, optional + Whether to show a progress bar. + + Returns + ------- + output : array + The contracted output, it will be scaled if + ``strip_exponent==True``. + exponent : float + The exponent of the output in base 10, returned only if + ``strip_exponent==True``. + + See Also + -------- + contract_core, contract_slice, slice_arrays, gather_slices + """ + if not self.sliced_inds: + return self.contract_core( + arrays, + order=order, + prefer_einsum=prefer_einsum, + strip_exponent=strip_exponent, + check_zero=check_zero, + backend=backend, + implementation=implementation, + autojit=autojit, + progbar=progbar, + ) + + slices = ( + self.contract_slice( + arrays, + i, + order=order, + prefer_einsum=prefer_einsum, + strip_exponent=strip_exponent, + check_zero=check_zero, + backend=backend, + implementation=implementation, + autojit=autojit, + ) + for i in range(self.multiplicity) + ) + + return self.gather_slices(slices, backend=backend, progbar=progbar) + + def contract_mpi(self, arrays, comm=None, root=None, **kwargs): + """Contract the slices of this tree and sum them in parallel - + *assuming* we are already running under MPI. + + Parameters + ---------- + arrays : sequence of array + The input (unsliced arrays) + comm : None or mpi4py communicator + Defaults to ``mpi4py.MPI.COMM_WORLD`` if not given. + root : None or int, optional + If ``root=None``, an ``Allreduce`` will be performed such that + every process has the resulting tensor, else if an integer e.g. + ``root=0``, the result will be exclusively gathered to that + process using ``Reduce``, with every other process returning + ``None``. + kwargs + Supplied to :meth:`~cotengra.ContractionTree.contract_slice`. + """ + if not set(self.sliced_inds).isdisjoint(set(self.output)): + raise NotImplementedError( + "Sliced and output indices overlap - currently only a simple " + "sum of result slices is supported currently." + ) + + if comm is None: + from mpi4py import MPI + + comm = MPI.COMM_WORLD + + if self.multiplicity < comm.size: + raise ValueError( + f"Need to have more slices than MPI processes, but have " + f"{self.multiplicity} and {comm.size} respectively." + ) + + # round robin compute each slice, eagerly summing + result_i = None + for i in range(comm.rank, self.multiplicity, comm.size): + # note: fortran ordering is needed for the MPI reduce + x = do("asfortranarray", self.contract_slice(arrays, i, **kwargs)) + if result_i is None: + result_i = x + else: + result_i += x + + if root is None: + # everyone gets the summed result + result = do("empty_like", result_i) + comm.Allreduce(result_i, result) + return result + + # else we only sum reduce the result to process ``root`` + if comm.rank == root: + result = do("empty_like", result_i) + else: + result = None + comm.Reduce(result_i, result, root=root) + return result + + def benchmark( + self, + dtype, + max_time=60, + min_reps=3, + max_reps=100, + warmup=True, + **contract_opts, + ): + """Benchmark the contraction of this tree. + + Parameters + ---------- + dtype : {"float32", "float64", "complex64", "complex128"} + The datatype to use. + max_time : float, optional + The maximum time to spend benchmarking in seconds. + min_reps : int, optional + The minimum number of repetitions to perform, regardless of time. + max_reps : int, optional + The maximum number of repetitions to perform, regardless of time. + warmup : bool or int, optional + Whether to perform a warmup run before the benchmark. If an int, + the number of warmup runs to perform. + contract_opts + Supplied to :meth:`~cotengra.ContractionTree.contract_slice`. + + Returns + ------- + dict + A dictionary of benchmarking results. The keys are: + + - "time_per_slice" : float + The average time to contract a single slice. + - "est_time_total" : float + The estimated total time to contract all slices. + - "est_gigaflops" : float + The estimated gigaflops of the contraction. + + See Also + -------- + contract_slice + """ + import time + + from .utils import make_arrays_from_inputs + + arrays = make_arrays_from_inputs( + self.inputs, self.size_dict, dtype=dtype + ) + + for i in range(int(warmup)): + self.contract_slice(arrays, i % self.nslices, **contract_opts) + + t0 = time.time() + ti = t0 + i = 0 + while (ti - t0 < max_time) or (i < min_reps): + self.contract_slice(arrays, i % self.nslices, **contract_opts) + ti = time.time() + i += 1 + if i >= max_reps: + break + + time_per_slice = (ti - t0) / i + est_time_total = time_per_slice * self.nslices + est_gigaflops = self.total_flops(dtype=dtype) / (1e9 * est_time_total) + + return { + "time_per_slice": time_per_slice, + "est_time_total": est_time_total, + "est_gigaflops": est_gigaflops, + } + + plot_ring = plot_tree_ring + plot_tent = plot_tree_tent + plot_span = plot_tree_span + plot_flat = plot_tree_flat + plot_circuit = plot_tree_circuit + plot_rubberband = plot_tree_rubberband + plot_contractions = plot_contractions + plot_contractions_alt = plot_contractions_alt + + @functools.wraps(plot_hypergraph) + def plot_hypergraph(self, **kwargs): + hg = self.get_hypergraph(accel=False) + hg.plot(**kwargs) + + def describe(self, info="normal", join=" "): + """Return a string describing the contraction tree.""" + self.contract_stats() + if info == "normal": + return join.join( + ( + f"log10[FLOPs]={self.total_flops(log=10):.2f}", + f"log2[SIZE]={self.max_size(log=2):.2f}", + ) + ) + + elif info == "full": + s = [ + f"log10[FLOPS]={self.total_flops(log=10):.2f}", + f"log10[COMBO]={self.combo_cost(log=10):.2f}", + f"log2[SIZE]={self.max_size(log=2):.2f}", + f"log2[PEAK]={self.peak_size(log=2):.2f}", + ] + if self.sliced_inds: + s.append(f"NSLICES={self.multiplicity:.2f}") + return join.join(s) + + elif info == "concise": + s = [ + f"F={self.total_flops(log=10):.2f}", + f"C={self.combo_cost(log=10):.2f}", + f"S={self.max_size(log=2):.2f}", + f"P={self.peak_size(log=2):.2f}", + ] + if self.sliced_inds: + s.append(f"$={self.multiplicity:.2f}") + return join.join(s) + + def __repr__(self): + if self.is_complete(): + return f"<{self.__class__.__name__}(N={self.N})>" + else: + s = "<{}(N={}, branches={}, complete={})>" + return s.format( + self.__class__.__name__, + self.N, + len(self.children), + self.is_complete(), + ) + + def __str__(self): + if not self.is_complete(): + return self.__repr__() + else: + d = self.describe("concise", join=", ") + return f"<{self.__class__.__name__}(N={self.N}, {d})>" + + +def _reconfigure_tree(tree, *args, **kwargs): + return tree.subtree_reconfigure(*args, **kwargs) + + +def _slice_and_reconfigure_tree(tree, *args, **kwargs): + return tree.slice_and_reconfigure(*args, **kwargs) + + +def _get_tree_info(tree): + stats = tree.contract_stats() + stats["sliced_ind_set"] = frozenset(tree.sliced_inds) + return stats + + +def _describe_tree(tree, info="normal"): + return tree.describe(info=info) + + +class ContractionTreeCompressed(ContractionTree): + """A contraction tree for compressed contractions. Currently the only + difference is that this defaults to the 'surface' traversal ordering. + """ + + def set_state_from(self, other): + super().set_state_from(other) + self.set_surface_order_from_path(other.get_ssa_path()) + + @classmethod + def from_path( + cls, + inputs, + output, + size_dict, + *, + path=None, + ssa_path=None, + autocomplete="auto", + check=False, + **kwargs, + ): + """Create a (completed) ``ContractionTreeCompressed`` from the usual + inputs plus a standard contraction path or 'ssa_path' - you need to + supply one. This also set the default 'surface' traversal ordering to + be the initial path. + """ + if int(path is None) + int(ssa_path is None) != 1: + raise ValueError( + "Exactly one of ``path`` or ``ssa_path`` must be supplied." + ) + + if path is not None: + from .pathfinders.path_basic import linear_to_ssa + + ssa_path = linear_to_ssa(path) + + tree = cls(inputs, output, size_dict, **kwargs) + terms = list(tree.gen_leaves()) + + for p in ssa_path: + merge = [terms[i] for i in p] + terms.append(tree.contract_nodes(merge, check=check)) + + tree.set_surface_order_from_path(ssa_path) + + if (len(tree.children) < tree.N - 1) and autocomplete: + if autocomplete == "auto": + # warn that we are completing + warnings.warn( + "Path was not complete - contracting all remaining. " + "You can silence this warning with `autocomplete=True`." + "Or produce an incomplete tree with `autocomplete=False`." + ) + + tree.autocomplete(optimize="greedy-compressed") + + return tree + + def get_default_order(self): + return "surface_order" + + def get_default_objective(self): + if self._default_objective is None: + self._default_objective = get_score_fn("peak-compressed") + return self._default_objective + + def get_default_chi(self): + objective = self.get_default_objective() + try: + chi = objective.chi + except AttributeError: + chi = "auto" + + if chi == "auto": + chi = max(self.size_dict.values()) ** 2 + + return chi + + def get_default_compress_late(self): + objective = self.get_default_objective() + try: + return objective.compress_late + except AttributeError: + return False + + total_flops = ContractionTree.total_flops_compressed + total_write = ContractionTree.total_write_compressed + combo_cost = ContractionTree.combo_cost_compressed + total_cost = ContractionTree.total_cost_compressed + max_size = ContractionTree.max_size_compressed + peak_size = ContractionTree.peak_size_compressed + contraction_cost = ContractionTree.contraction_cost_compressed + contraction_width = ContractionTree.contraction_width_compressed + + total_flops_exact = ContractionTree.total_flops + total_write_exact = ContractionTree.total_write + combo_cost_exact = ContractionTree.combo_cost + total_cost_exact = ContractionTree.total_cost + max_size_exact = ContractionTree.max_size + peak_size_exact = ContractionTree.peak_size + + def get_contractor(self, *_, **__): + raise NotImplementedError( + "`cotengra` doesn't implement compressed contraction itself. " + "If you want to use compressed contractions, you need to use " + "`quimb` and the `TensorNetwork.contract_compressed` method, " + "with e.g. `optimize=tree.get_path()`." + ) + + def simulated_anneal( + self, + minimize=None, + tfinal=0.0001, + tstart=0.01, + tsteps=50, + numiter=50, + seed=None, + inplace=False, + progbar=False, + **kwargs, + ): + """Perform simulated annealing refinement of this *compressed* + contraction tree. + """ + from .pathfinders.path_compressed import WindowedOptimizer + + if minimize is None: + minimize = self.get_default_objective() + + wo = WindowedOptimizer( + self.inputs, + self.output, + self.size_dict, + minimize=minimize, + ssa_path=self.get_ssa_path(), + seed=seed, + ) + + wo.simulated_anneal( + tfinal=tfinal, + tstart=tstart, + tsteps=tsteps, + numiter=numiter, + progbar=progbar, + **kwargs, + ) + ssa_path = wo.get_ssa_path() + + rtree = self.__class__.from_path( + self.inputs, + self.output, + self.size_dict, + ssa_path=ssa_path, + objective=minimize, + ) + + if inplace: + self.set_state_from(rtree) + rtree = self + + rtree.contraction_cores.clear() + return rtree + + simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True) + + +class PartitionTreeBuilder: + """Function wrapper that takes a function that partitions graphs and + uses it to build a contraction tree. ``partition_fn`` should have + signature: + + def partition_fn(inputs, output, size_dict, + weight_nodes, weight_edges, **kwargs): + ... + return membership + + Where ``weight_nodes`` and ``weight_edges`` decsribe how to weight the + nodes and edges of the graph respectively and ``membership`` should be a + list of integers of length ``len(inputs)`` labelling which partition + each input node should be put it. + """ + + def __init__(self, partition_fn): + self.partition_fn = partition_fn + + def build_divide( + self, + inputs, + output, + size_dict, + random_strength=0.01, + cutoff=10, + parts=2, + parts_decay=0.5, + sub_optimize="greedy", + super_optimize="auto-hq", + check=False, + seed=None, + **partition_opts, + ): + tree = ContractionTree(inputs, output, size_dict, track_childless=True) + + rng = get_rng(seed) + rand_size_dict = jitter_dict(size_dict, random_strength, rng) + + dynamic_imbalance = ("imbalance" in partition_opts) and ( + "imbalance_decay" in partition_opts + ) + if dynamic_imbalance: + imbalance = partition_opts.pop("imbalance") + imbalance_decay = partition_opts.pop("imbalance_decay") + else: + imbalance = imbalance_decay = None + + dynamic_fix = partition_opts.get("fix_output_nodes", None) == "auto" + + while tree.childless: + tree_node = next(iter(tree.childless)) + subgraph = tuple(tree_node) + subsize = len(subgraph) + + # skip straight to better method + if subsize <= cutoff: + tree.contract_nodes( + [node_from_single(x) for x in subgraph], + optimize=sub_optimize, + check=check, + ) + continue + + # relative subgraph size + s = subsize / tree.N + + # let the target number of communities depend on subgraph size + parts_s = max(int(s**parts_decay * parts), 2) + + # let the imbalance either rise or fall + if dynamic_imbalance: + if imbalance_decay >= 0: + imbalance_s = s**imbalance_decay * imbalance + else: + imbalance_s = 1 - s**-imbalance_decay * (1 - imbalance) + partition_opts["imbalance"] = imbalance_s + + if dynamic_fix: + # for the top level subtree (s==1.0) we partition the outputs + # nodes first into their own bi-partition + parts_s = 2 + partition_opts["fix_output_nodes"] = s == 1.0 + + # partition! get community membership list e.g. + # [0, 0, 1, 0, 1, 0, 0, 2, 2, ...] + inputs = tuple(map(tuple, tree.node_to_terms(subgraph))) + output = tuple(tree.get_legs(tree_node)) + membership = self.partition_fn( + inputs, + output, + rand_size_dict, + parts=parts_s, + seed=rng, + **partition_opts, + ) + + # divide subgraph up e.g. if we enumerate the subgraph index sets + # (0, 1, 2, 3, 4, 5, 6, 7, 8, ...) -> + # ({0, 1, 3, 5, 6}, {2, 4}, {7, 8}) + new_subgs = tuple( + map(node_from_seq, separate(subgraph, membership)) + ) + + if len(new_subgs) == 1: + # no communities found - contract all remaining + tree.contract_nodes( + tuple(map(node_from_single, subgraph)), + optimize=sub_optimize, + check=check, + ) + continue + + # update tree structure with newly contracted subgraphs + tree.contract_nodes( + new_subgs, optimize=super_optimize, check=check + ) + + if check: + assert tree.is_complete() + + return tree + + def build_agglom( + self, + inputs, + output, + size_dict, + random_strength=0.01, + groupsize=4, + check=False, + sub_optimize="greedy", + seed=None, + **partition_opts, + ): + tree = ContractionTree(inputs, output, size_dict, track_childless=True) + rand_size_dict = jitter_dict(size_dict, random_strength, seed) + leaves = tuple(tree.gen_leaves()) + for node in leaves: + tree._add_node(node, check=check) + output = tuple(tree.output) + + while len(leaves) > groupsize: + parts = max(2, len(leaves) // groupsize) + + inputs = [tuple(tree.get_legs(node)) for node in leaves] + membership = self.partition_fn( + inputs, + output, + rand_size_dict, + parts=parts, + **partition_opts, + ) + leaves = [ + tree.contract_nodes(group, check=check, optimize=sub_optimize) + for group in separate(leaves, membership) + ] + + if len(leaves) > 1: + tree.contract_nodes(leaves, check=check, optimize=sub_optimize) + + if check: + assert tree.is_complete() + + return tree + + def trial_fn(self, inputs, output, size_dict, **partition_opts): + return self.build_divide(inputs, output, size_dict, **partition_opts) + + def trial_fn_agglom(self, inputs, output, size_dict, **partition_opts): + return self.build_agglom(inputs, output, size_dict, **partition_opts) + + +def jitter(x, strength, rng): + return x * (1 + strength * rng.expovariate(1.0)) + + +def jitter_dict(d, strength, seed=None): + rng = get_rng(seed) + return {k: jitter(v, strength, rng) for k, v in d.items()} + + +def separate(xs, blocks): + """Partition ``xs`` into ``n`` different list based on the corresponding + labels in ``blocks``. + """ + sorter = collections.defaultdict(list) + for x, b in zip(xs, blocks): + sorter[b].append(x) + x_b = list(sorter.items()) + x_b.sort() + return [x[1] for x in x_b] diff --git a/.venv/lib/python3.12/site-packages/cotengra/hyperoptimizers/hyper.py b/.venv/lib/python3.12/site-packages/cotengra/hyperoptimizers/hyper.py new file mode 100644 index 0000000..08c880e --- /dev/null +++ b/.venv/lib/python3.12/site-packages/cotengra/hyperoptimizers/hyper.py @@ -0,0 +1,1168 @@ +"""Base hyper optimization functionality.""" + +import functools +import importlib +import re +import time +import warnings +from math import log2, log10 + +from ..core import ( + ContractionTree, + ContractionTreeCompressed, +) +from ..core_multi import ContractionTreeMulti +from ..oe import PathOptimizer +from ..parallel import get_n_workers, parse_parallel_arg, should_nest, submit +from ..plot import ( + plot_parameters_parallel, + plot_scatter, + plot_scatter_alt, + plot_trials, + plot_trials_alt, +) +from ..reusable import ReusableOptimizer +from ..scoring import get_score_fn +from ..utils import BadTrial, get_rng + + +@functools.lru_cache(maxsize=None) +def get_default_hq_methods(): + methods = ["greedy"] + if importlib.util.find_spec("kahypar"): + methods.append("kahypar") + else: + methods.append("labels") + warnings.warn( + "Couldn't import `kahypar` - skipping from default hyper optimizer" + " and using basic `labels` method instead. `kahypar` is highly " + "recommended for the best quality contraction paths." + ) + return tuple(methods) + + +@functools.lru_cache(maxsize=None) +def get_default_optlib_eco(): + """Get the default optimizer favoring speed.""" + if importlib.util.find_spec("cmaes"): + optlib = "cmaes" + elif importlib.util.find_spec("nevergrad"): + optlib = "nevergrad" + elif importlib.util.find_spec("optuna"): + optlib = "optuna" + else: + optlib = "random" + warnings.warn( + "Couldn't find `optuna`, `cmaes`, or `nevergrad` so will use " + "completely random sampling in place of hyper-optimization. " + "It is recommended to install one of these libraries for higher " + "quality contraction paths." + ) + return optlib + + +@functools.lru_cache(maxsize=None) +def get_default_optlib(): + """Get the default optimizer balancing quality and speed.""" + if importlib.util.find_spec("optuna"): + optlib = "optuna" + elif importlib.util.find_spec("cmaes"): + optlib = "cmaes" + elif importlib.util.find_spec("nevergrad"): + optlib = "nevergrad" + else: + optlib = "random" + warnings.warn( + "Couldn't find `optuna`, `cmaes`, or `nevergrad` so will use " + "completely random sampling in place of hyper-optimization. " + "It is recommended to install one of these libraries for higher " + "quality hyper-optimization." + ) + return optlib + + +_PATH_FNS = {} +_OPTLIB_FNS = {} +_HYPER_SEARCH_SPACE = {} +_HYPER_CONSTANTS = {} + + +def get_hyper_space(): + return _HYPER_SEARCH_SPACE + + +def get_hyper_constants(): + return _HYPER_CONSTANTS + + +def register_hyper_optlib(name, init_optimizers, get_setting, report_result): + _OPTLIB_FNS[name] = (init_optimizers, get_setting, report_result) + + +def register_hyper_function(name, ssa_func, space, constants=None): + """Register a contraction path finder to be used by the hyper-optimizer. + + Parameters + ---------- + name : str + The name to call the method. + ssa_func : callable + The raw function that returns a 'ContractionTree', with signature + ``(inputs, output, size_dict, **kwargs)``. + space : dict[str, dict] + The space of hyper-parameters to search. + """ + if constants is None: + constants = {} + + _PATH_FNS[name] = ssa_func + _HYPER_SEARCH_SPACE[name] = space + _HYPER_CONSTANTS[name] = constants + + +def list_hyper_functions(): + """Return a list of currently registered hyper contraction finders.""" + return sorted(_PATH_FNS) + + +def base_trial_fn(inputs, output, size_dict, method, **kwargs): + tree = _PATH_FNS[method](inputs, output, size_dict, **kwargs) + return {"tree": tree} + + +class TrialSetObjective: + def __init__(self, trial_fn, objective): + self.trial_fn = trial_fn + self.objective = objective + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + trial["tree"].set_default_objective(self.objective) + return trial + + +class TrialConvertTree: + def __init__(self, trial_fn, cls): + self.trial_fn = trial_fn + self.cls = cls + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + + tree = trial["tree"] + if not isinstance(tree, self.cls): + tree.__class__ = self.cls + + return trial + + +class TrialTreeMulti: + def __init__(self, trial_fn, varmults, numconfigs): + self.trial_fn = trial_fn + self.varmults = varmults + self.numconfigs = numconfigs + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + + tree = trial["tree"] + if not isinstance(tree, ContractionTreeMulti): + tree.__class__ = ContractionTreeMulti + + tree.set_varmults(self.varmults) + tree.set_numconfigs(self.numconfigs) + + return trial + + +class SlicedTrialFn: + def __init__(self, trial_fn, **opts): + self.trial_fn = trial_fn + self.opts = opts + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + tree = trial["tree"] + + stats = tree.contract_stats() + trial.setdefault("original_flops", stats["flops"]) + trial.setdefault("original_write", stats["write"]) + trial.setdefault("original_size", stats["size"]) + + tree.slice_(**self.opts) + trial.update(tree.contract_stats()) + + return trial + + +class SimulatedAnnealingTrialFn: + def __init__(self, trial_fn, **opts): + self.trial_fn = trial_fn + self.opts = opts + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + tree = trial["tree"] + stats = tree.contract_stats() + trial.setdefault("original_flops", stats["flops"]) + trial.setdefault("original_write", stats["write"]) + trial.setdefault("original_size", stats["size"]) + tree.simulated_anneal_(**self.opts) + trial.update(tree.contract_stats()) + return trial + + +class ReconfTrialFn: + def __init__(self, trial_fn, forested=False, parallel=False, **opts): + self.trial_fn = trial_fn + self.forested = forested + self.parallel = parallel + self.opts = opts + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + tree = trial["tree"] + + stats = tree.contract_stats() + trial.setdefault("original_flops", stats["flops"]) + trial.setdefault("original_write", stats["write"]) + trial.setdefault("original_size", stats["size"]) + + if self.forested: + tree.subtree_reconfigure_forest_( + parallel=self.parallel, **self.opts + ) + else: + tree.subtree_reconfigure_(**self.opts) + + tree.already_optimized.clear() + trial.update(tree.contract_stats()) + + return trial + + +class SlicedReconfTrialFn: + def __init__(self, trial_fn, forested=False, parallel=False, **opts): + self.trial_fn = trial_fn + self.forested = forested + self.parallel = parallel + self.opts = opts + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + tree = trial["tree"] + + stats = tree.contract_stats() + trial.setdefault("original_flops", stats["flops"]) + trial.setdefault("original_write", stats["write"]) + trial.setdefault("original_size", stats["size"]) + + if self.forested: + tree.slice_and_reconfigure_forest_( + parallel=self.parallel, **self.opts + ) + else: + tree.slice_and_reconfigure_(**self.opts) + + tree.already_optimized.clear() + trial.update(tree.contract_stats()) + + return trial + + +class CompressedReconfTrial: + def __init__(self, trial_fn, minimize=None, **opts): + self.trial_fn = trial_fn + self.minimize = minimize + self.opts = opts + + def __call__(self, *args, **kwargs): + trial = self.trial_fn(*args, **kwargs) + tree = trial["tree"] + tree.windowed_reconfigure_(minimize=self.minimize, **self.opts) + return trial + + +class ComputeScore: + """The final score wrapper, that performs some simple arithmetic on the + trial score to make it more suitable for hyper-optimization. + """ + + def __init__( + self, + fn, + score_fn, + score_compression=0.75, + score_smudge=1e-6, + on_trial_error="warn", + seed=0, + ): + self.fn = fn + self.score_fn = score_fn + self.score_compression = score_compression + self.score_smudge = score_smudge + self.on_trial_error = { + "raise": "raise", + "warn": "warn", + "ignore": "ignore", + }[on_trial_error] + self.rng = get_rng(seed) + + def __call__(self, *args, **kwargs): + ti = time.time() + try: + trial = self.fn(*args, **kwargs) + trial["score"] = self.score_fn(trial) ** self.score_compression + # random smudge is for baytune/scikit-learn nan/inf bug + trial["score"] += self.rng.gauss(0.0, self.score_smudge) + except BadTrial: + trial = { + "score": float("inf"), + "flops": float("inf"), + "write": float("inf"), + "size": float("inf"), + } + except Exception as e: + if self.on_trial_error == "raise": + raise e + elif self.on_trial_error == "warn": + warnings.warn( + f"Trial error: {e}. Set `HyperOptimizer` kwarg " + "`on_trial_error='raise'` to raise this error, or " + "`on_trial_error='ignore'` to silence." + ) + trial = { + "score": float("inf"), + "flops": float("inf"), + "write": float("inf"), + "size": float("inf"), + } + + tf = time.time() + trial["time"] = tf - ti + # replace heavy ContractionTree with lightweight path info to avoid + # OOM when many workers serialize large trees into the IPC queue + tree = trial.get("tree") + if tree is not None: + trial["_path"] = tree.get_path() + trial["_sliced_inds"] = dict(tree.sliced_inds) + del trial["tree"] + return trial + + +def progress_description(best, info="concise"): + try: + return best["tree"].describe(info=info) + except KeyError: + return ( + f"log10[FLOPs]={log10(best['flops']):.2f} " + f"log2[SIZE]={log2(best['size']):.2f}" + ) + + +class HyperOptimizer(PathOptimizer): + """A path optimizer that samples a series of contraction trees + while optimizing the hyper parameters used to generate them. + + Parameters + ---------- + methods : None or sequence[str] or str, optional + Which method(s) to use from ``list_hyper_functions()``. + minimize : str, Objective or callable, optional + How to score each trial, used to train the optimizer and rank the + results. If a custom callable, it should take a ``trial`` dict as its + argument and return a single float. + max_repeats : int, optional + The maximum number of trial contraction trees to generate. + Default: 128. + max_time : None or float, optional + The maximum amount of time to run for. Use ``None`` for no limit. You + can also set an estimated execution 'rate' here like ``'rate:1e9'`` + that will terminate the search when the estimated FLOPs of the best + contraction found divided by the rate is greater than the time spent + searching, allowing quick termination on easy contractions. + parallel : 'auto', False, True, int, or distributed.Client + Whether to parallelize the search. + slicing_opts : dict, optional + If supplied, once a trial contraction path is found, try slicing with + the given options, and then update the flops and size of the trial with + the sliced versions. + slicing_reconf_opts : dict, optional + If supplied, once a trial contraction path is found, try slicing + interleaved with subtree reconfiguation with the given options, and + then update the flops and size of the trial with the sliced and + reconfigured versions. + reconf_opts : dict, optional + If supplied, once a trial contraction path is found, try subtree + reconfiguation with the given options, and then update the flops and + size of the trial with the reconfigured versions. + optlib : {'optuna', 'cmaes', 'nevergrad', 'skopt', ...}, optional + Which optimizer to sample and train with. + space : dict, optional + The hyper space to search, see ``get_hyper_space`` for the default. + score_compression : float, optional + Raise scores to this power in order to compress or accentuate the + differences. The lower this is, the more the selector will sample from + various optimizers rather than quickly specializing. + on_trial_error : {'warn', 'raise', 'ignore'}, optional + What to do if a trial fails. If ``'warn'`` (default), a warning will be + printed and the trial will be given a score of ``inf``. If ``'raise'`` + the error will be raised. If ``'ignore'`` the trial will be given a + score of ``inf`` silently. + max_training_steps : int, optional + The maximum number of trials to train the optimizer with. Setting this + can be helpful when the optimizer itself becomes costly to train (e.g. + for Gaussian Processes). + progbar : bool, optional + Show live progress of the best contraction found so far. + optlib_opts + Supplied to the hyper-optimizer library initialization. + """ + + compressed = False + multicontraction = False + + def __init__( + self, + methods=None, + minimize="flops", + max_repeats=128, + max_time=None, + parallel="auto", + simulated_annealing_opts=None, + slicing_opts=None, + slicing_reconf_opts=None, + reconf_opts=None, + optlib=None, + space=None, + score_compression=0.75, + on_trial_error="warn", + max_training_steps=None, + progbar=False, + **optlib_opts, + ): + self.max_repeats = max_repeats + self._repeats_start = 0 + self.max_time = max_time + self.parallel = parallel + + self.method_choices = [] + self.param_choices = [] + self.scores = [] + self.times = [] + self.costs_flops = [] + self.costs_write = [] + self.costs_size = [] + + if methods is None: + self._methods = get_default_hq_methods() + elif isinstance(methods, str): + self._methods = [methods] + else: + self._methods = list(methods) + + if optlib is None: + optlib = get_default_optlib() + + # which score to feed to the hyper optimizer (setter below handles) + self.minimize = minimize + self.score_compression = score_compression + self.on_trial_error = on_trial_error + self.best_score = float("inf") + self.max_training_steps = max_training_steps + + inf = float("inf") + self.best = {"score": inf, "size": inf, "flops": inf} + self.trials_since_best = 0 + + self.simulated_annealing_opts = ( + None + if simulated_annealing_opts is None + else dict(simulated_annealing_opts) + ) + self.slicing_opts = ( + None if slicing_opts is None else dict(slicing_opts) + ) + self.reconf_opts = None if reconf_opts is None else dict(reconf_opts) + self.slicing_reconf_opts = ( + None if slicing_reconf_opts is None else dict(slicing_reconf_opts) + ) + self.progbar = progbar + + if space is None: + space = get_hyper_space() + + self._optimizer = dict( + zip(["init", "get_setting", "report_result"], _OPTLIB_FNS[optlib]) + ) + + self._optimizer["init"](self, self._methods, space, **optlib_opts) + + @property + def minimize(self): + return self._minimize + + @minimize.setter + def minimize(self, minimize): + self._minimize = minimize + if callable(minimize): + self.objective = minimize + else: + self.objective = get_score_fn(minimize) + + @property + def parallel(self): + return self._parallel + + @parallel.setter + def parallel(self, parallel): + self._parallel = parallel + self._pool = parse_parallel_arg(parallel) + if self._pool is not None: + self._num_workers = get_n_workers(self._pool) + + self.pre_dispatch = self._num_workers + @property + def tree(self): + return self.best["tree"] + + @property + def path(self): + return self.tree.get_path() + + def setup(self, inputs, output, size_dict): + trial_fn = TrialSetObjective(base_trial_fn, self.objective) + + if self.compressed: + assert not self.multicontraction + trial_fn = TrialConvertTree(trial_fn, ContractionTreeCompressed) + + if self.multicontraction: + assert not self.compressed + trial_fn = TrialTreeMulti(trial_fn, self.varmults, self.numconfigs) + + nested_parallel = should_nest(self._pool) + + if self.simulated_annealing_opts is not None: + trial_fn = SimulatedAnnealingTrialFn( + trial_fn, **self.simulated_annealing_opts + ) + + if self.slicing_opts is not None: + trial_fn = SlicedTrialFn(trial_fn, **self.slicing_opts) + + if self.slicing_reconf_opts is not None: + self.slicing_reconf_opts.setdefault("parallel", nested_parallel) + trial_fn = SlicedReconfTrialFn( + trial_fn, **self.slicing_reconf_opts + ) + + if self.reconf_opts is not None: + if self.compressed: + trial_fn = CompressedReconfTrial(trial_fn, **self.reconf_opts) + else: + self.reconf_opts.setdefault("parallel", nested_parallel) + trial_fn = ReconfTrialFn(trial_fn, **self.reconf_opts) + + # make sure score computation is performed worker side + trial_fn = ComputeScore( + trial_fn, + score_fn=self.objective, + score_compression=self.score_compression, + on_trial_error=self.on_trial_error, + ) + + return trial_fn, (inputs, output, size_dict) + + def _maybe_cancel_futures(self): + if self._pool is not None: + while self._futures: + f = self._futures.pop()[-1] + f.cancel() + + def _maybe_report_result(self, setting, trial): + score = trial["score"] + + new_best = score < self.best_score + if new_best: + self.best_score = score + + # only fit optimizers after the training epoch if the score is best + should_report = ( + (self.max_training_steps is None) + or (len(self.scores) < self.max_training_steps) + or new_best + ) and ( + # don't report bad trials + # XXX: should we map to some high value? + trial["score"] < float("inf") + ) + + if should_report: + self._optimizer["report_result"](self, setting, trial, score) + + self.method_choices.append(setting["method"]) + self.param_choices.append(setting["params"]) + # keep track of all costs and sizes + self.costs_flops.append(trial["flops"]) + self.costs_write.append(trial["write"]) + self.costs_size.append(trial["size"]) + self.scores.append(trial["score"]) + self.times.append(trial["time"]) + + def _gen_results(self, repeats, trial_fn, trial_args): + constants = get_hyper_constants() + + for _ in repeats: + setting = self._optimizer["get_setting"](self) + method = setting["method"] + + trial = trial_fn( + *trial_args, + method=method, + **setting["params"], + **constants[method], + ) + + self._maybe_report_result(setting, trial) + + yield trial + + def _get_and_report_next_future(self): + futures_map = {future: setting for setting, future in self._futures} + + if not futures_map: + return { + "score": float("inf"), + "flops": float("inf"), + "write": float("inf"), + "size": float("inf"), + "time": 0.0, + } + + future0 = next(iter(futures_map)) + if future0.__class__.__module__.split(".", 1)[0] == "distributed": + from distributed import as_completed + + deadline = getattr(self, "_qibotn_deadline", None) + timeout = None if deadline is None else max(0.0, deadline - time.time()) + try: + future = next(iter(as_completed(futures_map, timeout=timeout))) + except TimeoutError: + for future in futures_map: + future.cancel() + self._futures = [] + return { + "score": float("inf"), + "flops": float("inf"), + "write": float("inf"), + "size": float("inf"), + "time": 0.0, + } + else: + # use as_completed to block efficiently instead of busy-polling + import concurrent.futures as _cf + + future = next(_cf.as_completed(futures_map)) + + setting = futures_map[future] + self._futures = [(s, f) for s, f in self._futures if f is not future] + try: + trial = future.result() + except Exception: + trial = { + "score": float("inf"), + "flops": float("inf"), + "write": float("inf"), + "size": float("inf"), + "time": 0.0, + } + self._maybe_report_result(setting, trial) + return trial + + def _gen_results_parallel(self, repeats, trial_fn, trial_args): + constants = get_hyper_constants() + self._futures = [] + + for _ in repeats: + setting = self._optimizer["get_setting"](self) + method = setting["method"] + + try: + future = submit( + self._pool, + trial_fn, + *trial_args, + method=method, + **setting["params"], + **constants[method], + ) + except Exception: + # pool broken — drain remaining futures and stop submitting + break + self._futures.append((setting, future)) + + if len(self._futures) >= self.pre_dispatch: + yield self._get_and_report_next_future() + + while self._futures: + yield self._get_and_report_next_future() + + def _search(self, inputs, output, size_dict): + # start a timer? + if self.max_time is not None: + t0 = time.time() + if isinstance(self.max_time, str): + which, amount = re.match( + r"(rate|equil):(.+)", self.max_time + ).groups() + + if which == "rate": + rate = float(amount) + + def should_stop(): + return (time.time() - t0) > (self.best["flops"] / rate) + + elif which == "equil": + amount = int(amount) + + def should_stop(): + return self.trials_since_best > amount + + else: + + def should_stop(): + return (time.time() - t0) > self.max_time + + else: + + def should_stop(): + return False + + trial_fn, trial_args = self.setup(inputs, output, size_dict) + + r_start = self._repeats_start + len(self.scores) + r_stop = r_start + self.max_repeats + repeats = range(r_start, r_stop) + + # create the trials lazily + if self._pool is not None: + trials = self._gen_results_parallel(repeats, trial_fn, trial_args) + else: + trials = self._gen_results(repeats, trial_fn, trial_args) + + if self.progbar: + import tqdm + + pbar = tqdm.tqdm(trials, total=self.max_repeats) + pbar.set_description( + progress_description(self.best), refresh=False + ) + trials = pbar + + # assess the trials + for trial in trials: + # check if we have found a new best + if trial["score"] < self.best["score"]: + self.trials_since_best = 0 + self.best = trial + self.best["params"] = dict(self.param_choices[-1]) + self.best["params"]["method"] = self.method_choices[-1] + + if self.progbar: + pbar.set_description( + progress_description(self.best), refresh=False + ) + + else: + self.trials_since_best += 1 + + # check if we have run out of time + if should_stop(): + break + + if self.progbar: + pbar.close() + + self._maybe_cancel_futures() + + # rebuild the best ContractionTree once from lightweight path info + if "_path" in self.best: + tree = ContractionTree.from_path( + inputs, output, size_dict, path=self.best["_path"] + ) + for ind, si in self.best["_sliced_inds"].items(): + tree.remove_ind_(ind, project=si.project, inplace=True) + self.best["tree"] = tree + del self.best["_path"], self.best["_sliced_inds"] + + def search(self, inputs, output, size_dict): + """Run this optimizer and return the ``ContractionTree`` for the best + path it finds. + """ + self._search( + inputs, + output, + size_dict, + ) + return self.tree + + def get_tree(self): + """Return the ``ContractionTree`` for the best path found.""" + return self.tree + + def __call__(self, inputs, output, size_dict, memory_limit=None): + """``opt_einsum`` interface, returns direct ``path``.""" + self._search( + inputs, + output, + size_dict, + ) + return tuple(self.path) + + def get_trials(self, sort=None): + trials = list( + zip( + self.method_choices, + self.costs_size, + self.costs_flops, + self.costs_write, + self.param_choices, + ) + ) + + if sort == "method": + trials.sort(key=lambda t: t[0]) + if sort == "combo": + trials.sort( + key=lambda t: log2(t[1]) / 1e3 + log2(t[2] + 256 * t[3]) + ) + if sort == "size": + trials.sort( + key=lambda t: log2(t[1]) + log2(t[2]) / 1e3 + log2(t[3]) / 1e3 + ) + if sort == "flops": + trials.sort( + key=lambda t: log2(t[1]) / 1e3 + log2(t[2]) + log2(t[3]) / 1e3 + ) + if sort == "write": + trials.sort( + key=lambda t: log2(t[1]) / 1e3 + log2(t[2]) / 1e3 + log2(t[3]) + ) + + return trials + + def print_trials(self, sort=None): + header = "{:>11} {:>11} {:>11} {}" + print( + header.format( + "METHOD", + "log2[SIZE]", + "log10[FLOPS]", + "log10[WRITE]", + "PARAMS", + ) + ) + row = "{:>11} {:>11.2f} {:>11.2f} {:>11.2f} {}" + for choice, size, flops, write, params in self.get_trials(sort): + print( + row.format( + choice, log2(size), log10(flops), log10(write), params + ) + ) + + def to_df(self): + """Create a single ``pandas.DataFrame`` with all trials and scores.""" + import pandas + + return pandas.DataFrame( + data={ + "run": list(range(len(self.costs_size))), + "time": self.times, + "method": self.method_choices, + "size": list(map(log2, self.costs_size)), + "flops": list(map(log10, self.costs_flops)), + "write": list(map(log10, self.costs_write)), + "random_strength": [ + p.get("random_strength", 1e-6) for p in self.param_choices + ], + "score": self.scores, + } + ).sort_values(by="method") + + def to_dfs_parametrized(self): + """Create a ``pandas.DataFrame`` for each method, with all parameters + and scores for each trial. + """ + import pandas as pd + + rows = {} + for i in range(len(self.scores)): + row = { + "run": i, + "time": self.times[i], + **self.param_choices[i], + "flops": log10(self.costs_flops[i]), + "write": log2(self.costs_write[i]), + "size": log2(self.costs_size[i]), + "score": self.scores[i], + } + method = self.method_choices[i] + rows.setdefault(method, []).append(row) + + return { + method: pd.DataFrame(rows[method]).sort_values(by="score") + for method in rows + } + + plot_trials = plot_trials + plot_trials_alt = plot_trials_alt + plot_scatter = plot_scatter + plot_scatter_alt = plot_scatter_alt + plot_parameters_parallel = plot_parameters_parallel + + +class ReusableHyperOptimizer(ReusableOptimizer): + """Like ``HyperOptimizer`` but it will re-instantiate the optimizer + whenever a new contraction is detected, and also cache the paths (and + sliced indices) found. + + Parameters + ---------- + directory : None, True, or str, optional + If specified use this directory as a persistent cache. If ``True`` auto + generate a directory in the current working directory based on the + options which are most likely to affect the path (see + `ReusableHyperOptimizer._get_path_relevant_opts`). + overwrite : bool or 'improved', optional + If ``True``, the optimizer will always run, overwriting old results in + the cache. This can be used to update paths without deleting the whole + cache. If ``'improved'`` then only overwrite if the new path is better. + hash_method : {'a', 'b', ...}, optional + The method used to hash the contraction tree. The default, ``'a'``, is + faster hashwise but doesn't recognize when indices are permuted. + cache_only : bool, optional + If ``True``, the optimizer will only use the cache, and will raise + ``KeyError`` if a contraction is not found. + directory_split : "auto" or bool, optional + If specified, the hash will be split into two parts, the first part + will be used as a subdirectory, and the second part will be used as the + filename. This is useful for avoiding a very large flat diretory. If + "auto" it will check the current cache if any and guess from that. + opt_kwargs + Supplied to ``HyperOptimizer``. + """ + + def _get_path_relevant_opts(self): + """Get a frozenset of the options that are most likely to affect the + path. These are the options that we use when the directory name is not + manually specified. + """ + return [ + ("methods", None), + ("minimize", "flops"), + ("max_repeats", 128), + ("max_time", None), + ("slicing_opts", None), + ("slicing_reconf_opts", None), + ("simulated_annealing_opts", None), + ("reconf_opts", None), + ("compressed", False), + ("multicontraction", False), + ] + + def _get_suboptimizer(self): + return HyperOptimizer(**self._suboptimizer_kwargs) + + def _deconstruct_tree(self, opt, tree): + return { + "path": tree.get_path(), + "score": tree.get_score(), + # dont' need to store all slice info, just which indices + "sliced_inds": tuple(tree.sliced_inds), + } + + def _reconstruct_tree(self, inputs, output, size_dict, con): + tree = ContractionTree.from_path( + inputs, + output, + size_dict, + path=con["path"], + objective=self.minimize, + ) + + for ix in con["sliced_inds"]: + tree.remove_ind_(ix) + + return tree + + +class HyperCompressedOptimizer(HyperOptimizer): + """A compressed contraction path optimizer that samples a series of ordered + contraction trees while optimizing the hyper parameters used to generate + them. + + Parameters + ---------- + chi : None or int, optional + The maximum bond dimension to compress to. If ``None`` then use the + square of the largest existing dimension. If ``minimize`` is specified + as a full scoring function, this is ignored. + methods : None or sequence[str] or str, optional + Which method(s) to use from ``list_hyper_functions()``. + minimize : str, Objective or callable, optional + How to score each trial, used to train the optimizer and rank the + results. If a custom callable, it should take a ``trial`` dict as its + argument and return a single float. + max_repeats : int, optional + The maximum number of trial contraction trees to generate. + Default: 128. + max_time : None or float, optional + The maximum amount of time to run for. Use ``None`` for no limit. You + can also set an estimated execution 'rate' here like ``'rate:1e9'`` + that will terminate the search when the estimated FLOPs of the best + contraction found divided by the rate is greater than the time spent + searching, allowing quick termination on easy contractions. + parallel : 'auto', False, True, int, or distributed.Client + Whether to parallelize the search. + slicing_opts : dict, optional + If supplied, once a trial contraction path is found, try slicing with + the given options, and then update the flops and size of the trial with + the sliced versions. + slicing_reconf_opts : dict, optional + If supplied, once a trial contraction path is found, try slicing + interleaved with subtree reconfiguation with the given options, and + then update the flops and size of the trial with the sliced and + reconfigured versions. + reconf_opts : dict, optional + If supplied, once a trial contraction path is found, try subtree + reconfiguation with the given options, and then update the flops and + size of the trial with the reconfigured versions. + optlib : {'baytune', 'nevergrad', 'chocolate', 'skopt'}, optional + Which optimizer to sample and train with. + space : dict, optional + The hyper space to search, see ``get_hyper_space`` for the default. + score_compression : float, optional + Raise scores to this power in order to compress or accentuate the + differences. The lower this is, the more the selector will sample from + various optimizers rather than quickly specializing. + max_training_steps : int, optional + The maximum number of trials to train the optimizer with. Setting this + can be helpful when the optimizer itself becomes costly to train (e.g. + for Gaussian Processes). + progbar : bool, optional + Show live progress of the best contraction found so far. + optlib_opts + Supplied to the hyper-optimizer library initialization. + """ + + compressed = True + multicontraction = False + + def __init__( + self, + chi=None, + methods=("greedy-compressed", "greedy-span", "kahypar-agglom"), + minimize="peak-compressed", + **kwargs, + ): + if (chi is not None) and not callable(minimize): + minimize += f"-{chi}" + + kwargs["methods"] = methods + kwargs["minimize"] = minimize + + if kwargs.pop("slicing_opts", None) is not None: + raise ValueError( + "Cannot use slicing_opts with compressed contraction." + ) + if kwargs.pop("slicing_reconf_opts", None) is not None: + raise ValueError( + "Cannot use slicing_reconf_opts with compressed contraction." + ) + + super().__init__(**kwargs) + + +class ReusableHyperCompressedOptimizer(ReusableHyperOptimizer): + """Like ``HyperCompressedOptimizer`` but it will re-instantiate the + optimizer whenever a new contraction is detected, and also cache the paths + found. + + Parameters + ---------- + chi : None or int, optional + The maximum bond dimension to compress to. If ``None`` then use the + square of the largest existing dimension. If ``minimize`` is specified + as a full scoring function, this is ignored. + directory : None, True, or str, optional + If specified use this directory as a persistent cache. If ``True`` auto + generate a directory in the current working directory based on the + options which are most likely to affect the path (see + `ReusableHyperOptimizer._get_path_relevant_opts`). + overwrite : bool, optional + If ``True``, the optimizer will always run, overwriting old results in + the cache. This can be used to update paths without deleting the whole + cache. + hash_method : {'a', 'b', ...}, optional + The method used to hash the contraction tree. The default, ``'a'``, is + faster hashwise but doesn't recognize when indices are permuted. + cache_only : bool, optional + If ``True``, the optimizer will only use the cache, and will raise + ``KeyError`` if a contraction is not found. + opt_kwargs + Supplied to ``HyperCompressedOptimizer``. + """ + + def __init__( + self, + chi=None, + methods=("greedy-compressed", "greedy-span", "kahypar-agglom"), + minimize="peak-compressed", + **kwargs, + ): + if (chi is not None) and not callable(minimize): + minimize += f"-{chi}" + + kwargs["methods"] = methods + kwargs["minimize"] = minimize + + if kwargs.pop("slicing_opts", None) is not None: + raise ValueError( + "Cannot use slicing_opts with compressed contraction." + ) + if kwargs.pop("slicing_reconf_opts", None) is not None: + raise ValueError( + "Cannot use slicing_reconf_opts with compressed contraction." + ) + if kwargs.pop("simulated_annealing_opts", None) is not None: + raise ValueError( + "Cannot use simulated_annealing_opts " + "with compressed contraction." + ) + + super().__init__(**kwargs) + + def _get_suboptimizer(self): + return HyperCompressedOptimizer(**self._suboptimizer_kwargs) + + def _deconstruct_tree(self, opt, tree): + return { + "path": tree.get_path(), + "score": tree.get_score(), + "sliced_inds": tuple(tree.sliced_inds), + } + + def _reconstruct_tree(self, inputs, output, size_dict, con): + return ContractionTreeCompressed.from_path( + inputs, + output, + size_dict, + path=con["path"], + objective=self.minimize, + ) + + +class HyperMultiOptimizer(HyperOptimizer): + compressed = False + multicontraction = True diff --git a/.venv/lib/python3.12/site-packages/cotengra/parallel.py b/.venv/lib/python3.12/site-packages/cotengra/parallel.py new file mode 100644 index 0000000..e08d5d7 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/cotengra/parallel.py @@ -0,0 +1,583 @@ +"""Interface for parallelism.""" + +import atexit +import collections +import functools +import importlib +import inspect +import numbers +import operator +import warnings + +_AUTO_BACKEND = None + +# check for loky, joblib (vendors loky), then default to concurrent.futures +have_loky = importlib.util.find_spec("loky") is not None +have_joblib = importlib.util.find_spec("joblib") is not None +if have_loky or have_joblib: + _DEFAULT_BACKEND = "loky" +else: + _DEFAULT_BACKEND = "concurrent.futures" + + +@functools.lru_cache(None) +def choose_default_num_workers(): + import os + + if "COTENGRA_NUM_WORKERS" in os.environ: + return int(os.environ["COTENGRA_NUM_WORKERS"]) + + if "OMP_NUM_THREADS" in os.environ: + return int(os.environ["OMP_NUM_THREADS"]) + + return os.cpu_count() + + +def get_pool(n_workers=None, maybe_create=False, backend=None): + """Get a parallel pool.""" + if backend is None: + backend = _DEFAULT_BACKEND + + if backend == "dask": + return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create) + + if backend == "ray": + return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create) + + # above backends are distributed, don't specify n_workers + if n_workers is None: + n_workers = choose_default_num_workers() + + if backend == "loky": + get_reusable_executor = get_loky_get_reusable_executor() + return get_reusable_executor(max_workers=n_workers) + + if backend == "concurrent.futures": + return _get_process_pool_cf(n_workers=n_workers) + + if backend == "threads": + return _get_thread_pool_cf(n_workers=n_workers) + + +@functools.lru_cache(None) +def _infer_backed_cached(pool_class): + if pool_class.__name__ == "RayExecutor": + return "ray" + + path = pool_class.__module__.split(".") + + if path[0] == "concurrent": + return "concurrent.futures" + + if path[0] == "joblib": + return "loky" + + if path[0] == "distributed": + return "dask" + + return path[0] + + +def _infer_backend(pool): + """Return the backend type of ``pool`` - cached for speed.""" + return _infer_backed_cached(pool.__class__) + + +def get_n_workers(pool=None): + """Extract how many workers our pool has (mostly for working out how many + tasks to pre-dispatch). + """ + if pool is None: + pool = get_pool() + + try: + return pool._max_workers + except AttributeError: + pass + + backend = _infer_backend(pool) + + if backend == "dask": + workers = pool.scheduler_info(n_workers=-1)["workers"] + return sum(int(w.get("nthreads", 1) or 1) for w in workers.values()) + + if backend == "ray": + while True: + try: + return int(get_ray().available_resources()["CPU"]) + except KeyError: + import time + + time.sleep(1e-3) + + if backend == "mpi4py": + from mpi4py import MPI + + return MPI.COMM_WORLD.size + + raise ValueError(f"Can't find number of workers in pool {pool}.") + + +def parse_parallel_arg(parallel): + """ """ + global _AUTO_BACKEND + + if parallel == "auto": + return get_pool(maybe_create=False, backend=_AUTO_BACKEND) + + if parallel is False: + return None + + if parallel is True: + if _AUTO_BACKEND is None: + _AUTO_BACKEND = _DEFAULT_BACKEND + parallel = _AUTO_BACKEND + + if isinstance(parallel, numbers.Integral): + _AUTO_BACKEND = _DEFAULT_BACKEND + return get_pool( + n_workers=parallel, maybe_create=True, backend=_DEFAULT_BACKEND + ) + + if parallel == "loky": + return get_pool(maybe_create=True, backend="loky") + + if parallel == "concurrent.futures": + return get_pool(maybe_create=True, backend="concurrent.futures") + + if parallel == "threads": + return get_pool(maybe_create=True, backend="threads") + + if parallel == "dask": + _AUTO_BACKEND = "dask" + return get_pool(maybe_create=True, backend="dask") + + if parallel == "ray": + _AUTO_BACKEND = "ray" + return get_pool(maybe_create=True, backend="ray") + + return parallel + + +def set_parallel_backend(backend): + """Create a parallel pool of type ``backend`` which registers it as the + default for ``'auto'`` parallel. + """ + return parse_parallel_arg(backend) + + +def maybe_leave_pool(pool): + """Logic required for nested parallelism in dask.distributed.""" + if _infer_backend(pool) == "dask": + return _maybe_leave_pool_dask() + + +def maybe_rejoin_pool(is_worker, pool): + """Logic required for nested parallelism in dask.distributed.""" + if is_worker and _infer_backend(pool) == "dask": + _rejoin_pool_dask() + + +def submit(pool, fn, *args, **kwargs): + """Interface for submitting ``fn(*args, **kwargs)`` to ``pool``.""" + if _infer_backend(pool) == "dask": + kwargs.setdefault("pure", False) + return pool.submit(fn, *args, **kwargs) + + +def scatter(pool, data): + """Interface for maybe turning ``data`` into a remote object or reference.""" + if _infer_backend(pool) in ("dask", "ray"): + return pool.scatter(data) + return data + + +def can_scatter(pool): + """Whether ``pool`` can make objects remote.""" + return _infer_backend(pool) in ("dask", "ray") + + +def should_nest(pool): + """Given argument ``pool`` should we try nested parallelism.""" + if pool is None: + return False + backend = _infer_backend(pool) + if backend in ("ray", "dask"): + return backend + return False + + +# ---------------------------------- loky ----------------------------------- # + + +@functools.lru_cache(1) +def get_loky_get_reusable_executor(): + try: + from loky import get_reusable_executor + except ImportError: + from joblib.externals.loky import get_reusable_executor + return get_reusable_executor + + +# --------------------------- concurrent.futures ---------------------------- # + + +class CachedProcessPoolExecutor: + def __init__(self): + self._pool = None + self._n_workers = -1 + atexit.register(self.shutdown) + + def __call__(self, n_workers=None): + if n_workers != self._n_workers: + from concurrent.futures import ProcessPoolExecutor + + self.shutdown() + self._pool = ProcessPoolExecutor(n_workers) + self._n_workers = n_workers + return self._pool + + def is_initialized(self): + return self._pool is not None + + def shutdown(self): + if self._pool is not None: + self._pool.shutdown() + self._pool = None + + def __del__(self): + self.shutdown() + + +ProcessPoolHandler = CachedProcessPoolExecutor() + + +def _get_process_pool_cf(n_workers=None): + return ProcessPoolHandler(n_workers) + + +class CachedThreadPoolExecutor: + def __init__(self): + self._pool = None + self._n_workers = -1 + atexit.register(self.shutdown) + + def __call__(self, n_workers=None): + if n_workers != self._n_workers: + from concurrent.futures import ThreadPoolExecutor + + self.shutdown() + self._pool = ThreadPoolExecutor(n_workers) + self._n_workers = n_workers + return self._pool + + def is_initialized(self): + return self._pool is not None + + def shutdown(self): + if self._pool is not None: + self._pool.shutdown() + self._pool = None + + def __del__(self): + self.shutdown() + + +ThreadPoolHandler = CachedThreadPoolExecutor() + + +def _get_thread_pool_cf(n_workers=None): + return ThreadPoolHandler(n_workers) + + +# ---------------------------------- DASK ----------------------------------- # + + +def _get_pool_dask(n_workers=None, maybe_create=False): + """Maybe get an existing or create a new dask.distrbuted client. + + Parameters + ---------- + n_workers : None or int, optional + The number of workers to request if creating a new client. + maybe_create : bool, optional + Whether to create an new local cluster and client if no existing client + is found. + + Returns + ------- + None or dask.distributed.Client + """ + try: + from dask.distributed import get_client + except ImportError: + if not maybe_create: + return None + else: + raise + + try: + client = get_client() + except ValueError: + if not maybe_create: + return None + + import shutil + import tempfile + + from dask.distributed import Client, LocalCluster + + local_directory = tempfile.mkdtemp() + lc = LocalCluster( + n_workers=n_workers, + threads_per_worker=1, + local_directory=local_directory, + memory_limit=0, + ) + client = Client(lc) + + warnings.warn( + "Parallel specified but no existing global dask client found... " + "created one (with {} workers).".format(get_n_workers(client)) + ) + + @atexit.register + def delete_local_dask_directory(): + shutil.rmtree(local_directory, ignore_errors=True) + + if n_workers is not None: + current_n_workers = get_n_workers(client) + if n_workers != current_n_workers: + warnings.warn( + "Found existing client (with {} workers which) doesn't match " + "the requested {}... using it instead.".format( + current_n_workers, n_workers + ) + ) + + return client + + +def _maybe_leave_pool_dask(): + try: + from dask.distributed import secede + + secede() # for nested parallelism + is_dask_worker = True + except (ImportError, ValueError): + is_dask_worker = False + return is_dask_worker + + +def _rejoin_pool_dask(): + from dask.distributed import rejoin + + rejoin() + + +# ----------------------------------- RAY ----------------------------------- # + + +@functools.lru_cache(None) +def get_ray(): + """ """ + import ray + + return ray + + +class RayFuture: + """Basic ``concurrent.futures`` like future wrapping a ray ``ObjectRef``.""" + + __slots__ = ("_obj", "_cancelled") + + def __init__(self, obj): + self._obj = obj + self._cancelled = False + + def result(self, timeout=None): + return get_ray().get(self._obj, timeout=timeout) + + def done(self): + return self._cancelled or bool( + get_ray().wait([self._obj], timeout=0)[0] + ) + + def cancel(self): + get_ray().cancel(self._obj) + self._cancelled = True + + +def _unpack_futures_tuple(x): + return tuple(map(_unpack_futures, x)) + + +def _unpack_futures_list(x): + return list(map(_unpack_futures, x)) + + +def _unpack_futures_dict(x): + return {k: _unpack_futures(v) for k, v in x.items()} + + +def _unpack_futures_identity(x): + return x + + +_unpack_dispatch = collections.defaultdict( + lambda: _unpack_futures_identity, + { + RayFuture: operator.attrgetter("_obj"), + tuple: _unpack_futures_tuple, + list: _unpack_futures_list, + dict: _unpack_futures_dict, + }, +) + + +def _unpack_futures(x): + """Allows passing futures by reference - takes e.g. args and kwargs and + replaces all ``RayFuture`` objects with their underyling ``ObjectRef`` + within all nested tuples, lists and dicts. + + [Subclassing ``ObjectRef`` might avoid needing this.] + """ + return _unpack_dispatch[x.__class__](x) + + +@functools.lru_cache(2**14) +def get_remote_fn(fn, **remote_opts): + """Cached retrieval of remote function.""" + ray = get_ray() + if remote_opts: + return ray.remote(**remote_opts)(fn) + return ray.remote(fn) + + +@functools.lru_cache(2**14) +def get_fn_as_remote_object(fn): + ray = get_ray() + return ray.put(fn) + + +@functools.lru_cache(None) +def get_deploy(**remote_opts): + """Alternative for 'non-function' callables - e.g. partial + functions - pass the callable object too. + """ + ray = get_ray() + + def deploy(fn, *args, **kwargs): + return fn(*args, **kwargs) + + if remote_opts: + return ray.remote(**remote_opts)(deploy) + return ray.remote(deploy) + + +class RayExecutor: + """Basic ``concurrent.futures`` like interface using ``ray``.""" + + def __init__(self, *args, default_remote_opts=None, **kwargs): + ray = get_ray() + if not ray.is_initialized(): + ray.init(*args, **kwargs) + + self.default_remote_opts = ( + {} if default_remote_opts is None else dict(default_remote_opts) + ) + + def _maybe_inject_remote_opts(self, remote_opts=None): + """Return the default remote options, possibly overriding some with + those supplied by a ``submit call``. + """ + ropts = self.default_remote_opts + if remote_opts is not None: + ropts = {**ropts, **remote_opts} + return ropts + + def submit(self, fn, *args, pure=False, remote_opts=None, **kwargs): + """Remotely run ``fn(*args, **kwargs)``, returning a ``RayFuture``.""" + # want to pass futures by reference + args = _unpack_futures_tuple(args) + kwargs = _unpack_futures_dict(kwargs) + + ropts = self._maybe_inject_remote_opts(remote_opts) + + # this is the same test ray uses to accept functions + if inspect.isfunction(fn): + # can use the faster cached remote function + obj = get_remote_fn(fn, **ropts).remote(*args, **kwargs) + else: + fn_obj = get_fn_as_remote_object(fn) + obj = get_deploy(**ropts).remote(fn_obj, *args, **kwargs) + + return RayFuture(obj) + + def map(self, func, *iterables, remote_opts=None): + """Remote map ``func`` over arguments ``iterables``.""" + ropts = self._maybe_inject_remote_opts(remote_opts) + remote_fn = get_remote_fn(func, **ropts) + objs = tuple(map(remote_fn.remote, *iterables)) + ray = get_ray() + return map(ray.get, objs) + + def scatter(self, data): + """Push ``data`` into the distributed store, returning an ``ObjectRef`` + that can be supplied to ``submit`` calls for example. + """ + ray = get_ray() + return ray.put(data) + + def shutdown(self): + """Shutdown the parent ray cluster, this ``RayExecutor`` instance + itself does not need any cleanup. + """ + get_ray().shutdown() + + +_RAY_EXECUTOR = None + + +def _get_pool_ray(n_workers=None, maybe_create=False): + """Maybe get an existing or create a new RayExecutor, thus initializing, + ray. + + Parameters + ---------- + n_workers : None or int, optional + The number of workers to request if creating a new client. + maybe_create : bool, optional + Whether to create initialize ray and return a RayExecutor if not + initialized already. + + Returns + ------- + None or RayExecutor + """ + try: + import ray + except ImportError: + if not maybe_create: + return None + else: + raise + + global _RAY_EXECUTOR + + if (_RAY_EXECUTOR is None) or (not ray.is_initialized()): + if not maybe_create: + return None + _RAY_EXECUTOR = RayExecutor(num_cpus=n_workers) + + if n_workers is not None: + current_n_workers = get_n_workers(_RAY_EXECUTOR) + if n_workers != current_n_workers: + warnings.warn( + "Found initialized ray (with {} workers which) doesn't match " + "the requested {}... sticking with old number.".format( + current_n_workers, n_workers + ) + ) + + return _RAY_EXECUTOR diff --git a/.venv/lib/python3.12/site-packages/qmatchatea/py_emulator.py b/.venv/lib/python3.12/site-packages/qmatchatea/py_emulator.py new file mode 100644 index 0000000..85215a8 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/qmatchatea/py_emulator.py @@ -0,0 +1,1009 @@ +# This code is part of qmatchatea. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +The :py:class:`QCEmulator` class enables full-python simulations. + +Functions and classes +~~~~~~~~~~~~~~~~~~~~~ + +""" + +# pylint: disable=protected-access, bare-except + +import os +import time +import warnings +from copy import deepcopy + +import numpy as np +import psutil +import qredtea as qrt +import qtealeaves.tensors as qtt +from qiskit import QuantumCircuit +from qtealeaves.abstracttns.abstract_tn import _AbstractTN +from qtealeaves.convergence_parameters import TNConvergenceParameters +from qtealeaves.emulator import MPIMPS, MPS, TTN +from qtealeaves.mpos import DenseMPO +from qtealeaves.observables import TNObservables +from qtealeaves.simulation.tn_simulation import run_tn_measurements + +from .circuit import Qcircuit +from .circuit.observables import QCObservableStep +from .utils import QCBackend, SimulationResults +from .utils.tn_utils import QCOperators +from .utils.utils import QCCheckpoints, QCConvergenceParameters + +__all__ = ["QCEmulator", "run_py_simulation"] + + +class QCEmulator: + """ + Emulator class to run quantum circuits, powered by either + TTNs or MPS. + + + Parameters + ---------- + + num_sites: int + Number of sites + convergence_parameters: :py:class:`QCConvergenceParameters` + Class for handling convergence parameters. In particular, in the MPS simulator we are + interested in: + - the *maximum bond dimension* :math:`\\chi`; + - the *cut ratio* :math:`\\epsilon` after which the singular values are neglected, i.e. + if :math:`\\lambda_1` is the bigger singular values then after an SVD we neglect all the + singular values such that :math:`\\frac{\\lambda_i}{\\lambda_1}\\leq\\epsilon` + local_dim: int, optional + Local dimension of the degrees of freedom. Default to 2. + tensor_backend: TensorBackend, optional + Contains all the information on the tensors, such as dtype and device. + Default to TensorBackend() (dtype=np.complex128, device="cpu"). + qc_backend: QCBackend, optional + Backend for the qmatchatea emulation, containing the backend and other important infos. + Default to QCBackend() (ansatz="MPS", precision="A", device="cpu") + initialize: str, optional + Initialization procedure. + Default to "vacuum", the 0000...0 state. + Available: "random", "vacuum", path_to_file + """ + + ansatzes = {"MPS": MPS, "TTN": TTN, "MPIMPS": MPIMPS} + + # pylint: disable-next=too-many-arguments + def __init__( + self, + num_sites, + convergence_parameters=QCConvergenceParameters(), + local_dim=2, + tensor_backend=qtt.TensorBackend(), + qc_backend=QCBackend(), + initialize="vacuum", + ): + if not isinstance(convergence_parameters, TNConvergenceParameters): + raise TypeError( + "convergence_parameters must be of the QCConvergenceParameters class" + ) + if qc_backend.device != tensor_backend.device: + raise ValueError( + "Tensor backend and QCBackend have different devices, " + + f"{tensor_backend.device} and {qc_backend.device} respectively." + ) + + self._trunc_tracking_mode = convergence_parameters.trunc_tracking_mode + self._qc_backend = qc_backend + + # Classical registers to hold qiskit informations + self.cl_regs = {} + + # Observables measured + self.is_measured = [ + True, + False, + False, + True, + True, + True, + True, + True, + True, + False, + False, + ] + + # If a TTN, pad with empty sites until you get to a power of 2 sites + if self.ansatz == "TTN": + if num_sites & (num_sites - 1) == 0: + exponent = np.ceil(np.log2(num_sites)) + num_sites = int(2**exponent) + + # Initialize based on the intialized keyword + if os.path.isfile(initialize): + if initialize.endswith( + "pkl" + self.ansatzes[qc_backend.ansatz.upper()].extension + ): + self.emulator = self.ansatzes[qc_backend.ansatz.upper()].read_pickle( + filename=initialize + ) + elif initialize.endswith( + self.ansatzes[qc_backend.ansatz.upper()].extension + ): + self.emulator = self.ansatzes[qc_backend.ansatz.upper()].read( + filename=initialize, + tensor_backend=tensor_backend, + cmplx=np.iscomplex(np.empty(1, dtype=tensor_backend.dtype))[0], + order="F", + ) + else: + raise IOError(f"Extension {initialize} not supported by QCEmulator") + + self.emulator._tensor_backend = tensor_backend + self.emulator._convergence_parameters = convergence_parameters + else: + self.emulator = self.ansatzes[qc_backend.ansatz.upper()]( + num_sites=num_sites, + convergence_parameters=convergence_parameters, + local_dim=local_dim, + initialize=initialize, + tensor_backend=tensor_backend, + ) + + @property + def tensor_backend(self): + """Tensor backend of the simulation""" + return self.emulator._tensor_backend + + @property + def ansatz(self): + """Ansatz of the emulator""" + return self._qc_backend.ansatz + + def __getattr__(self, __name: str): + """ + Check for the attribute in emulator, i.e. the QCEmulator inherits all + the emulator calls. + This call is for convenience and for retrocompatibility + + .. warning:: + The method `__getattr__` is called when `__getattribute__` fails, + so it already covers the possibility of the attribute being in the + base class + """ + return self.emulator.__getattribute__(__name) + + @classmethod + def from_emulator( + cls, emulator, conv_params=None, tensor_backend=None, qc_backend=QCBackend() + ): + """ + Initialize the QCEmulator class starting from an emulator class, i.e. either + MPS or TTN + + Parameters + ---------- + emulator : :class:`_AbstractTN` + Either an MPS or TTN emulator + conv_params : :class:`TNConvergenceParameters`, optional + Convergence parameters. If None, the convergence parameters of the emulator + are used + tensor_backend: TensorBackend, optional + Contains all the information on the tensors, such as dtype and device. + Default to TensorBackend() (dtype=np.complex128, device="cpu"). + qc_backend : QCBackend(), optional + Backend of the qmatchatea simulation + + Return + ------ + QCEmulator + The quantum circuit emulator class + """ + if not isinstance(emulator, _AbstractTN): + raise TypeError("The emulator should be a TN emulator class") + if conv_params is None: + conv_params = emulator._convergence_parameters + if tensor_backend is None: + tensor_backend = emulator._tensor_backend + + simulator = cls( + emulator.num_sites, + conv_params, + emulator.local_dim, + tensor_backend=tensor_backend, + qc_backend=qc_backend, + ) + emulator._convergence_parameters = conv_params + emulator._tensor_backend = tensor_backend + simulator.emulator = emulator + simulator.emulator.convert( + device=tensor_backend.device, dtype=tensor_backend.dtype + ) + + return simulator + + @classmethod + def from_tensor_list( + cls, tensor_list, conv_params=None, tensor_backend=None, qc_backend=QCBackend() + ): + """ + Initialize the QCEmulator class starting from a tensor list, i.e. either + MPS or TTN + + Parameters + ---------- + tensor_list : list of tensors + Either an MPS or TTN list of tensors + conv_params : :class:`TNConvergenceParameters`, optional + Convergence parameters. If None, the convergence parameters of the emulator + are used + tensor_backend: TensorBackend, optional + Contains all the information on the tensors, such as dtype and device. + Default to TensorBackend() (dtype=np.complex128, device="cpu"). + qc_backend : QCBackend(), optional + Backend of the qmatchatea simulation + + Return + ------ + QCEmulator + The quantum circuit emulator class + """ + # A list of lists is a TTN, while a list of tensors is an MPS + initial_state = cls.ansatzes[qc_backend.ansatz].from_tensor_list( + tensor_list, conv_params=conv_params, tensor_backend=tensor_backend + ) + + simulator = cls.from_emulator( + initial_state, + conv_params=conv_params, + tensor_backend=tensor_backend, + qc_backend=qc_backend, + ) + + return simulator + + def meas_projective( + self, nmeas=1024, qiskit_convention=True, seed=None, unitary_setup=None + ): + """See the parent method""" + return self.emulator.meas_projective( + nmeas=nmeas, + qiskit_convention=qiskit_convention, + seed=seed, + unitary_setup=unitary_setup, + ) + + def to_statevector(self, qiskit_order=True, max_qubit_equivalent=20): + """See the parent method""" + return self.emulator.to_statevector(qiskit_order, max_qubit_equivalent) + + def apply_two_site_gate(self, operator, control, target): + """Apply a two-site gate, regardless of the position on the chain + + Parameters + ---------- + operator : QTeaTensor + Gate to be applied + control : int + control qubit index + target : int + target qubit index + + Returns + ------- + singvals_cut + singular values cut in the process + """ + local_dim = self.local_dim[0] + if operator.shape == (local_dim**2, local_dim**2): + operator = operator.reshape([local_dim] * 4) + # Reorder for qiskit convention on the two-qubits gates + if control < target or self.ansatz == "TTN": + operator = operator.transpose([1, 0, 3, 2]) + + singvals_cut = self.apply_two_site_operator(operator, [control, target]) + + # Trunc tracking mode is stored in self.emulator._convergence_parameters + singvals_cut = self.emulator._postprocess_singvals_cut(singvals_cut) + + # Bring to CPU/host if attribute available via some example tensor; must be + # iso center in case of mixed device + if isinstance(self.emulator, MPIMPS): + # Will have problems with mixed-device MPI-MPS, but we can live + # with this for now. Overwriting `get_tensor_of_site` in MPIMPS + # is definitely necessary + tensor = self.emulator[0] + else: + # in MPS, iso moves to the right and stays on device, TTN is less + # obvious + idx = max([control, target]) + tensor = self.emulator.get_tensor_of_site(idx) + + singvals_cut = tensor.get_of(singvals_cut) + + return [singvals_cut] + + def apply_multi_site_gate(self, operator, sites): + """ + Apply a n-sites gate, regardless of the position on the chain + + Parameters + ---------- + operator : QTeaTensor | List[QTeaTensor] + If a single QTeaTensor, it is the unitary matrix of the + n-qubits gate. If a List[QTeaTensor] it is already + written in the MPO form + sites : List[int] + Sites to which the operator should be applied + + Returns + ------- + singvals_cut + singular values cut in the process + """ + # This site order could be reversed for the qiskit convention + site_order = np.argsort(sites) + local_dim = self.local_dim[sites[0]] + if isinstance(operator, self.tensor_backend.tensor_cls): + operator = operator.reshape([local_dim] * len(sites) * 2) + transpose_idxs = np.arange(operator.ndim).reshape(2, -1) + transpose_idxs[0, :] = transpose_idxs[0, site_order] + transpose_idxs[1, :] = transpose_idxs[1, site_order] + operator.transpose_update(transpose_idxs.reshape(-1)) + operator = DenseMPO.from_matrix( + operator, sites, local_dim, self._convergence_parameters + ) + + singvals_cut = self.apply_mpo(operator) + + # Avoid errors due to no singv cut + singvals_cut = np.append(singvals_cut, 0) + if self._trunc_tracking_mode == "M": + singvals_cut = max(0, singvals_cut.max()) + elif self._trunc_tracking_mode == "C": + singvals_cut = (singvals_cut**2).sum() + + if hasattr(singvals_cut, "get"): + singvals_cut = singvals_cut.get() + + return [singvals_cut] + + def meas_observables(self, observables, operators): + """Measure all the observables + + Parameters + ---------- + observables : :py:class:`TNObservables` + All the observables to be measured + oeprators : :py:class:`TNOperators` + List of operators that form the circuit stored in THE CORRECT DEVICE. + If you are running on GPU the operators should be on the GPU. + + Returns + ------- + TNObservables + Observables with the results in results_buffer + """ + if not isinstance(observables, TNObservables): + raise TypeError("observables must be TNObservables") + + with warnings.catch_warnings(): + # We use a function that raises a warning for a specific thing we are not interested in. + # So we filter it out. + warnings.filterwarnings( + "ignore", + message="Tried to compute energy with no effective operators. Returning nan", + ) + # At the moment, observables are only measured serially + if self.ansatz == "MPIMPS": + if self._qc_backend.mpi_settings[-1] < 0: + self.emulator.reinstall_isometry_serial() + else: + self.emulator.reinstall_isometry_parallel( + self._qc_backend.mpi_settings[-1] + ) + rank = self.emulator.rank + tensor_list = self.emulator.mpi_gather_tn() + if rank != 0: + return observables + emulator = MPS.from_tensor_list( + tensor_list, + self.emulator._convergence_parameters, + self.tensor_backend, + ) + else: + rank = 0 + emulator = self.emulator + + if rank == 0: + emulator.normalize() + observables = run_tn_measurements( + state=emulator, + observables=observables, + operators=operators, + params={}, + tensor_backend=self.tensor_backend, + tn_type=6 if self.ansatz in ("MPS", "MPSMPI") else 5, + ) + + return observables + + def run_circuit_from_instruction(self, op_list, instr_list): + """ + Run a circuit from the istructions. + + Parameters + ---------- + op_list : list of tensors + List of operators that form the circuit + instr_list : list of instructions + Instruction for the circuit, i.e. [op_name, op_idx, [sites] ] + + Return + ------ + singvals_cut : list of float + Singular values cutted, selected through the _trunc_tracking_mode + """ + singvals_cut = [] + for instr in instr_list: + sites = instr[2] + num_sites = len(sites) + idx = instr[1] + if instr[0] == "barrier": + continue + + if num_sites == 1: + self.emulator.apply_one_site_operator(op_list[idx], *sites) + + elif num_sites == 2: + singv_cut = self.apply_two_site_gate(op_list[idx], sites[0], sites[1]) + + # Avoid errors due to no singv cut + singv_cut = np.append(singv_cut, 0) + if self._trunc_tracking_mode == "M": + singvals_cut.append(np.max(singv_cut, initial=0.0)) + elif self._trunc_tracking_mode == "C": + singvals_cut.append(np.sum(singv_cut**2)) + + else: + raise ValueError("Only one and two-site operations are implemented") + return singvals_cut + + # pylint: disable-next=too-many-statements, too-many-branches, too-many-locals + def run_from_qk(self, circuit, operators=None, checkpoints=QCCheckpoints()): + """ + Run a qiskit quantum circuit on the simulator + + Parameters + ---------- + circuit : :py:class:`QuantumCircuit` + qiskit quantum circuit + operators : TNOperators + Operators class + checkpoints : QCCheckpoints + Checkpoints class + + Returns + ------- + List[float] + singular values cutted in the simulation + Dictionary[TNObservables] + The dictionary with the observables measured mid circuit + List[float] + Memory used in the simulation in bytes + """ + # data structure of the quantum circuit + data = circuit.data[checkpoints._initial_line :] + process = psutil.Process() + memory = np.zeros(len(data)) + obs_dict = {} + singvals_cut = [] + for creg in circuit.cregs: + self.cl_regs[creg.name] = np.zeros(creg.size) + + start_time = time.time() + barrier_cnt = 0 + cache_gate_tensors = getattr(self._qc_backend, "cache_gate_tensors", False) + track_memory = getattr(self._qc_backend, "track_memory", True) + gate_tensor_cache = {} + + def cached_gate_tensor(operation, gate_name, num_qubits): + cache_key = None + if cache_gate_tensors: + try: + params = tuple(str(param) for param in operation.params) + cache_key = ( + gate_name, + num_qubits, + params, + str(self.tensor_backend.dtype), + str(self.tensor_backend.device), + ) + except (AttributeError, TypeError): + cache_key = None + + if cache_key is not None and cache_key in gate_tensor_cache: + return gate_tensor_cache[cache_key] + + gate_mat = operation.to_matrix() + gate = self.tensor_backend.tensor_cls.from_elem_array( + gate_mat, self.tensor_backend.dtype, self.tensor_backend.device + ) + if cache_key is not None: + gate_tensor_cache[cache_key] = gate + return gate + + # Run over instances + for idx, instance in enumerate(data): + operation = instance.operation + qubits = instance.qubits + clbits = instance.clbits + gate_name = operation.name + num_qubits = len(qubits) + qubits = [circuit.find_bit(qub).index for qub in qubits] + + # Checking for classical conditions on this gate. + # + # NOTE: Gate.condition will be deprecated in Qiskit 2.0.0 + # so we need to find an alternative way to make this work. + # (https://docs.quantum.ibm.com/api/qiskit/qiskit.circuit.Gate#condition) + + if operation.condition is None: + apply_gate = True + else: + # NOTE: condition should be a tuple (classical_bit, bit_value) + bit_idx = [clbit.index for clbit in operation.condition[0]] + bit_value = self.cl_regs[operation.condition[0].name][bit_idx[0]] + # ^^ possible warning here: we are checking only the first bit_idx + + # Apply the gate only if condition is met: + apply_gate = bit_value == operation.condition[1] + + # Handling special circuit elements. + if gate_name == "barrier": + if self._qc_backend.mpi_settings[barrier_cnt] < 0: + self.emulator.reinstall_isometry_serial() + else: + self.emulator.reinstall_isometry_parallel( + self._qc_backend.mpi_settings[barrier_cnt] + ) + barrier_cnt += 1 + continue + if gate_name == "measure": + meas_state, _ = self.apply_projective_operator(*qubits) + self.cl_regs[clbits[0].register.name][0] = meas_state + apply_gate = False + elif gate_name == "reset": + self.reset(qubits) + apply_gate = False + elif gate_name == "MeasureObservables": + tic = time.time() + obs = self.meas_observables(operation.observables, operators) + toc = time.time() + obs.results_buffer["time"] = tic - start_time + obs.results_buffer["energy"] = None + obs.results_buffer["norm"] = self.norm() + obs.results_buffer["measurement_time"] = toc - tic + obs_dict[operation.label] = obs + continue + if gate_name in ("id", "identity"): + apply_gate = False + # possible bug warning: + # Check that the previous if/elif return either `apply_gate=False` + # or `continue`. Otherwise, it is expected that `operation` has a + # method to_matrix(), which is used to apply the gate if apply_gate==True. + + if apply_gate: + # Grab the operator matrix and move it to the correct device + gate = cached_gate_tensor(operation, gate_name, num_qubits) + if num_qubits == 1: + self.emulator.apply_one_site_operator(gate, *qubits) + elif num_qubits == 2: + singv_cut = self.apply_two_site_gate(gate, *qubits) + singvals_cut += singv_cut + else: + singv_cut = self.apply_multi_site_gate(gate, qubits) + singvals_cut += singv_cut + + if track_memory: + memory[idx] = process.memory_info().rss + # Check if you can change settings every n iterations + self._runtime_checks_updates( + idx + checkpoints._initial_line, self.num_sites, singvals_cut + ) + # Save checkpoints if needed + checkpoints.save_checkpoint(idx + checkpoints._initial_line, self.emulator) + + return singvals_cut, obs_dict, memory + + # pylint: disable-next=too-many-statements, too-many-branches, too-many-locals + def run_from_qcirc(self, qcirc, starting_idx=0, checkpoints=QCCheckpoints()): + """ + Run a simulation starting from a Qcircuit on a portion of the TN state + + Parameters + ---------- + qcirc : :class:`Qcircuit` + Quantum circuit + starting_idx : int, optional + MPS index that correspond to the index 0 of the Qcircuit. Default to 0. + checkpoints : QCCheckpoints, optional + Checkpoints in the simulation + + Returns + ------- + List[float] + singular values cutted in the simulation + Dictionary[TNObservables] + The dictionary with the observables measured mid circuit + List[float] + Memory used in the simulation in bytes + """ + if not isinstance(qcirc, Qcircuit): + raise TypeError(f"qcirc must be of type Qcircuit, not {type(qcirc)}") + + process = psutil.Process() + memory = np.zeros(len(qcirc)) + obs_dict = {} + singvals_cut = [] + start_time = time.time() + cnt = -1 + cache_gate_tensors = getattr(self._qc_backend, "cache_gate_tensors", False) + track_memory = getattr(self._qc_backend, "track_memory", True) + gate_tensor_cache = {} + + def cached_operator_tensor(operation): + cache_key = None + if cache_gate_tensors: + try: + cache_key = ( + operation.name, + tuple(str(param) for param in operation.operator.reshape(-1)), + str(self.tensor_backend.dtype), + str(self.tensor_backend.device), + ) + except (AttributeError, TypeError): + cache_key = None + + if cache_key is not None and cache_key in gate_tensor_cache: + return gate_tensor_cache[cache_key] + + gate = self.tensor_backend.tensor_cls.from_elem_array( + operation.operator, + self.tensor_backend.dtype, + self.tensor_backend.device, + ) + if cache_key is not None: + gate_tensor_cache[cache_key] = gate + return gate + + for layer in qcirc: + for instruction in layer: + cnt += 1 + if cnt < checkpoints._initial_line: + continue + sites = [ss + starting_idx for ss in instruction[1]] + operation = instruction[0] + + # Check for classical conditioning + appy_operation = operation.c_if.is_satisfied(qcirc) + if appy_operation: + # First, check for particular keywords + if isinstance(operation, QCObservableStep): + operators = ( + self.tensor_backend.base_tensor_cls.convert_operator_dict( + operation.operators, + params={}, + symmetries=[], + generators=[], + base_tensor_cls=self.tensor_backend.base_tensor_cls, + dtype=self.tensor_backend.dtype, + device=self.tensor_backend.device, + ) + ) + tic = time.time() + obs = self.meas_observables( + operation.observables, + operators, + ) + toc = time.time() + obs.results_buffer["time"] = tic - start_time + obs.results_buffer["norm"] = self.norm() + obs.results_buffer["measurement_time"] = toc - tic + operation.observables = obs + operation.postprocess_obs_indexing() # Postprocess for qregisters + for elem in obs.obs_list: + obs.results_buffer.update(obs.obs_list[elem].results_buffer) + obs_dict[operation.name] = deepcopy( + operation.observables.results_buffer + ) + del obs + + # Check for particular keywords + elif operation.name == "renormalize": + self.normalize() + elif operation.name == "measure": + res = self.emulator.apply_projective_operator( + *sites, operation.selected_output + ) + # Update measured value + qcirc.modify_cregister( + res, operation.cregister, operation.cl_idx + ) + elif operation.name == "add_site": + self.emulator.add_site(operation.position) + elif operation.name == "remove_site": + self.apply_projective_operator(operation.position, remove=True) + + # Apply gates + elif len(sites) == 1: + gate = cached_operator_tensor(operation) + self.site_canonize(*sites, keep_singvals=True) + self.apply_one_site_operator(gate, *sites) + elif len(sites) == 2: + gate = cached_operator_tensor(operation) + svd_cut = self.apply_two_site_gate(gate, *sites) + singvals_cut += svd_cut + else: + gate = cached_operator_tensor(operation) + svd_cut = self.apply_multi_site_gate(gate, sites) + singvals_cut += svd_cut + + # Check if you can change settings every n iterations + self._runtime_checks_updates(cnt, self.num_sites, singvals_cut) + # Save checkpoints if needed + checkpoints.save_checkpoint(cnt, self.emulator) + if track_memory: + memory[cnt] = process.memory_info().rss + + return singvals_cut, obs_dict, memory + + def _runtime_checks_updates(self, idx, frequency, norm_cut): + """ + Perform the checks to change the device and the precision if + idx%frequency is 0. + + Parameters + ---------- + idx : int + Index of the current operation of the quantum circuit + frequency: int + The checks are done every frequency operations + norm_cut: float + The norm cut in the last simulation + """ + if idx % frequency == 0: + device = self._qc_backend.resolve_device( + self.emulator.current_max_bond_dim, self.tensor_backend.device + ) + precision = self._qc_backend.resolve_precision( + (1 - np.array(norm_cut)).prod() + ) + self.emulator.convert(device=device, dtype=precision) + + +# pylint: disable-next=too-many-statements, too-many-branches, too-many-locals, too-many-arguments +def run_py_simulation( + circ, + local_dim=2, + convergence_parameters=QCConvergenceParameters(), + operators=QCOperators(), + observables=TNObservables(), + initial_state=None, + backend=QCBackend(), + checkpoints=QCCheckpoints(), +): + """ + Transpile the circuit to adapt it to the linear structure of the MPS and run the circuit, + obtaining in output the measurements. + + Parameters + ---------- + circ: QuantumCircuit + qiskit quantum circuit object to simulate + local_dim: int, optional + Local dimension of the single degree of freedom. Default is 2, for qubits + convergence_parameters: :py:class:`QCConvergenceParameters`, optional + Maximum bond dimension and cut ratio. Default to max_bond_dim=10, cut_ratio=1e-9. + operators: :py:class:`QCOperators`, optional + Operator class with the observables operators ALREADY THERE. If None, then it is + initialized empty. Default to None. + observables: :py:class:`TNObservables`, optional + The observables to be measured at the end of the simulation. Default to TNObservables(), + which contains no observables to measure. + initial_state : list of ndarray, optional + Initial state of the simulation. If None, ``|00...0>`` is considered. Default to None. + backend: :py:class:`QCBackend`, optional + Backend containing all the information for where to run the simulation + checkpoints: :py:class:`QCCheckpoints`, optional + Class to handle checkpoints in the simulation + + Returns + ------- + result: qmatchatea.SimulationResults + Results of the simulation, containing the following data: + - Measures + - Statevector + - Computational time + - Singular values cut + - Entanglement + - Measure probabilities + - MPS state + - MPS file size + - Observables measurements + """ + if isinstance(circ, (QuantumCircuit, Qcircuit)): + num_qubits = circ.num_qubits + else: + raise TypeError( + "Only qiskit Quantum Circuits and Qcircuit are implemented for" + + f" simulation, not {type(circ)}" + ) + start = time.time() + tensor_backend = _resolve_tensor_backend( + tensor_module=backend.tensor_module, + device=backend.resolve_device(1, "cpu"), + dtype=backend.resolve_precision(1), + ) + + if backend.mpi_approach != "SR" and backend.ansatz == "MPS": + backend._ansatz = "MPIMPS" + + operators = tensor_backend.base_tensor_cls.convert_operator_dict( + operators, + params={}, + symmetries=[], + generators=[], + base_tensor_cls=tensor_backend.base_tensor_cls, + dtype=tensor_backend.dtype, + device=tensor_backend.device, + ) + # Check if you selected restart from a checkpoint + initial_state = checkpoints.restart_from_checkpoint(initial_state) + + # The scalar check is to avoid a warning + if np.isscalar(initial_state): + if initial_state is None: + initial_state = "vacuum" + simulator = QCEmulator( + num_qubits, + convergence_parameters, + local_dim=local_dim, + tensor_backend=tensor_backend, + qc_backend=backend, + initialize=initial_state.lower(), + ) + elif isinstance(initial_state, _AbstractTN): + simulator = QCEmulator.from_emulator( + initial_state, + conv_params=convergence_parameters, + tensor_backend=tensor_backend, + qc_backend=backend, + ) + else: + simulator = QCEmulator.from_tensor_list( + initial_state, + conv_params=convergence_parameters, + tensor_backend=tensor_backend, + qc_backend=backend, + ) + if isinstance(circ, QuantumCircuit): + singvals_cut, obs_dict, memory = simulator.run_from_qk( + circ, operators, checkpoints=checkpoints + ) + elif isinstance(circ, Qcircuit): + singvals_cut, obs_dict, memory = simulator.run_from_qcirc( + circ, checkpoints=checkpoints + ) + else: + # Duplicate from above, but makes linter happy + raise TypeError( + "Only qiskit Quantum Circuits and Qcircuit are implemented for pure python" + + f" simulation, not {type(circ)}" + ) + + end = time.time() + + tic = time.time() + observables = simulator.meas_observables(observables, operators) + toc = time.time() + + observables.results_buffer["time"] = end - start + observables.results_buffer["energy"] = None + observables.results_buffer["norm"] = simulator.norm() + observables.results_buffer["measurement_time"] = toc - tic + observables.results_buffer["memory"] = memory / (1024**3) + + result_dict = observables.results_buffer + + # Observables postprocessing + postprocess = False + if simulator.ansatz == "MPIMPS": + if simulator.rank == 0: + postprocess = True + else: + postprocess = True + + if postprocess: + for elem in observables.obs_list: + result_dict.update(observables.obs_list[elem].results_buffer) + # Special treatment for TNState2file + if str(elem) == "TNState2File": + for value in observables.obs_list[elem].name: + result_dict["tn_state_path"] = observables.obs_list[ + elem + ].results_buffer[value] + + # Storing the results of measurement happened mid-circuit + # under their label + # pylint: disable-next=consider-using-dict-items + for label in obs_dict: + obs_values = obs_dict[label] + tmp = obs_values.results_buffer + for elem in obs_values.obs_list: + tmp.update(obs_values.obs_list[elem].results_buffer) + # Special treatment for TNState2file + if str(elem) == "TNState2File": + for value in obs_values.obs_list[elem].name: + tmp["tn_state_path"] = observables.obs_list[ + elem + ].results_buffer[value] + + result_dict[label] = tmp + + results = SimulationResults() + results.set_results(result_dict, singvals_cut) + + return results + + +def _resolve_tensor_backend(tensor_module, device, dtype): + """ + Resolve the string name of the module used for the tensor + operations. + + Parameters + ---------- + tensor_module : str + Name of the module used for the tensor operations + + Returns + ------- + qtealeaves.tensors._AbstractTensor + """ + + # First fake initialization, to have access to the tensor_cls for + # the correct dtype + if tensor_module == "numpy": + tensor_backend = qtt.TensorBackend() + elif tensor_module == "torch": + tensor_backend = qrt.torchapi.default_pytorch_backend() + elif tensor_module == "tensorflow": + tensor_backend = qrt.tensorflowapi.default_tensorflow_backend() + elif tensor_module == "jax": + tensor_backend = qrt.jaxapi.default_jax_backend() + else: + raise ValueError(f"Tensor class with {tensor_module} is not available.") + + # Get the correct dtype + tmp_tensor = tensor_backend([1, 1]) + dtype = tmp_tensor.dtype_from_char(dtype) + + # Return real tensor backend with correct dtype + if tensor_module == "numpy": + return qtt.TensorBackend(device=device, dtype=dtype) + if tensor_module == "torch": + return qrt.torchapi.default_pytorch_backend(device=device, dtype=dtype) + if tensor_module == "tensorflow": + return qrt.tensorflowapi.default_tensorflow_backend(device=device, dtype=dtype) + if tensor_module == "jax": + return qrt.jaxapi.default_jax_backend(device=device, dtype=dtype) + + # Makes linter happy + raise ValueError(f"Tensor class with {tensor_module} is not available.") diff --git a/.venv/lib/python3.12/site-packages/qredtea/torchapi/qteatorchtensor.py b/.venv/lib/python3.12/site-packages/qredtea/torchapi/qteatorchtensor.py new file mode 100644 index 0000000..753f778 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/qredtea/torchapi/qteatorchtensor.py @@ -0,0 +1,2997 @@ +# This code is part of qredtea. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +Tensor class based on pytorch; pytorch supports both CPU and GPU in one framework. +""" + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-branches +# pylint: disable=too-many-lines +# pylint: disable=too-many-locals +# pylint: disable=too-many-public-methods +# pylint: disable=too-many-statements +# pylint: disable=wrong-import-position + +import gc +import itertools +import logging +import math +from contextlib import nullcontext + +import numpy as np + +# All imports cause problems for the build server at the moment +# * torch is simply not install +# * qtealeaves is lost in the conan generator somewhere +from qtealeaves.tooling.devices import _CPU_DEVICE, _GPU_DEVICE, _XLA_DEVICE, DeviceList +from qtealeaves.tooling.mpisupport import MPI, TN_MPI_TYPES + +# pylint: disable-next=wrong-import-order,ungrouped-imports +from qredtea.tooling import QRedTeaBackendLibraryImportError + +try: + import torch as to +except ImportError as exc: + raise QRedTeaBackendLibraryImportError() from exc +GPU_AVAILABLE = to.cuda.is_available() +try: + # pylint: disable-next=unused-import + import torch_xla + + # Okay, this is a bold assumption, but how to get the device count? + XLA_AVAILABLE = True + import torch_xla.core.xla_model as xm + + # xla_device is global variable, similar to GPU, we cannot select + # a specific device if multiple are avaiable. + # pylint: disable-next=invalid-name + xla_device = xm.xla_device() +except ImportError: + XLA_AVAILABLE = False + # pylint: disable-next=invalid-name + xla_device = None + +# pylint: disable=import-error,ungrouped-imports,no-name-in-module +from qtealeaves.convergence_parameters import TNConvergenceParameters +from qtealeaves.operators import TNOperators + +# pylint: disable-next=unused-import +from qtealeaves.solvers import DenseTensorEigenSolverH, EigenSolverH +from qtealeaves.tensors import ( + QteaTensor, + TensorBackend, + _AbstractDataMover, + _AbstractQteaBaseTensor, + _parse_block_size, + _process_svd_ctrl, +) +from qtealeaves.tooling import write_tensor + +from qredtea.symmetries import AbelianSymmetryInjector, QteaAbelianTensor +from qredtea.tooling import ( + QRedTeaError, + QRedTeaLinAlgError, + QRedTeaLinkError, + QRedTeaRankError, +) + +# pylint: enable=import-error,ungrouped-imports,no-name-in-module + +ACCELERATOR_DEVICES = DeviceList([_GPU_DEVICE, _XLA_DEVICE]) + +# pylint: disable-next=invalid-name +_BLOCK_SIZE_BOND_DIMENSION, _BLOCK_SIZE_BYTE = _parse_block_size() + +# pylint: disable-next=invalid-name +_USE_STREAMS = False + + +__all__ = [ + "QteaTorchTensor", + "default_pytorch_backend", + "default_abelian_pytorch_backend", + "set_block_size_qteatorchtensors", + "get_gpu_available", + "DataMoverPytorch", +] + +logger = logging.getLogger(__name__) + + +def get_gpu_available(): + """Returns boolean on availability of GPU.""" + return GPU_AVAILABLE + + +def set_block_size_qteatorchtensors( + block_size_bond_dimension=None, block_size_byte=None +): + """ + Allows to overwrite bond dimension decisions to enhance performance + on hardware by keeping "better" or "consistent" bond dimensions. + Only one of the two can be used right now. + + **Arguments** + + block_size_bond_dimension : int + Direct handling of bond dimension. + + block_size_byte : int + Control dimension of tensors (in SVD cuts) via blocks of bytes. + For example, nvidia docs suggest multiples of sixteen float64 + or 32 float32 for A100, i.e., 128 bytes. + """ + # pylint: disable-next=invalid-name,global-statement + global _BLOCK_SIZE_BOND_DIMENSION + # pylint: disable-next=invalid-name,global-statement + global _BLOCK_SIZE_BYTE + + _BLOCK_SIZE_BOND_DIMENSION = block_size_bond_dimension + _BLOCK_SIZE_BYTE = block_size_byte + + if (block_size_bond_dimension is not None) and (block_size_byte is not None): + # We do not want to handle both of them, will be ignored lateron, + # but raise warning as early as possible + logger.warning( + "Ignoring BLOCK_SIZE_BOND_DIMENSION in favor of BLOCK_SIZE_BYTE." + ) + + +def set_streams_qteatorchtensors(use_streams): + """ + Allow to decide if streams are used. + + **Arguments** + + use_streams : bool + If True, streams will be used, otherwise we return a nullcontext even + if streams would be possible. + + """ + # pylint: disable-next=invalid-name,global-statement + global _USE_STREAMS + + _USE_STREAMS = use_streams + + +# class set_block_size_qteatorchtensors once to resolve if both variables +# are set +set_block_size_qteatorchtensors(_BLOCK_SIZE_BOND_DIMENSION, _BLOCK_SIZE_BYTE) + + +class QteaTorchStream(to.cuda.StreamContext): + # pylint: disable=too-few-public-methods + """ + Wrapper for stream with torch which provides access to synchronize. + """ + + def __init__(self): + self.stream = to.cuda.Stream() + super().__init__(self.stream) + + def synchronize(self, *args, **kwargs): + """Synchronize the stream used within the instance.""" + self.stream.synchronize() + + +class QteaTorchTensor(_AbstractQteaBaseTensor): + """ + Tensor for Quantum TEA based on the pytorch tensors. + """ + + implemented_devices = DeviceList([_CPU_DEVICE, _GPU_DEVICE, _XLA_DEVICE]) + + def __init__( + self, + links, + ctrl="Z", + are_links_outgoing=None, # pylint: disable=unused-argument + base_tensor_cls=None, # pylint: disable=unused-argument + dtype=to.complex128, + device=None, + requires_grad=None, + ): + """ + + links : list of ints with shape (links works towards generalization) + """ + super().__init__(links) + + if ctrl is None: + self._elem: to.Tensor = None + return + + if requires_grad is None: + requires_grad = False + if ctrl in ["N"]: + self._elem = to.empty(links, dtype=dtype, requires_grad=requires_grad) + elif ctrl in ["O"]: + self._elem = to.ones(links, dtype=dtype, requires_grad=requires_grad) + elif ctrl in ["Z"]: + self._elem = to.zeros(links, dtype=dtype, requires_grad=requires_grad) + elif ctrl in ["1", "eye"]: + if len(links) != 2: + raise ValueError("Initialization with identity only for rank-2.") + if links[0] != links[1]: + raise ValueError("Initialization with identity only for square matrix.") + self._elem = to.eye(links[0], dtype=dtype, requires_grad=requires_grad) + elif ctrl in ["R", "random"]: + if dtype in [to.complex64, to.complex128]: + self._elem = to.rand( + *links, requires_grad=requires_grad + ) + 1j * to.rand(*links, requires_grad=requires_grad) + else: + self._elem = to.rand(*links, requires_grad=requires_grad) + elif ctrl in ["ground"]: + dim = int(to.prod(to.Tensor(links)).item()) + self._elem = to.zeros([dim], dtype=dtype) + self._elem[0] = 1.0 + self._elem = to.reshape(self._elem, links) + self._elem.requires_grad_(requires_grad) + elif np.isscalar(ctrl) and np.isreal(ctrl): + # This prohibits initialization with complex numbers. + # In case of adding complex numbers, enforce a complex dtype! + self._elem = ctrl * to.ones(links, dtype=dtype, requires_grad=requires_grad) + else: + raise QRedTeaError(f"Unknown initialization {ctrl} of type {type(ctrl)}.") + + self.convert(dtype, device) + + if (not self._elem.is_leaf) and (requires_grad): + self._elem = self.elem.detach().clone().requires_grad_(True) + + # -------------------------------------------------------------------------- + # Properties + # -------------------------------------------------------------------------- + + @property + def are_links_outgoing(self): + """Define property of outgoing links as property (always False).""" + return [False] * self.ndim + + @property + def device(self): + """Device where the tensor is stored.""" + return self.device_str(self._elem) + + @property + def elem(self): + """Elements of the tensor.""" + return self._elem + + @property + def dtype(self): + """Data type of the underlying arrays.""" + return self._elem.dtype + + @property + def dtype_eps(self): + """Data type's machine precision.""" + eps_dict = { + "torch.float16": 1e-3, + "torch.float32": 1e-7, + "torch.float64": 1e-14, + "torch.complex64": 1e-7, + "torch.complex128": 1e-14, + } + + return eps_dict[str(self.dtype)] + + @property + def linear_algebra_library(self): + """Specification of the linear algebra library used as string `torch``.""" + return "torch" + + @property + def links(self): + """Here, as well dimension of tensor along each dimension.""" + return self.shape + + @property + def ndim(self): + """Rank of the tensor.""" + return self._elem.ndim + + @property + def shape(self): + """Dimension of tensor along each dimension.""" + return tuple(self._elem.shape) + + # -------------------------------------------------------------------------- + # Data type tooling beyond properties + # -------------------------------------------------------------------------- + + def dtype_from_char(self, dtype): + """Resolve data type from chars C, D, S, Z and optionally H.""" + data_types = { + "A": to.complex128, + "C": to.complex64, + "D": to.float64, + "H": to.float16, + "S": to.float32, + "Z": to.complex128, + "I": to.int32, + } + + return data_types[dtype] + + # inherits `dtype_from_mpi` + # inherits `dtype_real` + # inherits `dtype_to_char` + + # -------------------------------------------------------------------------- + # Overwritten operators + # -------------------------------------------------------------------------- + # + # inherit def __eq__ + # inherit def __ne__ + + def __add__(self, other): + """ + Addition of a scalar to a tensor adds it to all the entries. + If other is another tensor, elementwise addition if they have the same shape + """ + new_tensor = self.copy() + if isinstance(other, QteaTorchTensor): + if self.elem.requires_grad: + # torch backpropagation dislikes in-place operations. + new_tensor._elem = new_tensor.elem + other.elem + else: + new_tensor._elem += other.elem + elif not to.is_tensor(other): + # Assume it is scalar then + new_tensor._elem += other + else: + raise TypeError( + "Addition for QteaTensor is defined only for scalars and QteaTensor," + + f" not {type(other)}" + ) + return new_tensor + + def __iadd__(self, other): + """In-place addition of tensor with tensor or scalar (update).""" + if isinstance(other, QteaTorchTensor): + self._elem += other.elem + elif not to.is_tensor(other): + # Assume it is scalar then + self._elem += other + else: + raise TypeError( + "Addition for QteaTorchTensor is defined only for scalars" + + f" and QteaTorchTensor, not {type(other)}" + ) + return self + + def __mul__(self, factor): + """Multiplication of tensor with scalar returning new tensor as result.""" + return self.from_elem_array( + factor * self._elem, dtype=self.dtype, device=self.device + ) + + def __matmul__(self, other): + """Matrix multiplication as contraction over last and first index of self and other.""" + idx = self.ndim - 1 + return self.tensordot(other, ([idx], [0])) + + def __imul__(self, factor): + """In-place multiplication of tensor with scalar (update).""" + self._elem *= factor + return self + + def __itruediv__(self, factor): + """In-place division of tensor with scalar (update).""" + if factor == 0: + raise ZeroDivisionError("Trying to divide by zero.") + self._elem /= factor + return self + + def __sub__(self, other): + """ + Subtraction of a scalar to a tensor subtracts it to all the entries. + If other is another tensor, elementwise subtraction if they have the same shape + """ + new_tensor = self.copy() + if isinstance(other, QteaTorchTensor): + new_tensor._elem -= other.elem + elif not to.is_tensor(other): + # Assume it is scalar then + new_tensor._elem -= other + else: + raise TypeError( + "Subtraction for QteaTorchTensor is defined only for scalars" + + f" and QteaTorchTensor, not {type(other)}" + ) + return new_tensor + + def __truediv__(self, factor): + """Division of tensor with scalar.""" + if factor == 0: + raise ZeroDivisionError("Trying to divide by zero.") + elem = self._elem / factor + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def __neg__(self): + """Negative of a tensor returned as a new tensor.""" + # pylint: disable-next=invalid-unary-operand-type + neg_elem = -self._elem + return self.from_elem_array(neg_elem, dtype=self.dtype, device=self.device) + + # -------------------------------------------------------------------------- + # Printing functions + # -------------------------------------------------------------------------- + + def __str__(self): + """ + Output of print() function. + """ + elem_str = str(self._elem) + elem_str = elem_str[elem_str.find("[") : elem_str.rfind("]") + 1] + + return ( + f"{self.__class__.__name__}(" + elem_str + ", " + f"shape={self.shape}, dtype={self.dtype}, device={self.device})" + ) + + # -------------------------------------------------------------------------- + # classmethod, classmethod like + # -------------------------------------------------------------------------- + + @staticmethod + def convert_operator_dict( + op_dict, + params=None, + symmetries=None, + generators=None, + base_tensor_cls=None, + dtype=to.complex128, + device=_CPU_DEVICE, + ): + """ + Iterate through an operator dict and convert the entries. Converts as well + to rank-4 tensors. + + **Arguments** + + op_dict : instance of :class:`TNOperators` + Contains the operators as xp.ndarray. + + params : dict, optional + To resolve operators being passed as callable. + + symmetries: list, optional, for compatability with symmetric tensors. + Must be empty list. + + generators : list, optional, for compatability with symmetric tensors. + Must be empty list. + + base_tensor_cls : None, optional, for compatability with symmetric tensors. + No checks on this one here. + + dtype : data type for xp, optional + Specify data type. + Default to `to.complex128` + + device : str + Device for the simulation. Available "cpu" and "gpu" + Default to "cpu" + + **Details** + + The conversion to rank-4 tensors is useful for future implementations, + either to support adding interactions with a bond dimension greater than + one between them or for symmetries. We add dummy links of dimension one. + The order is (dummy link to the left, old link-1, old link-2, dummy link + to the right). + """ + qteatensor_dict = QteaTensor.convert_operator_dict( + op_dict, + params=params, + symmetries=symmetries, + generators=generators, + base_tensor_cls=base_tensor_cls, + dtype=np.complex128, + device=_CPU_DEVICE, + ) + + new_op_dict = TNOperators( + set_names=qteatensor_dict.set_names, + mapping_func=qteatensor_dict.mapping_func, + ) + for key, value in qteatensor_dict.items(): + new_op_dict[key] = QteaTorchTensor.from_qteatensor(value) + new_op_dict[key].convert(dtype, device) + + return new_op_dict + + def copy(self, dtype=None, device=None): + """Make a copy of a tensor; using detach and clone and keeping the requires grad option""" + if dtype is None: + dtype = self.dtype + if device is None: + device = self.device + + if self.elem.requires_grad or self.elem.grad_fn is not None: + return self.from_elem_array(self._elem.clone(), dtype=dtype, device=device) + + return self.from_elem_array( + self._elem.clone().detach(), dtype=dtype, device=device + ) + + def eye_like(self, link): + """ + Generate identity matrix. + + **Arguments** + + self : instance of :class:`QteaTensor` + Extract data type etc from this one here. + + link : same as returned by `links` property, here integer. + Dimension of the square, identity matrix. + """ + elem = to.eye(link) + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + @classmethod + def from_qteatensor(cls, qteatensor, dtype=None, device=None): + """Convert QteaTensor based on numpy/cupy into QteaTorchTensor.""" + elem = to.from_numpy(qteatensor.elem) + return cls.from_elem_array(elem, dtype=dtype, device=device) + + def random_unitary(self, link): + """ + Generate a random unitary matrix via performing a SVD on a + random tensor. + + **Arguments** + + self : instance of :class:`QteaTensor` + Extract data type etc from this one here. + + link : same as returned by `links` property, here integer. + Dimension of the square, random unitary matrix. + """ + elem = to.rand(link, link) + elem, _, _ = to.linalg.svd(elem, full_matrices=False) + + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + @classmethod + def read(cls, filehandle, dtype, device, base_tensor_cls, cmplx=True, order="F"): + """Read a tensor from file via QteaTensor.""" + qteatensor = QteaTensor.read( + filehandle, + np.complex128, + _CPU_DEVICE, + base_tensor_cls, + cmplx=cmplx, + order=order, + ) + obj = cls.from_qteatensor(qteatensor) + obj.convert(dtype, device) + return obj + + # Overwrite method due to requires_grad + def zeros_like(self, requires_grad=False): + """Get a tensor same as `self` but filled with zeros.""" + return type(self)( + self.shape, + ctrl="Z", + dtype=self.dtype, + device=self.device, + requires_grad=requires_grad, + ) + + def identity_like(self, fuse_point, requires_grad=False): + """Get an identity for the legs fused as: (0, fuse_point),(fuse_point+1,...). + Same shape as `self`.""" + mat = self.copy() + + # fuse legs + mat.fuse_links_update(0, fuse_point) + mat.fuse_links_update(1, mat.ndim - 1) + # make identity + identity = type(self)( + mat.shape, + ctrl="1", + dtype=self.dtype, + device=self.device, + requires_grad=requires_grad, + ) + # reshape back into the original shape + identity.reshape_update(self.shape) + return identity + + # -------------------------------------------------------------------------- + # Checks and asserts + # -------------------------------------------------------------------------- + # + # inherit def assert_normalized + # inherit def assert_unitary + # inherit def sanity_check + + def are_equal(self, other, tol=1e-7): + """Check if two tensors are equal.""" + if self.ndim != other.ndim: + return False + + if np.any(self.shape != other.shape): + return False + + return to.isclose(self._elem, other.elem, atol=tol, rtol=tol).all().item() + + def assert_identical_irrep(self, link_idx): + """Assert that specified link is identical irreps.""" + if self.shape[link_idx] != 1: + raise QRedTeaLinkError("Link dim is greater one in identical irrep check.") + + def assert_identity(self, tol=1e-7): + """Check if tensor is an identity matrix.""" + if not self.is_close_identity(tol=tol): + logger.error("Error information tensor:\n%s", self._elem) + raise QRedTeaError("Tensor not diagonal with ones.") + + def is_close_identity(self, tol=1e-7): + """Check if rank-2 tensor is close to identity.""" + if self.ndim != 2: + return False + + if self.shape[0] != self.shape[1]: + return False + + eye = to.eye(self.shape[0], device=self._elem.device) + eps = (to.abs(eye - self._elem)).max().item() + + return eps < tol + + def is_implemented_device(self, query): + """ + Check if argument query is an implemented device. + + Parameters + ---------- + + query : str + String to be tested if it corresponds to a device + implemented with this tensor. + + Returns + ------- + + is_implemented : bool + True if string is available as device. + """ + return query in self.implemented_devices + + @staticmethod + def is_xla_static(device_str): + """ + Check if device is XLA or not. + + Parameters + ---------- + + device_str : str + Check for this string if it is a XLA device. + + Returns + ------- + + is_xla : bool + True if device is a XLA. + """ + return device_str.startswith(_XLA_DEVICE) + + def is_xla(self, query=None): + """ + Check if device is XLA or not. + + Parameters + ---------- + + query : str | None, optional + If given, check for this string. If `None`, self.device + will be checked. + Default to None. + + Returns + ------- + + is_xla : bool + True if device is a XLA. + + """ + return self.is_xla_static(self.device if query is None else query) + + def is_dtype_complex(self): + """Check if data type is complex.""" + return self._elem.is_complex() + + # -------------------------------------------------------------------------- + # Single-tensor operations + # -------------------------------------------------------------------------- + # + # inherit def flip_links_update + + # pylint: disable-next=unused-argument + def attach_dummy_link(self, position, is_outgoing=True): + """Attach dummy link at given position (inplace update).""" + self.reshape_update(self._attach_dummy_link_shape(position)) + return self + + def conj(self): + """Return the complex conjugated in a new tensor.""" + # For both real and complex tensors, conj() does not (always?) + # return a true copy, but the "same memory address" which means + # inplace-updates on the conjugate tensor modify the original self. + + # Educated guess: Assuming gradients shoud be preserved in such a + # case, return a clone() and not a detach().clone() + if self._elem.is_complex(): + return self.from_elem_array( + self._elem.clone().conj(), dtype=self.dtype, device=self.device + ) + + return self.from_elem_array( + self._elem.clone(), dtype=self.dtype, device=self.device + ) + + def conj_update(self): + """Apply the complex conjugated to the tensor in place.""" + self._elem = to.conj(self._elem) + + def _convert_check(self, device): + """Run check that device is implemented and available.""" + if device is not None: + if device not in self.implemented_devices: + raise ValueError( + f"Device {device} is not implemented. Select from" + + f" {self.implemented_devices}" + ) + if self.is_gpu(query=device) and (not GPU_AVAILABLE): + raise ImportError("CUDA GPU is not available") + if self.is_xla(query=device) and (not XLA_AVAILABLE): + raise ImportError("XLA is not available.") + + # pylint: disable-next=unused-argument + def convert(self, dtype=None, device=None, stream=None): + """Convert underlying array to the specified data type inplace.""" + # Both devices available, figure out what we currently have + # and start converting + current = self.device + + if self.is_xla() or self.is_xla(query=device): + # Conversion causes problem if data types are converted + # on xla device, handle conversion in separate function + return self._xla_convert(dtype, device, stream) + + if device is not None: + self._convert_check(device) + + if device == current: + # We already are in the correct device + pass + elif self.is_gpu(query=device): + # We go from the cpu to gpu + cuda_str = device.replace(_GPU_DEVICE, "cuda") + self._elem = self._elem.to(device=cuda_str) + elif self.is_cpu(query=device): + # We go from gpu to cpu + self._elem = self._elem.to(device=_CPU_DEVICE) + else: + # CPU to XLA has to go here once not covered in special case + raise QRedTeaError( + f"Conversion {current} to {device} not possible or not considered yet." + ) + + if (dtype is not None) and (dtype != self.dtype): + self._elem = self._elem.type(dtype) + + return self + + def convert_singvals(self, singvals, dtype, device): + """Convert the singular values via a tensor.""" + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + if device is not None: + self._convert_check(device) + + # Both devices available, figure out what we currently have + # and start converting + current = self.device + + if device == current: + # We already are in the correct device + pass + elif self.is_gpu(query=device): + # We go to the cpu to gpu + cuda_str = device.replace(_GPU_DEVICE, "cuda") + singvals = singvals.to(device=cuda_str) + elif self.is_cpu(query=device): + # We go from gpu to cpu + singvals = singvals.to(device=_CPU_DEVICE) + elif self.is_xla(query=device): + # We go from the cpu to xla + singvals = singvals.to(device=xla_device) + + if dtype is not None: + dtype = { + "torch.float16": to.float16, + "torch.float32": to.float32, + "torch.float64": to.float64, + "torch.complex32": to.float16, + "torch.complex64": to.float32, + "torch.complex128": to.float64, + }[str(self.dtype)] + + if dtype != singvals.dtype: + target_is_complex = to.is_complex(to.empty((), dtype=dtype)) + if singvals.is_complex() and not target_is_complex: + singvals = singvals.real + singvals = singvals.to(dtype) + return singvals + + def diag(self, real_part_only=False, do_get=False): + """Return the diagonal as array of rank-2 tensor.""" + if self.ndim != 2: + raise QRedTeaRankError("Can only run on rank-2.") + + diag = self._elem.diag() + + if real_part_only and diag.is_complex(): + diag = to.real(diag) + + if self.device in ACCELERATOR_DEVICES and do_get: + diag = diag.to(device=_CPU_DEVICE) + + return diag + + def eig_api( + self, matvec_func, links, conv_params, args_func=None, kwargs_func=None + ): + """ + Interface to hermitian eigenproblem + + **Arguments** + + matvec_func : callable + Mulitplies "matrix" with "vector" + + links : links according to :class:`QteaTensor` + Contain the dimension of the problem. + + conv_params : instance of :class:`TNConvergenceParameters` + Settings for eigenproblem with Arnoldi method. + + args_func : arguments for matvec_func + + kwargs_func : keyword arguments for matvec_func + + **Returns** + + eigenvalues : scalar + + eigenvectors : instance of :class:`QteaTensor` + """ + eig_api_no_arpack = True + eig_api_qtea_half = self.dtype == to.float16 + + # scipy eigsh switches for complex data types to eigs and + # can only solve k eigenvectors of a nxn matrix with + # k < n - 1. This leads to problems with 2x2 matrices + # where one can get not even one eigenvector. + eig_api_qtea_dim2 = self.is_dtype_complex() and (np.prod(self.shape) == 2) + + if eig_api_qtea_half or eig_api_qtea_dim2 or eig_api_no_arpack: + val, vec = self.eig_api_qtea( + matvec_func, + conv_params, + args_func=args_func, + kwargs_func=kwargs_func, + ) + + # Half precision had problems with normalization (most likely + # as eigh is executed on higher precision + vec /= vec.norm_sqrt() + + return val, vec + + return self.eig_api_arpack( + matvec_func, + links, + conv_params, + args_func=args_func, + kwargs_func=kwargs_func, + ) + + def eig_api_qtea(self, matvec_func, conv_params, args_func=None, kwargs_func=None): + """ + Interface to hermitian eigenproblem via qtealeaves.solvers. Arguments see `eig_api`. + """ + solver_cls = ( + EigenSolverH # DenseTensorEigenSolverH if self.is_gpu() else EigenSolverH + ) + solver = solver_cls( + self, + matvec_func, + conv_params, + args_func=args_func, + kwargs_func=kwargs_func, + ) + + return solver.solve() + + def eig_api_arpack( + self, matvec_func, links, conv_params, args_func=None, kwargs_func=None + ): + """ + Interface to hermitian eigenproblem via Arpack. Arguments see `eig_api`. + Possible implementation is https://github.com/rfeinman/Torch-ARPACK. + """ + raise NotImplementedError("Arpack is non-default interface for pytorch.") + + def einsum(self, einsum_str, *others): + """ + Call to einsum with `self` as first tensor. + + Arguments + --------- + + einsum_str : str + Einsum contraction rule. + + other: List[:class:`QteaTorchTensors`] + 2nd, 3rd, ..., n-th tensor in einsum rule as + positional arguments. + + Results + ------- + + tensor : :class:`QteaTorchTensor` + Contracted tensor according to the einsum rules. + + Details + ------- + + The call ``np.einsum(einsum_str, x.elem, y.elem, z.elem)`` translates + into ``x.einsum(einsum_str, y, z)`` for x, y, and z being + :class:`QteaTorchTensor`. + """ + tensors = [self.elem] + [tensor.elem for tensor in others] + elem = to.einsum(einsum_str, *tensors) + + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + # pylint: disable-next=unused-argument + def fuse_links_update(self, fuse_low, fuse_high, is_link_outgoing=True): + """ + Fuses one set of links to a single link (inplace-update). + + Parameters + ---------- + fuse_low : int + First index to fuse + fuse_high : int + Last index to fuse. + + Example: if you want to fuse links 1, 2, and 3, fuse_low=1, fuse_high=3. + Therefore the function requires links to be already sorted before in the + correct order. + """ + self.reshape_update(self._fuse_links_update_shape(fuse_low, fuse_high)) + + def get_of(self, variable): + """Run the get method to transfer to host on variable (same device as self).""" + if not isinstance(variable, to.Tensor): + # It is not a to.Tensor, but no other variables can be + # sent back and forth between CPU and device, so it must + # be already on the host. + return variable + + if self.device_str(variable) in ACCELERATOR_DEVICES: + return variable.detach().to(device=_CPU_DEVICE) + + return variable + + def getsizeof(self): + """Size in memory (approximate, e.g., without considering meta data).""" + # Enable fast switch, previously use sys.getsizeof had trouble + # in resolving size, numpy attribute is only for array without + # metadata, but metadata like dimensions is only small overhead. + # (fast switch if we want to use another approach for estimating + # the size of a numpy array) + return self._elem.numel() * self._elem.element_size() + + def get_entry(self): + """Get entry if scalar on host.""" + if np.prod(self.shape) != 1: + raise QRedTeaError("Cannot use `get_entry`, more than one.") + + if self.device in ACCELERATOR_DEVICES: + return self._elem.to(device=_CPU_DEVICE).reshape(-1).item() + + return self._elem.reshape(-1).item() + + @classmethod + def mpi_bcast(cls, tensor, comm, tensor_backend, root=0): + """ + Broadcast QteaTorchTensor via MPI. + """ + is_root = comm.Get_rank() == root + dtype = tensor_backend.dtype + + # Broadcast the dim of the shape + dim = tensor.ndim if is_root else 0 + dim = comm.bcast(dim, root=root) + + # Broadcast shape via numpy + shape = ( + np.array(list(tensor.shape), dtype=int) + if is_root + else np.zeros(dim, dtype=int) + ) + comm.Bcast([shape, TN_MPI_TYPES[".""" + return self.norm_sqrt() ** 2 + + def norm_sqrt(self): + """ + Calculate the square root of the norm of the tensor, + i.e., sqrt( ). + """ + if self.is_xla(): + return self._xla_norm_sqrt() + return to.linalg.vector_norm(self._elem) + + def normalize(self): + """Normalize tensor with sqrt().""" + self._elem /= self.norm_sqrt() + return self + + def remove_dummy_link(self, position): + """Remove the dummy link at given position (inplace update).""" + # Could use xp.squeeze + new_shape = self._remove_dummy_link_shape(position) + self.reshape_update(new_shape) + return self + + def scale_link(self, link_weights, link_idx, do_inverse=False): + """ + Scale tensor along one link at `link_idx` with weights. + + **Arguments** + + link_weights : np.ndarray + Scalar weights, e.g., singular values. + + link_idx : int + Link which should be scaled. + + do_inverse : bool, optional + If `True`, scale with inverse instead of multiplying with + link weights. + Default to `False` + + **Returns** + + updated_link : instance of :class:`QteaTensor` + """ + if do_inverse: + vec = _scale_link_inverse_vector(link_weights) + return self.scale_link(vec, link_idx) + + key = self._scale_link_einsum(link_idx) + tmp = to.einsum(key, self._elem, link_weights) + return self.from_elem_array(tmp, dtype=self.dtype, device=self.device) + + def scale_link_update(self, link_weights, link_idx, do_inverse=False): + """ + Scale tensor along one link at `link_idx` with weights (inplace update). + + **Arguments** + + link_weights : np.ndarray + Scalar weights, e.g., singular values. + + link_idx : int + Link which should be scaled. + + do_inverse : bool, optional + If `True`, scale with inverse instead of multiplying with + link weights. + Default to `False` + """ + if do_inverse: + vec = _scale_link_inverse_vector(link_weights) + return self.scale_link_update(vec, link_idx) + + if link_idx == 0: + shape = list(self.shape) + dim1 = shape[0] + dim2 = np.prod(shape[1:]) + self.reshape_update((dim1, dim2)) + + for ii, scalar in enumerate(link_weights): + self._elem[ii, :] *= scalar + + self.reshape_update(shape) + return self + + if link_idx + 1 == self.ndim: + # For last link xp.multiply will do the job as the + # last index is one memory block anyway + self._elem[:] = self._elem * link_weights + return self + + # Needs permutation, einsum is probably best despite + # not being tuned to work inplace for now + key = self._scale_link_einsum(link_idx) + self._elem = to.einsum(key, self._elem, link_weights) + + return self + + def set_diagonal_entry(self, position, value): + """Set the diagonal element in a rank-2 tensor (inplace update)""" + if self.ndim != 2: + raise QRedTeaRankError("Can only run on rank-2 tensor.") + self._elem[position, position] = value + + def set_matrix_entry(self, idx_row, idx_col, value): + """Set one element in a rank-2 tensor (inplace update)""" + if self.ndim != 2: + raise QRedTeaRankError("Can only run on rank-2 tensor.") + self._elem[idx_row, idx_col] = value + + @staticmethod + def set_seed(seed, devices=None): # pylint: disable=unused-argument + """ + Set the seed for this tensor backend and the specified devices. + + Arguments + --------- + + seed : list[int] + List of integers used as a seed; list has length 4. + + devices : list[str] | None, optional + Can pass a list of devices via a string, e.g., to + specify GPU by index. torch sets the seed for all + devices, so there is no specific need for it as of + now (we keep it for compatability). + Default to `None` (set for all devices) + """ + # Find single integer as seed + elegant_pairing = lambda nn, mm: nn**2 + nn + mm if nn >= mm else mm * 2 + nn + intermediate_a = elegant_pairing(seed[0], seed[1]) + intermediate_b = elegant_pairing(seed[2], seed[3]) + single_seed = elegant_pairing(intermediate_a, intermediate_b) + to.manual_seed(single_seed) + + def set_subtensor_entry(self, corner_low, corner_high, tensor): + """ + Set a subtensor (potentially expensive as looping explicitly, inplace update). + + **Arguments** + + corner_low : list of ints + The lower index of each dimension of the tensor to set. Length + must match rank of tensor `self`. + + corner_high : list of ints + The higher index of each dimension of the tensor to set. Length + must match rank of tensor `self`. + + tensor : :class:`QteaTorchTensor` + Tensor to be set as subtensor. Rank must match tensor `self`. + Dimensions must match `corner_high - corner_low`. + + **Examples** + + To set the tensor of shape 2x2x2 in a larger tensor `self` of shape + 8x8x8 the corresponing call is in comparison to a numpy syntax: + + * self.set_subtensor_entry([2, 4, 2], [4, 6, 4], tensor) + * self[2:4, 4:6, 2:4] = tensor + + To be able to work with all ranks, we currently avoid the numpy + syntax in our implementation. + """ + lists = [] + for ii, corner_ii in enumerate(corner_low): + corner_jj = corner_high[ii] + lists.append(list(range(corner_ii, corner_jj))) + + shape = self.elem.shape + cdim = np.cumprod(np.array(shape[::-1], dtype=int))[::-1] + cdim = np.array(list(cdim[1:]) + [1], dtype=int) + + # Reshape does not make a copy, but points to memory (unlike flatten) + self_1d = self.elem.reshape(-1) + sub_1d = tensor.elem.reshape(-1) + + kk = -1 + for elem in itertools.product(*lists): + kk += 1 + elem = np.array(elem, dtype=int) + idx = np.sum(elem * cdim) + + self_1d[idx] = sub_1d[kk] + + # self._elem never changed shape, we are done + + def to_dense(self, true_copy=False): + """Return dense tensor (if `true_copy=False`, same object may be returned).""" + if true_copy: + return self.copy() + + return self + + def to_dense_singvals(self, s_vals, true_copy=False): + """Convert singular values to dense vector without symmetries.""" + if true_copy: + return s_vals.detach().clone() + + return s_vals + + def trace(self, return_real_part=False, do_get=False): + """Take the trace of a rank-2 tensor.""" + if self.ndim != 2: + raise QRedTeaRankError("Can only run on rank-2 tensor.") + + value = self._elem.trace() + + if return_real_part and value.is_complex(): + value = value.real + + if self.device in ACCELERATOR_DEVICES and do_get: + value = value.to(device=_CPU_DEVICE) + + return value + + def transpose(self, permutation): + """Permute the links of the tensor and return new tensor.""" + elem = self._elem.permute(tuple(permutation)) + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def transpose_update(self, permutation): + """Permute the links of the tensor inplace.""" + self._elem = self._elem.permute(tuple(permutation)) + + def write(self, filehandle, cmplx=None): + """ + Write tensor in original Fortran compatible way. + + **Details** + + 1) Number of links + 2) Line with link dimensions + 3) Entries of tensors line-by-line in column-major ordering. + """ + # Generate numpy version for write + if self.is_cpu(): + elem_np = self._elem.detach().numpy() + else: + elem_np = self._elem.detach().to(device=_CPU_DEVICE).numpy() + + if cmplx is None: + cmplx = np.sum(np.abs(np.imag(elem_np))) > 1e-15 + + write_tensor(elem_np, filehandle, cmplx=cmplx) + + def expm(self, fuse_point=None, prefactor=1): + """ + Take the matrix exponential with a scalar prefactor, Exp(prefactor * self). + Reshapes the tensor into a matrix by fusing links up to INCLUDING fuse_point + into one, and links after into the second dimension. + + Parameters + ---------- + fuse_point : int, optional + If given, reshapes the tensor into a matrix by fusing links up to INCLUDING fuse_point + into one, and links after into the second dimension. + To compute the exponential of a 4-leg tensor, for example, by fusing (0,1),(2,3), + set fuse_point=1. + Default is None. + + prefactor : float, optional + Prefactor of the tensor to be exponentiated. + Default to 1. + + Return + ------ + mat : instance of :class:`QteaTensor` + Exponential of input tensor. + + Details + ------- + + To compute the exponential of a 4-leg tensor by fusing (0,1),(2,3), + set fuse_point=1. + """ + self.assert_rank_2() + mat = self.copy() + original_shape = mat.shape + + # Fuse the links. + if fuse_point is not None: + mat.fuse_links_update( + fuse_low=0, fuse_high=fuse_point, is_link_outgoing=False + ) + mat.fuse_links_update( + fuse_low=1, fuse_high=mat.ndim - 1, is_link_outgoing=True + ) + elif mat.ndim != 2: + raise QRedTeaRankError( + f"Not a matrix, hence cannot take expm. Expected rank 2, but got shape {mat.shape}." + ) + + # Take the exponent and reshape back into the original shape. + # pylint: disable-next=protected-access + mat._elem = to.matrix_exp(prefactor * mat.elem) + mat.reshape_update(original_shape) + return mat + + # -------------------------------------------------------------------------- + # Two-tensor operations + # -------------------------------------------------------------------------- + + def add_update(self, other, factor_this=None, factor_other=None): + """ + Inplace addition as `self = factor_this * self + factor_other * other`. + + **Arguments** + + other : same instance as `self` + Will be added to `self`. Unmodified on exit. + + factor_this : scalar + Scalar weight for tensor `self`. + + factor_other : scalar + Scalar weight for tensor `other` + """ + if factor_this is not None: + self._elem *= factor_this + + if factor_other is None: + factor_other = 1 + elif to.is_tensor(factor_other): + factor_other = factor_other.item() + + if self.elem.requires_grad or other.elem.requires_grad: + self._elem = to.add(self._elem, other.elem, alpha=factor_other) + else: + to.add(self._elem, other.elem, alpha=factor_other, out=self._elem) + + return self + + def dot(self, other): + """Inner product of two tensors .""" + return to.vdot(self.elem.reshape(-1), other.elem.reshape(-1)) + + def split_qr( + self, + legs_left, + legs_right, + perm_left=None, + perm_right=None, + is_q_link_outgoing=True, # pylint: disable=unused-argument + disable_streams=False, # pylint: disable=unused-argument + ): + """ + Split the tensor via a QR decomposition. + + Parameters + ---------- + + self : instance of :class:`QteaTensor` + Tensor upon which apply the QR + legs_left : list of int + Legs that will compose the rows of the matrix + legs_right : list of int + Legs that will compose the columns of the matrix + perm_left : list of int, optional + permutations of legs after the QR on left tensor + perm_right : list of int, optional + permutation of legs after the QR on right tensor + disable_streams : boolean, optional + No effect here, but in general can disable streams + to avoid nested generation of streams. + + Returns + ------- + + tens_left: instance of :class:`QteaTensor` + unitary tensor after the QR, i.e., Q. + tens_right: instance of :class:`QteaTensor` + upper triangular tensor after the QR, i.e., R + """ + is_good_bipartition, is_sorted_l, is_sorted_r = self._split_checks_links( + legs_left, legs_right + ) + + if is_good_bipartition and is_sorted_l and is_sorted_r: + dim1 = np.prod(np.array(self.shape)[legs_left]) + dim2 = np.prod(np.array(self.shape)[legs_right]) + + tens_left, tens_right = self._split_qr_dim(dim1, dim2) + + k_dim = tens_right.shape[0] + + tens_left.reshape_update(list(np.array(self.shape)[legs_left]) + [k_dim]) + tens_right.reshape_update([k_dim] + list(np.array(self.shape)[legs_right])) + + else: + # Reshaping + matrix = self._elem.permute(legs_left + legs_right) + shape_left = np.array(self.shape)[legs_left] + shape_right = np.array(self.shape)[legs_right] + matrix = matrix.reshape(np.prod(shape_left), np.prod(shape_right)) + k_dim = np.min([matrix.shape[0], matrix.shape[1]]) + + if self.dtype == to.float16: + matrix = matrix.type(to.float32) + + # QR decomposition + mat_left, mat_right = to.linalg.qr(matrix) + + if self.dtype == to.float16: + mat_left = mat_left.to(to.float16) + mat_right = mat_right.to(to.float16) + + # Reshape back to tensors + tens_left = self.from_elem_array( + mat_left.reshape(list(shape_left) + [k_dim]), + dtype=self.dtype, + device=self.device, + ) + tens_right = self.from_elem_array( + mat_right.reshape([k_dim] + list(shape_right)), + dtype=self.dtype, + device=self.device, + ) + + if perm_left is not None: + tens_left.transpose_update(perm_left) + + if perm_right is not None: + tens_right.transpose_update(perm_right) + + return tens_left, tens_right + + # pylint: disable-next=unused-argument + def split_qrte( + self, + tens_right, + singvals_self, + operator=None, + conv_params=None, + is_q_link_outgoing=True, + ): + """ + Perform an Truncated ExpandedQR decomposition, generalizing the idea + of https://arxiv.org/pdf/2212.09782.pdf for a general bond expansion + given the isometry center of the network on `tens_left`. + It should be rather general for three-legs tensors, and thus applicable + with any tensor network ansatz. Notice that, however, you do not have + full control on the approximation, since you know only a subset of the + singular values truncated. + + Parameters + ---------- + tens_left: xp.array + Left tensor + tens_right: xp.array + Right tensor + singvals_left: xp.array + Singular values array insisting on the link to the left of `tens_left` + operator: xp.array or None + Operator to contract with the tensors. If None, no operator is contracted + + Returns + ------- + tens_left: ndarray + left tensor after the EQR + tens_right: ndarray + right tensor after the EQR + singvals: ndarray + singular values kept after the EQR + singvals_cutted: ndarray + subset of thesingular values cutted after the EQR, + normalized with the biggest singval + """ + raise NotImplementedError("QR truncated-expanded not implemented for torch.") + + # torch.linalg has no rq (neither has torch itself) + # def split_rq( + + # pylint: disable-next=too-many-branches + def split_svd( + self, + legs_left, + legs_right, + perm_left=None, + perm_right=None, + contract_singvals="N", + conv_params=None, + no_truncation=False, + is_link_outgoing_left=True, # pylint: disable=unused-argument + disable_streams=False, # pylint: disable=unused-argument + ): + """ + Perform a truncated Singular Value Decomposition by + first reshaping the tensor into a legs_left x legs_right + matrix, and permuting the legs of the ouput tensors if needed. + If the contract_singvals = ('L', 'R') it takes care of + renormalizing the output tensors such that the norm of + the MPS remains 1 even after a truncation. + + Parameters + ---------- + self : instance of :class:`QteaTensor` + Tensor upon which apply the SVD + legs_left : list of int + Legs that will compose the rows of the matrix + legs_right : list of int + Legs that will compose the columns of the matrix + perm_left : list of int, optional + permutations of legs after the SVD on left tensor + perm_right : list of int, optional + permutation of legs after the SVD on right tensor + contract_singvals: string, optional + How to contract the singular values. + 'N' : no contraction + 'L' : to the left tensor + 'R' : to the right tensor + conv_params : :py:class:`TNConvergenceParameters`, optional + Convergence parameters to use in the procedure. If None is given, + then use the default convergence parameters of the TN. + Default to None. + no_truncation : boolean, optional + Allow to run without truncation + Default to `False` (hence truncating by default) + disable_streams : boolean, optional + No effect here, but in general can disable streams + to avoid nested generation of streams. + + Returns + ------- + tens_left: instance of :class:`QteaTensor` + left tensor after the SVD + tens_right: instance of :class:`QteaTensor` + right tensor after the SVD + singvals: xp.ndarray + singular values kept after the SVD + singvals_cut: xp.ndarray + singular values cut after the SVD, normalized with the biggest singval + """ + tensor = self._elem + + # Reshaping + matrix = tensor.permute(legs_left + legs_right) + shape_left = np.array(tensor.shape)[legs_left] + shape_right = np.array(tensor.shape)[legs_right] + matrix = matrix.reshape([np.prod(shape_left), np.prod(shape_right)]) + + # Main tensor module does not know about xla - use GPU logic + device = _GPU_DEVICE if self.device in ACCELERATOR_DEVICES else self.device + + if conv_params is None: + svd_ctrl = "A" + max_bond_dimension = min(matrix.shape) + else: + svd_ctrl = conv_params.svd_ctrl + max_bond_dimension = conv_params.max_bond_dimension + + svd_ctrl = _process_svd_ctrl( + svd_ctrl, + max_bond_dimension, + matrix.shape, + device, + contract_singvals, + ) + if matrix.dtype == to.float16: + matrix = matrix.to(to.float32) + + # SVD decomposition + if svd_ctrl in ("E", "X"): + try: + mat_left, singvals_tot, mat_right = self._split_svd_eigvl( + matrix, + svd_ctrl, + max_bond_dimension, + contract_singvals, + ) + + except to._C._LinAlgError: # pylint: disable=protected-access + # Likely leads to Cuda memory access error, but let's try + logger.warning("SVD in mode E or X failed; falling back to mode D.") + svd_ctrl = "D" + + if svd_ctrl in ("E+QR", "X+QR"): + try: + mat_left, singvals_tot, mat_right = self._split_svd_eigvl_qr( + matrix, + svd_ctrl, + max_bond_dimension, + contract_singvals, + ) + + except to._C._LinAlgError: # pylint: disable=protected-access + logger.warning( + "SVD in mode E+QR or X+QR failed; falling back to mode D." + ) + svd_ctrl = "D" + + if svd_ctrl in ("D", "V"): + try: + mat_left, singvals_tot, mat_right = self._split_svd_normal(matrix) + except QRedTeaLinAlgError: + # Try random SVD instead + logger.warning("To avoid failure in SVD, switching to random SVD.") + svd_ctrl = "R" + + if svd_ctrl == "R": + mat_left, singvals_tot, mat_right = self._split_svd_random( + matrix, max_bond_dimension + ) + + if self.dtype == to.float16: + mat_left = mat_left.to(to.float16) + mat_right = mat_right.to(to.float16) + singvals_tot = singvals_tot.to(to.float16) + + # Truncation + if not no_truncation: + cut, singvals, singvals_cut = self._truncate_singvals( + singvals_tot, conv_params + ) + + if cut < mat_left.shape[1]: + # Cutting bond dimension + mat_left = mat_left[:, :cut] + mat_right = mat_right[:cut, :] + elif cut > mat_left.shape[1]: + # Extending bond dimension to comply with ideal hardware + # settings + dim = mat_left.shape[1] + delta = cut - dim + + mat_left = to.nn.ConstantPad2d((0, delta, 0, 0), 0)(mat_left) + mat_right = to.nn.ConstantPad2d((0, 0, 0, delta), 0)(mat_right) + singvals = to.nn.ConstantPad1d((0, delta), 0)(singvals) + + else: + singvals = singvals_tot + singvals_cut = [] # xp.array([], dtype=self.dtype) + cut = mat_left.shape[1] + + # Contract singular values if requested + if svd_ctrl in ("D", "V", "R"): + if contract_singvals.upper() == "L": + mat_left = to.multiply(mat_left, singvals) + elif contract_singvals.upper() == "R": + mat_right = to.multiply(singvals, mat_right.T).T + elif contract_singvals.upper() != "N": + raise ValueError( + f"Contract_singvals option {contract_singvals} is not " + + "implemented. Choose between right (R), left (L) or None (N)." + ) + + # Reshape back to tensors + tens_left = mat_left.reshape(list(shape_left) + [cut]) + if perm_left is not None: + tens_left = tens_left.permute(perm_left) + + tens_right = mat_right.reshape([cut] + list(shape_right)) + if perm_right is not None: + tens_right = tens_right.permute(perm_right) + + # Convert into QteaTensor + tens_left = self.from_elem_array( + tens_left, dtype=self.dtype, device=self.device + ) + tens_right = self.from_elem_array( + tens_right, dtype=self.dtype, device=self.device + ) + return tens_left, tens_right, singvals, singvals_cut + + def stack_link(self, other, link): + """ + Stack two tensors along a given link. Same as `to.cat([self, other], dim=link)`. + + **Arguments** + + other : instance of :class:`QteaTorchTensor` + Links must match `self` up to the specified link. + + link : integer + Stack along this link. + + **Returns** + + new_this : instance of :class:QteaTorchTensor` + """ + + newelem = to.cat([self.elem, other.elem], dim=link) + new_this = self.from_elem_array(newelem, dtype=self.dtype, device=self.device) + + return new_this + + # pylint: disable=invalid-name + def stack_first_and_last_link(self, other): + """Stack first and last link of tensor targeting MPS addition.""" + newdim_self = list(self.shape) + newdim_self[0] += other.shape[0] + newdim_self[-1] += other.shape[-1] + + d1 = self.shape[0] + # we want to.prod(self.shape[1:-1]), workaround for now + d2 = 1 + for dd in self.shape[1:-1]: + d2 *= dd + d3 = self.shape[-1] + i1 = other.shape[0] + i3 = other.shape[-1] + + new_dims = [d1 + i1, d2, d3 + i3] + + new_this = type(self)(new_dims, ctrl="Z", dtype=self.dtype, device=self.device) + + # pylint: disable-next=protected-access + new_this._elem[:d1, :, :d3] = self.elem.reshape([d1, d2, d3]) + # pylint: disable-next=protected-access + new_this._elem[d1:, :, d3:] = other.elem.reshape([i1, d2, i3]) + new_this.reshape_update(newdim_self) + + return new_this + + # pylint: disable-next=unused-argument + def tensordot(self, other, contr_idx, disable_streams=False): + """Tensor contraction of two tensors along the given indices.""" + + # move tensor 'other' to the device of 'self', if needed + tmp_other_device = other.device + other.convert(device=self.device) + if other.device != tmp_other_device: + logger.warning( + "Switching tensor device on the fly. (%s -> %s)", + tmp_other_device, + other.device, + ) + + elem = to.tensordot(self._elem, other._elem, dims=contr_idx) + tens = self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + # move 'other' back to original device + other.convert(device=tmp_other_device) + + return tens + + def stream(self, disable_streams=False): + """ + Get the instance of a context which can be used to parallelize. + + Parameters + ---------- + + disable_streams : bool, optional + Allows to disable streams to avoid nested creation of + streams. Globally, streams should be disabled via the + `set_streams_qteatorchtensors` function of the corresponding + base tensor module. + Default to False. + + Returns + ------- + + Context manager, e.g., + :class:`QteaTorchStream` if running on GPU and enabled + :class:`nullcontext(AbstractContextManager)` otherwise + + """ + if _USE_STREAMS and (not disable_streams) and self.is_gpu(): + return QteaTorchStream() + + return nullcontext() + + @staticmethod + # pylint: disable-next=unused-argument + def free_memory_device(device=None): + """ + Free the unused device memory that is otherwise occupied by the cache. + Otherwise cupy will keep the memory occupied for caching reasons. + We follow the approach from https://stackoverflow.com/questions/70508960 + + Parameters + ---------- + + device : str | None + No effect with torch as `to.cuda.empty_cache` does not allow + to specify a specific device. If enable later via other calls, + the device is a string, e.g., "gpu:0" + """ + if GPU_AVAILABLE: + gc.collect() + to.cuda.empty_cache() + + # -------------------------------------------------------------------------- + # Gradient descent: backwards propagation + # -------------------------------------------------------------------------- + + @staticmethod + # pylint: disable-next=keyword-arg-before-vararg + def get_optimizer(name="SGD", *args, **kwargs): + """Gets the optimizer with a given name. + For name='SGD', this is the torch.optim.SGD object. + Warning: different torch optimizers have different parameters. + + Parameters + ---------- + *args: list of tensors to be otimised over + + **kwargs: + lr: a real number representing the learning rate + + Returns + ---------- + optimizer: the optimizer object, here torch.optim.SGD + """ + if name == "SGD": + return to.optim.SGD(*args, **kwargs) + if name == "AdamW": + return to.optim.AdamW(*args, **kwargs) + + raise QRedTeaError(f"Unknown optimizer name: {name}.") + + @staticmethod + def get_gradient_clipper(): + """ + Gets the torch's gradient clipper function. + """ + return to.nn.utils.clip_grad_value_ + + def backward(self, **kwargs): + """Implements a step of backward propagation and returns the gradients. + + Parameters + ---------- + + **kwargs: + + retain_graph: boolean, required by PyTorch to retain the graph used to + calculate the forward function + + Returns + ---------- + gradients: list of gradients + """ + if self.ndim != 0: + raise QRedTeaRankError("Not a scalar, cannot compute gradients") + + return self.elem.backward(**kwargs) + + # -------------------------------------------------------------------------- + # Internal methods + # -------------------------------------------------------------------------- + # + # inherit _invert_link_selection + + # -------------------------------------------------------------------------- + # MISC + # -------------------------------------------------------------------------- + + @staticmethod + def get_default_datamover(): + """The default datamover compatible with this class.""" + return DataMoverPytorch() + + # -------------------------------------------------------------------------- + # Methods needed for _AbstractQteaBaseTensor + # -------------------------------------------------------------------------- + + def assert_diagonal(self, tol=1e-7): + """Check that tensor is a diagonal matrix up to tolerance.""" + if self.ndim != 2: + raise QRedTeaRankError("Not a matrix, hence not the identity.") + + tmp = to.diag(to.diag(self._elem)) + tmp -= self._elem + + if to.abs(tmp).max().item() > tol: + raise QRedTeaError("Matrix not diagonal.") + + return + + def assert_int_values(self, tol=1e-7): + """Check that there are only integer values in the tensor.""" + if self.is_dtype_complex(): + tmp = to.imag(self.elem) + if to.abs(tmp).max().item() > tol: + raise QRedTeaError("Matrix not an integer due to imaginary part.") + + tmp = to.round(to.real(self.elem)) + tmp -= to.real(self.elem) + else: + tmp = to.round(self._elem) + tmp -= self._elem + + if to.abs(tmp).max().item() > tol: + raise QRedTeaError("Matrix is not an integer matrix.") + + return + + def assert_real_valued(self, tol=1e-7): + """Check that all tensor entries are real-valued.""" + if not self._elem.is_complex(): + return + + tmp = to.imag(self._elem) + + if to.abs(tmp).max().item() > tol: + raise QRedTeaError("Tensor is not real-valued.") + + def eig(self): + """ + Compute eigenvalues and eigenvectors of a two-leg tensor + + Return + ------ + eigvals, eigvecs : instances of :class:`QteaTorchTensor` + Eigenvalues and corresponding eigenvectors of input tensor. + """ + if self.ndim != 2: + raise QRedTeaRankError("Works only with two-leg tensor") + + eigvals, eigvecs = to.linalg.eig(self.elem) + eigvals = self.from_elem_array(eigvals, dtype=self.dtype, device=self.device) + eigvecs = self.from_elem_array(eigvecs, dtype=self.dtype, device=self.device) + return eigvals, eigvecs + + def eigvalsh(self): + """ + Compute eigenvalues of a two-leg Hermitian tensor + + Return + ------ + eigvals : to.tensor + Eigenvalues of input tensor. + """ + if self.ndim != 2: + raise QRedTeaRankError("Works only with two-leg tensor") + + eigvals = to.linalg.eigvalsh(self.elem) + return eigvals + + def elementwise_abs_smaller_than(self, value): + """Return boolean if each tensor element is smaller than `value`""" + return (to.abs(self._elem) < value).all().item() + + def _expand_tensor(self, link, new_dim, ctrl="R"): + """Expand tensor along given link and to new dimension.""" + newdim = list(self.shape) + newdim[link] = new_dim - newdim[link] + + expansion = type(self)(newdim, ctrl=ctrl, dtype=self.dtype, device=self.device) + + return self.stack_link(expansion, link) + + def expand_tensor(self, link, new_dim, ctrl="R"): + """Expand tensor for one link up to dimension `new_dim`.""" + return self._expand_tensor(link, new_dim, ctrl=ctrl) + + def flatten(self): + """Returns flattened version (rank-1) of dense array in native array type.""" + return self._elem.flatten() + + @classmethod + def from_elem_array(cls, tensor, dtype=None, device=None): + """ + New QteaTorchTensor from array + + **Arguments** + + tensor : to.tensor + Array for new tensor. + + dtype : data type, optional + Can allow to specify data type. + If not `None`, it will convert. + Default to `None` + """ + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + if isinstance(tensor, np.ndarray): + if not tensor.flags["WRITEABLE"]: + # Avoids torch warning if numpy side is not writeable + tensor = tensor.copy() + tensor = to.from_numpy(tensor) + + is_float = tensor.is_complex() or tensor.is_floating_point() + if dtype is None and (not is_float): + logger.warning( + ( + "Initializing a tensor with integer dtype can be dangerous " + "for the simulation. Please specify the dtype keyword in the " + "from_elem_array method if it was not intentional." + ) + ) + + if dtype is None: + dtype = tensor.dtype + if device is None: + # We can actually check with torch where we are running + device = cls.device_str(tensor) + + if cls.is_xla_static(device): + device = xla_device + + obj = cls(tensor.shape, ctrl=None, dtype=dtype, device=device) + obj._elem = tensor + + obj.convert(dtype, device) + + return obj + + def get_attr(self, *args): + """High-risk resolve attribute for an operation on an elementary array.""" + attributes = [] + + for elem in args: + if elem == "cumsum": + # Special treatment for cumsum to resolve additional argument which + # is not present in numpy + attributes.append(_cumsum_like_numpy) + elif elem == "log": + # Special treatment to allow integers + attributes.append(_log_like_numpy) + elif elem == "sum": + # Special treatment for cumsum to resolve additional argument which + # is not present in numpy + attributes.append(_sum_like_numpy) + elif elem == "linalg.eigh": + # linalg.eigh cannot be resolved + attributes.append(to.linalg.eigh) + elif not hasattr(to, elem): + raise QRedTeaError( + f"This tensor's elementary array does not support {elem}." + ) + else: + attributes.append(getattr(to, elem)) + + if len(attributes) == 1: + return attributes[0] + + return tuple(attributes) + + def get_argsort_func(self): + """Return callable to argsort function.""" + return to.argsort + + def get_diag_entries_as_int(self): + """Return diagonal entries of rank-2 tensor as integer on host and as numpy.""" + if self.ndim != 2: + raise QRedTeaRankError("Not a matrix, cannot get diagonal.") + + tmp = to.diag(self._elem) + if self.device in ACCELERATOR_DEVICES: + tmp = tmp.detach().to(device=_CPU_DEVICE) + + if tmp.is_complex(): + tmp = to.real(tmp) + + return tmp.type(to.int32).numpy() + + def get_sqrt_func(self): + """Return callable to sqrt function.""" + return to.sqrt + + def get_submatrix(self, row_range, col_range): + """Extract a submatrix of a rank-2 tensor for the given rows / cols.""" + if self.ndim != 2: + raise QRedTeaRankError("Cannot only set submatrix for rank-2 tensors.") + + row1, row2 = row_range + col1, col2 = col_range + + return self.from_elem_array( + self._elem[row1:row2, col1:col2], dtype=self.dtype, device=self.device + ) + + def kron(self, other, idxs=None): + """ + Perform the kronecker product between two tensors. + By default, do it over all the legs, but you can also + specify which legs should be kroned over. + The legs over which the kron is not done should have + the same dimension. + + Parameters + ---------- + other : QteaTensor + Tensor to kron with self + idxs : Tuple[int], optional + Indexes over which to perform the kron. + If None, kron over all indeces. Default to None. + + Returns + ------- + QteaTensor + The kronned tensor + + Details + ------- + + Performing the kronecker product between a tensor of shape (2, 3, 4) + and a tensor of shape (1, 2, 3) will result in a tensor of shape (2, 6, 12). + + To perform the normal kronecker product between matrices just pass rank-2 tensors. + + To perform kronecker product between vectors first transfor them in rank-2 tensors + of shape (1, -1) + + Performing the kronecker product only along **some** legs means that along that + leg it is an elementwise product and not a kronecker. For Example, if idxs=(0, 2) + for the tensors of shapes (2, 3, 4) and (1, 3, 2) the output will be of shape + (2, 3, 8). + """ + + subscipts, final_shape = self._einsum_for_kron(self.shape, other.shape, idxs) + + elem = to.einsum(subscipts, self._elem, other._elem).reshape(tuple(final_shape)) + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def mask_to_device(self, mask): + """ + Send a mask to the device where the tensor is. + (right now only CPU --> GPU, CPU --> CPU). + """ + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + if self.is_cpu(): + return mask + + if self.is_xla(): + target_device = xla_device + elif self.is_gpu(): + target_device = self.device + else: + raise QRedTeaError(f"Unknown device {self.device}") + + if not to.is_tensor(mask): + mask = to.from_numpy(np.array(mask, dtype=bool)) + # if target_device == 'gpu': + target_device = "cuda" + mask_on_device = mask.to(device=target_device) + return mask_on_device + + def mask_to_host(self, mask): + """ + Send a mask to the host where we need it for symmetric tensors, e.g., + degeneracies. Return as numpy. + """ + if self.is_cpu(): + if to.is_tensor(mask): + return mask.numpy() + + return mask + + if to.is_tensor(mask): + mask_on_host = mask.to(device=_CPU_DEVICE).numpy() + else: + mask_on_host = mask + + return mask_on_host + + def permute_rows_cols_update(self, inds): + """Permute rows and columns of rank-2 tensor with `inds`. Inplace update.""" + if self.ndim != 2: + raise QRedTeaRankError( + "Cannot only permute rows & cols for rank-2 tensors." + ) + + tmp = self._elem[inds, :][:, inds] + self._elem *= 0.0 + self._elem += tmp + return self + + def prepare_eig_api(self, conv_params): + """ + Return variables for eigsh. + + **Returns** + + kwargs : dict + Keyword arguments for eigs call. + If initial guess can be passed, key "v0" is + set with value `None` + + LinearOperator : `None` + + eigsh : `None` + """ + tolerance = conv_params.sim_params["arnoldi_min_tolerance"] + + kwargs = { + "k": 1, + "which": "LA", + "ncv": None, + "maxiter": None, + "tol": tolerance, + "return_eigenvectors": True, + } + + if self.is_cpu(): + kwargs["v0"] = None + + # For now, we always want to use qtea solver, indicate for + # symmetric tensors + kwargs["use_qtea_solver"] = True + + return kwargs, None, None + + def reshape(self, shape, **kwargs): + """Reshape a tensor.""" + if kwargs.get("order", "C") != "C": + raise QRedTeaError("Cannot consider order in reshape.") + elem = self._elem.reshape(shape) + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def reshape_update(self, shape, **kwargs): + """Reshape tensor dimensions inplace.""" + if kwargs.get("order", "C") != "C": + raise QRedTeaError("Cannot consider order in reshape.") + self._elem = self._elem.reshape(shape) + + def set_submatrix(self, row_range, col_range, tensor): + """Set a submatrix of a rank-2 tensor for the given rows / cols.""" + + if self.ndim != 2: + raise QRedTeaRankError("Cannot only set submatrix for rank-2 tensors.") + + row1, row2 = row_range + col1, col2 = col_range + + # remove the +=! + # pylint: disable-next=protected-access + self._elem[row1:row2, col1:col2] += tensor.elem.reshape( + row2 - row1, col2 - col1 + ) + + def subtensor_along_link(self, link, lower, upper): + """ + Extract and return a subtensor select range (lower, upper) for one line. + """ + dim1, dim2, dim3 = self._shape_as_rank_3(link) + elem = self._elem.reshape([dim1, dim2, dim3]) + + elem = elem[:, lower:upper, :] + + new_shape = list(self.shape) + new_shape[link] = upper - lower + + elem = elem.reshape(new_shape) + + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def subtensor_along_link_inds(self, link, inds): + """ + Extract and return a subtensor via indices for one link. + + Arguments + --------- + + link : int + Select only specific indices along this link (but all indices + along any other link). + + inds : list[int] + Indices to be selected and stored in the subtensor. + + Returns + ------- + + subtensor : :class:`QteaTorchTensor` + Subtensor with selected indices. + + Details + ------- + + The numpy equivalent is ``subtensor = tensor[:, :, inds, :]`` + for a rank-4 tensor and ``link=2``. + """ + # pylint: disable-next=invalid-name + d1, d2, d3 = self._shape_as_rank_3(link) + + elem = self._elem.reshape([d1, d2, d3]) + elem = elem[:, inds, :] + + new_shape = list(self.shape) + new_shape[link] = len(inds) + elem = elem.reshape(new_shape) + + return self.from_elem_array(elem, dtype=self.dtype, device=self.device) + + def _truncate_singvals(self, singvals, conv_params=None): + """ + Truncate the singular values followling the + strategy selected in the convergence parameters class + + Parameters + ---------- + singvals : np.ndarray + Array of singular values + conv_params : :py:class:`TNConvergenceParameters`, optional + Convergence parameters to use in the procedure. If None is given, + then use the default convergence parameters of the TN. + Default to None. + + Returns + ------- + cut : int + Number of singular values kept + singvals_kept : np.ndarray + Normalized singular values kept + singvals_cutted : np.ndarray + Normalized singular values cutted + """ + + if conv_params is None: + conv_params = TNConvergenceParameters() + logger.info("Using default convergence parameters.") + elif not isinstance(conv_params, TNConvergenceParameters): + raise ValueError( + "conv_params must be TNConvergenceParameters or None, " + + f"not {type(conv_params)}." + ) + + if conv_params.trunc_method == "R": + cut = self._truncate_sv_ratio(singvals, conv_params) + elif conv_params.trunc_method == "N": + cut = self._truncate_sv_norm(singvals, conv_params) + else: + raise QRedTeaError(f"Unkown trunc_method {conv_params.trunc_method}") + + # Divide singvals in kept and cut (can handle suggested padding) + singvals_kept = singvals[: min(cut, len(singvals))] + singvals_cutted = singvals[min(cut, len(singvals)) :] + # Renormalizing the singular values vector to its norm + # before the truncation + norm_kept = (singvals_kept**2).sum() + norm_trunc = (singvals_cutted**2).sum() + normalization_factor = to.sqrt(norm_kept / (norm_kept + norm_trunc)) + singvals_kept /= normalization_factor + + # Renormalize cut singular values to track the norm loss + singvals_cutted /= to.sqrt(norm_trunc + norm_kept) + + return cut, singvals_kept, singvals_cutted + + def _truncate_sv_ratio(self, singvals, conv_params): + """ + Truncate the singular values based on the ratio + with the biggest one. + + Parameters + ---------- + singvals : to.ndarray + Array of singular values + conv_params : :py:class:`TNConvergenceParameters`, optional + Convergence parameters to use in the procedure. + + Returns + ------- + cut : int + Number of singular values kept + """ + lambda1 = singvals[0] + cut = to.nonzero(singvals / lambda1 < conv_params.cut_ratio) + if self.device in ACCELERATOR_DEVICES: + cut = cut.to(device=_CPU_DEVICE) + + chi_now = len(singvals) + chi_by_conv = conv_params.max_bond_dimension + chi_by_ratio = cut[0].item() if len(cut) > 0 else chi_now + chi_min = conv_params.min_bond_dimension + + return self._truncate_decide_chi(chi_now, chi_by_conv, chi_by_ratio, chi_min) + + def _truncate_sv_norm(self, singvals, conv_params): + """ + Truncate the singular values based on the + total norm cut. + """ + norm = to.cumsum(to.flip(singvals, dims=(0,)), 0) + norm /= norm[-1].clone() + + # Search for the first index where the constraint is broken, + # so you need to stop an index before + cut = to.nonzero(norm > conv_params.cut_ratio) + if self.is_gpu(): + cut = cut.to(device=_CPU_DEVICE) + + chi_now = len(singvals) + chi_by_conv = conv_params.max_bond_dimension + chi_by_norm = len(singvals) - int(cut[0]) if len(cut) > 0 else chi_now + chi_min = conv_params.min_bond_dimension + + return self._truncate_decide_chi(chi_now, chi_by_conv, chi_by_norm, chi_min) + + def _truncate_decide_chi(self, chi_now, chi_by_conv, chi_by_trunc, chi_min): + """ + Decide on the bond dimension based on the various values chi and + potential hardware preference indicated. + + **Arguments** + + chi_now : int + Current value of the bond dimension + + chi_by_conv : int + Maximum bond dimension as suggested by convergence parameters. + + chi_by_trunc : int + Bond dimension suggested by truncating (either ratio or norm). + + chi_min : int + Minimum bond dimension as suggested by convergence parameters. + """ + return self._truncate_decide_chi_static( + chi_now, + chi_by_conv, + chi_by_trunc, + chi_min, + _BLOCK_SIZE_BOND_DIMENSION, + _BLOCK_SIZE_BYTE, + self.elem.element_size(), + ) + + def vector_with_dim_like(self, dim, dtype=None): + """Generate a vector in the native array of the base tensor.""" + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + if dtype is None: + dtype = self.dtype + + vec = to.empty(dim, dtype=dtype) + + target_device = self.device + if self.is_gpu(query=target_device): + vec.to(device="cuda") + elif self.is_xla(query=target_device): + vec.to(device=xla_device) + + return vec + + # -------------------------------------------------------------------------- + # Internal methods (not required by abstract class) + # -------------------------------------------------------------------------- + + @staticmethod + def device_str(obj): + """Resolve once the device for qteatorchtensor purposes as str.""" + device_int = obj.get_device() + device = _CPU_DEVICE if device_int == -1 else f"{_GPU_DEVICE}:{device_int}" + if device.startswith(_GPU_DEVICE) and XLA_AVAILABLE: + # Another assumption: if XLA avaivable, it is used + device = device.replcae(_GPU_DEVICE, _XLA_DEVICE) + return device + + def dtype_mpi(self): + """Resolve the dtype for sending tensors via MPI""" + return { + # pylint: disable=c-extension-no-member + "Z": MPI.DOUBLE_COMPLEX, + "C": MPI.COMPLEX, + "S": MPI.REAL, + "D": MPI.DOUBLE_PRECISION, + "I": MPI.INT, + # pylint: enable=c-extension-no-member + }[self.dtype_to_char()] + + # pylint: disable-next=unused-argument + def _xla_convert(self, dtype, device, stream): + """Conversion for tensors on XLA (always via host even if on device).""" + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + current = self.device + do_convert_dtype = (dtype is not None) and (dtype != self.dtype) + + # always convert data types on CPU + if do_convert_dtype and (not self.is_cpu()): + self._elem = self._elem.to(device=_CPU_DEVICE) + current = _CPU_DEVICE + + if do_convert_dtype: + self._elem = self._elem.type(dtype) + + if device is not None: + self._convert_check(device) + + if device == current: + # We already are in the correct device + pass + elif self.is_gpu(query=device): + # We go from the cpu to gpu + cuda_str = device.replace(_GPU_DEVICE, "cuda") + self._elem = self._elem.to(device=cuda_str) + elif self.is_cpu(query=device): + # We go from gpu to cpu + self._elem = self._elem.to(device=_CPU_DEVICE) + elif self.is_xla(query=device): + # We go from the cpu to xla + self._elem = self._elem.to(device=xla_device) + + return self + + def _xla_norm_sqrt(self): + """Avoid problems with complex numbers and vector norm.""" + return to.sqrt(self.dot(self.conj()).real) + + def _split_qr_dim(self, rows, cols): + """Split via QR knowing dimension of rows and columns.""" + if self.dtype == to.float16: + matrix = self._elem.type(to.float32).reshape(rows, cols) + qmat, rmat = to.linalg.qr(matrix) + qmat = qmat.type(to.float16) + rmat = rmat.type(to.float16) + else: + qmat, rmat = to.linalg.qr(self._elem.reshape(rows, cols)) + + qtens = self.from_elem_array(qmat, dtype=self.dtype, device=self.device) + rtens = self.from_elem_array(rmat, dtype=self.dtype, device=self.device) + + return qtens, rtens + + # pylint: disable-next=unused-argument + def _split_svd_eigvl(self, matrix, svd_ctrl, max_bond_dimension, contract_singvals): + """ + SVD of the matrix through an eigvenvalue decomposition. + + Parameters + ---------- + matrix: to.Tensor + Matrix to decompose + svd_crtl : str + If "E" normal eigenvalue decomposition. If "X" use the sparse. + max_bond_dimension : int + Maximum bond dimension + contract_singvals: str + Whhere to contract the singular values + + Returns + ------- + to.Tensor + Matrix U + to.Tensor + Singular values + to.Tensor + Matrix V^dagger + + Details + ------- + + We use ᵀ*=^†, the adjoint. + + - In the contract-to-right case, which means: + H = AAᵀ* = USV Vᵀ*SUᵀ* = U S^2 Uᵀ* + To compute SVᵀ* we have to use: + A = USVᵀ* -> Uᵀ* A = S Vᵀ* + - In the contract-to-left case, which means: + H = Aᵀ*A = VSUᵀ* USVᵀ* = VS^2 Vᵀ* + First, we are given V, but we want Vᵀ*. However, let's avoid double work. + To compute US we have to use: + A = USVᵀ* -> AV = US + Vᵀ* = right.T.conj() (with the conjugation done in place) + """ + if contract_singvals == "R": + # The left tensor is unitary + herm_mat = matrix @ matrix.conj().T + else: + # contract_singvals == "L", the right tensor is unitary + herm_mat = matrix.conj().T @ matrix + + if svd_ctrl == "E": + eigenvalues, eigenvectors = to.linalg.eigh(herm_mat) + elif svd_ctrl == "X": + logger.warning("Falling back from X to E: pytorch has no eigsh.") + # num_eigvl = min(herm_mat.shape[0] - 1, max_bond_dimension - 1) + eigenvalues, eigenvectors = to.linalg.eigh(herm_mat) + else: + raise ValueError( + f"svd_ctrl = {svd_ctrl} not valid with eigenvalue decomposition" + ) + + # Eigenvalues are sorted ascendingly, singular values descendengly + # Only positive eigenvalues makes sense. Due to numerical precision, + # there will be very small negative eigvl. We put them to 0. + eigenvalues[eigenvalues < 0] = 0 + singvals = to.sqrt(to.flip(eigenvalues, dims=(0,))[: min(matrix.shape)]) + eigenvectors = to.flip(eigenvectors, dims=(1,)) + + # Taking only the meaningful part of the eigenvectors + if contract_singvals == "R": + left = eigenvectors[:, : min(matrix.shape)] + right = left.T.conj() @ matrix + else: + right = eigenvectors[:, : min(matrix.shape)] + left = matrix @ right + right = to.conj(right.T) + + return left, singvals, right + + def _split_svd_eigvl_qr( + self, matrix, svd_ctrl, max_bond_dimension, contract_singvals + ): + """ + SVD through eigendecomposition on the better-shaped Gram matrix. + + The +QR control means that the opposite isometry side is cheaper for + the Gram solve. For torch we restore the requested singular-value + placement by rescaling instead of running an additional QR. + """ + svd_ctrl_2 = svd_ctrl.replace("+QR", "") + contract_singvals_2 = {"L": "R", "R": "L"}[contract_singvals] + + left, singvals, right = self._split_svd_eigvl( + matrix, svd_ctrl_2, max_bond_dimension, contract_singvals_2 + ) + safe_singvals = to.where(singvals != 0, singvals, to.ones_like(singvals)) + + if contract_singvals == "L": + left = left * singvals + right = right / safe_singvals.reshape(-1, 1) + else: + left = left / safe_singvals + right = singvals.reshape(-1, 1) * right + + return left, singvals, right + + def _split_svd_normal(self, matrix): + """ + Normal SVD of the matrix. First try the faster gesdd iterative method. + If it fails, resort to gesvd. + + Parameters + ---------- + matrix: to.Tensor + Matrix to decompose + + Returns + ------- + to.Tensor + Matrix U + to.Tensor + Singular values + to.Tensor + Matrix V^dagger + """ + # from torch documentation: By default (driver= None), + # we call ‘gesvdj’ and, if it fails, we fallback to ‘gesvd’. + mat_left, singvals_tot, mat_right = to.linalg.svd(matrix, full_matrices=False) + + if self.is_gpu(): + # There is an ugly failure of SVDs on the GPU with inf values without + # any error message although the matrix is well-behaved (no problem + # on CPU, can be solved with random SVD on GPU). Go for check over performance + # here. + if not to.all(to.isfinite(singvals_tot)): + raise QRedTeaLinAlgError( + "Torch SVD failed with non-finite values for singular values." + ) + + return mat_left, singvals_tot, mat_right + + def _split_svd_random(self, matrix, max_bond_dimension): + """ + SVD of the matrix through a random SVD decomposition + as prescribed in page 227 of Halko, Martinsson, Tropp's 2011 SIAM paper: + "Finding structure with randomness: Probabilistic algorithms for constructing + approximate matrix decompositions" + + Parameters + ---------- + matrix: to.Tensor + Matrix to decompose + max_bond_dimension : int + Maximum bond dimension + + Returns + ------- + to.Tensor + Matrix U + to.Tensor + Singular values + to.Tensor + Matrix V^dagger + """ + # pylint: disable-next=invalid-name,global-variable-not-assigned + global xla_device + + # pylint: disable-next=nested-min-max + rank = min(max_bond_dimension, min(matrix.shape)) + + # This could be parameterized but in the paper they use this + # value + n_samples = 2 * rank + + random = to.randn(matrix.shape[1], n_samples, dtype=self.dtype) + if self.is_gpu(): + random = random.to(device="cuda") + elif self.is_xla(): + random = random.to(device=xla_device) + + reduced_matrix = matrix @ random + # Find orthonormal basis + ortho, _ = to.linalg.qr(reduced_matrix) + + # Second part + to_svd = ortho.T @ matrix + left_tilde, singvals, right = to.linalg.svd(to_svd, full_matrices=False) + left = ortho @ left_tilde + + return left, singvals, right + + +def _scale_link_inverse_vector(link_weights): + """Construct the inverse of singular values setting zeros to one.""" + # Have to handle zeros here ... as we allow padding singular + # values with zeros, we must also automatically avoid division + # by zero due to exact zeros. But we can assume it must be at + # the end of the array + vec = to.clone(link_weights) + if link_weights[-1] == 0.0: + vec[vec == 0.0] = 1.0 + + vec = 1.0 / vec + return vec + + +def _cumsum_like_numpy(array, axis=None, **kwargs): + """Provide cumsum function with same arguments as numpy.""" + # numpy has default of axis=None which acts on the flattened array + # which torch does not support + if axis is None and array.ndim != 1: + raise QRedTeaRankError("Running cumsum without axis on tensor with rank != 1.") + + if axis is None: + axis = 0 + + return to.cumsum(array, axis, **kwargs) + + +def _log_like_numpy(array): + """Provide log function accepting as well scalar integers not being a tensor.""" + if isinstance(array, to.Tensor): + return to.log(array) + + return math.log(array) + + +def _sum_like_numpy(array, axis=None, **kwargs): + """Provide sum function with same arguments as numpy.""" + # numpy has default of axis=None which acts on the flattened array + # which torch does not support + if axis is None: + return to.sum(array, **kwargs) + + return to.sum(array, axis, **kwargs) + + +class DataMoverPytorch(_AbstractDataMover): + """ + Data mover to move QteaTorchTensor between torch CPU and torch GPU format. + """ + + tensor_cls = (QteaTorchTensor,) + + def __init__(self): + pass + + @property + def device_memory(self): + """Current memory occupied in the device""" + # return self.mempool.used_bytes() + raise NotImplementedError("pytorch data mover") + + def sync_move(self, tensor, device): + """ + Move the tensor `tensor` to the device `device` + synchronously with the main computational stream + + Parameters + ---------- + tensor : _AbstractTensor + The tensor to be moved + device: str + The device where to move the tensor + """ + if GPU_AVAILABLE or XLA_AVAILABLE: + tensor.convert(dtype=None, device=device) + + # pylint: disable-next=unused-argument + def async_move(self, tensor, device, stream=None): + """ + Move the tensor `tensor` to the device `device` + asynchronously with respect to the main computational + stream + + Parameters + ---------- + tensor : _AbstractTensor + The tensor to be moved + device: str + The device where to move the tensor + stream : stream-object + Stream to be used to move the data if a stream + different from the data mover's stream should be + used. + Default to None (use DataMover's stream) + """ + logger.debug("Moving still sync for pytorch.") + self.sync_move(tensor, device) + + def wait(self): + """ + Put a barrier for the streams and wait them + """ + # pylint: disable-next=unnecessary-pass + pass + + +def default_pytorch_backend(device="cpu", dtype=to.complex128): + """ + Generate a default tensor backend for dense tensors, i.e., with + a :class:`QteaTorchTensor`. + + **Arguments** + + dtype : data type, optional + Data type for pytorch. + Default to to.complex128 + + device : device specification, optional + Default to `"cpu"`. + Available: `"cpu", "gpu", "xla"` + + **Returns** + + tensor_backend : :class:`TensorBackend` + """ + tensor_backend = TensorBackend( + tensor_cls=QteaTorchTensor, + base_tensor_cls=QteaTorchTensor, + device=device, + dtype=dtype, + symmetry_injector=None, + datamover=DataMoverPytorch(), + ) + + return tensor_backend + + +def default_abelian_pytorch_backend(device="cpu", dtype=to.complex128): + """ + Generate a default tensor backend for symmetric tensors, i.e., with + a :class:`QteaTorchTensor`. The tensors support Abelian symmetries. + + **Arguments** + + dtype : data type, optional + Data type for pytorch. + Default to to.complex128 + + device : device specification, optional + Default to `"cpu"`. + Available: `"cpu", "gpu", "xla"` + + **Returns** + + tensor_backend : :class:`TensorBackend` + """ + tensor_backend = TensorBackend( + tensor_cls=QteaAbelianTensor, + base_tensor_cls=QteaTorchTensor, + device=device, + dtype=dtype, + symmetry_injector=AbelianSymmetryInjector(), + datamover=DataMoverPytorch(), + ) + + return tensor_backend diff --git a/.venv/lib/python3.12/site-packages/qtealeaves/emulator/mpi_mps_simulator.py b/.venv/lib/python3.12/site-packages/qtealeaves/emulator/mpi_mps_simulator.py new file mode 100644 index 0000000..c09bd4c --- /dev/null +++ b/.venv/lib/python3.12/site-packages/qtealeaves/emulator/mpi_mps_simulator.py @@ -0,0 +1,691 @@ +# This code is part of qtealeaves. +# +# This code is licensed under the Apache License, Version 2.0. You may +# obtain a copy of this license in the LICENSE.txt file in the root directory +# of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. +# +# Any modifications or derivative works of this code must retain this +# copyright notice, and modified files need to carry a notice indicating +# that they have been altered from the originals. + +""" +The module contains a the MPI version of the MPS simulator. + +Code for the MPI simulations should be run as: + +.. code-block:: + mpiexec -n 4 python my_mpi_script.py + +where we used 4 processes as an example. +""" +import os + +import numpy as np + +from qtealeaves.convergence_parameters import TNConvergenceParameters +from qtealeaves.tensors import TensorBackend +from qtealeaves.tooling.mpisupport import MPI, TN_MPI_TYPES + +from .mps_simulator import MPS + +__all__ = ["MPIMPS"] + + +def _mpi_array_dtype(array): + """Return the MPI dtype for numpy arrays and CPU tensor buffers.""" + dtype = array.dtype + if hasattr(dtype, "str"): + return TN_MPI_TYPES[dtype.str] + + # qredtea torch singular values are raw torch.Tensor objects, not + # QteaTorchTensor instances, so they do not expose dtype_mpi(). + import torch + + return { + torch.complex128: MPI.DOUBLE_COMPLEX, + torch.complex64: MPI.COMPLEX, + torch.float64: MPI.DOUBLE_PRECISION, + torch.float32: MPI.REAL, + torch.int64: MPI.INT, + }[dtype] + + +def _mpi_send_array(comm, array, to_): + if hasattr(array, "resolve_conj"): + array = array.resolve_conj().contiguous() + comm.Send([array, _mpi_array_dtype(array)], to_) + + +def _mpi_empty_like(array, shape): + if hasattr(array, "resolve_conj"): + import torch + + return torch.empty(shape, dtype=array.dtype, device="cpu") + return np.empty(shape, array.dtype) + + +def _mpi_recv_array(comm, template, shape, from_): + array = _mpi_empty_like(template, shape) + comm.Recv([array, _mpi_array_dtype(array)], from_) + if hasattr(template, "device") and hasattr(array, "to"): + array = array.to(device=template.device) + return array + + +# pylint: disable-next=too-many-instance-attributes +class MPIMPS(MPS): + """ + MPI version of the MPS emulator that divides the MPS between the different nodes + + Parameters + ---------- + num_sites: int + Number of sites + convergence_parameters: :py:class:`TNConvergenceParameters` + Class for handling convergence parameters. In particular, in the MPS simulator we are + interested in: + - the *maximum bond dimension* :math:`\\chi`; + - the *cut ratio* :math:`\\epsilon` after which the singular + values are neglected, i.e. if :math:`\\lamda_1` is the + bigger singular values then after an SVD we neglect all the + singular values such that :math:`\\frac{\\lambda_i}{\\lambda_1}\\leq\\epsilon` + local_dim: int or list of ints, optional + Local dimension of the degrees of freedom. Default to 2. + If a list is given, then it must have length num_sites. + initialize: str, optional + The method for the initialization. Default to "vacuum" + Available: + - "vacuum", for the |000...0> state + - "random", for a random state at given bond dimension + tensor_backend : `None` or instance of :class:`TensorBackend` + Default for `None` is :class:`QteaTensor` with np.complex128 on CPU. + + """ + + # pylint: disable-next=too-many-arguments + def __init__( + self, + num_sites, + convergence_parameters, + local_dim=2, + initialize="vacuum", + tensor_backend=None, + ): + if MPI is None: + raise ImportError("No module mpi4py found in python environment") + # MPI variables + # pylint: disable-next=c-extension-no-member + self.comm = MPI.COMM_WORLD + self.size = self.comm.Get_size() + self.rank = self.comm.Get_rank() + self.tot_sites = num_sites + + # Number of sites in the local MPS + modulus = num_sites % self.size + local_num_size = int(np.floor(num_sites // self.size)) + self.indexes = [0] + [ + local_num_size + 1 if ii < modulus else local_num_size + for ii in range(self.size) + ] + local_num_size = self.indexes[self.rank + 1] + + # indexes takes into account which indexes are in each core + self.indexes = np.cumsum(self.indexes) + + # The par_map is a dicrionary where the index is the position of the + # sites in the full chain, while the value the position on the + # subchain in this process + self.par_map = dict( + zip( + np.arange( + self.indexes[self.rank], self.indexes[self.rank + 1], dtype=int + ), + np.arange(local_num_size, dtype=int), + ) + ) + + # Auxiliary site for the boundaries + if self.rank < self.size - 1: + local_num_size += 1 + + if not np.isscalar(local_dim): + local_dim = local_dim[ + self.indexes[self.rank] : self.indexes[self.rank + 1] + + int(self.rank != (self.size - 1)) + ] + + super().__init__( + local_num_size, + convergence_parameters, + local_dim=local_dim, + initialize=initialize, + tensor_backend=tensor_backend, + ) + + # MPS initializetion not aware of device + self.convert(self.tensor_backend.dtype, self.tensor_backend.memory_device) + + @property + def mpi_dtype(self): + """Return the MPI version of the MPS dtype (going via first tensor)""" + return TN_MPI_TYPES[np.dtype(self[0].dtype).str] + + def get_tensor_of_site(self, idx): + """Retrieve tensor of specifc site.""" + return self[self.par_map[idx]] + + def apply_one_site_operator(self, op, pos): + """ + Applies a one operator `op` to the site `pos` of the MPIMPS. + Instead of communicating the changes on the boundaries we + perform an additional contraction. + + Parameters + ---------- + op: numpy array shape (local_dim, local_dim) + Matrix representation of the quantum gate + pos: int + Position of the qubit where to apply `op`. + """ + # Apply the gate on the right MPS + if pos in self.par_map: + super().apply_one_site_operator(op, self.par_map[pos]) + + # For one-qubit gates it is more convenient to apply them both to + # the real and auxiliary qubits if they are on the boundaries + elif pos - 1 in self.par_map: + super().apply_one_site_operator(op, self.num_sites - 1) + + return None + + # pylint: disable-next=too-many-arguments + def apply_two_site_operator(self, op, pos, swap=False, svd=None, parallel=None): + """ + Applies a two-site operator `op` to the site `pos`, `pos+1` of the MPS. + Then, perform the necessary communications between the interested + process and the process + + Parameters + ---------- + op: numpy array shape (local_dim, local_dim, local_dim, local_dim) + Matrix representation of the quantum gate + pos: int or list of ints + Position of the qubit where to apply `op`. If a list is passed, + the two sites should be adjacent. The first index is assumed to + be the control, and the second the target. The swap argument is + overwritten if a list is passed. + swap: bool + If True swaps the operator. This means that instead of the + first contraction in the following we get the second. + It is written is a list of pos is passed. + svd : None + Required for compatibility. Can be only True. + parallel: None + Required for compatibility. Can be only True + + Returns + ------- + singular_values_cutted: ndarray + Array of singular values cutted, normalized to the biggest singular value + + """ + if not np.isscalar(pos) and len(pos) == 2: + pos = min(pos[0], pos[1]) + elif not np.isscalar(pos): + raise ValueError( + f"pos should be only scalar or len 2 array-like, not len {len(pos)}" + ) + + # Hardcoded but necessary for compatibility + svd = True + if parallel is None: + parallel_env = os.environ.get("QTEALEAVES_MPIMPS_PARALLEL", "1").lower() + parallel = parallel_env not in ("0", "false", "no", "off") + + if pos in self.par_map: + res = super().apply_two_site_operator( + op, self.par_map[pos], swap, svd=svd, parallel=parallel + ) + + # Send the information back to the auxiliary if it was the first site + if self.par_map[pos] == 0 and self.rank > 0: + self.mpi_send_tensor(self[0], to_=self.rank - 1) + _mpi_send_array(self.comm, self.singvals[1], self.rank - 1) + + # Send the information towards the next if it was the last site + elif self.par_map[pos] == self.num_sites - 2 and self.rank < self.size - 1: + self.mpi_send_tensor(self[self.num_sites - 1], to_=self.rank + 1) + _mpi_send_array( + self.comm, self.singvals[self.num_sites - 1], self.rank + 1 + ) + + else: + res = [] + # Receive the information from the MPS on the right + if pos == self.indexes[self.rank + 1] and self.rank < self.size - 1: + tens = self.mpi_receive_tensor(from_=self.rank + 1) + + self[self.num_sites - 1] = tens + + singvals = _mpi_recv_array( + self.comm, + self.singvals[self.num_sites], + tens.shape[2], + self.rank + 1, + ) + self._singvals[self.num_sites] = singvals + + # Receive the information from the MPS from the left + if pos == self.indexes[self.rank] - 1 and self.rank > 0: + tens = self.mpi_receive_tensor(from_=self.rank - 1) + self[0] = tens + + singvals = _mpi_recv_array( + self.comm, + self.singvals[0], + tens.shape[0], + self.rank - 1, + ) + self._singvals[0] = singvals + + return res + + def apply_projective_operator(self, site, selected_output=None, remove=False): + """ + Apply a projective operator to the site **site**, and give the measurement as output. + You can also decide to select a given output for the measurement, if the probability is + non-zero. Finally, you have the possibility of removing the site after the measurement. + + Parameters + ---------- + site: int + Index of the site you want to measure + selected_output: int, optional + If provided, the selected state is measured. Throw an error if the probability of the + state is 0 + remove: bool, optional + If True, the measured index is traced away after the measurement. Default to False. + + Returns + ------- + meas_state: int | None + Measured state or None if site not in this part of the MPI-MPS. + state_prob : float | None + Probability of measuring the output state or None if site not + in this part of the MPI-MPS. + """ + self.reinstall_isometry_serial() + if site in self.par_map: + res = super().apply_projective_operator( + self.par_map[site], selected_output, remove + ) + else: + res = (None, None) + + # Move informations to further right + self.reinstall_isometry_serial(left=False, from_site=site) + # Move information to the left + self.reinstall_isometry_serial() + + return res + + # pylint: disable-next=arguments-differ + def reinstall_isometry_serial(self, left=False, from_site=None): + """ + Reinstall the isometry center on position 0 of the full MPS. + + This step is serial because we have to serially pass the information + along the MPS. It cannot be parallelized. + + Parameters + ---------- + left: bool, optional + If True, reinstall the isometry to the left. + If False, to the right. Defaulto to False + from_site: int, optional + The site from which the isometrization should start. + By default None, i.e. the other end of the MPS chain. + + Returns + ------- + None + """ + if from_site is None: + from_site = self.num_sites - 1 if left else 0 + extrem = np.nonzero(from_site <= self.indexes)[0][0] + + if left: + boundaries = (extrem, -1, -1) + tidx = 0 + to_ = self.rank - 1 + from_ = self.rank + 1 + else: + boundaries = (extrem, self.size, 1) + tidx = self.num_sites - 1 + to_ = self.rank + 1 + from_ = self.rank - 1 + + for ii in range(*boundaries): + if self.rank == ii: + self._first_non_orthogonal_left = self.num_sites - 1 + self._first_non_orthogonal_right = self.num_sites - 1 + requires_singvals = self._requires_singvals + self._requires_singvals = True + if left: + self.right_canonize(0, False, True) + else: + self.left_canonize(self.num_sites - 1, False, True) + self._requires_singvals = requires_singvals + + # Send tensor + if (self.rank > 0 and left) or (self.rank + 1 < self.size and not left): + self.mpi_send_tensor(self[tidx], to_=to_) + + elif (self.rank == ii - 1 and left) or (self.rank == ii + 1 and not left): + # Receive tensor + tens = self.mpi_receive_tensor(from_=from_) + self[self.num_sites - 1 - tidx] = tens + + # pylint: disable-next=arguments-differ + def reinstall_isometry_parallel(self, num_cycles): + """ + Reinstall the isometry by applying identities to all even sites and + to all odd sites, and repeating for `num_cycles` cycles. + The reinstallation is exact for `num_cycles=num_sites/2`. + Method from https://arxiv.org/abs/2312.02667 + + This step is serial because we have to serially pass the information + along the MPS. It cannot be parallelized. + + Parameters + ---------- + num_cycles: int + Number of cycles for reinstalling the isometry + + Returns + ------- + None + """ + for _ in range(num_cycles): + # Apply on all even sites + for ii in range(0, self.tot_sites - 1, 2): + self.apply_two_site_operator( + self[0].eye_like(4), ii, svd=True, parallel=True + ) + # Apply on all odd sites + for ii in range(1, self.tot_sites - 1, 2): + self.apply_two_site_operator( + self[0].eye_like(4), ii, svd=True, parallel=True + ) + + def mpi_gather_tn(self): + """ + Gather the tensors on process 0. + We do not use MPI.comm.Gather because we would gather lists of np.arrays + without using the np.array advantages, making it slower than the single + communications. + + Returns + ------- + list on np.ndarray or None + List of tensors on the rank 0 process, None on the others + """ + self.comm.Barrier() + if self.rank != 0: + num_tensors = ( + self.num_sites if self.rank == self.size - 1 else self.num_sites - 1 + ) + for jj in range(num_tensors): + self.mpi_send_tensor(self[jj], to_=0) + tensor_list = None + else: + tensor_list = [None for _ in range(self.tot_sites)] + tensor_list[: self.num_sites - 1] = self.tensors[:-1] + + tidx = self.num_sites - 1 + for ii in range(1, self.size): + num_tensors = self.indexes[ii + 1] - self.indexes[ii] + for jj in range(num_tensors): + tens = self.mpi_receive_tensor(from_=ii) + tensor_list[tidx + jj] = tens + tidx += num_tensors + + self.comm.Barrier() + + return tensor_list + + def mpi_scatter_tn(self, tensor_list): + """ + Scatter the tensors on process 0. + We do not use MPI.comm.Scatter because we would gather lists of np.arrays + without using the np.array advantages, making it slower than the single + communications. + + Parameters + ---------- + tensor_list : list of lists of np.ndarrays + The index i of the list is sent to the rank i + + Returns + ------- + list on np.ndarray or None + List of tensors on the rank 0 process, None on the others + """ + self.comm.Barrier() + if self.rank == 0: + for ridx, sub_tensorlist in enumerate(tensor_list[1:]): + for idx, tens in enumerate(sub_tensorlist): + self.mpi_send_tensor(tens, to_=ridx + 1) + + tensor_list = tensor_list[0] + else: + num_tensors = len(tensor_list[self.rank]) + tensor_list = [None for _ in range(num_tensors)] + for idx in range(num_tensors): + tens = self.mpi_receive_tensor(from_=0) + tensor_list[idx] = tens + + self.comm.Barrier() + + return tensor_list + + def to_tensor_list(self): + """ + Return the tensor list of the full MPS. Thus, here there are + communications between the different processes and all the tensorlist + is returned on process 0 + + Returns + ------- + list of np.ndarray or None + List of tensors on the rank 0 process, None on the others + """ + return self.mpi_gather_tn() + + def to_statevector(self, qiskit_order=False, max_qubit_equivalent=20): + """ + Serially compute the statevector + + Parameters + ---------- + qiskit_order: bool, optional + weather to use qiskit ordering or the theoretical one. For + example the state |011> has 0 in the first position for the + theoretical ordering, while for qiskit ordering it is on the + last position. + max_qubit_equivalent: int, optional + Maximum number of qubit sites the MPS can have and still be + transformed into a statevector. + If the number of sites is greater, it will throw an exception. + Default to 20. + + Returns + ------- + np.ndarray or None + Statevector on process 0, None on the others + """ + + tensorlist = self.to_tensor_list() + if self.rank == 0: + mps = MPS.from_tensor_list(tensorlist) + statevect = mps.to_statevector(qiskit_order, max_qubit_equivalent) + else: + statevect = None + + return statevect + + @classmethod + def from_tensor_list( + cls, + tensor_list, + conv_params=None, + tensor_backend=None, + target_device=None, + ): + """ + Initialize the MPS tensors using a list of correctly shaped tensors + + Parameters + ---------- + tensor_list : list of ndarrays or cupy arrays + List of tensor for initializing the MPS + conv_params : :py:class:`TNConvergenceParameters`, optional + Convergence parameters for the new MPS. If None, the maximum bond + bond dimension possible is assumed, and a cut_ratio=1e-9. + Default to None. + tensor_backend : `None` or instance of :class:`TensorBackend` + Default for `None` is :class:`QteaTensor` with np.complex128 on CPU. + target_device: None | str, optional + If `None`, take memory device of tensor backend. + If string is `any`, do not convert. Otherwise, + use string as device string. + + Returns + ------- + obj : :py:class:`MPIMPS` + The MPIMPS class + """ + mismatches = [ + tensor_list[ii].shape[2] != tensor_list[ii + 1].shape[0] + for ii in range(len(tensor_list) - 1) + ] + if any(mismatches): + msg = f"Mismatches for tensors equals to True: {mismatches}." + raise ValueError(f"Dimension mismatch when constructing MPS:{msg}") + + if conv_params is None: + max_bond_dim = max(elem.shape[2] for elem in tensor_list) + conv_params = TNConvergenceParameters(max_bond_dimension=int(max_bond_dim)) + if tensor_backend is None: + # Have to resolve it here in case target device is not given + tensor_backend = TensorBackend() + if target_device is None: + target_device = tensor_backend.memory_device + elif target_device == "any": + target_device = None + + local_dim = [elem.shape[1] for elem in tensor_list] + obj = cls( + len(tensor_list), conv_params, local_dim, tensor_backend=tensor_backend + ) + + # Convert data type (lateron device if GPU enabled?) + for elem in tensor_list: + elem.convert(obj.tensor_backend.dtype, target_device) + + if obj.rank == 0: + tensorlist = [ + tensor_list[ + obj.indexes[rank] : obj.indexes[rank + 1] + + int(rank != obj.size - 1) + ] + for rank in range(obj.size) + ] + else: + list_sizes = obj.indexes[1:] - obj.indexes[:-1] + 1 + list_sizes[-1] -= 1 + tensorlist = [ + [None for _ in range(list_sizes[rank])] for rank in range(obj.size) + ] + + tensor_list = obj.mpi_scatter_tn(tensorlist) + obj._tensors = tensor_list + + return obj + + @classmethod + def from_statevector( + cls, + statevector, + local_dim=2, + conv_params=None, + tensor_backend=None, + ): + """Serially decompose the statevector and then initialize the MPS""" + mps = MPS.from_statevector( + statevector, local_dim, conv_params, tensor_backend=tensor_backend + ) + + return cls.from_tensor_list( + mps.to_tensor_list(), conv_params, tensor_backend=tensor_backend + ) + + # --------------------------- + # ----- MEASURE METHODS ----- + # --------------------------- + + def meas_local(self, op_list): + """ + Measure a local observable along all sites of the MPS + + Parameters + ---------- + op_list : list of :class:`_AbstractQteaTensor` + local operator to measure on each site + + Return + ------ + measures : ndarray, shape (num_sites) + Measures of the local operator along each site on rank-0 + """ + res = super().meas_local(op_list) + + # Call back on the site 0 the results + if self.rank != 0: + self.comm.Send([res, self.mpi_dtype[res.dtype.str]], 0) + tot_res = None + else: + tot_res = np.empty(self.tot_sites, dtype=res.dtype) + tot_res[: self.num_sites - 1] = res[:-1] + + tidx = self.num_sites - 1 + for ii in range(1, self.size): + num_tensors = self.indexes[ii] - self.indexes[ii - 1] + self.comm.Recv( + [tot_res[tidx : tidx + num_tensors], self.mpi_dtype[res.dtype.str]], + ii, + ) + tidx += num_tensors + + return tot_res + + def _get_eff_op_on_pos(self, pos): + """ + Obtain the list of effective operators adjacent + to the position pos and the index where they should + be contracted + + Parameters + ---------- + pos : int + Index of the tensor w.r.t. which we have to retrieve + the effective operators + + Returns + ------- + list of IndexedOperators + List of effective operators + list of ints + Indexes where the operators should be contracted + """ + raise NotImplementedError("This function has to be overwritten") diff --git a/.venv/lib/python3.12/site-packages/quimb/tensor/tn1d/core.py b/.venv/lib/python3.12/site-packages/quimb/tensor/tn1d/core.py new file mode 100644 index 0000000..c4b0233 --- /dev/null +++ b/.venv/lib/python3.12/site-packages/quimb/tensor/tn1d/core.py @@ -0,0 +1,5061 @@ +"""Classes and algorithms related to 1D tensor networks.""" + +import functools +import itertools +import operator +from math import log, log2 +from numbers import Integral + +import scipy.sparse.linalg as spla +from autoray import ( + conj, + dag, + do, + get_dtype_name, + get_namespace, + reshape, + size, + transpose, +) + +import quimb as qu + +from ...linalg.base_linalg import norm_trace_dense +from ...utils import ( + deprecated, + ensure_dict, + pairwise, + partition_all, + print_multi_line, +) +from .. import array_ops as ops +from ..tensor_core import ( + Tensor, + bonds, + new_bond, + oset, + rand_uuid, + tags_to_oset, + tensor_canonize_bond, + tensor_compress_bond, +) +from ..tnag.core import ( + TensorNetworkGen, + TensorNetworkGenOperator, + TensorNetworkGenVector, + tensor_network_ag_sum, + tensor_network_align, + tensor_network_apply_op_op, + tensor_network_apply_op_vec, +) + +align_TN_1D = deprecated( + tensor_network_align, "align_TN_1D", "tensor_network_align" +) + + +def expec_TN_1D(*tns, compress=None, eps=1e-15): + """Compute the expectation of several 1D TNs, using transfer matrix + compression if any are periodic. + + Parameters + ---------- + tns : sequence of TensorNetwork1D + The MPS and MPO to find expectation of. Should start and begin with + an MPS e.g. ``(MPS, MPO, ..., MPS)``. + compress : {None, False, True}, optional + Whether to perform transfer matrix compression on cyclic systems. If + set to ``None`` (the default), decide heuristically. + eps : float, optional + The accuracy of the transfer matrix compression. + + Returns + ------- + x : float + The expectation value. + """ + expec_tn = functools.reduce(operator.or_, tensor_network_align(*tns)) + + # if OBC or <= 0.0 specified use exact contraction + cyclic = any(tn.cyclic for tn in tns) + if not cyclic: + compress = False + + n = expec_tn.L + isflat = all(isinstance(tn, TensorNetwork1DFlat) for tn in tns) + + # work out whether to compress, could definitely be improved ... + if compress is None and isflat: + # compression only worth it for long, high bond dimension TNs. + total_bd = qu.prod(tn.bond_size(0, 1) for tn in tns) + compress = (n >= 100) and (total_bd >= 1000) + + if compress: + expec_tn.replace_section_with_svd(1, n, eps=eps, inplace=True) + return expec_tn ^ all + + return expec_tn ^ ... + + +def gate_TN_1D( + tn, + G, + where, + contract=False, + tags=None, + propagate_tags="sites", + info=None, + inplace=False, + cur_orthog=None, + **compress_opts, +): + r"""Act with the gate ``G`` on sites ``where``, maintaining the outer + indices of the 1D tensor network:: + + + contract=False contract=True + . . . . <- where + o-o-o-o-o-o-o o-o-o-GGG-o-o-o + | | | | | | | | | | / \ | | | + GGG + | | + + + contract='split-gate' contract='swap-split-gate' + . . . . <- where + o-o-o-o-o-o-o o-o-o-o-o-o-o + | | | | | | | | | | | | | | + G~G G~G + | | \ / + X + / \ + + contract='swap+split' + . . <- where + o-o-o-G=G-o-o-o + | | | | | | | | + + Note that the sites in ``where`` do not have to be contiguous. By default, + site tags will be propagated to the gate tensors, identifying a + 'light cone'. + + Parameters + ---------- + tn : TensorNetwork1DVector + The 1D vector-like tensor network, for example, and MPS. + G : array + A square array to act with on sites ``where``. It should have twice the + number of dimensions as the number of sites. The second half of these + will be contracted with the MPS, and the first half indexed with the + correct ``site_ind_id``. Sites are read left to right from the shape. + A two-dimensional array is permissible if each dimension factorizes + correctly. + where : int or sequence of int + Where the gate should act. + contract : {False, 'split-gate', 'swap-split-gate', + 'auto-split-gate', True, 'swap+split'}, optional + Whether to contract the gate into the 1D tensor network. If, + + - False: leave the gate uncontracted, the default + - 'split-gate': like False, but split the gate if it is two-site. + - 'swap-split-gate': like 'split-gate', but decompose the gate as + if a swap had first been applied + - 'auto-split-gate': automatically select between the above three + options, based on the rank of the gate. + - True: contract the gate into the tensor network, if the gate acts + on more than one site, this will produce an ever larger tensor. + - 'swap+split': Swap sites until they are adjacent, then contract + the gate and split the resulting tensor, then swap the sites back + to their original position. In this way an MPS structure can be + explicitly maintained at the cost of rising bond-dimension. + + tags : str or sequence of str, optional + Tag the new gate tensor with these tags. + propagate_tags : {'sites', 'register', False, True}, optional + Add any tags from the sites to the new gate tensor (only matters if + ``contract=False`` else tags are merged anyway): + + - If ``'sites'``, then only propagate tags matching e.g. 'I{}' and + ignore all others. I.e. just propagate the lightcone. + - If ``'register'``, then only propagate tags matching the sites of + where this gate was actually applied. I.e. ignore the lightcone, + just keep track of which 'registers' the gate was applied to. + - If ``False``, propagate nothing. + - If ``True``, propagate all tags. + + inplace, bool, optional + Perform the gate in place. + compress_opts + Supplied to :meth:`~quimb.tensor.tensor_core.Tensor.split` + if ``contract='swap+split'`` or + :meth:`~quimb.tensor.tn1d.core.MatrixProductState.gate_with_auto_swap` + if ``contract='swap+split'``. + + Returns + ------- + TensorNetwork1DVector + + See Also + -------- + MatrixProductState.gate_split + + Examples + -------- + >>> p = MPS_rand_state(3, 7) + >>> p.gate_(spin_operator('X'), where=1, tags=['GX']) + >>> p + + + >>> p.outer_inds() + ('k0', 'k1', 'k2') + """ + if isinstance(where, Integral): + where = (where,) + ng = len(where) # number of sites the gate acts on + + if contract == "auto-mps": + # automatically choose a contract mode based on maintaining MPS form + if ng == 1: + contract = True + elif ng == 2: + contract = "swap+split" + else: + contract = "nonlocal" + + # check special MPS methods + if contract == "swap+split": + if ng == 1: + # no swapping or splitting needed + contract = True + else: + return tn.gate_with_auto_swap( + G, + where, + cur_orthog=cur_orthog, + info=info, + inplace=inplace, + **compress_opts, + ) + + elif contract == "nonlocal": + if ng == 1: + # no MPO needed + contract = True + else: + return tn.gate_nonlocal( + G, + where, + cur_orthog=cur_orthog, + info=info, + inplace=inplace, + **compress_opts, + ) + + # can use generic gate method + return TensorNetworkGenVector.gate( + tn, + G, + where, + contract=contract, + tags=tags, + propagate_tags=propagate_tags, + info=info, + inplace=inplace, + **compress_opts, + ) + + +def superop_TN_1D( + tn_super, + tn_op, + upper_ind_id="k{}", + lower_ind_id="b{}", + so_outer_upper_ind_id=None, + so_inner_upper_ind_id=None, + so_inner_lower_ind_id=None, + so_outer_lower_ind_id=None, +): + r"""Take a tensor network superoperator and act with it on a + tensor network operator, maintaining the original upper and lower + indices of the operator:: + + + outer_upper_ind_id upper_ind_id + | | | ... | | | | ... | + +----------+ +----------+ + | tn_super +---+ | tn_super +---+ + +----------+ | upper_ind_id +----------+ | + | | | ... | | | | | ... | | | | ... | | + inner_upper_ind_id| +-----------+ +-----------+ | + | + | tn_op | = | tn_op | | + inner_lower_ind_id| +-----------+ +-----------+ | + | | | ... | | | | | ... | | | | ... | | + +----------+ | lower_ind_id +----------+ | + | tn_super +---+ | tn_super +---+ + +----------+ +----------+ + | | | ... | <-- | | | ... | + outer_lower_ind_id lower_ind_id + + + Parameters + ---------- + tn_super : TensorNetwork + The superoperator in the form of a 1D-like tensor network. + tn_op : TensorNetwork + The operator to be acted on in the form of a 1D-like tensor network. + upper_ind_id : str, optional + Current id of the upper operator indices, e.g. usually ``'k{}'``. + lower_ind_id : str, optional + Current id of the lower operator indices, e.g. usually ``'b{}'``. + so_outer_upper_ind_id : str, optional + Current id of the superoperator's upper outer indices, these will be + reindexed to form the new effective operators upper indices. + so_inner_upper_ind_id : str, optional + Current id of the superoperator's upper inner indices, these will be + joined with those described by ``upper_ind_id``. + so_inner_lower_ind_id : str, optional + Current id of the superoperator's lower inner indices, these will be + joined with those described by ``lower_ind_id``. + so_outer_lower_ind_id : str, optional + Current id of the superoperator's lower outer indices, these will be + reindexed to form the new effective operators lower indices. + + Returns + ------- + KAK : TensorNetwork + The tensornetwork of the superoperator acting on the operator. + """ + n = tn_op.L + + if so_outer_upper_ind_id is None: + so_outer_upper_ind_id = getattr(tn_super, "outer_upper_ind_id", "kn{}") + if so_inner_upper_ind_id is None: + so_inner_upper_ind_id = getattr(tn_super, "inner_upper_ind_id", "k{}") + if so_inner_lower_ind_id is None: + so_inner_lower_ind_id = getattr(tn_super, "inner_lower_ind_id", "b{}") + if so_outer_lower_ind_id is None: + so_outer_lower_ind_id = getattr(tn_super, "outer_lower_ind_id", "bn{}") + + reindex_map = {} + for i in range(n): + upper_bnd = rand_uuid() + lower_bnd = rand_uuid() + reindex_map[upper_ind_id.format(i)] = upper_bnd + reindex_map[lower_ind_id.format(i)] = lower_bnd + reindex_map[so_inner_upper_ind_id.format(i)] = upper_bnd + reindex_map[so_inner_lower_ind_id.format(i)] = lower_bnd + reindex_map[so_outer_upper_ind_id.format(i)] = upper_ind_id.format(i) + reindex_map[so_outer_lower_ind_id.format(i)] = lower_ind_id.format(i) + + return tn_super.reindex(reindex_map) & tn_op.reindex(reindex_map) + + +def parse_cur_orthog(cur_orthog="calc", info=None): + if info is None: + info = {} + + if isinstance(cur_orthog, Integral): + info.setdefault("cur_orthog", (cur_orthog, cur_orthog)) + else: + info.setdefault("cur_orthog", cur_orthog) + + return info + + +def convert_cur_orthog(fn): + @functools.wraps(fn) + def wrapped(self, *args, cur_orthog=None, info=None, **kwargs): + info = parse_cur_orthog(cur_orthog, info) + return fn(self, *args, info=info, **kwargs) + + return wrapped + + +class TensorNetwork1D(TensorNetworkGen): + """Base class for tensor networks with a one-dimensional structure.""" + + _NDIMS = 1 + _EXTRA_PROPS = ("_site_tag_id", "_L") + _CONTRACT_STRUCTURED = True + + def _compatible_1d(self, other): + """Check whether ``self`` and ``other`` are compatible 2D tensor + networks such that they can remain a 2D tensor network when combined. + """ + return isinstance(other, TensorNetwork1D) and all( + getattr(self, e) == getattr(other, e) + for e in TensorNetwork1D._EXTRA_PROPS + ) + + def combine(self, other, *, virtual=False, check_collisions=True): + """Combine this tensor network with another, returning a new tensor + network. If the two are compatible, cast the resulting tensor network + to a :class:`TensorNetwork1D` instance. + + Parameters + ---------- + other : TensorNetwork1D or TensorNetwork + The other tensor network to combine with. + virtual : bool, optional + Whether the new tensor network should copy all the incoming tensors + (``False``, the default), or view them as virtual (``True``). + check_collisions : bool, optional + Whether to check for index collisions between the two tensor + networks before combining them. If ``True`` (the default), any + inner indices that clash will be mangled. + + Returns + ------- + TensorNetwork1D or TensorNetwork + """ + new = super().combine( + other, virtual=virtual, check_collisions=check_collisions + ) + if self._compatible_1d(other): + new.view_as_(TensorNetwork1D, like=self) + return new + + @property + def L(self): + """The number of sites, i.e. length.""" + return self._L + + @property + def nsites(self): + """The number of sites.""" + return self._L + + def gen_site_coos(self): + """Generate the coordinates of all possible sites.""" + return range(self._L) + + def site_tag(self, i): + """The name of the tag specifiying the tensor at site ``i``.""" + if not isinstance(i, str): + i = i % self.L + return self._site_tag_id.format(i) + + def slice2sites(self, tag_slice): + """Take a slice object, and work out its implied start, stop and step, + taking into account cyclic boundary conditions. + + Examples + -------- + Normal slicing: + + >>> p = MPS_rand_state(10, bond_dim=7) + >>> p.slice2sites(slice(5)) + (0, 1, 2, 3, 4) + + >>> p.slice2sites(slice(4, 8)) + (4, 5, 6, 7) + + Slicing from end backwards: + + >>> p.slice2sites(slice(..., -3, -1)) + (9, 8) + + Slicing round the end: + + >>> p.slice2sites(slice(7, 12)) + (7, 8, 9, 0, 1) + + >>> p.slice2sites(slice(-3, 2)) + (7, 8, 9, 0, 1) + + If the start point is > end point (*before* modulo n), then step needs + to be negative to return anything. + """ + if tag_slice.start is None: + start = 0 + elif tag_slice.start is ...: + if tag_slice.step == -1: + start = self.L - 1 + else: + start = -1 + else: + start = tag_slice.start + + if tag_slice.stop in (..., None): + stop = self.L + else: + stop = tag_slice.stop + + step = 1 if tag_slice.step is None else tag_slice.step + + return tuple(s % self.L for s in range(start, stop, step)) + + def maybe_convert_coo(self, x): + """Check if ``x`` is an integer and convert to the + corresponding site tag if so. + """ + if isinstance(x, Integral): + return self.site_tag(x) + + if isinstance(x, slice): + return tuple(map(self.site_tag, self.slice2sites(x))) + + return x + + def contract_structured( + self, + tag_slice, + structure_bsz=5, + optimize="auto", + inplace=False, + **contract_opts, + ): + """Perform a structured contraction, translating ``tag_slice`` from a + ``slice`` or `...` to a cumulative sequence of tags. + + Parameters + ---------- + tag_slice : slice or ... (Ellipsis) + The range of sites, or `...` for all. + structure_bsz : int, optional + The number of sites to group together for each sub-contraction. + inplace : bool, optional + Whether to perform the contraction inplace. + + Returns + ------- + TensorNetwork, Tensor or scalar + The result of the contraction, still a ``TensorNetwork`` if the + contraction was only partial. + + See Also + -------- + contract, contract_tags, contract_cumulative + """ + # check for all sites + if tag_slice is ...: + # else slice over all sites + tag_slice = slice(0, self.L) + + if optimize is None: + # this helps a lot vs greedy for large bond triple overlap e.g. + optimize = "auto" + + # filter sites by the slice, but also which sites are present at all + tags_seq = filter( + self.tag_map.__contains__, + map(self.site_tag, self.slice2sites(tag_slice)), + ) + + # partition sites into `structure_bsz` groups + if structure_bsz > 1: + tags_seq = partition_all(structure_bsz, tags_seq) + + # contract each block of sites cumulatively + return self.contract_cumulative( + tags_seq, + optimize=optimize, + inplace=inplace, + **contract_opts, + ) + + def compute_left_environments(self, **contract_opts): + """Compute the left environments of this 1D tensor network. + + Parameters + ---------- + contract_opts + Supplied to + :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`. + + Returns + ------- + dict[int, Tensor] + Environments indexed by the site they are to the left of, so keys + run from (1, ... L - 1). + """ + left_envs = {1: self.select(0).contract(all, **contract_opts)} + for i in range(2, self.L): + tll = left_envs[i - 1] + tll.drop_tags() + tnl = self.select(i - 1) | tll + left_envs[i] = tnl.contract(all, **contract_opts) + + return left_envs + + def compute_right_environments(self, **contract_opts): + """Compute the right environments of this 1D tensor network. + + Parameters + ---------- + contract_opts + Supplied to + :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract`. + + Returns + ------- + dict[int, Tensor] + Environments indexed by the site they are to the right of, so keys + run from (0, ... L - 2). + """ + right_envs = { + self.L - 2: self.select(-1).contract(all, **contract_opts) + } + for i in range(self.L - 3, -1, -1): + trr = right_envs[i + 1] + trr.drop_tags() + tnr = self.select(i + 1) | trr + right_envs[i] = tnr.contract() + + return right_envs + + def flatten( + self, + fuse_multibonds=True, + inplace=False, + ) -> "TensorNetwork1DFlat": + """Contract all tensors at each site together, yielding a single tensor + per site. By default, any multibonds between flattened sites will also + be fused together. If not already, the resulting tensor network will be + promoted to a :class:`TensorNetwork1DFlat`. + + Parameters + ---------- + fuse_multibonds : bool, optional + Whether to fuse any multibonds that are created by this process. + Defaults to ``True``. + inplace : bool, optional + Whether to modify this tensor network inplace, or return a new + one. Defaults to ``False``. + + Returns + ------- + TensorNetwork1DFlat + """ + tn = super().flatten(fuse_multibonds=fuse_multibonds, inplace=inplace) + if not isinstance(tn, TensorNetwork1DFlat): + tn.view_as_(TensorNetwork1DFlat, like=self) + return tn + + flatten_ = functools.partialmethod(flatten, inplace=True) + + def _repr_info(self): + info = super()._repr_info() + info["L"] = self.L + info["max_bond"] = self.max_bond() + return info + + +class TensorNetwork1DVector(TensorNetwork1D, TensorNetworkGenVector): + """1D Tensor network which overall is like a vector with a single type of + site ind. + """ + + _EXTRA_PROPS = ( + "_site_tag_id", + "_site_ind_id", + "_L", + ) + + def reindex_sites(self, new_id, where=None, inplace=False): + """Update the physical site index labels to a new string specifier. + Note that this doesn't change the stored id string with the TN. + + Parameters + ---------- + new_id : str + A string with a format placeholder to accept an int, e.g. "ket{}". + where : None or slice + Which sites to update the index labels on. If ``None`` (default) + all sites. + inplace : bool + Whether to reindex in place. + """ + if where is None: + where = self.gen_sites_present() + elif isinstance(where, slice): + where = self.slice2sites(where) + else: + where = where + + return super().reindex_sites(new_id, where, inplace=inplace) + + reindex_sites_ = functools.partialmethod(reindex_sites, inplace=True) + + def site_ind(self, i): + """Get the physical index name of site ``i``.""" + if not isinstance(i, str): + i = i % self.L + return self.site_ind_id.format(i) + + @functools.wraps(gate_TN_1D) + def gate(self, *args, inplace=False, **kwargs): + return gate_TN_1D(self, *args, inplace=inplace, **kwargs) + + gate_ = functools.partialmethod(gate, inplace=True) + + @functools.wraps(expec_TN_1D) + def expec(self, *args, **kwargs): + return expec_TN_1D(self, *args, **kwargs) + + def correlation(self, A, i, j, B=None, **expec_opts): + """Correlation of operator ``A`` between ``i`` and ``j``. + + Parameters + ---------- + A : array + The operator to act with, can be multi site. + i : int or sequence of int + The first site(s). + j : int or sequence of int + The second site(s). + expec_opts + Supplied to :func:`~quimb.tensor.tn1d.core.expec_TN_1D`. + + Returns + ------- + C : float + The correlation `` + - ``. + + Examples + -------- + >>> ghz = (MPS_computational_state('0000') + + ... MPS_computational_state('1111')) / 2**0.5 + >>> ghz.correlation(pauli('Z'), 0, 1) + 1.0 + >>> ghz.correlation(pauli('Z'), 0, 1, B=pauli('X')) + 0.0 + """ + if B is None: + B = A + + bra = self.H + + pA = self.gate(A, i, contract=True) + cA = expec_TN_1D(bra, pA, **expec_opts) + + pB = self.gate(B, j, contract=True) + cB = expec_TN_1D(bra, pB, **expec_opts) + + pAB = pA.gate_(B, j, contract=True) + cAB = expec_TN_1D(bra, pAB, **expec_opts) + + return cAB - cA * cB + + +class TensorNetwork1DOperator(TensorNetwork1D, TensorNetworkGenOperator): + _EXTRA_PROPS = ( + "_site_tag_id", + "_upper_ind_id", + "_lower_ind_id", + "_L", + ) + + def reindex_lower_sites(self, new_id, where=None, inplace=False): + """Update the lower site index labels to a new string specifier. + + Parameters + ---------- + new_id : str + A string with a format placeholder to accept an int, e.g. + ``"ket{}"``. + where : None or slice + Which sites to update the index labels on. If ``None`` (default) + all sites. + inplace : bool + Whether to reindex in place. + """ + if where is None: + start = 0 + stop = self.L + else: + start = 0 if where.start is None else where.start + stop = self.L if where.stop is ... else where.stop + + return self.reindex( + {self.lower_ind(i): new_id.format(i) for i in range(start, stop)}, + inplace=inplace, + ) + + reindex_lower_sites_ = functools.partialmethod( + reindex_lower_sites, inplace=True + ) + + def reindex_upper_sites(self, new_id, where=None, inplace=False): + """Update the upper site index labels to a new string specifier. + + Parameters + ---------- + new_id : str + A string with a format placeholder to accept an int, e.g. "ket{}". + where : None or slice + Which sites to update the index labels on. If ``None`` (default) + all sites. + inplace : bool + Whether to reindex in place. + """ + if where is None: + start = 0 + stop = self.L + else: + start = 0 if where.start is None else where.start + stop = self.L if where.stop is ... else where.stop + + return self.reindex( + {self.upper_ind(i): new_id.format(i) for i in range(start, stop)}, + inplace=inplace, + ) + + reindex_upper_sites_ = functools.partialmethod( + reindex_upper_sites, inplace=True + ) + + +def set_default_compress_mode(opts, cyclic=False): + opts.setdefault("cutoff_mode", "rel" if cyclic else "rsum2") + + +class TensorNetwork1DFlat(TensorNetwork1D): + """1D Tensor network which has a flat structure.""" + + _EXTRA_PROPS = ("_site_tag_id", "_L") + + def left_canonize_site(self, i, bra=None, create_bond=False): + r"""Left canonize this TN's ith site, inplace:: + + i i + -o-o- ->-s- + ... | | ... ==> ... | | ... + + Parameters + ---------- + i : int + Which site to canonize. The site at i + 1 also absorbs the + non-isometric part of the decomposition of site i. + bra : None or matching TensorNetwork to self, optional + If set, also update this TN's data with the conjugate canonization. + create_bond : bool, optional + Whether to create a new bond between the two tensors if none + exists. If False, an error will be raised in such a case. + """ + tl, tr = self[i], self[i + 1] + tensor_canonize_bond(tl, tr, create_bond=create_bond) + if bra is not None: + # TODO: handle left inds + bra[i].modify(data=conj(tl.data)) + bra[i + 1].modify(data=conj(tr.data)) + + def right_canonize_site(self, i, bra=None, create_bond=False): + r"""Right canonize this TN's ith site, inplace:: + + i i + -o-o- -s-<- + ... | | ... ==> ... | | ... + + Parameters + ---------- + i : int + Which site to canonize. The site at i - 1 also absorbs the + non-isometric part of the decomposition of site i. + bra : None or matching TensorNetwork to self, optional + If set, also update this TN's data with the conjugate canonization. + create_bond : bool, optional + Whether to create a new bond between the two tensors if none + exists. If False, an error will be raised in such a case. + """ + tl, tr = self[i - 1], self[i] + tensor_canonize_bond(tr, tl, create_bond=create_bond) + if bra is not None: + # TODO: handle left inds + bra[i].modify(data=conj(tr.data)) + bra[i - 1].modify(data=conj(tl.data)) + + def left_canonicalize( + self, + stop=None, + start=None, + normalize=False, + bra=None, + create_bond=False, + inplace=False, + ): + r"""Left canonicalize all or a portion of this TN (i.e. sweep the + orthogonality center to the right). If this is a MPS, + this implies that:: + + i i + >->->->->->->-o-o- +-o-o- + | | | | | | | | | ... => | | | ... + >->->->->->->-o-o- +-o-o- + + Parameters + ---------- + start : int, optional + If given, the site to start left canonicalizing at. + stop : int, optional + If given, the site to stop left canonicalizing at. + normalize : bool, optional + Whether to normalize the state, only works for OBC. + bra : MatrixProductState, optional + If supplied, simultaneously left canonicalize this MPS too, + assuming it to be the conjugate state. + create_bond : bool, optional + Whether to create new bonds between the two tensors if none + exists. If False, an error will be raised in such a case. + inplace : bool, optional + Whether to perform the operation inplace. If ``bra`` is supplied + then it is always modifed inplace. + + Returns + ------- + TensorNetwork1DFlat + """ + mps = self if inplace else self.copy() + + if start is None: + start = -1 if mps.cyclic else 0 + if stop is None: + stop = mps.L - 1 + + for i in range(start, stop): + mps.left_canonize_site(i, bra=bra, create_bond=create_bond) + + if normalize: + factor = mps[-1].norm() + mps[-1] /= factor + if bra is not None: + bra[-1] /= factor + + return mps + + left_canonicalize_ = functools.partialmethod( + left_canonicalize, inplace=True + ) + left_canonize = left_canonicalize_ + + def right_canonicalize( + self, + stop=None, + start=None, + normalize=False, + bra=None, + create_bond=False, + inplace=False, + ): + r"""Right canonicalize all or a portion of this TN (i.e. sweep the + orthogonality center to the left). If this is a MPS, + + i i + -o-o-<-<-<-<-<-<-< -o-o-+ + ... | | | | | | | | | -> ... | | | + -o-o-<-<-<-<-<-<-< -o-o-+ + + + Parameters + ---------- + start : int, optional + If given, the site to start right canonizing at. + stop : int, optional + If given, the site to stop right canonizing at. + normalize : bool, optional + Whether to normalize the state. + bra : MatrixProductState, optional + If supplied, simultaneously right canonicalize this MPS too, + assuming it to be the conjugate state. + create_bond : bool, optional + Whether to create new bonds between the two tensors if none + exists. If False, an error will be raised in such a case. + inplace : bool, optional + Whether to perform the operation inplace. If ``bra`` is supplied + then it is always modifed inplace. + + Returns + ------- + TensorNetwork1DFlat + """ + mps = self if inplace else self.copy() + + if start is None: + start = mps.L - (0 if mps.cyclic else 1) + if stop is None: + stop = 0 + + for i in range(start, stop, -1): + mps.right_canonize_site(i, bra=bra, create_bond=create_bond) + + if normalize: + factor = mps[0].norm() + mps[0] /= factor + if bra is not None: + bra[0] /= factor + + return mps + + right_canonicalize_ = functools.partialmethod( + right_canonicalize, inplace=True + ) + right_canonize = right_canonicalize_ + + def canonize_cyclic(self, i, bra=None, method="isvd", inv_tol=1e-10): + """Bring this MatrixProductState into (possibly only approximate) + canonical form at site(s) ``i``. + + Parameters + ---------- + i : int or slice + The site or range of sites to make canonical. + bra : MatrixProductState, optional + Simultaneously canonize this state as well, assuming it to be the + co-vector. + method : {'isvd', 'svds', ...}, optional + How to perform the lateral compression. + inv_tol : float, optional + Tolerance with which to invert the gauge. + """ + if isinstance(i, Integral): + start, stop = i, i + 1 + elif isinstance(i, slice): + start, stop = i.start, i.stop + else: + start, stop = min(i), max(i) + 1 + if tuple(i) != tuple(range(start, stop)): + raise ValueError( + "Parameter ``i`` should be an integer or " + f"contiguous block of integers, got {i}." + ) + + k = self.copy() + b = k.H + k.add_tag("_KET") + b.add_tag("_BRA") + kb = k & b + + # approximate the rest of the chain with a separable transfer operator + kbc = kb.replace_section_with_svd( + start, + stop, + eps=0.0, + which="!any", + method=method, + max_bond=1, + ltags="_LEFT", + rtags="_RIGHT", + ) + + EL = kbc["_LEFT"].squeeze() + # explicitly symmetrize to hermitian + EL.modify(data=(EL.data + dag(EL.data)) / 2) + # split into upper 'ket' part and lower 'bra' part, symmetric + (EL_lix,) = EL.bonds(kbc[k.site_tag(start), "_BRA"]) + _, x = EL.split(EL_lix, method="eigh", cutoff=-1, get="arrays") + + ER = kbc["_RIGHT"].squeeze() + # explicitly symmetrize to hermitian + ER.modify(data=(ER.data + dag(ER.data)) / 2) + # split into upper 'ket' part and lower 'bra' part, symmetric + (ER_lix,) = ER.bonds(kbc[k.site_tag(stop - 1), "_BRA"]) + _, y = ER.split(ER_lix, method="eigh", cutoff=-1, get="arrays") + + self.insert_gauge(x, start - 1, start, tol=inv_tol) + self.insert_gauge(y, stop, stop - 1, tol=inv_tol) + + if bra is not None: + for i in (start - 1, start, stop, stop - 1): + bra[i].modify(data=self[i].data.conj()) + + def shift_orthogonality_center( + self, + current, + new, + bra=None, + create_bond=False, + ): + """Move the orthogonality center of this MPS. + + Parameters + ---------- + current : int + The current orthogonality center. + new : int + The target orthogonality center. + bra : MatrixProductState, optional + If supplied, simultaneously move the orthogonality center of this + MPS too, assuming it to be the conjugate state. + create_bond : bool, optional + Whether to create new bonds between two tensors if none + exists. If False, an error will be raised in such a case. + """ + if new > current: + for i in range(current, new): + self.left_canonize_site(i, bra=bra, create_bond=create_bond) + else: + for i in range(current, new, -1): + self.right_canonize_site(i, bra=bra, create_bond=create_bond) + + def canonicalize( + self, + where, + cur_orthog="calc", + info=None, + bra=None, + create_bond=False, + inplace=False, + ): + r"""Gauge this MPS into mixed canonical form, implying:: + + i i + >->->->->- ->-o-<- -<-<-<-<-< +-o-+ + | | | | |...| | |...| | | | | -> | | | + >->->->->- ->-o-<- -<-<-<-<-< +-o-+ + + You can also supply a min/max of sites to orthogonalize around, and a + current location of the orthogonality center for efficiency:: + + current where + ....... ..... + >->->-c-c-c-c-<-<-<-<-<-< >->->->->->-w-<-<-<-<-<-< + | | | | | | | | | | | | | -> | | | | | | | | | | | | | + >->->-c-c-c-c-<-<-<-<-<-< >->->->->->-w-<-<-<-<-<-< + cmin cmax i j + + This would only move ``cmin`` to ``i`` and ``cmax`` to ``j`` if + necessary. + + Parameters + ---------- + where : int or sequence of int + Which site(s) to orthogonalize around. If a sequence of int then + make sure that section from min(where) to max(where) is orthog. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + bra : MatrixProductState, optional + If supplied, simultaneously mixed canonicalize this MPS too, + assuming it to be the conjugate state. + create_bond : bool, optional + Whether to create new bonds between two tensors if none + exists. If False, an error will be raised in such a case. + inplace : bool, optional + Whether to perform the operation inplace. If ``bra`` is supplied + then it is always modifed inplace. + + Returns + ------- + MatrixProductState + The mixed canonical form MPS. + """ + mps = self if inplace else self.copy() + + if isinstance(where, Integral): + i = j = where + else: + i, j = min(where), max(where) + + info = parse_cur_orthog(cur_orthog, info) + cur_orthog = info["cur_orthog"] + if cur_orthog == "calc": + cur_orthog = mps.calc_current_orthog_center() + + if cur_orthog is not None: + if isinstance(cur_orthog, int): + cmin = cmax = cur_orthog + else: + cmin, cmax = min(cur_orthog), max(cur_orthog) + + if i > cmin: + mps.shift_orthogonality_center( + current=cmin, new=i, bra=bra, create_bond=create_bond + ) + else: + i = min(j, cmin) + + if j < cmax: + mps.shift_orthogonality_center( + current=cmax, new=j, bra=bra, create_bond=create_bond + ) + else: + j = max(i, cmax) + + else: + mps.left_canonicalize_(i, bra=bra) + mps.right_canonicalize_(j, bra=bra) + + info["cur_orthog"] = (i, j) + + return mps + + canonicalize_ = functools.partialmethod(canonicalize, inplace=True) + canonize = canonicalize_ + + def left_compress_site( + self, i, bra=None, create_bond=False, **compress_opts + ): + """Left compress this 1D TN's ith site, such that the site is then + left unitary with its right bond (possibly) reduced in dimension. + + Parameters + ---------- + i : int + Which site to compress. + bra : None or matching TensorNetwork to self, optional + If set, also update this TN's data with the conjugate compression. + create_bond : bool, optional + Whether to create new bonds between the two tensors if none + exists. If False, an error will be raised in such a case. + compress_opts + Supplied to :meth:`Tensor.split`. By default absorb is set to + ``'right'`` and reduced to ``'left'``. Other notable options are + ``max_bond`` and ``cutoff``. + """ + set_default_compress_mode(compress_opts, self.cyclic) + compress_opts.setdefault("absorb", "right") + compress_opts.setdefault("reduced", "left") + + tl, tr = self[i], self[i + 1] + tensor_compress_bond(tl, tr, create_bond=create_bond, **compress_opts) + + if bra is not None: + # TODO: handle left inds + bra[i].modify(data=conj(tl.data)) + bra[i + 1].modify(data=conj(tr.data)) + + def right_compress_site( + self, i, bra=None, create_bond=False, **compress_opts + ): + """Right compress this 1D TN's ith site, such that the site is then + right unitary with its left bond (possibly) reduced in dimension. + + Parameters + ---------- + i : int + Which site to compress. + bra : None or matching TensorNetwork to self, optional + If set, update this TN's data with the conjugate compression. + create_bond : bool, optional + Whether to create new bonds between the two tensors if none + exists. If False, an error will be raised in such a case. + compress_opts + Supplied to :meth:`Tensor.split`. By default absorb is set to + ``'left'`` and reduced to ``'right'``. Other notable options are + ``max_bond`` and ``cutoff``. + """ + set_default_compress_mode(compress_opts, self.cyclic) + compress_opts.setdefault("absorb", "left") + compress_opts.setdefault("reduced", "right") + + tl, tr = self[i - 1], self[i] + tensor_compress_bond(tl, tr, create_bond=create_bond, **compress_opts) + + if bra is not None: + # TODO: handle left inds + bra[i].modify(data=conj(tr.data)) + bra[i - 1].modify(data=conj(tl.data)) + + def left_compress( + self, + start=None, + stop=None, + bra=None, + create_bond=False, + **compress_opts, + ): + """Compress this 1D TN, from left to right, such that it becomes + left-canonical (unless ``absorb != 'right'``). + + Parameters + ---------- + start : int, optional + Site to begin compressing on. + stop : int, optional + Site to stop compressing at (won't itself be an isometry). + bra : None or TensorNetwork like this one, optional + If given, update this TN as well, assuming it to be the conjugate. + create_bond : bool, optional + Whether to create new bonds between adjacent tensors if none + exists. If False, an error will be raised in such a case. + compress_opts + Supplied to :meth:`Tensor.split`. Notably, ``max_bond``, + ``cutoff``. + """ + if start is None: + start = -1 if self.cyclic else 0 + if stop is None: + stop = self.L - 1 + + for i in range(start, stop): + self.left_compress_site( + i, bra=bra, create_bond=create_bond, **compress_opts + ) + + def right_compress( + self, + start=None, + stop=None, + bra=None, + create_bond=False, + **compress_opts, + ): + """Compress this 1D TN, from right to left, such that it becomes + right-canonical (unless ``absorb != 'left'``). + + Parameters + ---------- + start : int, optional + Site to begin compressing on. + stop : int, optional + Site to stop compressing at (won't itself be an isometry). + bra : None or TensorNetwork like this one, optional + If given, update this TN as well, assuming it to be the conjugate. + create_bond : bool, optional + Whether to create new bonds between adjacent tensors if none + exists. If False, an error will be raised in such a case. + compress_opts + Supplied to :meth:`Tensor.split`. Notably, ``max_bond``, + ``cutoff``. + """ + if start is None: + start = self.L - (0 if self.cyclic else 1) + if stop is None: + stop = 0 + + for i in range(start, stop, -1): + self.right_compress_site( + i, bra=bra, create_bond=create_bond, **compress_opts + ) + + def compress(self, form=None, create_bond=False, **compress_opts): + """Compress this 1D Tensor Network, possibly into canonical form. + + Parameters + ---------- + form : None, int, 'right', 'left' or 'flat', optional + Output form of the TN. The default `None` currently maps to + 'right'. `'right'` results in a right canonical TN, + with orthogonality center at site 0. `'left'` results in a + left canonical TN, with orthogonality center at site L - 1. + An integer value specifies the desired orthogonality center. + `'flat'` is a non-canonical method that performs a sweep of + compressions only (no canonicalization) from both sides. + create_bond : bool, optional + Whether to create new bonds between adjacent tensors if none + exists. If False, an error will be raised in such a case. + compress_opts + Supplied to :meth:`Tensor.split`. Notably, ``max_bond``, + ``cutoff``. + """ + if form is None: + form = "right" + + if isinstance(form, Integral): + if form < self.L // 2: + self.left_canonize(create_bond=create_bond) + self.right_compress(**compress_opts) + self.left_canonize(stop=form) + else: + self.right_canonize(create_bond=create_bond) + self.left_compress(**compress_opts) + self.right_canonize(stop=form) + + elif form == "left": + self.right_canonize( + bra=compress_opts.get("bra", None), create_bond=create_bond + ) + self.left_compress(**compress_opts) + + elif form == "right": + self.left_canonize( + bra=compress_opts.get("bra", None), create_bond=create_bond + ) + self.right_compress(**compress_opts) + + elif form == "flat": + compress_opts["absorb"] = "both" + self.right_compress( + stop=self.L // 2, create_bond=create_bond, **compress_opts + ) + self.left_compress( + stop=self.L // 2, create_bond=create_bond, **compress_opts + ) + + else: + raise ValueError( + f"Form specifier {form} not understood, should be either " + "'left', 'right', 'flat' or an int specifiying a new orthog " + "center." + ) + + @convert_cur_orthog + def compress_site( + self, + i, + canonize=True, + info=None, + bra=None, + **compress_opts, + ): + r"""Compress the bonds adjacent to site ``i``, by default first setting + the orthogonality center to that site:: + + i i + -o-o-o-o-o- --> ->->~o~<-<- + | | | | | | | | | | + + Parameters + ---------- + i : int + Which site to compress around + canonize : bool, optional + Whether to first set the orthogonality center to site ``i``. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + bra : MatrixProductState, optional + The conjugate state to also apply the compression to. + compress_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + """ + if canonize: + self.canonicalize_(i, info=info, bra=bra) + + if self.cyclic or i > 0: + self.left_compress_site(i - 1, bra=bra, **compress_opts) + + if self.cyclic or i < self.L - 1: + self.right_compress_site(i + 1, bra=bra, **compress_opts) + + def bond(self, i, j): + """Get the name of the index defining the bond between sites i and j.""" + (bond,) = self[i].bonds(self[j]) + return bond + + def bond_size(self, i, j): + """Return the size of the bond between site ``i`` and ``j``.""" + b_ix = self.bond(i, j) + return self[i].ind_size(b_ix) + + def bond_sizes(self): + bnd_szs = [self.bond_size(i, i + 1) for i in range(self.L - 1)] + if self.cyclic: + bnd_szs.append(self.bond_size(-1, 0)) + return bnd_szs + + def amplitude(self, b): + """Compute the amplitude of configuration ``b``. + + Parameters + ---------- + b : sequence of int + The configuration to compute the amplitude of. + + Returns + ------- + c_b : scalar + """ + if len(b) != self.nsites: + raise ValueError( + f"Bit-string {b} length does not " + f"match MPS length {self.nsites}." + ) + + selector = {self.site_ind(i): int(xi) for i, xi in enumerate(b)} + mps_b = self.isel(selector) + return mps_b ^ ... + + @convert_cur_orthog + def singular_values(self, i, info=None, method="svd"): + r"""Find the singular values associated with the ith bond:: + + ....L.... i + o-o-o-o-o-l-o-o-o-o-o-o-o-o-o-o-o + | | | | | | | | | | | | | | | | + i-1 ..........R.......... + + Leaves the 1D TN in mixed canoncial form at bond ``i``. + + Parameters + ---------- + i : int + Which bond, or equivalently, the number of sites in the + left partition. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + + Returns + ------- + svals : 1d-array + The singular values. + """ + if not (0 < i < self.L): + raise ValueError(f"Need 0 < i < {self.L}, got i={i}.") + + self.canonicalize_(i, info=info) + + Tm1 = self[i] + left_inds = Tm1.bonds(self[i - 1]) + return Tm1.singular_values(left_inds, method=method) + + def ensure_bonds_exist(self): + """Ensure that all bonds between adjacent sites are present in the + tensor network, creating new bonds of size 1 if necessary. + """ + for i in range(self.L - 1): + ti = self[i] + tj = self[i + 1] + if not ti.bonds(tj): + ti.new_bond(tj) + if self.cyclic: + ti = self[-1] + tj = self[0] + if not ti.bonds(tj): + ti.new_bond(tj) + + def expand_bond_dimension( + self, + new_bond_dim, + rand_strength=0.0, + bra=None, + create_bond=False, + inplace=True, + ): + """Expand the bond dimensions of this 1D tensor network to at least + ``new_bond_dim``. + + Parameters + ---------- + new_bond_dim : int + Minimum bond dimension to expand to. + rand_strength : float, optional + If ``rand_strength > 0``, fill the new tensor entries with gaussian + noise of strength ``rand_strength``. + bra : MatrixProductState, optional + Mirror the changes to ``bra`` inplace, treating it as the conjugate + state. + create_bond : bool, optional + Whether to create new bonds between adjacent tensors if none + exists. If ``False``, unconnected sites will remain unconnected. + inplace : bool, optional + Whether to perform the expansion in place. + + Returns + ------- + MatrixProductState + """ + tn = self if inplace else self.copy() + + if create_bond: + tn.ensure_bonds_exist() + + tn = super().expand_bond_dimension( + new_bond_dim=new_bond_dim, + rand_strength=rand_strength, + inplace=True, + ) + + if bra is not None: + for coo in tn.gen_sites_present(): + bra[coo].modify(data=tn[coo].data.conj()) + + return tn + + def count_canonized(self): + """Count the number of canonical sites to the left and right of the + tensor network. For cyclic TNs, this is always 0. + + Returns + ------- + (int, int) + The number of canonical sites to the left and right of the + orthogonality center. + """ + if self.cyclic: + return 0, 0 + + ov = self.H & self + num_can_l = 0 + num_can_r = 0 + + def isidentity(x): + d = x.shape[0] + if get_dtype_name(x) in ("float32", "complex64"): + rtol, atol = 1e-5, 1e-6 + else: + rtol, atol = 1e-9, 1e-11 + + idtty = do("eye", d, like=x) + return do("allclose", x, idtty, rtol=rtol, atol=atol) + + for i in range(self.L - 1): + ov ^= slice(max(0, i - 1), i + 1) + x = ov[i].data + if isidentity(x): + num_can_l += 1 + else: + break + + for j in reversed(range(num_can_l + 1, self.L)): + ov ^= slice(j, min(self.L, j + 2)) + x = ov[j].data + if isidentity(x): + num_can_r += 1 + else: + break + + return num_can_l, num_can_r + + def calc_current_orthog_center(self): + """Calculate the site(s) of the current orthogonality center. + + Returns + ------- + (int, int) + The min/max of sites around which the TN is currently orthogonal. + """ + lo, ro = self.count_canonized() + return lo, self.L - ro - 1 + + def as_cyclic(self, inplace=False): + """Convert this flat, 1D, TN into cyclic form by adding a dummy bond + between the first and last sites. + """ + tn = self if inplace else self.copy() + + # nothing to do + if tn.cyclic: + return tn + + tn.new_bond(0, -1) + tn.cyclic = True + return tn + + def show(self, max_width=None): + l1 = "" + l2 = "" + l3 = "" + num_can_l, num_can_r = self.count_canonized() + for i in range(self.L - 1): + bdim = self.bond_size(i, i + 1) + strl = len(str(bdim)) + l1 += f" {bdim}" + l2 += ( + ">" + if i < num_can_l + else "<" + if i >= self.L - num_can_r + else "●" + ) + ("─" if bdim < 100 else "━") * strl + l3 += "│" + " " * strl + strl = len(str(bdim)) + + l1 += " " + l2 += "<" if num_can_r > 0 else "●" + l3 += "│" + + if self.cyclic: + bdim = self.bond_size(0, self.L - 1) + bnd_str = ("─" if bdim < 100 else "━") * strl + l1 = f" {bdim}{l1}{bdim} " + l2 = f"+{bnd_str}{l2}{bnd_str}+" + l3 = f" {' ' * strl}{l3}{' ' * strl} " + + print_multi_line(l1, l2, l3, max_width=max_width) + + +class MatrixProductState(TensorNetwork1DVector, TensorNetwork1DFlat): + """Initialise a matrix product state, with auto labelling and tagging. + + Parameters + ---------- + arrays : sequence of arrays + The tensor arrays to form into a MPS. + sites : sequence of int, optional + Construct the MPO on these sites only. If not given, enumerate from + zero. Should be monotonically increasing and match ``arrays``. + L : int, optional + The number of sites the MPO should be defined on. If not given, this is + taken as the max ``sites`` value plus one (i.e.g the number of arrays + if ``sites`` is not given). + shape : str, optional + String specifying layout of *input* arrays. E.g. 'lrp' (the default) + indicates the shape corresponds left-bond, right-bond, physical index. + End tensors have either 'l' or 'r' dropped from the string. The + arrays will be permuted to 'lrp' order. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + site_ind_id : str + A string specifiying how to label the physical site indices. Should + contain a ``'{}'`` placeholder. It is used to generate the actual + indices like: ``map(site_ind_id.format, range(len(arrays)))``. + site_tag_id : str + A string specifiying how to tag the tensors at each site. Should + contain a ``'{}'`` placeholder. It is used to generate the actual tags + like: ``map(site_tag_id.format, range(len(arrays)))``. + """ + + _EXTRA_PROPS = ( + "_site_tag_id", + "_site_ind_id", + "cyclic", + "_L", + ) + + def __init__( + self, + arrays, + *, + sites=None, + L=None, + shape="lrp", + tags=None, + site_ind_id="k{}", + site_tag_id="I{}", + **tn_opts, + ): + # short-circuit for copying MPSs + if isinstance(arrays, MatrixProductState): + super().__init__(arrays) + return + + arrays = tuple(arrays) + + if sites is None: + # assume dense + sites = range(len(arrays)) + if L is None: + L = len(arrays) + num_sites = L + else: + sites = tuple(sites) + if L is None: + L = max(sites) + 1 + num_sites = len(sites) + + self._L = len(arrays) + self._site_ind_id = site_ind_id + self._site_tag_id = site_tag_id + self.cyclic = ops.ndim(arrays[0]) == 3 + + tensors = [] + tags = tags_to_oset(tags) + bonds = [rand_uuid() for _ in range(num_sites)] + # account for cyclic case + bonds.append(bonds[0]) + + for i, (site, array) in enumerate(zip(sites, arrays)): + inds = [] + + if L == 1: + # only one site + if self.cyclic: + # bond is a self loop on the single tensor + shape_desired = "lrp" + inds.append(bonds[i]) + inds.append(bonds[i]) + # XXX: should we just trace it out instead? + else: + # no bonds, just physical index + shape_desired = "p" + + elif (i == 0) and not self.cyclic: + # only right bond + shape_desired = "rp" + inds.append(bonds[i + 1]) + elif (i == num_sites - 1) and not self.cyclic: + # only left bond + shape_desired = "lp" + inds.append(bonds[i]) + else: + shape_desired = "lrp" + # both bonds + inds.append(bonds[i]) + inds.append(bonds[i + 1]) + + # this is the perm needed to bring the arrays from + # their current `shape`, to the desired 'lrud' order + shape_given = [x for x in shape if x in shape_desired] + order = [shape_given.index(x) for x in shape_desired] + + # physical index + inds.append(site_ind_id.format(site)) + + tensors.append( + Tensor( + data=transpose(array, order), + inds=inds, + tags=tags | oset([site_tag_id.format(site)]), + ) + ) + + super().__init__(tensors, virtual=True, **tn_opts) + + @classmethod + def from_fill_fn( + cls, + fill_fn, + L, + bond_dim, + phys_dim=2, + sites=None, + cyclic=False, + shape="lrp", + site_ind_id="k{}", + site_tag_id="I{}", + tags=None, + ): + """Create an MPS by supplying a 'filling' function to generate the data + for each site. + + Parameters + ---------- + fill_fn : callable + A function with signature + ``fill_fn(shape : tuple[int]) -> array_like``. + L : int + The number of sites. + bond_dim : int + The bond dimension. + phys_dim : int or Sequence[int], optional + The physical dimension(s) of each site, if a sequence it will be + cycled over. + sites : None or sequence of int, optional + Construct the MPS on these sites only. If not given, enumerate from + zero. + cyclic : bool, optional + Whether the MPS should be cyclic (periodic). + shape : str, optional + String specifying layout of *input* arrays. E.g. 'lrp' (the + default) indicates the shape corresponds left-bond, right-bond, + physical index. End tensors have either 'l' or 'r' dropped from the + string. The arrays will be permuted to 'lrp' order. + site_ind_id : str, optional + How to label the physical site indices. + site_tag_id : str, optional + How to tag the physical sites. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + + Returns + ------- + MatrixProductState + """ + if set(shape) - set("lrp"): + raise ValueError("Invalid shape string: {}".format(shape)) + + # check for site varying physical dimensions + if isinstance(phys_dim, Integral): + phys_dims = itertools.repeat(phys_dim) + else: + phys_dims = itertools.cycle(phys_dim) + + mps = cls.new( + L=L, + cyclic=cyclic, + site_ind_id=site_ind_id, + site_tag_id=site_tag_id, + ) + + # which sites are actually present + if sites is None: + sites = range(L) + else: + sites = tuple(sites) + num_sites = len(sites) + + global_tags = tags_to_oset(tags) + bonds = [rand_uuid() for _ in range(num_sites)] + bonds.append(bonds[0]) + + for i, site in enumerate(sites): + inds = [] + data_shape = [] + for c in shape: + if c == "l": + if (i - 1) >= 0 or cyclic: + inds.append(bonds[i]) + data_shape.append(bond_dim) + elif c == "r": + if (i + 1) < num_sites or cyclic: + inds.append(bonds[i + 1]) + data_shape.append(bond_dim) + else: # c == 'p': + inds.append(site_ind_id.format(site)) + data_shape.append(next(phys_dims)) + data = fill_fn(data_shape) + tags = global_tags | oset((site_tag_id.format(site),)) + mps |= Tensor(data, inds=inds, tags=tags) + + return mps + + @classmethod + def from_dense( + cls, + psi, + dims=2, + tags=None, + site_ind_id="k{}", + site_tag_id="I{}", + **split_opts, + ): + """Create a ``MatrixProductState`` directly from a dense vector + + Parameters + ---------- + psi : array_like + The dense state to convert to MPS from. + dims : int or sequence of int + Physical subsystem dimensions of each site. If a single int, all + sites have this same dimension, by default, 2. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + site_ind_id : str, optional + How to index the physical sites, see + :class:`~quimb.tensor.tn1d.core.MatrixProductState`. + site_tag_id : str, optional + How to tag the physical sites, see + :class:`~quimb.tensor.tn1d.core.MatrixProductState`. + split_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split` to + in order to partition the dense vector into tensors. + ``absorb='left'`` is set by default, to ensure the compression + is canonical / optimal. + + Returns + ------- + MatrixProductState + + Examples + -------- + + >>> dims = [2, 2, 2, 2, 2, 2] + >>> psi = rand_ket(prod(dims)) + >>> mps = MatrixProductState.from_dense(psi, dims) + >>> mps.show() + 2 4 8 4 2 + o-o-o-o-o-o + | | | | | | + """ + set_default_compress_mode(split_opts) + # ensure compression is canonical / optimal + split_opts.setdefault("absorb", "right") + + # make sure array_like + psi = ops.asarray(psi) + + if isinstance(dims, Integral): + # assume all sites have the same dimension + L = round(log(size(psi), dims)) + dims = (dims,) * L + else: + dims = tuple(dims) + L = len(dims) + + # create a bare MPS TN object + mps = cls.new( + L=L, + cyclic=False, + site_ind_id=site_ind_id, + site_tag_id=site_tag_id, + ) + + inds = [mps.site_ind(i) for i in range(L)] + + tm = Tensor(data=reshape(psi, dims), inds=inds) + for i in range(L - 1): + # progressively split off one more physical index + tl, tm = tm.split( + left_inds=None, + right_inds=inds[i + 1 :], + ltags=mps.site_tag(i), + get="tensors", + **split_opts, + ) + # add left tensor + mps |= tl + + # add final right tensor + tm.add_tag(mps.site_tag(L - 1)) + mps |= tm + + # add global tags + if tags is not None: + mps.add_tag(tags) + + return mps + + def add_MPS(self, other, inplace=False, **kwargs): + """Add another MatrixProductState to this one.""" + return tensor_network_ag_sum(self, other, inplace=inplace, **kwargs) + + add_MPS_ = functools.partialmethod(add_MPS, inplace=True) + + def permute_arrays(self, shape="lrp"): + """Ensure the arrays are stored internally in the specified order. + This doesn't change how the overall object interacts with other tensor + networks but may be useful for extracting the underlying arrays + consistently. This is an inplace operation. + + Parameters + ---------- + shape : str, optional + A permutation of ``'lrp'`` specifying the *desired* order of the + [l]eft, [r]ight, and [p]hysical indices respectively. + """ + self.ensure_bonds_exist() + + for i in self.gen_sites_present(): + inds = {"p": self.site_ind(i)} + if self.cyclic or i > 0: + inds["l"] = self.bond(i, (i - 1) % self.L) + if self.cyclic or i < self.L - 1: + inds["r"] = self.bond(i, (i + 1) % self.L) + inds = [inds[s] for s in shape if s in inds] + self[i].transpose_(*inds) + + def normalize(self, bra=None, eps=1e-15, insert=None): + """Normalize this MPS, optional with co-vector ``bra``. For periodic + MPS this uses transfer matrix SVD approximation with precision ``eps`` + in order to be efficient. Inplace. + + Parameters + ---------- + bra : MatrixProductState, optional + If given, normalize this MPS with the same factor. + eps : float, optional + If cyclic, precision to approximation transfer matrix with. + Default: 1e-14. + insert : int, optional + Insert the corrective normalization on this site, random if + not given. + + Returns + ------- + old_norm : float + The old norm ``self.H @ self``. + """ + norm = expec_TN_1D(self.H, self, eps=eps) + + if insert is None: + insert = -1 + + self[insert].modify(data=self[insert].data / norm**0.5) + if bra is not None: + bra[insert].modify(data=bra[insert].data / norm**0.5) + + return norm + + def gate_split(self, G, where, inplace=False, **compress_opts): + r"""Apply a two-site gate and then split resulting tensor to retrieve a + MPS form:: + + -o-o-A-B-o-o- + | | | | | | -o-o-GGG-o-o- -o-o-X~Y-o-o- + | | GGG | | ==> | | | | | | ==> | | | | | | + | | | | | | i j i j + i j + + As might be found in TEBD. + + Parameters + ---------- + G : array + The gate, with shape ``(d**2, d**2)`` for physical dimension ``d``. + where : (int, int) + Indices of the sites to apply the gate to. + compress_opts + Supplied to :func:`~quimb.tensor.tensor_split`. + + See Also + -------- + gate, gate_with_auto_swap + """ + set_default_compress_mode(compress_opts, self.cyclic) + ix_i, ix_j = map(self.site_ind, where) + # note that 'reduce-split' is unecessary: tensors have ndim<=3 + return self.gate_inds( + G, (ix_i, ix_j), contract="split", inplace=inplace, **compress_opts + ) + + gate_split_ = functools.partialmethod(gate_split, inplace=True) + + @convert_cur_orthog + def swap_sites_with_compress( + self, i, j, info=None, inplace=False, **compress_opts + ): + """Swap sites ``i`` and ``j`` by contracting, then splitting with the + physical indices swapped. If the sites are not adjacent, this will + happen multiple times. + + Parameters + ---------- + i : int + The first site to swap. + j : int + The second site to swap. + cur_orthog : int, sequence of int, or 'calc' + If known, the current orthogonality center. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + inplace : bond, optional + Perform the swaps inplace. + compress_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + """ + mps = self if inplace else self.copy() + + i, j = sorted((i, j)) + if i + 1 != j: + mps.swap_site_to_(j, i, info=info) + # first site is now at j + 1, move back up + mps.swap_site_to_(i + 1, j, info=info) + return mps + + mps.canonicalize_((i, j), info=info) + + # get site tensors and indices + ix_i, ix_j = map(mps.site_ind, (i, j)) + Ti, Tj = mps[i], mps[j] + _, unshared = Ti.filter_bonds(Tj) + + # split the contracted tensor, swapping the site indices + Tij = Ti @ Tj + lix = [i for i in unshared if i != ix_i] + [ix_j] + set_default_compress_mode(compress_opts, self.cyclic) + sTi, sTj = Tij.split(lix, get="tensors", **compress_opts) + + # reindex and transpose the tensors to directly update original tensors + sTi.reindex_({ix_j: ix_i}).transpose_like_(Ti) + Ti.modify(data=sTi.data) + sTj.reindex_({ix_i: ix_j}).transpose_like_(Tj) + Tj.modify(data=sTj.data) + + absorb = compress_opts.get("absorb", None) + if absorb == "left": + info["cur_orthog"] = (i, i) + elif absorb == "right": + info["cur_orthog"] = (j, j) + + return mps + + swap_sites_with_compress_ = functools.partialmethod( + swap_sites_with_compress, + inplace=True, + ) + + @convert_cur_orthog + def swap_site_to( + self, + i, + f, + info=None, + inplace=False, + **compress_opts, + ): + r"""Swap site ``i`` to site ``f``, compressing the bond after each + swap:: + + i f + 0 1 2 3 4 5 6 7 8 9 0 1 2 4 5 6 7 3 8 9 + o-o-o-x-o-o-o-o-o-o >->->->->->->-x-<-< + | | | | | | | | | | -> | | | | | | | | | | + + + Parameters + ---------- + i : int + The site to move. + f : int + The new location for site ``i``. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + inplace : bond, optional + Perform the swaps inplace. + compress_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + """ + mps = self if inplace else self.copy() + + if i == f: + return mps + if i < f: + compress_opts.setdefault("absorb", "right") + js = range(i, f) + if f < i: + compress_opts.setdefault("absorb", "left") + js = range(i - 1, f - 1, -1) + + for j in js: + mps.swap_sites_with_compress( + j, j + 1, info=info, inplace=True, **compress_opts + ) + + return mps + + swap_site_to_ = functools.partialmethod(swap_site_to, inplace=True) + + @convert_cur_orthog + def gate_with_auto_swap( + self, + G, + where, + info=None, + swap_back=True, + inplace=False, + **compress_opts, + ): + """Perform a two site gate on this MPS by, if necessary, swapping and + compressing the sites until they are adjacent, using ``gate_split``, + then unswapping the sites back to their original position. + + Parameters + ---------- + G : array + The gate, with shape ``(d**2, d**2)`` for physical dimension ``d``. + where : (int, int) + Indices of the sites to apply the gate to. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + swap_back : bool, optional + Whether to swap the sites back to their original position after + applying the gate. If not, for sites ``i < j``, the site ``j`` will + remain swapped to ``i + 1``, and sites between ``i + 1`` and ``j`` + will be shifted one place up. + inplace : bond, optional + Perform the swaps inplace. + compress_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + + See Also + -------- + gate, gate_split + """ + mps = self if inplace else self.copy() + + i, j = where + + if i > j: + # work with i < j but flip application of gate when necessary + i, j = j, i + final_gate_where = (i + 1, i) + absorb = "left" + else: + final_gate_where = (i, i + 1) + absorb = "right" + + need_to_swap = i + 1 != j + + # move j site adjacent to i site + if need_to_swap: + mps.swap_site_to( + j, i + 1, info=info, inplace=True, **compress_opts + ) + + # make sure sites are orthog center + mps.canonicalize_((i, i + 1), info=info) + + # apply gate and split back into MPS form + mps.gate_split_( + G, where=final_gate_where, absorb=absorb, **compress_opts + ) + + # absorb setting guarantees this + info["cur_orthog"] = (i + 1, i + 1) + + if need_to_swap and swap_back: + # move j site back to original position + mps.swap_site_to( + i + 1, j, info=info, inplace=True, **compress_opts + ) + + return mps + + gate_with_auto_swap_ = functools.partialmethod( + gate_with_auto_swap, + inplace=True, + ) + + @convert_cur_orthog + def gate_with_submpo( + self, + submpo, + where=None, + method="direct", + transpose=False, + info=None, + inplace=False, + inplace_mpo=False, + **compress_opts, + ): + """Apply an MPO, which only acts on a subset of sites, to this MPS, + compressing the MPS with the MPO only on the minimal set of sites + covering `where`, keeping the MPS form:: + + │ │ │ + A───A─A + │ │ │ -> │ │ │ │ │ │ │ │ + >─>─O━O━O━O─<─< + │ │ │ │ │ │ │ │ + o─o─o─o─o─o─o─o + + Parameters + ---------- + submpo : MatrixProductOperator + The MPO to apply. + where : sequence of int, optional + The range of sites the MPO acts on, will be inferred from the + support of the MPO if not given. + method : {"direct", "lazy", "fit", ...}, optional + The compression method to use. If "lazy", the MPO is simply added + to the MPS without any contraction or compression. Else `method` + is passed to + :func:`~quimb.tensor.tn1d.compress.tensor_network_1d_compress` + and controls how the compression of the subsection is performed. + transpose : bool, optional + Whether to transpose the MPO before applying it. By default the + lower inds of the MPO are contracted with the MPS, if transposed + the upper inds are contracted. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + inplace : bool, optional + Whether to perform the application and compression inplace. + compress_opts + Supplied to + :func:`~quimb.tensor.tn1d.compress.tensor_network_1d_compress`. + + Returns + ------- + MatrixProductState + """ + from .compress import tensor_network_1d_compress + + psi = self if inplace else self.copy() + + # get the span of sites the sub-MPO acts on + if where is None: + where = tuple(submpo.gen_sites_present()) + + if method != "lazy": + si, sf = min(where), max(where) + + # make the psi canonical around the sub-MPO region + psi.canonicalize_((si, sf), info=info) + + # lazily combine the sub-MPO with the MPS + psi.gate_with_op_lazy_( + submpo, + transpose=transpose, + inplace_op=inplace_mpo, + ) + + if method == "lazy": + # we just add the sub-MPO, no contraction or compression + return psi + + # split off the sub MPS-MPO TN section + sub_site_tags = [psi.site_tag(s) for s in range(si, sf + 1)] + _, subpsi = psi.partition(sub_site_tags, which="any", inplace=True) + + # compress it! + tensor_network_1d_compress( + subpsi, + site_tags=sub_site_tags, + method=method, + # the sub TN can't be automatically permuted when missing sites + permute_arrays=False, + inplace=True, + **compress_opts, + ) + + if compress_opts.get("sweep_reverse", False): + info["cur_orthog"] = (sf, sf) + else: + info["cur_orthog"] = (si, si) + + # recombine the compressed sub region TN + psi |= subpsi + + return psi + + gate_with_submpo_ = functools.partialmethod(gate_with_submpo, inplace=True) + + def gate_with_mpo( + self, + mpo, + method="direct", + transpose=False, + inplace=False, + inplace_mpo=False, + **compress_opts, + ): + """Gate this MPS with an MPO and compress the result with one of + various methods back to MPS form:: + + │ │ │ │ │ │ │ │ + A─A─A─A─A─A─A─A + │ │ │ │ │ │ │ │ -> │ │ │ │ │ │ │ │ + O━O━O━O━O━O━O━O + │ │ │ │ │ │ │ │ + o─o─o─o─o─o─o─o + + Parameters + ---------- + mpo : MatrixProductOperator + The MPO to apply. + max_bond : int, optional + A maximum bond dimension to keep when compressing. + cutoff : float, optional + A singular value cutoff to use when compressing. + method : {'direct", 'dm', 'zipup', 'zipup-first', 'fit', ...}, optional + The compression method to use. + transpose : bool, optional + Whether to transpose the MPO before applying it. By default the + lower inds of the MPO are contracted with the MPS, if transposed + the upper inds are contracted. + inplace : bool, optional + Whether to perform the compression inplace. + inplace_mpo : bool, optional + Whether to reindex the operator tensor network ``mpo`` inplace, a + minor performance gain if you don't need to use it afterwards. + compress_opts + Other options supplied to + :func:`~quimb.tensor.tn1d.compress.tensor_network_1d_compress`. + + Returns + ------- + MatrixProductState + """ + from .compress import tensor_network_1d_compress + + psi = self if inplace else self.copy() + + # lazily combine the MPO with the MPS + psi.gate_with_op_lazy_( + mpo, + transpose=transpose, + inplace_op=inplace_mpo, + ) + + # compress it! + return tensor_network_1d_compress( + psi, + method=method, + inplace=True, + **compress_opts, + ) + + gate_with_mpo_ = functools.partialmethod(gate_with_mpo, inplace=True) + + @convert_cur_orthog + def gate_nonlocal( + self, + G, + where, + dims=None, + method="direct", + info=None, + inplace=False, + **compress_opts, + ): + """Apply a potentially non-local gate to this MPS by first decomposing + it into an MPO, then compressing the MPS with MPO only on the minimal + set of sites covering `where`. + + Parameters + ---------- + G : array_like + The gate to apply. + where : sequence of int + The sites to apply the gate to. + max_bond : int, optional + A maximum bond dimension to keep when compressing. + cutoff : float, optional + A singular value cutoff to use when compressing. + dims : sequence of int, optional + The factorized dimensions of the gate ``G``, which should match the + physical dimensions of the sites it acts on. Calculated if not + supplied. If a single int, all sites are assumed to have this same + dimension. + method : {"direct", "lazy", "fit", ...}, optional + The compression method to use. If "lazy", the MPO is simply added + to the MPS without any contraction or compression. Else `method` + is passed to + :func:`~quimb.tensor.tn1d.compress.tensor_network_1d_compress` + and controls how the compression of the subsection is performed. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + inplace : bool, optional + Whether to perform the compression inplace. + compress_opts + Supplied to + :func:`~quimb.tensor.tn1d.compress.tensor_network_1d_compress`. + + Returns + ------- + MatrixProductState + """ + if dims is None: + dims = tuple(self.phys_dim(i) for i in where) + + # create a sub-MPO and lazily combine it with the MPS + mpo = MatrixProductOperator.from_dense( + G, dims=dims, sites=where, L=self.L + ) + + return self.gate_with_submpo_( + mpo, + where=where, + method=method, + info=info, + inplace=inplace, + inplace_mpo=True, + **compress_opts, + ) + + gate_nonlocal_ = functools.partialmethod(gate_nonlocal, inplace=True) + + def flip(self, inplace=False): + """Reverse the order of the sites in the MPS, such that site ``i`` is + now at site ``L - i - 1``. + """ + flipped = self if inplace else self.copy() + + retag_map = { + self.site_tag(i): self.site_tag(self.L - i - 1) for i in self.sites + } + reindex_map = { + self.site_ind(i): self.site_ind(self.L - i - 1) for i in self.sites + } + + return flipped.retag_(retag_map).reindex_(reindex_map) + + @convert_cur_orthog + def magnetization( + self, + i, + direction="Z", + info=None, + ): + """Compute the magnetization at site ``i``.""" + if self.cyclic: + msg = ( + "``magnetization`` currently makes use of orthogonality for" + " efficiencies sake, for cyclic systems is it still " + "possible to compute as a normal expectation." + ) + raise NotImplementedError(msg) + + self.canonicalize_(i, info=info) + + # +-k-+ + # | O | + # +-b-+ + + Tk = self[i] + ind1, ind2 = self.site_ind(i), "__tmp__" + Tb = Tk.H.reindex({ind1: ind2}) + + O_data = qu.spin_operator(direction, S=(self.phys_dim(i) - 1) / 2) + TO = Tensor(O_data, inds=(ind1, ind2)) + + return Tk.contract(TO, Tb) + + @convert_cur_orthog + def schmidt_values(self, i, info=None, method="svd"): + r"""Find the schmidt values associated with the bipartition of this + MPS between sites on either site of ``i``. In other words, ``i`` is the + number of sites in the left hand partition:: + + ....L.... i + o-o-o-o-o-S-o-o-o-o-o-o-o-o-o-o-o + | | | | | | | | | | | | | | | | + i-1 ..........R.......... + + The schmidt values, ``S``, are the singular values associated with the + ``(i - 1, i)`` bond, squared, provided the MPS is mixed canonized at + one of those sites. + + Parameters + ---------- + i : int + The number of sites in the left partition. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + + Returns + ------- + S : 1d-array + The schmidt values. + """ + if self.cyclic: + raise NotImplementedError + + return self.singular_values(i, info=info, method=method) ** 2 + + @convert_cur_orthog + def entropy(self, i, info=None, method="svd"): + """The entropy of bipartition between the left block of ``i`` sites and + the rest. + + Parameters + ---------- + i : int + The number of sites in the left partition. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + + Returns + ------- + float + """ + if self.cyclic: + msg = ( + "For cyclic systems, try explicitly computing the entropy " + "of the (compressed) reduced density matrix." + ) + raise NotImplementedError(msg) + + S = self.schmidt_values(i, info=info, method=method) + S = S[S > 0.0] + return do("sum", -S * do("log2", S)) + + @convert_cur_orthog + def schmidt_gap(self, i, info=None, method="svd"): + """The schmidt gap of bipartition between the left block of ``i`` sites + and the rest. + + Parameters + ---------- + i : int + The number of sites in the left partition. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + + Returns + ------- + float + """ + if self.cyclic: + raise NotImplementedError + + S = self.schmidt_values(i, info=info, method=method) + + if len(S) == 1: + return S[0] + + return S[0] - S[1] + + def partial_trace_to_mpo( + self, keep, upper_ind_id="b{}", rescale_sites=True + ): + r"""Partially trace this matrix product state, producing a matrix + product operator. + + Parameters + ---------- + keep : sequence of int or slice + Indicies of the sites to keep. + upper_ind_id : str, optional + The ind id of the (new) 'upper' inds, i.e. the 'bra' inds. + rescale_sites : bool, optional + If ``True`` (the default), then the kept sites will be rescaled to + ``(0, 1, 2, ...)`` etc. rather than keeping their original site + numbers. + + Returns + ------- + rho : MatrixProductOperator + The density operator in MPO form. + """ + p_bra = self.copy() + p_bra.reindex_sites_(upper_ind_id, where=keep) + rho = self.H & p_bra + # now have e.g: + # | | | | + # o-o-o-o-o-o-o-o-o + # | | | | | + # o-o-o-o-o-o-o-o-o + # | | | | + + if isinstance(keep, slice): + keep = self.slice2sites(keep) + + keep = sorted(keep) + + for i in self.gen_sites_present(): + if i in keep: + # | + # -o- | + # ... -o- ... -> ... -O- ... + # i| i| + rho ^= self.site_tag(i) + else: + # | + # -o-o- | + # ... | ... -> ... -OO- ... + # -o-o- |i+1 + # i |i+1 + if i < self.L - 1: + rho >>= [self.site_tag(i), self.site_tag(i + 1)] + else: + rho >>= [self.site_tag(i), self.site_tag(max(keep))] + + rho.drop_tags(self.site_tag(i)) + + # if single site a single tensor is produced + if isinstance(rho, Tensor): + rho = rho.as_network() + + if rescale_sites: + # e.g. [3, 4, 5, 7, 9] -> [0, 1, 2, 3, 4] + retag, reind = {}, {} + for new, old in enumerate(keep): + retag[self.site_tag(old)] = self.site_tag(new) + reind[self.site_ind(old)] = self.site_ind(new) + reind[upper_ind_id.format(old)] = upper_ind_id.format(new) + + rho.retag_(retag) + rho.reindex_(reind) + L = len(keep) + else: + L = self.L + + # transpose upper and lower tags to match other MPOs + rho.view_as_( + MatrixProductOperator, + cyclic=self.cyclic, + L=L, + site_tag_id=self.site_tag_id, + lower_ind_id=upper_ind_id, + upper_ind_id=self.site_ind_id, + ) + rho.fuse_multibonds_() + return rho + + def partial_trace(self, *_, **__): + raise AttributeError( + "`mps.partial_trace` has been renamed to " + "`mps.partial_trace_to_mpo`. Soon `mps.partial_trace` " + "will produce (dense) local reduced density matrices to match " + "methods elsewhere in quimb." + ) + + def ptr(self, *_, **__): + raise AttributeError( + "`mps.ptr` has been renamed to `mps.partial_trace_to_mpo`." + ) + + def partial_trace_to_dense_canonical( + self, where, normalized=True, info=None, **contract_opts + ): + """Compute the dense local reduced density matrix by canonicalizing + around the target sites and then contracting the local tensors. Note + this moves the orthogonality around inplace, and records it in `info`. + + Parameters + ---------- + where : int or tuple[int] + The site or sites to compute the reduced density matrix for. + normalized : bool, optional + Explicitly normalize the local reduced density matrix. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + contract_opts + Passed to `tensor_contract` when computing the reduced local + density matrix. + + Returns + ------- + array_like + """ + if self.cyclic: + raise NotImplementedError("Only supports OBC.") + + if isinstance(where, Integral): + where = (where,) + + # canonicalize around our sites + self.canonicalize_(where, info=info) + + # form the local reduced density matrix tn + kix = [self.site_ind(i) for i in where] + bix = [f"__b{i}__" for i in where] + k = self[min(where) : max(where) + 1] + b = k.reindex(dict(zip(kix, bix))).conj_() + rho_tn = k | b + + # contract down to a matrix + rho = rho_tn.to_dense(kix, bix, **contract_opts) + + if normalized: + # locally normalize, usually unnecessary for an MPS but cheap + rho = rho / do("trace", rho) + + return rho + + def local_expectation_canonical( + self, G, where, normalized=True, info=None, **contract_opts + ): + """Compute a local expectation value (via forming the reduced density + matrix). Note this moves the orthogonality around inplace, and records + it in `info`. + + Parameters + ---------- + G : array_like + The local operator to compute the expectation of. + where : int or tuple[int] + The site or sites to compute the expectation at. + normalized : bool, optional + Explicitly normalize the local reduced density matrix. + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + contract_opts + Passed to `tensor_contract` when computing the reduced local + density matrix. + + Returns + ------- + float + """ + rho = self.partial_trace_to_dense_canonical( + where, normalized=normalized, info=info, **contract_opts + ) + return do("trace", G @ rho) + + def compute_local_expectation_canonical( + self, + terms, + normalized=True, + return_all=False, + info=None, + inplace=False, + **contract_opts, + ): + """Compute many local expectations at once, via forming the relevant + reduced density matrices via canonicalization. This moves the + orthogonality around inplace, and records it in `info`. + + Parameters + ---------- + terms : dict[int or tuple[int], array_like] + The local terms to compute values for. + normalized : bool, optional + Explicitly normalize each local reduced density matrix. + return_all : bool, optional + Whether to return each expectation in `terms` separately + or sum them all together (the default). + info : dict, optional + If supplied, will be used to infer and store various extra + information. Currently, the key "cur_orthog" is used to store the + current orthogonality center. Its input value can be ``"calc"``, a + single site, or a pair of sites representing the min/max range, + inclusive. It will be updated to the actual range after. + inplace : bool, optional + Whether to perform the required canonicalizations inplace. + contract_opts + Supplied to + :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract` + when contracting the local density matrices. + + Returns + ------- + float or dict[in or tuple[int], float] + The expecetation value(s), either summed or for each term if + `return_all=True`. + + See Also + -------- + compute_local_expectation_via_envs, local_expectation_canonical + partial_trace_to_dense_canonical + """ + if self.cyclic: + raise NotImplementedError("Only supports OBC.") + + if info is None: + # this is used to keep track of canonical center + info = {} + + if inplace: + mps = self + else: + mps = self.copy() + info = info.copy() + + cur_orthog = info.get("cur_orthog", "calc") + if isinstance(cur_orthog, tuple): + # have a canonical center already -> start close to it + terms = sorted( + terms.items(), key=lambda kv: abs(min(kv[0]) - cur_orthog[0]) + ) + else: + # sort by the smallest site so we sweep in one direction + terms = sorted(terms.items(), key=lambda kv: min(kv[0])) + + expecs = { + where: mps.local_expectation_canonical( + G, + where, + normalized=normalized, + info=info, + **contract_opts, + ) + for where, G in terms + } + + if return_all: + return expecs + + return functools.reduce(operator.add, expecs.values()) + + def compute_local_expectation_via_envs( + self, + terms, + normalized=True, + return_all=False, + **contract_opts, + ): + """Compute many local expectations at once, via forming the relevant + local overlaps using left and right environments formed via + contraction. This does not require any canonicalization and can be + quicker if the canonical center is not already aligned. + + Parameters + ---------- + terms : dict[int or tuple[int], array_like] + The local terms to compute values for. + normalized : bool, optional + Explicitly normalize each local reduced density matrix. + return_all : bool, optional + Whether to return each expectation in `terms` separately + or sum them all together (the default). + contract_opts + Supplied to + :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract` + when contracting the local overlaps. + + Returns + ------- + float or dict[int or tuple[int], float] + The expecetation value(s), either summed or for each term if + `return_all=True`. + + See Also + -------- + compute_local_expectation_canonical, compute_left_environments, + compute_right_environments + """ + norm, ket, bra = self.make_norm(return_all=True) + + left_envs = norm.compute_left_environments(**contract_opts) + right_envs = norm.compute_right_environments(**contract_opts) + + expecs = {} + + if normalized: + nfactor = (norm.select(0) | right_envs[0]).contract( + all, **contract_opts + ) + else: + nfactor = None + + for where, G in terms.items(): + sitemin = min(where) + sitemax = max(where) + tags = [ket.site_tag(i) for i in range(sitemin, sitemax + 1)] + # form: + # sitemin sitemax + # : : + # ┌─┐ ┌─┐ + # ┌──┤k├─┤k├──┐ + # │ └┬┘ └┬┘ │ + # │ │ │ │ + # ┌┴┐ ┌┴───┴┐ ┌┴┐ + # │l│ │ G │ │r│ + # └┬┘ └┬───┬┘ └┬┘ + # │ │ │ │ + # │ ┌┴┐ ┌┴┐ │ + # └──┤b├─┤b├──┘ + # └─┘ └─┘ + # (n.b. might be non-gated sites in between as well) + k = ket.select_any(tags, virtual=False) + b = bra.select_any(tags, virtual=False) + k.gate_(G, where, contract=False) + + tn_local_overlap = k | b + if sitemin in left_envs: + tn_local_overlap |= left_envs[sitemin] + if sitemax in right_envs: + tn_local_overlap |= right_envs[sitemax] + + x = tn_local_overlap.contract(all, **contract_opts) + if normalized: + x = x / nfactor + + expecs[where] = x + + if return_all: + return expecs + + return functools.reduce(operator.add, expecs.values()) + + def compute_local_expectation( + self, + terms, + normalized=True, + return_all=False, + method="canonical", + info=None, + inplace=False, + **contract_opts, + ): + """Compute many local expectations at once. + + Parameters + ---------- + terms : dict[int or tuple[int], array_like] + The local terms to compute values for. + normalized : bool, optional + Explicitly normalize each local term. + return_all : bool, optional + Whether to return each expectation in `terms` separately + or sum them all together (the default). + method : {'canonical', 'envs'}, optional + The method to use to compute the local expectations. + + - 'canonical': canonicalize around the sites of interest and + contract the local reduced density matrices, moving the canonical + center around as needed. + - 'envs': form the local overlaps using left and right environments + and contract these directly. This can be quicker if the canonical + center is not already aligned. + + info : dict, optional + If supplied, and `method=="canonical"`, will be used to infer and + store various extra information. Currently the key "cur_orthog" is + used to store the current orthogonality center. Its input value can + be ``"calc"``, a single site, or a pair of sites representing the + min/max range, inclusive. It will be updated to the actual range + after. + inplace : bool, optional + If `method=="canonical"`, whether to perform the required + canonicalizations inplace or on a copy of the state. + contract_opts + Supplied to + :meth:`~quimb.tensor.tensor_core.TensorNetwork.contract` + when contracting the local overlaps or density matrices. + + Returns + ------- + float or dict[int or tuple[int], float] + The expecetation value(s), either summed or for each term if + `return_all=True`. + + See Also + -------- + compute_local_expectation_canonical, compute_local_expectation_via_envs + """ + if method == "canonical": + return self.compute_local_expectation_canonical( + terms, + normalized=normalized, + return_all=return_all, + info=info, + inplace=inplace, + **contract_opts, + ) + elif method == "envs": + return self.compute_local_expectation_via_envs( + terms, + normalized=normalized, + return_all=return_all, + **contract_opts, + ) + else: + raise ValueError( + f"Unrecognized method: {method}, should be one of: " + "'canonical', 'envs'." + ) + + @convert_cur_orthog + def bipartite_schmidt_state(self, sz_a, get="ket", info=None): + r"""Compute the reduced state for a bipartition of an OBC MPS, in terms + of the minimal left/right schmidt basis:: + + A B + ......... ........... + >->->->->--s--<-<-<-<-<-< -> +-s-+ + | | | | | | | | | | | | | + k0 k1... kA kB + + Parameters + ---------- + sz_a : int + The number of sites in subsystem A, must be ``0 < sz_a < N``. + get : {'ket', 'rho', 'ket-dense', 'rho-dense'}, optional + Get the: + + - 'ket': vector form as tensor. + - 'rho': density operator form, i.e. vector outer product + - 'ket-dense': like 'ket' but return ``qarray``. + - 'rho-dense': like 'rho' but return ``qarray``. + + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + """ + if self.cyclic: + raise NotImplementedError("MPS must have OBC.") + + s = do("diag", self.singular_values(sz_a, info=info)) + + if "dense" in get: + kd = qu.qarray(s.reshape(-1, 1)) + if "ket" in get: + return kd + elif "rho" in get: + return kd @ kd.H + + else: + k = Tensor(s, (self.site_ind("A"), self.site_ind("B"))) + if "ket" in get: + return k + elif "rho" in get: + return k & k.reindex({"kA": "bA", "kB": "bB"}) + + @staticmethod + def _do_lateral_compress( + mps, + kb, + section, + leave_short, + ul, + ll, + heps, + hmethod, + hmax_bond, + verbosity, + compressed, + **compress_opts, + ): + # section + # ul -o-o-o-o-o-o-o-o-o- ul -\ /- + # | | | | | | | | | ==> 0~~~~~0 + # ll -o-o-o-o-o-o-o-o-o- ll -/ : \- + # hmax_bond + + if leave_short: + # if section is short doesn't make sense to lateral compress + # work out roughly when this occurs by comparing bond size + left_sz = mps.bond_size(section[0] - 1, section[0]) + right_sz = mps.bond_size(section[-1], section[-1] + 1) + + if mps.phys_dim() ** len(section) <= left_sz * right_sz: + if verbosity >= 1: + print( + f"Leaving lateral compress of section '{section}' as" + f" it is too short: length={len(section)}, eff " + f"size={left_sz * right_sz}." + ) + return + + if verbosity >= 1: + print( + f"Laterally compressing section {section}. Using options: " + f"eps={heps}, method={hmethod}, max_bond={hmax_bond}" + ) + + section_tags = map(mps.site_tag, section) + kb.replace_with_svd( + section_tags, + (ul, ll), + heps, + inplace=True, + ltags="_LEFT", + rtags="_RIGHT", + method=hmethod, + max_bond=hmax_bond, + **compress_opts, + ) + + compressed.append(section) + + @staticmethod + def _do_vertical_decomp( + mps, + kb, + section, + sysa, + sysb, + compressed, + ul, + ur, + ll, + lr, + vmethod, + vmax_bond, + veps, + verbosity, + **compress_opts, + ): + if section == sysa: + label = "A" + elif section == sysb: + label = "B" + else: + return + + section_tags = [mps.site_tag(i) for i in section] + + if section in compressed: + # ----U---- | <- vmax_bond + # -\ /- / ----U---- + # L~~~~R ==> \ ==> + # -/ \- / ----D---- + # ----D---- | <- vmax_bond + + # try and choose a sensible method + if vmethod is None: + left_sz = mps.bond_size(section[0] - 1, section[0]) + right_sz = mps.bond_size(section[-1], section[-1] + 1) + if left_sz * right_sz <= 2**13: + # cholesky is not rank revealing + vmethod = "eigh" if vmax_bond else "cholesky" + else: + vmethod = "isvd" + + if verbosity >= 1: + print( + f"Performing vertical decomposition of section {label}, " + f"using options: eps={veps}, method={vmethod}, " + f"max_bond={vmax_bond}." + ) + + # do vertical SVD + kb.replace_with_svd( + section_tags, + (ul, ur), + right_inds=(ll, lr), + eps=veps, + ltags="_UP", + rtags="_DOWN", + method=vmethod, + inplace=True, + max_bond=vmax_bond, + **compress_opts, + ) + + # cut joined bond by reindexing to upper- and lower- ind_id. + kb.cut_between( + (mps.site_tag(section[0]), "_UP"), + (mps.site_tag(section[0]), "_DOWN"), + f"_tmp_ind_u{label}", + f"_tmp_ind_l{label}", + ) + + else: + # just unfold and fuse physical indices: + # | + # -A-A-A-A-A-A-A- -AAAAAAA- + # | | | | | | | ===> + # -A-A-A-A-A-A-A- -AAAAAAA- + # | + + if verbosity >= 1: + print(f"Just vertical unfolding section {label}.") + + kb, sec = kb.partition(section_tags, inplace=True) + sec_l, sec_u = sec.partition("_KET", inplace=True) + T_UP = sec_u ^ all + T_UP.add_tag("_UP") + T_UP.fuse_( + {f"_tmp_ind_u{label}": [mps.site_ind(i) for i in section]} + ) + T_DN = sec_l ^ all + T_DN.add_tag("_DOWN") + T_DN.fuse_( + {f"_tmp_ind_l{label}": [mps.site_ind(i) for i in section]} + ) + kb |= T_UP + kb |= T_DN + + def partial_trace_compress( + self, + sysa, + sysb, + eps=1e-8, + method=("isvd", None), + max_bond=(None, 1024), + leave_short=True, + renorm=True, + lower_ind_id="b{}", + verbosity=0, + **compress_opts, + ): + r"""Perform a compressed partial trace using singular value + lateral then vertical decompositions of transfer matrix products:: + + + .....sysa...... ...sysb.... + o-o-o-o-A-A-A-A-A-A-A-A-o-o-B-B-B-B-B-B-o-o-o-o-o-o-o-o-o + | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + ==> form inner product + + ............... ........... + o-o-o-o-A-A-A-A-A-A-A-A-o-o-B-B-B-B-B-B-o-o-o-o-o-o-o-o-o + | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + o-o-o-o-A-A-A-A-A-A-A-A-o-o-B-B-B-B-B-B-o-o-o-o-o-o-o-o-o + + ==> lateral SVD on each section + + .....sysa...... ...sysb.... + /\ /\ /\ /\ + ... ~~~E A~~~~~~~~~~~A E~E B~~~~~~~B E~~~ ... + \/ \/ \/ \/ + + ==> vertical SVD and unfold on A & B + + | | + /-------A-------\ /-----B-----\ + ... ~~~E E~E E~~~ ... + \-------A-------/ \-----B-----/ + | | + + With various special cases including OBC or end spins included in + subsytems. + + + Parameters + ---------- + sysa : sequence of int + The sites, which should be contiguous, defining subsystem A. + sysb : sequence of int + The sites, which should be contiguous, defining subsystem B. + eps : float or (float, float), optional + Tolerance(s) to use when compressing the subsystem transfer + matrices and vertically decomposing. + method : str or (str, str), optional + Method(s) to use for laterally compressing the state then + vertially compressing subsytems. + max_bond : int or (int, int), optional + The maximum bond to keep for laterally compressing the state then + vertially compressing subsytems. + leave_short : bool, optional + If True (the default), don't try to compress short sections. + renorm : bool, optional + If True (the default), renomalize the state so that ``tr(rho)==1``. + lower_ind_id : str, optional + The index id to create for the new density matrix, the upper_ind_id + is automatically taken as the current site_ind_id. + compress_opts : dict, optional + If given, supplied to ``partial_trace_compress`` to govern how + singular values are treated. See ``tensor_split``. + verbosity : {0, 1}, optional + How much information to print while performing the compressed + partial trace. + + Returns + ------- + rho_ab : TensorNetwork + Density matrix tensor network with + ``outer_inds = ('k0', 'k1', 'b0', 'b1')`` for example. + """ + N = self.L + + if (len(sysa) + len(sysb) == N) and not self.cyclic: + return self.bipartite_schmidt_state(len(sysa), get="rho") + + # parse horizontal and vertical svd tolerances and methods + try: + heps, veps = eps + except (ValueError, TypeError): + heps = veps = eps + try: + hmethod, vmethod = method + except (ValueError, TypeError): + hmethod = vmethod = method + try: + hmax_bond, vmax_bond = max_bond + except (ValueError, TypeError): + hmax_bond = vmax_bond = max_bond + + # the sequence of sites in each of the 'environment' sections + envm = range(max(sysa) + 1, min(sysb)) + envl = range(0, min(sysa)) + envr = range(max(sysb) + 1, N) + + # spread norm, and if not cyclic put in mixed canonical form, taking + # care that the orthogonality centre is in right place to use identity + k = self.copy() + k.left_canonize() + k.right_canonize(max(sysa) + (bool(envm) or bool(envr))) + + # form the inner product + b = k.conj() + k.add_tag("_KET") + b.add_tag("_BRA") + kb = k | b + + # label the various partitions + names = ("_ENVL", "_SYSA", "_ENVM", "_SYSB", "_ENVR") + for name, where in zip(names, (envl, sysa, envm, sysb, envr)): + if where: + kb.add_tag(name, where=map(self.site_tag, where), which="any") + + if self.cyclic: + # can combine right and left envs + sections = [envm, sysa, sysb, (*envr, *envl)] + else: + sections = [envm] + # if either system includes end, can ignore and use identity + if 0 not in sysa: + sections.append(sysa) + if N - 1 not in sysb: + sections.append(sysb) + + # ignore empty sections + sections = list(filter(len, sections)) + + # figure out the various indices + ul_ur_ll_lrs = [] + for section in sections: + # ...section[i].... + # ul[i] -o-o-o-o-o-o-o-o-o- ur[i] + # | | | | | | | | | + # ll[i] -o-o-o-o-o-o-o-o-o- lr[i] + + st_left = self.site_tag(section[0] - 1) + st_right = self.site_tag(section[0]) + (ul,) = bonds(kb["_KET", st_left], kb["_KET", st_right]) + (ll,) = bonds(kb["_BRA", st_left], kb["_BRA", st_right]) + + st_left = self.site_tag(section[-1]) + st_right = self.site_tag(section[-1] + 1) + (ur,) = bonds(kb["_KET", st_left], kb["_KET", st_right]) + (lr,) = bonds(kb["_BRA", st_left], kb["_BRA", st_right]) + + ul_ur_ll_lrs.append((ul, ur, ll, lr)) + + # lateral compress sections if long + compressed = [] + for section, (ul, _, ll, _) in zip(sections, ul_ur_ll_lrs): + self._do_lateral_compress( + self, + kb, + section, + leave_short, + ul, + ll, + heps, + hmethod, + hmax_bond, + verbosity, + compressed, + **compress_opts, + ) + + # vertical compress and unfold system sections only + for section, (ul, ur, ll, lr) in zip(sections, ul_ur_ll_lrs): + self._do_vertical_decomp( + self, + kb, + section, + sysa, + sysb, + compressed, + ul, + ur, + ll, + lr, + vmethod, + vmax_bond, + veps, + verbosity, + **compress_opts, + ) + + if not self.cyclic: + # check if either system is at end, and thus reduces to identities + # + # A-A-A-A-A-A-A-m-m-m- \-m-m-m- + # | | | | | | | | | | ... ==> | | | ... + # A-A-A-A-A-A-A-m-m-m- /-m-m-m- + # + if 0 in sysa: + # get neighbouring tensor + if envm: + try: + TU = TD = kb["_ENVM", "_LEFT"] + except KeyError: + # didn't lateral compress + TU = kb["_ENVM", "_KET", self.site_tag(envm[0])] + TD = kb["_ENVM", "_BRA", self.site_tag(envm[0])] + else: + TU = kb["_SYSB", "_UP"] + TD = kb["_SYSB", "_DOWN"] + (ubnd,) = kb["_KET", self.site_tag(sysa[-1])].bonds(TU) + (lbnd,) = kb["_BRA", self.site_tag(sysa[-1])].bonds(TD) + + # delete the A system + kb.delete("_SYSA") + kb.reindex_({ubnd: "_tmp_ind_uA", lbnd: "_tmp_ind_lA"}) + else: + # or else replace the left or right envs with identites since + # + # >->->->-A-A-A-A- +-A-A-A-A- + # | | | | | | | | ... ==> | | | | | + # >->->->-A-A-A-A- +-A-A-A-A- + # + kb.replace_with_identity("_ENVL", inplace=True) + + if N - 1 in sysb: + # get neighbouring tensor + if envm: + try: + TU = TD = kb["_ENVM", "_RIGHT"] + except KeyError: + # didn't lateral compress + TU = kb["_ENVM", "_KET", self.site_tag(envm[-1])] + TD = kb["_ENVM", "_BRA", self.site_tag(envm[-1])] + else: + TU = kb["_SYSA", "_UP"] + TD = kb["_SYSA", "_DOWN"] + (ubnd,) = kb["_KET", self.site_tag(sysb[0])].bonds(TU) + (lbnd,) = kb["_BRA", self.site_tag(sysb[0])].bonds(TD) + + # delete the B system + kb.delete("_SYSB") + kb.reindex_({ubnd: "_tmp_ind_uB", lbnd: "_tmp_ind_lB"}) + else: + kb.replace_with_identity("_ENVR", inplace=True) + + kb.reindex_( + { + "_tmp_ind_uA": self.site_ind("A"), + "_tmp_ind_lA": lower_ind_id.format("A"), + "_tmp_ind_uB": self.site_ind("B"), + "_tmp_ind_lB": lower_ind_id.format("B"), + } + ) + + if renorm: + # normalize + norm = kb.trace(["kA", "kB"], ["bA", "bB"]) + + ts = [] + tags = kb.tags + + # check if we have system A + if "_SYSA" in tags: + ts.extend(kb[sysa[0]]) + + # check if we have system B + if "_SYSB" in tags: + ts.extend(kb[sysb[0]]) + + # If we dont' have either (OBC with both at ends) use middle envm + if len(ts) == 0: + ts.extend(kb[envm[0]]) + + nt = len(ts) + + if verbosity > 0: + print(f"Renormalizing for norm {norm} among {nt} tensors.") + + # now spread the norm out among tensors + for t in ts: + t.modify(data=t.data / norm ** (1 / nt)) + + return kb + + def logneg_subsys( + self, + sysa, + sysb, + compress_opts=None, + approx_spectral_opts=None, + verbosity=0, + approx_thresh=2**12, + ): + r"""Compute the logarithmic negativity between subsytem blocks, e.g.:: + + sysa sysb + ......... ..... + ... -o-o-o-o-o-o-A-A-A-A-A-o-o-o-B-B-B-o-o-o-o-o-o-o- ... + | | | | | | | | | | | | | | | | | | | | | | | | + + Parameters + ---------- + sysa : sequence of int + The sites, which should be contiguous, defining subsystem A. + sysb : sequence of int + The sites, which should be contiguous, defining subsystem B. + eps : float, optional + Tolerance to use when compressing the subsystem transfer matrices. + method : str or (str, str), optional + Method(s) to use for laterally compressing the state then + vertially compressing subsytems. + compress_opts : dict, optional + If given, supplied to ``partial_trace_compress`` to govern how + singular values are treated. See ``tensor_split``. + approx_spectral_opts + Supplied to :func:`~quimb.approx_spectral_function`. + + Returns + ------- + ln : float + The logarithmic negativity. + + See Also + -------- + MatrixProductState.partial_trace_compress, approx_spectral_function + """ + if not self.cyclic and (len(sysa) + len(sysb) == self.L): + # pure bipartition with OBC + psi = self.bipartite_schmidt_state(len(sysa), get="ket-dense") + d = round(psi.shape[0] ** 0.5) + return qu.logneg(psi, [d, d]) + + compress_opts = ensure_dict(compress_opts) + approx_spectral_opts = ensure_dict(approx_spectral_opts) + + # set the default verbosity for each method + compress_opts.setdefault("verbosity", verbosity) + approx_spectral_opts.setdefault("verbosity", verbosity) + + # form the compressed density matrix representation + rho_ab = self.partial_trace_compress(sysa, sysb, **compress_opts) + + # view it as an operator + rho_ab_pt_lo = rho_ab.aslinearoperator(["kA", "bB"], ["bA", "kB"]) + + if rho_ab_pt_lo.shape[0] <= approx_thresh: + tr_norm = norm_trace_dense(rho_ab_pt_lo.to_dense(), isherm=True) + else: + # estimate its spectrum and sum the abs(eigenvalues) + tr_norm = qu.approx_spectral_function( + rho_ab_pt_lo, abs, **approx_spectral_opts + ) + + # clip below 0 + return max(0, log2(tr_norm)) + + @convert_cur_orthog + def measure( + self, + site, + remove=False, + outcome=None, + renorm=True, + info=None, + get=None, + seed=None, + backend_random="numpy", + inplace=False, + ): + r"""Measure this MPS at ``site``, including projecting the state. + Optionally remove the site afterwards, yielding an MPS with one less + site. In either case the orthogonality center of the returned MPS is + ``min(site, new_L - 1)``. + + Parameters + ---------- + site : int + The site to measure. + remove : bool, optional + Whether to remove the site completely after projecting the + measurement. If ``True``, sites greater than ``site`` will be + retagged and reindex one down, and the MPS will have one less site. + E.g:: + + 0-1-2-3-4-5-6 + / / / - measure and remove site 3 + 0-1-2-4-5-6 + - reindex sites (4, 5, 6) to (3, 4, 5) + 0-1-2-3-4-5 + + outcome : None or int, optional + Specify the desired outcome of the measurement. If ``None``, it + will be randomly sampled according to the local density matrix. + renorm : bool, optional + Whether to renormalize the state post measurement. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + get : {None, 'outcome'}, optional + If ``'outcome'``, simply return the outcome, and don't perform any + projection. + seed : None, int, or np.random.Generator, optional + A random seed or generator to use. + backend_random : {'numpy', None, ...}, optional + The backend to use for random sampling. If ``None``, will be + inferred from the tensor data. By default numpy is used meaning the + probability distributions are always converted to numpy arrays, for + consistency. + inplace : bool, optional + Whether to perform the measurement in place or not. + + Returns + ------- + outcome : int + The measurement outcome, drawn from ``range(phys_dim)``. + psi : MatrixProductState + The measured state, if ``get != 'outcome'``. + """ + if self.cyclic: + raise ValueError("Not supported on cyclic MPS yet.") + + tn = self if inplace else self.copy() + L = tn.L + d = self.phys_dim(site) + + # make sure MPS is canonicalized + tn.canonicalize_(site, info=info) + + # local tensor and physical dim + t = tn[site] + ind = tn.site_ind(site) + + # array namespace + xp = t.get_namespace() + # random namespace + if backend_random is None: + rxp = xp + convert = False + elif backend_random == "numpy": + rxp = get_namespace("numpy") + convert = xp is not rxp + else: + rxp = get_namespace(backend_random) + convert = False + + # diagonal of reduced density matrix = probs + tii = t.contract(t.H, output_inds=(ind,)) + pi = xp.real(tii.data) + + if convert: + pi = xp.to_numpy(pi) + + pi = pi / rxp.sum(pi) + if outcome is None: + # sample an outcome + rng = rxp.random.default_rng(seed) + outcome = rng.choice(rxp.size(pi), p=pi) + + if backend_random == "numpy": + # XXX: unnecessary? numpy always returns int for scalar size? + outcome = int(outcome) + + if get == "outcome": + return outcome + + # project the outcome and renormalize + t.isel_({ind: outcome}) + + if renorm: + t.modify(data=t.data / pi[outcome] ** 0.5) + + if remove: + # contract the projected tensor into neighbor + if site == L - 1: + tn ^= slice(site - 1, site + 1) + else: + tn ^= slice(site, site + 2) + + # adjust structure for one less spin + for i in range(site + 1, L): + tn[i].reindex_({tn.site_ind(i): tn.site_ind(i - 1)}) + tn[i].retag_({tn.site_tag(i): tn.site_tag(i - 1)}) + tn._L = L - 1 + else: + # re-expand index, populating non-measured outcomes with zeros + zeros = xp.zeros_like(t.data) + arrays = [zeros] * d + arrays[outcome] = t.data + t.modify(data=xp.stack(arrays, axis=-1), inds=(*t.inds, ind)) + + return outcome, tn + + measure_ = functools.partialmethod(measure, inplace=True) + + def sample_configuration( + self, + seed=None, + backend_random="numpy", + info=None, + ): + """Sample a configuration from this MPS. + + Parameters + ---------- + seed : None, int, or np.random.Generator, optional + A random seed or generator to use. + backend_random : {'numpy', None, ...}, optional + The backend to use for random sampling. If ``None``, will be + inferred from the tensor data. By default numpy is used meaning the + probability distributions are always converted to numpy arrays, for + consistency. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + """ + # array namespace + xp = self.get_namespace() + # random namespace + if backend_random is None: + # use array backend + rxp = xp + convert = False + elif backend_random == "numpy": + # use numpy to sample regardless of array backend + rxp = get_namespace("numpy") + convert = xp is not rxp + else: + # manual backend + rxp = get_namespace(backend_random) + convert = False + + # if seed is already a generator this simply returns it + rng = rxp.random.default_rng(seed) + + # right canonicalize + psi = self.canonicalize(0, info=info) + + config = [] + omega = 1.0 + for i in range(psi.L): + # form local density matrix + ki = psi[i] + bi = ki.H + ix = psi.site_ind(i) + # contract diagonal to get probabilities + pi = (ki & bi).contract(output_inds=[ix]).data + pi = xp.real(pi) + + if convert: + pi = xp.to_numpy(pi) + + pi = pi / rxp.sum(pi) + xi = rng.choice(rxp.size(pi), p=pi) + config.append(xi) + # track local probability + omega = omega * pi[xi] + + # project outcome + psi.isel_({ix: xi}) + if i < psi.L - 1: + # and absorb projected site into next site + psi.contract_tags_([psi.site_tag(i), psi.site_tag(i + 1)]) + + return config, omega + + def sample(self, C, seed=None, backend_random="numpy", info=None): + """Generate ``C`` samples rom this MPS, along with their probabilities. + + Parameters + ---------- + C : int + The number of samples to generate. + seed : None, int, or np.random.Generator, optional + A random seed or generator to use. + backend_random : {'numpy', None, ...}, optional + The backend to use for random sampling. If ``None``, will be + inferred from the tensor data. By default numpy is used meaning the + probability distributions are always converted to numpy arrays, for + consistency. + info : dict, optional + If given, will be used to infer and store various extra + information. Currently the key "cur_orthog" is used to store the + current orthogonality center. + + Yields + ------ + config : sequence of int + The sample configuration. + omega : float + The probability of this configuration. + """ + + if info is None: + info = {} + + # do right canonicalization once (supplying info avoids re-performing) + psi0 = self.canonicalize(0, info=info) + + if backend_random is None: + # use array backend + rxp = psi0.get_namespace() + elif backend_random == "numpy": + # use numpy to sample regardless of array backend + rxp = get_namespace("numpy") + else: + # manual backend + rxp = get_namespace(backend_random) + + rng = rxp.random.default_rng(seed) + for _ in range(C): + yield psi0.sample_configuration( + seed=rng, + info=info, + backend_random=backend_random, + ) + + +class MatrixProductOperator(TensorNetwork1DOperator, TensorNetwork1DFlat): + """Initialise a matrix product operator, with auto labelling and tagging. + + Parameters + ---------- + arrays : sequence of arrays + The tensor arrays to form into a MPO. + sites : sequence of int, optional + Construct the MPO on these sites only. If not given, enumerate from + zero. Should be monotonically increasing and match ``arrays``. + L : int, optional + The number of sites the MPO should be defined on. If not given, this is + taken as the max ``sites`` value plus one (i.e.g the number of arrays + if ``sites`` is not given). + shape : str, optional + String specifying layout of *input* arrays. E.g. 'lrp' (the + default) indicates the shape corresponds left-bond, right-bond, + 'up' physical index, 'down' physical index. End tensors have either + 'l' or 'r' dropped from the string. The arrays will be permuted to + 'lrud' order. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + upper_ind_id : str + A string specifiying how to label the upper physical site indices. + Should contain a ``'{}'`` placeholder. It is used to generate the + actual indices like: ``map(upper_ind_id.format, range(len(arrays)))``. + lower_ind_id : str + A string specifiying how to label the lower physical site indices. + Should contain a ``'{}'`` placeholder. It is used to generate the + actual indices like: ``map(lower_ind_id.format, range(len(arrays)))``. + site_tag_id : str + A string specifiying how to tag the tensors at each site. Should + contain a ``'{}'`` placeholder. It is used to generate the actual tags + like: ``map(site_tag_id.format, range(len(arrays)))``. + """ + + _EXTRA_PROPS = ( + "_site_tag_id", + "_upper_ind_id", + "_lower_ind_id", + "cyclic", + "_L", + ) + + def __init__( + self, + arrays, + *, + sites=None, + L=None, + shape="lrud", + tags=None, + upper_ind_id="k{}", + lower_ind_id="b{}", + site_tag_id="I{}", + **tn_opts, + ): + # short-circuit for copying + if isinstance(arrays, MatrixProductOperator): + super().__init__(arrays) + return + + arrays = tuple(arrays) + + if sites is None: + # assume dense + sites = range(len(arrays)) + if L is None: + L = len(arrays) + num_sites = L + else: + sites = tuple(sites) + if L is None: + L = max(sites) + 1 + num_sites = len(sites) + + self._L = L + self._upper_ind_id = upper_ind_id + self._lower_ind_id = lower_ind_id + self._site_tag_id = site_tag_id + self.cyclic = ops.ndim(arrays[0]) == 4 + + tensors = [] + tags = tags_to_oset(tags) + bonds = [rand_uuid() for _ in range(num_sites)] + # account for cyclic case + bonds.append(bonds[0]) + + for i, (site, array) in enumerate(zip(sites, arrays)): + inds = [] + + if L == 1: + # only one site + if self.cyclic: + # bond is a self loop on the single tensor + shape_desired = "lrud" + inds.append(bonds[i]) + inds.append(bonds[i]) + # XXX: should we just trace it out instead? + else: + # no bonds, just physical indices + shape_desired = "ud" + elif (i == 0) and not self.cyclic: + # only right bond + shape_desired = "rud" + inds.append(bonds[i + 1]) + elif (i == num_sites - 1) and not self.cyclic: + # only left bond + shape_desired = "lud" + inds.append(bonds[i]) + else: + shape_desired = "lrud" + # both bonds + inds.append(bonds[i]) + inds.append(bonds[i + 1]) + + # this is the perm needed to bring the arrays from + # their current `shape`, to the desired 'lrud' order + shape_given = [x for x in shape if x in shape_desired] + order = [shape_given.index(x) for x in shape_desired] + + # physical indices + inds.append(upper_ind_id.format(site)) + inds.append(lower_ind_id.format(site)) + + tensors.append( + Tensor( + data=transpose(array, order), + inds=inds, + tags=tags | oset([site_tag_id.format(site)]), + ) + ) + + super().__init__(tensors, virtual=True, **tn_opts) + + @classmethod + def from_fill_fn( + cls, + fill_fn, + L, + bond_dim, + phys_dim=2, + sites=None, + cyclic=False, + shape="lrud", + tags=None, + upper_ind_id="k{}", + lower_ind_id="b{}", + site_tag_id="I{}", + ): + """Create an MPO by supplying a 'filling' function to generate the data + for each site. + + Parameters + ---------- + fill_fn : callable + A function with signature + ``fill_fn(shape : tuple[int]) -> array_like``. + L : int + The number of sites. + bond_dim : int + The bond dimension. + phys_dim : int or Sequence[int], optional + The physical dimension(s) of each site, if a sequence it will be + cycled over. + sites : None or sequence of int, optional + Construct the MPO on these sites only. If not given, enumerate from + zero. + cyclic : bool, optional + Whether the MPO should be cyclic (periodic). + shape : str, optional + String specifying layout of *input* arrays. E.g. 'lrp' (the + default) indicates the shape corresponds left-bond, right-bond, + 'up' physical index, 'down' physical index. End tensors have either + 'l' or 'r' dropped from the string. The arrays will be permuted to + 'lrud' order. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + upper_ind_id : str + A string specifiying how to label the upper physical site indices. + Should contain a ``'{}'`` placeholder. + lower_ind_id : str + A string specifiying how to label the lower physical site indices. + Should contain a ``'{}'`` placeholder. + site_tag_id : str, optional + How to tag the physical sites. Should contain a ``'{}'`` + placeholder. + + Returns + ------- + MatrixProductState + """ + if set(shape) - {"l", "r", "u", "d"}: + raise ValueError(f"Invalid shape string: {shape}.") + + # check for site varying physical dimensions + if isinstance(phys_dim, Integral): + phys_dims = itertools.repeat(phys_dim) + else: + phys_dims = itertools.cycle(phys_dim) + + mpo = cls.new( + L=L, + cyclic=cyclic, + site_tag_id=site_tag_id, + upper_ind_id=upper_ind_id, + lower_ind_id=lower_ind_id, + ) + + # which sites are actually present + if sites is None: + sites = range(L) + else: + sites = tuple(sites) + num_sites = len(sites) + + global_tags = tags_to_oset(tags) + bonds = [rand_uuid() for _ in range(num_sites)] + bonds.append(bonds[0]) + + for i, site in enumerate(sites): + p = next(phys_dims) + inds = [] + data_shape = [] + for c in shape: + if c == "l": + if (i - 1) >= 0 or cyclic: + inds.append(bonds[i]) + data_shape.append(bond_dim) + elif c == "r": + if (i + 1) < L or cyclic: + inds.append(bonds[i + 1]) + data_shape.append(bond_dim) + elif c == "u": + inds.append(upper_ind_id.format(site)) + data_shape.append(p) + else: # c == "d" + inds.append(lower_ind_id.format(site)) + data_shape.append(p) + data = fill_fn(data_shape) + tags = global_tags | oset((site_tag_id.format(site),)) + mpo |= Tensor(data, inds=inds, tags=tags) + + return mpo + + @classmethod + def from_dense( + cls, + A, + dims=2, + sites=None, + L=None, + tags=None, + site_tag_id="I{}", + upper_ind_id="k{}", + lower_ind_id="b{}", + **split_opts, + ): + """Build an MPO from a raw dense matrix. + + Parameters + ---------- + A : array + The dense operator, it should be reshapeable to ``(*dims, *dims)``. + dims : int, sequence of int, optional + The physical subdimensions of the operator. If any integer, assume + all sites have the same dimension. If a sequence, the dimension of + each site. Default is 2. + sites : sequence of int, optional + The sites to place the operator on. If None, will place it on + first `len(dims)` sites. + L : int, optional + The total number of sites in the MPO, if the operator represents + only a subset. + tags : str or sequence of str, optional + Global tags to attach to all tensors. + site_tag_id : str, optional + The string to use to label the site tags. + upper_ind_id : str, optional + The string to use to label the upper physical indices. + lower_ind_id : str, optional + The string to use to label the lower physical indices. + split_opts + Supplied to :func:`~quimb.tensor.tensor_core.tensor_split`. + + Returns + ------- + MatrixProductOperator + """ + set_default_compress_mode(split_opts) + # ensure compression is canonical / optimal + split_opts.setdefault("absorb", "right") + + # make sure array_like + A = ops.asarray(A) + + if isinstance(dims, Integral): + # assume all sites have the same dimension + ng = round(log(size(A), dims) / 2) + dims = (dims,) * ng + else: + dims = tuple(dims) + ng = len(dims) + + if sites is None: + sorted_sites = sites = range(ng) + else: + sorted_sites = sorted(sites) + + if L is None: + L = max(sites) + 1 + + # create a bare MPO TN object + mpo = cls.new( + L=L, + cyclic=False, + upper_ind_id=upper_ind_id, + lower_ind_id=lower_ind_id, + site_tag_id=site_tag_id, + ) + + # initial inds and tensor contains desired site order ... + uix = [mpo.upper_ind(i) for i in sites] + lix = [mpo.lower_ind(i) for i in sites] + tm = Tensor(data=reshape(A, (*dims, *dims)), inds=uix + lix) + + # ... but want to create MPO in sorted site order + uix = [mpo.upper_ind(i) for i in sorted_sites] + lix = [mpo.lower_ind(i) for i in sorted_sites] + + for i, site in enumerate(sorted_sites[:-1]): + # progressively split off one more pair of physical indices + tl, tm = tm.split( + left_inds=None, + right_inds=uix[i + 1 :] + lix[i + 1 :], + ltags=mpo.site_tag(site), + get="tensors", + **split_opts, + ) + # add left tensor + mpo |= tl + + # add final right tensor + tm.add_tag(mpo.site_tag(sorted_sites[-1])) + mpo |= tm + + # add global tags + if tags is not None: + mpo.add_tag(tags) + + return mpo + + def fill_empty_sites( + self, mode="full", phys_dim=None, fill_array=None, inplace=False + ): + """Fill any empty sites of this MPO with identity tensors, adding + size 1 bonds or draping existing bonds where necessary such that the + resulting tensor has nearest neighbor bonds only. + + Parameters + ---------- + mode : {'full', 'minimal'}, optional + Whether to fill in all sites, including at either end, or simply + the minimal range covering the min to max current sites present. + phys_dim : int, optional + The physical dimension of the identity tensors to add. If not + specified, will use the upper physical dimension of the first + present site. + fill_array : array, optional + The array to use for the identity tensors. If not specified, will + use the identity array of the same dtype as the first present site. + inplace : bool, optional + Whether to perform the operation inplace. + + Returns + ------- + MatrixProductOperator + The modified MPO. + """ + mpo = self if inplace else self.copy() + + sites_present = tuple(mpo.gen_sites_present()) + sites_present_set = set(sites_present) + sitei = sites_present[0] + sitef = sites_present[-1] + + if fill_array is None: + t0 = mpo[sitei] + if phys_dim is None: + d = mpo.phys_dim(sitei) + fill_array = do("eye", d, dtype=t0.dtype, like=t0.data) + + if mode == "full": + sites_to_add = [ + site for site in range(mpo.L) if site not in sites_present_set + ] + elif mode == "minimal": + sites_to_add = [ + site + for site in range(sitei, sitef + 1) + if site not in sites_present_set + ] + else: + sites_to_add = list(mode) + sites_to_add_set = set(sites_to_add) + + new_sites = list(sites_present) + new_sites.extend(sites_to_add) + new_sites.sort() + + # add desired identites + for site in sites_to_add: + mpo |= Tensor( + data=fill_array, + inds=(mpo.upper_ind(site), mpo.lower_ind(site)), + tags=mpo.site_tag(site), + ) + + # connect up between existing tensors + for si, sj in pairwise(sites_present): + if bonds(mpo[si], mpo[sj]): + # existing bond -> drape it thru + sl = si + for k in range(si + 1, sj): + if k in sites_to_add_set: + mpo.drape_bond_between_(sl, sj, k) + sl = k + + else: + # no bond -> just add bond dim 1 + sl = si + for k in range(si, sj - 1): + if k in sites_to_add_set: + new_bond(mpo[sl], mpo[k]) + sl = k + new_bond(mpo[sl], mpo[sj]) + + # connect up on either side of existing patch + for si, sj in pairwise(new_sites): + if (sj <= sitei) or (si >= sitef): + new_bond(mpo[si], mpo[sj]) + + return mpo + + fill_empty_sites_ = functools.partialmethod(fill_empty_sites, inplace=True) + + def add_MPO(self, other, inplace=False, **kwargs): + return tensor_network_ag_sum(self, other, inplace=inplace, **kwargs) + + add_MPO_ = functools.partialmethod(add_MPO, inplace=True) + + def _apply_mps( + self, other, compress=False, contract=True, **compress_opts + ): + return tensor_network_apply_op_vec( + A=self, + x=other, + compress=compress, + contract=contract, + **compress_opts, + ) + + def _apply_mpo( + self, other, compress=False, contract=True, **compress_opts + ): + return tensor_network_apply_op_op( + A=self, + B=other, + contract=contract, + compress=compress, + **compress_opts, + ) + + def apply(self, other, compress=False, **compress_opts): + r"""Act with this MPO on another MPO or MPS, such that the resulting + object has the same tensor network structure/indices as ``other``. + + For an MPS:: + + | | | | | | | | | | | | | | | | | | + self: A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A + | | | | | | | | | | | | | | | | | | + other: x-x-x-x-x-x-x-x-x-x-x-x-x-x-x-x-x-x + + --> + + | | | | | | | | | | | | | | | | | | <- other.site_ind_id + out: y=y=y=y=y=y=y=y=y=y=y=y=y=y=y=y=y=y + + For an MPO:: + + | | | | | | | | | | | | | | | | | | + self: A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A-A + | | | | | | | | | | | | | | | | | | + other: B-B-B-B-B-B-B-B-B-B-B-B-B-B-B-B-B-B + | | | | | | | | | | | | | | | | | | + + --> + + | | | | | | | | | | | | | | | | | | <- other.upper_ind_id + out: C=C=C=C=C=C=C=C=C=C=C=C=C=C=C=C=C=C + | | | | | | | | | | | | | | | | | | <- other.lower_ind_id + + The resulting TN will have the same structure/indices as ``other``, but + probably with larger bonds (depending on compression). + + + Parameters + ---------- + other : MatrixProductOperator or MatrixProductState + The object to act on. + compress : bool, optional + Whether to compress the resulting object. + compress_opts + Supplied to :meth:`TensorNetwork1DFlat.compress`. + + Returns + ------- + MatrixProductOperator or MatrixProductState + """ + if isinstance(other, MatrixProductState): + return self._apply_mps(other, compress=compress, **compress_opts) + elif isinstance(other, MatrixProductOperator): + return self._apply_mpo(other, compress=compress, **compress_opts) + else: + raise TypeError( + "Can only Dot with a MatrixProductOperator or a " + f"MatrixProductState, got {type(other)}" + ) + + dot = apply + + def permute_arrays(self, shape="lrud"): + """Permute the indices of each tensor in this MPO to match ``shape``. + This doesn't change how the overall object interacts with other tensor + networks but may be useful for extracting the underlying arrays + consistently. This is an inplace operation. + + Parameters + ---------- + shape : str, optional + A permutation of ``'lrud'`` specifying the *desired* order of the + left, right, upper and lower (down) indices respectively. + """ + self.ensure_bonds_exist() + + for i in self.gen_sites_present(): + inds = {"u": self.upper_ind(i), "d": self.lower_ind(i)} + if self.cyclic or i > 0: + inds["l"] = self.bond(i, (i - 1) % self.L) + if self.cyclic or i < self.L - 1: + inds["r"] = self.bond(i, (i + 1) % self.L) + inds = [inds[s] for s in shape if s in inds] + self[i].transpose_(*inds) + + def trace(self, left_inds=None, right_inds=None): + """Take the trace of this MPO.""" + if left_inds is None: + left_inds = map(self.upper_ind, self.gen_sites_present()) + if right_inds is None: + right_inds = map(self.lower_ind, self.gen_sites_present()) + + return super().trace(left_inds, right_inds) + + def partial_transpose(self, sysa, inplace=False): + """Perform the partial transpose on this MPO by swapping the bra and + ket indices on sites in ``sysa``. + + Parameters + ---------- + sysa : sequence of int or int + The sites to transpose indices on. + inplace : bool, optional + Whether to perform the partial transposition inplace. + + Returns + ------- + MatrixProductOperator + """ + tn = self if inplace else self.copy() + + if isinstance(sysa, Integral): + sysa = (sysa,) + + tmp_ind_id = "__tmp_{}__" + + tn.reindex_({tn.upper_ind(i): tmp_ind_id.format(i) for i in sysa}) + tn.reindex_({tn.lower_ind(i): tn.upper_ind(i) for i in sysa}) + tn.reindex_({tmp_ind_id.format(i): tn.lower_ind(i) for i in sysa}) + return tn + + def rand_state(self, bond_dim, **mps_opts): + """Get a random vector matching this MPO.""" + return qu.tensor.MPS_rand_state( + self.L, + bond_dim=bond_dim, + phys_dim=[self.phys_dim(i) for i in self.gen_sites_present()], + dtype=self.dtype, + cyclic=self.cyclic, + **mps_opts, + ) + + def identity(self, **mpo_opts): + """Get a identity matching this MPO.""" + return qu.tensor.MPO_identity_like(self, **mpo_opts) + + def show(self, max_width=None): + l1 = "" + l2 = "" + l3 = "" + num_can_l, num_can_r = self.count_canonized() + for i in range(self.L - 1): + bdim = self.bond_size(i, i + 1) + strl = len(str(bdim)) + l1 += f"│{bdim}" + l2 += ( + ">" + if i < num_can_l + else "<" + if i >= self.L - num_can_r + else "●" + ) + ("─" if bdim < 100 else "━") * strl + l3 += "│" + " " * strl + + l1 += "│" + l2 += "<" if num_can_r > 0 else "●" + l3 += "│" + + if self.cyclic: + bdim = self.bond_size(0, self.L - 1) + bnd_str = ("─" if bdim < 100 else "━") * strl + l1 = f" {bdim}{l1}{bdim} " + l2 = f"+{bnd_str}{l2}{bnd_str}+" + l3 = f" {' ' * strl}{l3}{' ' * strl} " + + print_multi_line(l1, l2, l3, max_width=max_width) + + +class Dense1D(TensorNetwork1DVector): + """Mimics other 1D tensor network structures, but really just keeps the + full state in a single tensor. This allows e.g. applying gates in the same + way for quantum circuit simulation as lazily represented hilbert spaces. + + Parameters + ---------- + array : array_like + The full hilbert space vector - assumed to be made of equal hilbert + spaces each of size ``phys_dim`` and will be reshaped as such. + phys_dim : int, optional + The hilbert space size of each site, default: 2. + tags : sequence of str, optional + Extra tags to add to the tensor network. + site_ind_id : str, optional + String formatter describing how to label the site indices. + site_tag_id : str, optional + String formatter describing how to label the site tags. + tn_opts + Supplied to :class:`~quimb.tensor.tensor_core.TensorNetwork`. + """ + + _EXTRA_PROPS = ( + "_site_ind_id", + "_site_tag_id", + "_L", + ) + + def __init__( + self, + array, + phys_dim=2, + tags=None, + site_ind_id="k{}", + site_tag_id="I{}", + **tn_opts, + ): + # copy short-circuit + if isinstance(array, Dense1D): + super().__init__(array) + return + + # work out number of sites and sub-dimensions etc. + self._L = qu.infer_size(array, base=phys_dim) + dims = [phys_dim] * self.L + data = ops.asarray(array).reshape(*dims) + + # process site indices + self._site_ind_id = site_ind_id + site_inds = [self.site_ind(i) for i in range(self.L)] + + # process site tags + self._site_tag_id = site_tag_id + site_tags = oset(self.site_tag(i) for i in range(self.L)) + + if tags is not None: + # mix in global tags + site_tags = tags_to_oset(tags) | site_tags + + T = Tensor(data=data, inds=site_inds, tags=site_tags) + + super().__init__([T], virtual=True, **tn_opts) + + @classmethod + def rand(cls, n, phys_dim=2, dtype=float, **dense1d_opts): + """Create a random dense vector 'tensor network'.""" + array = qu.randn(phys_dim**n, dtype=dtype) + array /= qu.norm(array, "fro") + return cls(array, **dense1d_opts) + + +class SuperOperator1D(TensorNetwork1D): + r"""A 1D tensor network super-operator class:: + + 0 1 2 n-1 + | | | | <-- outer_upper_ind_id + O===O===O== =O + |\ |\ |\ |\ <-- inner_upper_ind_id + ) ) ) ... ) <-- K (size of local Kraus sum) + |/ |/ |/ |/ <-- inner_lower_ind_id + O===O===O== =O + | | : | | <-- outer_lower_ind_id + : + chi (size of entangling bond dim) + + Parameters + ---------- + arrays : sequence of arrays + The data arrays defining the superoperator, this should be a sequence + of 2n arrays, such that the first two correspond to the upper and lower + operators acting on site 0 etc. The arrays should be 5 dimensional + unless OBC conditions are desired, in which case the first two and last + two should be 4-dimensional. The dimensions of array can be should + match the ``shape`` option. + + """ + + _EXTRA_PROPS = ( + "_site_tag_id", + "_outer_upper_ind_id", + "_inner_upper_ind_id", + "_inner_lower_ind_id", + "_outer_lower_ind_id", + "cyclic", + "_L", + ) + + def __init__( + self, + arrays, + shape="lrkud", + site_tag_id="I{}", + outer_upper_ind_id="kn{}", + inner_upper_ind_id="k{}", + inner_lower_ind_id="b{}", + outer_lower_ind_id="bn{}", + tags=None, + tags_upper=None, + tags_lower=None, + **tn_opts, + ): + # short-circuit for copying + if isinstance(arrays, SuperOperator1D): + super().__init__(arrays) + return + + arrays = tuple(arrays) + self._L = len(arrays) // 2 + + # process indices + self._outer_upper_ind_id = outer_upper_ind_id + self._inner_upper_ind_id = inner_upper_ind_id + self._inner_lower_ind_id = inner_lower_ind_id + self._outer_lower_ind_id = outer_lower_ind_id + + sites_present = tuple(self.gen_sites_present()) + outer_upper_inds = map(outer_upper_ind_id.format, sites_present) + inner_upper_inds = map(inner_upper_ind_id.format, sites_present) + inner_lower_inds = map(inner_lower_ind_id.format, sites_present) + outer_lower_inds = map(outer_lower_ind_id.format, sites_present) + + # process tags + self._site_tag_id = site_tag_id + tags = tags_to_oset(tags) + tags_upper = tags_to_oset(tags_upper) + tags_lower = tags_to_oset(tags_lower) + + def gen_tags(): + for site_tag in self.site_tags: + yield (site_tag,) + tags + tags_upper + yield (site_tag,) + tags + tags_lower + + self.cyclic = ops.ndim(arrays[0]) == 5 + + # transpose arrays to 'lrkud' order + # u + # | + # l--O--r + # |\ + # d k + def gen_orders(): + lkud_ord = tuple(shape.replace("r", "").find(x) for x in "lkud") + rkud_ord = tuple(shape.replace("l", "").find(x) for x in "rkud") + lrkud_ord = tuple(map(shape.find, "lrkud")) + yield rkud_ord if not self.cyclic else lrkud_ord + yield rkud_ord if not self.cyclic else lrkud_ord + for _ in range(self.L - 2): + yield lrkud_ord + yield lrkud_ord + yield lkud_ord if not self.cyclic else lrkud_ord + yield lkud_ord if not self.cyclic else lrkud_ord + + def gen_inds(): + # |<- outer_upper_ind + # cycU_ix or pU_ix --O-- nU_ix + # /|<- inner_upper_ind + # k_ix ->( + # \|<- inner_lower_ind + # cycL_ix or pL_ix --O-- nL_ix + # |<- outer_lower_ind + if self.cyclic: + cycU_ix, cycL_ix = (rand_uuid(),), (rand_uuid(),) + else: + cycU_ix, cycL_ix = (), () + nU_ix, nL_ix, k_ix = rand_uuid(), rand_uuid(), rand_uuid() + yield ( + *cycU_ix, + nU_ix, + k_ix, + next(outer_upper_inds), + next(inner_upper_inds), + ) + yield ( + *cycL_ix, + nL_ix, + k_ix, + next(outer_lower_inds), + next(inner_lower_inds), + ) + pU_ix, pL_ix = nU_ix, nL_ix + for _ in range(self.L - 2): + nU_ix, nL_ix, k_ix = rand_uuid(), rand_uuid(), rand_uuid() + yield ( + pU_ix, + nU_ix, + k_ix, + next(outer_upper_inds), + next(inner_upper_inds), + ) + yield ( + pL_ix, + nL_ix, + k_ix, + next(outer_lower_inds), + next(inner_lower_inds), + ) + pU_ix, pL_ix = nU_ix, nL_ix + k_ix = rand_uuid() + yield ( + pU_ix, + *cycU_ix, + k_ix, + next(outer_upper_inds), + next(inner_upper_inds), + ) + yield ( + pL_ix, + *cycL_ix, + k_ix, + next(outer_lower_inds), + next(inner_lower_inds), + ) + + def gen_tensors(): + for array, tags, inds, order in zip( + arrays, gen_tags(), gen_inds(), gen_orders() + ): + yield Tensor(transpose(array, order), inds=inds, tags=tags) + + super().__init__(gen_tensors(), virtual=True, **tn_opts) + + @classmethod + def rand( + cls, + n, + K, + chi, + phys_dim=2, + herm=True, + cyclic=False, + dtype=complex, + **superop_opts, + ): + def gen_arrays(): + for i in range(n): + shape = [] + if cyclic or (i != 0): + shape += [chi] + if cyclic or (i != n - 1): + shape += [chi] + shape += [K, phys_dim, phys_dim] + data = qu.randn(shape=shape, dtype=dtype) + yield data + if herm: + yield data.conj() + else: + yield qu.randn(shape=shape, dtype=dtype) + + arrays = map(ops.sensibly_scale, gen_arrays()) + + return cls(arrays, **superop_opts) + + @property + def outer_upper_ind_id(self): + return self._outer_upper_ind_id + + @property + def inner_upper_ind_id(self): + return self._inner_upper_ind_id + + @property + def inner_lower_ind_id(self): + return self._inner_lower_ind_id + + @property + def outer_lower_ind_id(self): + return self._outer_lower_ind_id + + +class TNLinearOperator1D(spla.LinearOperator): + r"""A 1D tensor network linear operator like:: + + start stop - 1 + . . + :-O-O-O-O-O-O-O-O-O-O-O-O-: --+ + : | | | | | | | | | | | | : | + :-H-H-H-H-H-H-H-H-H-H-H-H-: acting on --V + : | | | | | | | | | | | | : | + :-O-O-O-O-O-O-O-O-O-O-O-O-: --+ + left_inds^ ^right_inds + + Like :class:`~quimb.tensor.tensor_core.TNLinearOperator`, but performs a + structured contract from one end to the other than can handle very long + chains possibly more efficiently by contracting in blocks from one end. + + + Parameters + ---------- + tn : TensorNetwork + The tensor network to turn into a ``LinearOperator``. + left_inds : sequence of str + The left indicies. + right_inds : sequence of str + The right indicies. + start : int + Index of starting site. + stop : int + Index of stopping site (does not include this site). + ldims : tuple of int, optional + If known, the dimensions corresponding to ``left_inds``. + rdims : tuple of int, optional + If known, the dimensions corresponding to ``right_inds``. + + See Also + -------- + TNLinearOperator + """ + + def __init__( + self, + tn, + left_inds, + right_inds, + start, + stop, + ldims=None, + rdims=None, + is_conj=False, + is_trans=False, + ): + self.tn = tn + self.start, self.stop = start, stop + + if ldims is None or rdims is None: + ind_sizes = tn.ind_sizes() + ldims = tuple(ind_sizes[i] for i in left_inds) + rdims = tuple(ind_sizes[i] for i in right_inds) + + self.left_inds, self.right_inds = left_inds, right_inds + self.ldims, ld = ldims, qu.prod(ldims) + self.rdims, rd = rdims, qu.prod(rdims) + self.tags = self.tn.tags + + # conjugate inputs/ouputs rather all tensors if necessary + self.is_conj = is_conj + self.is_trans = is_trans + self._conj_linop = None + self._adjoint_linop = None + self._transpose_linop = None + + super().__init__(dtype=self.tn.dtype, shape=(ld, rd)) + + def _matvec(self, vec): + in_data = reshape(vec, self.rdims) + + if self.is_conj: + in_data = conj(in_data) + + if self.is_trans: + i, f, s = self.start, self.stop, 1 + else: + i, f, s = self.stop - 1, self.start - 1, -1 + + # add the vector to the right of the chain + tnc = self.tn | Tensor(in_data, self.right_inds, tags=["_VEC"]) + tnc.view_like_(self.tn) + # tnc = self.tn.copy() + # tnc |= Tensor(in_data, self.right_inds, tags=['_VEC']) + + # absorb it into the rightmost site + tnc ^= ["_VEC", self.tn.site_tag(i)] + + # then do a structured contract along the whole chain + out_T = tnc ^ slice(i, f, s) + + out_data = out_T.transpose_(*self.left_inds).data.ravel() + if self.is_conj: + out_data = conj(out_data) + + return out_data + + def _matmat(self, mat): + d = mat.shape[-1] + in_data = reshape(mat, (*self.rdims, d)) + + if self.is_conj: + in_data = conj(in_data) + + if self.is_trans: + i, f, s = self.start, self.stop, 1 + else: + i, f, s = self.stop - 1, self.start - 1, -1 + + # add the vector to the right of the chain + in_ix = (*self.right_inds, "_mat_ix") + + tnc = self.tn | Tensor(in_data, inds=in_ix, tags=["_VEC"]) + tnc.view_like_(self.tn) + # tnc = self.tn.copy() + # tnc |= Tensor(in_data, inds=in_ix, tags=['_VEC']) + + # absorb it into the rightmost site + tnc ^= ["_VEC", self.tn.site_tag(i)] + + # then do a structured contract along the whole chain + out_T = tnc ^ slice(i, f, s) + + out_ix = (*self.left_inds, "_mat_ix") + out_data = reshape(out_T.transpose_(*out_ix).data, (-1, d)) + if self.is_conj: + out_data = conj(out_data) + + return out_data + + def copy(self, conj=False, transpose=False): + if transpose: + inds = (self.right_inds, self.left_inds) + dims = (self.rdims, self.ldims) + is_trans = not self.is_trans + else: + inds = (self.left_inds, self.right_inds) + dims = (self.ldims, self.rdims) + is_trans = self.is_trans + + if conj: + is_conj = not self.is_conj + else: + is_conj = self.is_conj + + return TNLinearOperator1D( + self.tn, + *inds, + self.start, + self.stop, + *dims, + is_conj=is_conj, + is_trans=is_trans, + ) + + def conj(self): + if self._conj_linop is None: + self._conj_linop = self.copy(conj=True) + return self._conj_linop + + def _transpose(self): + if self._transpose_linop is None: + self._transpose_linop = self.copy(transpose=True) + return self._transpose_linop + + def _adjoint(self): + """Hermitian conjugate of this TNLO.""" + # cache the adjoint + if self._adjoint_linop is None: + self._adjoint_linop = self.copy(conj=True, transpose=True) + return self._adjoint_linop + + def to_dense(self): + T = self.tn ^ slice(self.start, self.stop) + + if self.is_conj: + T = T.conj() + + return T.to_dense(self.left_inds, self.right_inds) + + def toarray(self): + return self.to_dense() + + @property + def A(self): + return self.to_dense() diff --git a/src/qibotn/__init__.py b/src/qibotn/__init__.py index 4942a45..fb2c1f7 100644 --- a/src/qibotn/__init__.py +++ b/src/qibotn/__init__.py @@ -1,5 +1,29 @@ import importlib.metadata as im -from qibotn.backends import MetaBackend - __version__ = im.version(__package__) + +_LAZY_EXPORTS = { + "MetaBackend": ("qibotn.backends", "MetaBackend"), + "cpu_backend": ("qibotn.expectation_runner", "cpu_backend"), + "cpu_expectation": ("qibotn.expectation_runner", "cpu_expectation"), + "mps_expectation": ("qibotn.expectation_runner", "mps_expectation"), + "cpu_runcard": ("qibotn.expectation_runner", "cpu_runcard"), + "pauli_pattern": ("qibotn.observables", "pauli_pattern"), + "pauli_sum": ("qibotn.observables", "pauli_sum"), +} + + +def __getattr__(name): + try: + module_name, object_name = _LAZY_EXPORTS[name] + except KeyError: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") from None + + from importlib import import_module + + value = getattr(import_module(module_name), object_name) + globals()[name] = value + return value + + +__all__ = sorted([*_LAZY_EXPORTS, "__version__"]) diff --git a/src/qibotn/backends/__init__.py b/src/qibotn/backends/__init__.py index 954f793..416fce8 100644 --- a/src/qibotn/backends/__init__.py +++ b/src/qibotn/backends/__init__.py @@ -1,10 +1,6 @@ -from typing import Union - from qibo.config import raise_error from qibotn.backends.abstract import QibotnBackend -from qibotn.backends.cpu import CpuTensorNet -from qibotn.backends.cutensornet import CuTensorNet # pylint: disable=E0401 PLATFORMS = ("cutensornet", "cpu", "quimb", "qmatchatea", "vidal") @@ -24,8 +20,12 @@ class MetaBackend: """ if platform == "cutensornet": # pragma: no cover + from qibotn.backends.cutensornet import CuTensorNet + return CuTensorNet(runcard) elif platform == "cpu": + from qibotn.backends.cpu import CpuTensorNet + return CpuTensorNet(runcard) elif platform == "quimb": # pragma: no cover import qibotn.backends.quimb as qmb @@ -55,8 +55,8 @@ class MetaBackend: for platform in PLATFORMS: try: MetaBackend.load(platform=platform) - available = True - except: - available = False - available_backends[platform] = available + except (ImportError, NotImplementedError, TypeError, ValueError): + available_backends[platform] = False + else: + available_backends[platform] = True return available_backends diff --git a/src/qibotn/backends/cpu.py b/src/qibotn/backends/cpu.py index 83770a6..91b0528 100644 --- a/src/qibotn/backends/cpu.py +++ b/src/qibotn/backends/cpu.py @@ -15,14 +15,9 @@ from qibo.config import raise_error from qibotn.backends.abstract import QibotnBackend from qibotn.backends.vidal import ( _observable_mpo_tensors, - _operator_terms_to_mpo, - _symbolic_hamiltonian_to_operator_terms, _unsupported_reason, ) -from qibotn.backends.vidal_mpi_segment import SegmentVidalMPIExecutor -from qibotn.backends.vidal_tebd import VidalTEBDExecutor from qibotn.observables import check_observable -from qibotn.result import TensorNetworkResult def _as_bool_or_dict(value, name): @@ -282,79 +277,35 @@ class CpuTensorNet(QibotnBackend, NumpyBackend): ): if compile_circuit is None: compile_circuit = self.compile_circuit - if preprocess: - if self.MPI_enabled: - from mpi4py import MPI - - self.rank = MPI.COMM_WORLD.Get_rank() - - from qibotn.backends.vidal import VidalBackend - - backend = VidalBackend() - backend.configure_tn_simulation( - max_bond_dimension=self.max_bond_dimension, - cut_ratio=self.cut_ratio, - tensor_module=self.tensor_module, - compile_circuit=compile_circuit, - mpi_approach="CT" if self.MPI_enabled else "SR", - mpi_term_batch_size=self.mpi_term_batch_size, - fallback=False, - ) - value = backend.expectation( - circuit, - observable, - preprocess=True, - compile_circuit=compile_circuit, - ) - self.rank = getattr(backend, "rank", self.rank) - self.last_truncation_error = getattr( - backend, "last_truncation_error", np.nan - ) - self.last_max_truncation_error = getattr( - backend, "last_max_truncation_error", np.nan - ) - return value - - mpo_tensors = _observable_mpo_tensors(observable, circuit.nqubits) if self.MPI_enabled: from mpi4py import MPI - comm = MPI.COMM_WORLD - self.rank = comm.Get_rank() - executor = SegmentVidalMPIExecutor( - nqubits=circuit.nqubits, - max_bond=self.max_bond_dimension, - cut_ratio=self.cut_ratio, - tensor_module=self.tensor_module, - comm=comm, - ) - executor.run_circuit(circuit) - self.last_truncation_error = float(executor.global_truncation_error()) - self.last_max_truncation_error = float( - executor.global_max_truncation_error() - ) - if mpo_tensors is not None: - value = executor.expectation_mpo_root(mpo_tensors) - else: - terms = _symbolic_hamiltonian_to_operator_terms(observable) - value = executor.expectation_mpo_root( - _operator_terms_to_mpo(terms, circuit.nqubits) - ) - return np.nan if self.rank != 0 else value + self.rank = MPI.COMM_WORLD.Get_rank() - executor = VidalTEBDExecutor( - nqubits=circuit.nqubits, - max_bond=self.max_bond_dimension, + from qibotn.backends.vidal import VidalBackend + + backend = VidalBackend() + backend.configure_tn_simulation( + max_bond_dimension=self.max_bond_dimension, cut_ratio=self.cut_ratio, tensor_module=self.tensor_module, + compile_circuit=compile_circuit, + mpi_approach="CT" if self.MPI_enabled else "SR", + mpi_term_batch_size=self.mpi_term_batch_size, + fallback=False, ) - executor.run_circuit(circuit) - self.last_truncation_error = float(executor.truncation_error) - self.last_max_truncation_error = float(executor.max_truncation_error) - if mpo_tensors is not None: - return executor.expectation_mpo(mpo_tensors) - terms = _symbolic_hamiltonian_to_operator_terms(observable) - return executor.expectation_mpo(_operator_terms_to_mpo(terms, circuit.nqubits)) + value = backend.expectation( + circuit, + observable, + preprocess=preprocess, + compile_circuit=compile_circuit, + ) + self.rank = getattr(backend, "rank", self.rank) + self.last_truncation_error = getattr(backend, "last_truncation_error", np.nan) + self.last_max_truncation_error = getattr( + backend, "last_max_truncation_error", np.nan + ) + return value def _quimb_backend(self): import qibotn.backends.quimb as qmb diff --git a/src/qibotn/expectation_runner.py b/src/qibotn/expectation_runner.py index 6af6de8..9592974 100644 --- a/src/qibotn/expectation_runner.py +++ b/src/qibotn/expectation_runner.py @@ -12,6 +12,50 @@ from qibotn.benchmark_cases import exact_pauli_sum from qibotn.observables import check_observable +def cpu_runcard( + observable=None, + *, + ansatz: str = "tn", + mpi: bool = False, + bond: int | None = 1024, + cut_ratio: float | None = 1e-12, + tensor_module: str = "torch", + quimb_backend: str = "torch", + dtype: str = "complex128", + torch_threads: int | None = 8, + parallel_opts: dict | None = None, + compile_circuit: bool = False, + preprocess: bool = False, +): + """Build the small CPU backend runcard used throughout qibotn.""" + return { + "MPI_enabled": mpi, + "MPS_enabled": ansatz.lower() == "mps", + "NCCL_enabled": False, + "expectation_enabled": observable if observable is not None else False, + "max_bond_dimension": bond, + "cut_ratio": cut_ratio, + "tensor_module": tensor_module, + "quimb_backend": quimb_backend, + "dtype": dtype, + "torch_threads": torch_threads, + "parallel_opts": parallel_opts or {}, + "compile_circuit": compile_circuit, + "preprocess": preprocess, + } + + +def cpu_backend(**kwargs): + """Return a configured qibotn CPU backend. + + Example: + ``backend = cpu_backend(ansatz="mps", bond=512, torch_threads=8)`` + """ + from qibotn.backends.cpu import CpuTensorNet + + return CpuTensorNet(cpu_runcard(**kwargs)) + + @dataclass class ExpectationConfig: ansatz: str = "tn" @@ -33,6 +77,15 @@ class ExpectationResult: parallel_stats: list | None = None +def _config_from_kwargs(**kwargs): + fields = ExpectationConfig.__dataclass_fields__ + config_kwargs = {name: kwargs.pop(name) for name in list(kwargs) if name in fields} + if kwargs: + unknown = ", ".join(sorted(kwargs)) + raise TypeError(f"Unknown expectation option(s): {unknown}") + return ExpectationConfig(**config_kwargs) + + def exact_for_observable(circuit, observable, nqubits): if isinstance(observable, dict) and "terms" in observable: terms = [ @@ -49,19 +102,18 @@ def exact_for_observable(circuit, observable, nqubits): def run_cpu_expectation(circuit, observable, config): - runcard = { - "MPI_enabled": config.mpi, - "MPS_enabled": config.ansatz.lower() == "mps", - "NCCL_enabled": False, - "expectation_enabled": observable, - "max_bond_dimension": config.bond, - "cut_ratio": config.cut_ratio, - "tensor_module": config.tensor_module, - "quimb_backend": config.quimb_backend, - "dtype": config.dtype, - "torch_threads": config.torch_threads, - "parallel_opts": config.parallel_opts or {}, - } + runcard = cpu_runcard( + observable, + ansatz=config.ansatz, + mpi=config.mpi, + bond=config.bond, + cut_ratio=config.cut_ratio, + tensor_module=config.tensor_module, + quimb_backend=config.quimb_backend, + dtype=config.dtype, + torch_threads=config.torch_threads, + parallel_opts=config.parallel_opts, + ) backend = construct_backend( backend="qibotn", platform="cpu", @@ -80,3 +132,26 @@ def run_cpu_expectation(circuit, observable, config): rank=rank, parallel_stats=list(stats) if stats is not None else None, ) + + +def cpu_expectation(circuit, observable=None, *, return_result=False, **kwargs): + """Compute a CPU TN/MPS expectation with concise keyword options. + + This is the preferred API for small scripts. Common options are + ``ansatz="tn" | "mps"``, ``bond``, ``cut_ratio``, ``mpi``, + ``torch_threads``, ``quimb_backend`` and ``parallel_opts``. + """ + config = _config_from_kwargs(**kwargs) + result = run_cpu_expectation(circuit, observable, config) + return result if return_result else result.value + + +def mps_expectation(circuit, observable=None, *, return_result=False, **kwargs): + """Compute expectation using the CPU Vidal/MPS path when possible.""" + kwargs.setdefault("ansatz", "mps") + return cpu_expectation( + circuit, + observable, + return_result=return_result, + **kwargs, + ) diff --git a/src/qibotn/observables.py b/src/qibotn/observables.py index 5fe3c3b..7f3c242 100644 --- a/src/qibotn/observables.py +++ b/src/qibotn/observables.py @@ -4,6 +4,30 @@ from qibo import hamiltonians from qibo.symbols import I, X, Y, Z +def pauli_pattern(pattern): + """Return the compact qibotn representation of a repeated Pauli string.""" + return {"pauli_string_pattern": pattern} + + +def pauli_sum(*terms): + """Return the compact qibotn representation of a Pauli sum. + + Each term is ``(coefficient, operators)`` where operators are pairs like + ``("X", 0)``. Example: + + ``pauli_sum((0.5, [("X", 0), ("Z", 1)]), (-1.0, [("Z", 3)]))`` + """ + return { + "terms": [ + { + "coefficient": coeff, + "operators": [(name, int(site)) for name, site in operators], + } + for coeff, operators in terms + ] + } + + def check_observable(observable, circuit_nqubit): """Checks the type of observable and returns the appropriate Hamiltonian.""" if observable is None: @@ -20,11 +44,10 @@ def check_observable(observable, circuit_nqubit): def build_observable(circuit_nqubit): """Construct the default benchmark observable used by qibotn.""" - hamiltonian_form = 0 - for i in range(circuit_nqubit): - hamiltonian_form += 0.5 * X(i % circuit_nqubit) * Z((i + 1) % circuit_nqubit) - - return hamiltonians.SymbolicHamiltonian(form=hamiltonian_form) + form = sum( + 0.5 * X(i) * Z((i + 1) % circuit_nqubit) for i in range(circuit_nqubit) + ) + return hamiltonians.SymbolicHamiltonian(form=form) def create_hamiltonian_from_dict(data, circuit_nqubit): @@ -50,7 +73,6 @@ def create_hamiltonian_from_dict(data, circuit_nqubit): term_expr = full_term_expr[0] for op in full_term_expr[1:]: term_expr *= op - terms.append(coeff * term_expr) if not terms: @@ -84,23 +106,20 @@ def create_hamiltonian_from_pauli_pattern(pattern, circuit_nqubit): continue factor = pauli_gates[name](qubit) expr = factor if expr is None else expr * factor - - if expr is None: - expr = I(0) - - return hamiltonians.SymbolicHamiltonian(form=expr) + return hamiltonians.SymbolicHamiltonian(form=expr or I(0)) def build_random_circuit(nqubits, nlayers, seed=42): """Build a random circuit with RY+RZ+CNOT layers for benchmarks.""" import numpy as np from qibo import Circuit, gates - np.random.seed(seed) + + rng = np.random.default_rng(seed) c = Circuit(nqubits) for _ in range(nlayers): for q in range(nqubits): - c.add(gates.RY(q, theta=np.random.uniform(0, 2*np.pi))) - c.add(gates.RZ(q, theta=np.random.uniform(0, 2*np.pi))) + c.add(gates.RY(q, theta=rng.uniform(0, 2 * np.pi))) + c.add(gates.RZ(q, theta=rng.uniform(0, 2 * np.pi))) for q in range(nqubits): c.add(gates.CNOT(q % nqubits, (q + 1) % nqubits)) return c diff --git a/src/qibotn/result.py b/src/qibotn/result.py index 2748cb0..194dae2 100644 --- a/src/qibotn/result.py +++ b/src/qibotn/result.py @@ -32,20 +32,19 @@ class TensorNetworkResult: statevector: ndarray def __post_init__(self): - # TODO: define the general convention when using backends different from qmatchatea if self.measured_probabilities is None: - self.measured_probabilities = {"default": self.measured_probabilities} + self.measured_probabilities = {} def probabilities(self): """Return calculated probabilities according to the given method.""" - if self.prob_type == "U": - measured_probabilities = deepcopy(self.measured_probabilities) - for bitstring, prob in self.measured_probabilities[self.prob_type].items(): - measured_probabilities[self.prob_type][bitstring] = prob[1] - prob[0] - probabilities = measured_probabilities[self.prob_type] - else: - probabilities = self.measured_probabilities - return probabilities + if self.prob_type != "U": + return self.measured_probabilities + + measured_probabilities = deepcopy(self.measured_probabilities) + values = measured_probabilities.get(self.prob_type, {}) + for bitstring, prob in values.items(): + values[bitstring] = prob[1] - prob[0] + return values def frequencies(self): """Return frequencies if a certain number of shots has been set.""" diff --git a/tests/test_cpu_backend.py b/tests/test_cpu_backend.py index 877ddd8..e5ea781 100644 --- a/tests/test_cpu_backend.py +++ b/tests/test_cpu_backend.py @@ -9,6 +9,7 @@ from qibotn.benchmark_cases import ( build_circuit as build_benchmark_circuit, exact_pauli_sum, ) +from qibotn import cpu_expectation, mps_expectation, pauli_pattern, pauli_sum def build_circuit(nqubits=6): @@ -46,6 +47,37 @@ def test_cpu_generic_tn_expectation_matches_statevector(): assert math.isclose(value, exact, abs_tol=1e-12) +def test_public_cpu_expectation_api_matches_statevector(): + circuit = build_circuit() + observable = pauli_sum((0.5, [("X", 0), ("Z", 1)]), (-0.25, [("Z", 5)])) + exact = exact_pauli_sum( + circuit, + [(0.5, (("X", 0), ("Z", 1))), (-0.25, (("Z", 5),))], + circuit.nqubits, + ) + + value = cpu_expectation(circuit, observable, torch_threads=1) + + assert math.isclose(value, exact, abs_tol=1e-12) + + +def test_public_mps_expectation_api_accepts_pauli_pattern(): + circuit = build_circuit() + exact_hamiltonian = hamiltonians.SymbolicHamiltonian( + form=X(1) * Z(2) * X(4) * Z(5) + ) + exact = exact_hamiltonian.expectation_from_state(circuit().state(numpy=True)) + + value = mps_expectation( + circuit, + pauli_pattern("IXZ"), + bond=64, + torch_threads=1, + ) + + assert math.isclose(value, exact, abs_tol=1e-12) + + def test_cpu_mps_expectation_matches_statevector(): circuit = build_circuit() observable = build_observable(circuit.nqubits)