Files
qibotn/.venv/lib/python3.12/site-packages/cotengra/core.py
jaunatisblue 28080dff1d
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
简化代码;加入.venv下内容
2026-05-18 02:47:40 +08:00

4131 lines
137 KiB
Python

"""Core contraction tree data structure and methods."""
import collections
import functools
import itertools
import math
import warnings
from dataclasses import dataclass
from typing import Optional
from autoray import do
from .contract import make_contractor
from .hypergraph import get_hypergraph
from .parallel import (
can_scatter,
maybe_leave_pool,
maybe_rejoin_pool,
parse_parallel_arg,
scatter,
submit,
)
from .pathfinders.path_simulated_annealing import (
parallel_temper_tree,
simulated_anneal_tree,
)
from .plot import (
plot_contractions,
plot_contractions_alt,
plot_hypergraph,
plot_tree_circuit,
plot_tree_flat,
plot_tree_ring,
plot_tree_rubberband,
plot_tree_span,
plot_tree_tent,
)
from .scoring import (
DEFAULT_COMBO_FACTOR,
CompressedStatsTracker,
get_score_fn,
)
from .utils import (
MaxCounter,
compute_size_by_dict,
deprecated,
get_rng,
get_symbol,
groupby,
inputs_output_to_eq,
interleave,
is_valid_node,
node_from_seq,
node_from_single,
node_get_single_el,
node_supremum,
oset,
prod,
unique,
)
def cached_node_property(name):
    """Build a decorator that memoizes a per-node method in ``self.info``.

    The wrapped method has signature ``meth(self, node)``. Its result is
    stored under the key ``name`` inside ``self.info[node]``. Later calls for
    the same node return the stored value without calling ``meth`` again.
    """

    def decorate(meth):
        @functools.wraps(meth)
        def cached_getter(self, node):
            try:
                return self.info[node][name]
            except KeyError:
                pass
            # not cached yet: compute first, then store under ``name``
            result = meth(self, node)
            self.info[node][name] = result
            return result

        return cached_getter

    return decorate
def union_it(bs):
    """Return the union of a sequence of set-like objects.

    This is a non-variadic form of ``set.union``: the result has the type of
    the first element, which is joined with all of the remaining ones.
    """
    first, *others = bs
    return first.union(*others)
def legs_union(legs_seq):
    """Merge a sequence of ``{index: count}`` leg mappings into one, adding
    up the counts of indices that appear in more than one mapping.

    The first mapping is copied, so none of the inputs are modified. Key
    order follows first appearance.
    """
    head, *tail = legs_seq
    combined = head.copy()
    for legs in tail:
        for ix, count in legs.items():
            if ix in combined:
                combined[ix] += count
            else:
                combined[ix] = count
    return combined
def legs_without(legs, ind):
    """Return a copy of ``legs`` with ``ind`` removed (if it is present).

    The input mapping itself is left untouched.
    """
    remaining = legs.copy()
    if ind in remaining:
        del remaining[ind]
    return remaining
def get_with_default(k, obj, default):
    """Return ``obj.get(k, default)``, taking the key as first argument."""
    lookup = obj.get
    return lookup(k, default)
@dataclass(order=True, frozen=True)
class SliceInfo:
    """Immutable, orderable description of a single sliced index.

    Instances compare field by field in declaration order: ``inner``,
    ``ind``, ``size``, ``project``.
    """

    # ``ContractionTree.nchunks`` counts the sliced indices with
    # ``inner=False`` as the sliced *output* indices
    inner: bool
    # the index name
    ind: str
    # the full dimension of the index
    size: int
    # if not None, only this single value of the index is used
    project: Optional[int]

    @property
    def sliced_range(self):
        """The index values to iterate over: ``range(size)`` when
        unprojected, otherwise the one-element list ``[project]``.
        """
        return range(self.size) if self.project is None else [self.project]
def get_slice_strides(sliced_inds):
    """Compute the 'strides' of an (ordered) mapping of sliced indices.

    The stride of each entry is the product of the ``.size`` of every entry
    after it, so the last stride is always 1.
    """
    sizes = [info.size for info in sliced_inds.values()]
    strides = []
    running = 1
    # walk backwards, accumulating the product of trailing sizes
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    strides.reverse()
    return strides
def add_maybe_exponent_stripped(x, y):
    """Add two values, either of which may be a ``(mantissa, exponent)``
    tuple representing ``mantissa * 10**exponent``.

    If neither is a tuple, return the plain sum ``x + y``. Otherwise a bare
    value is treated as having exponent ``0.0``. Both mantissas are rescaled
    to the larger exponent before adding, and a ``(mantissa, exponent)``
    tuple is returned.
    """
    x_split = isinstance(x, tuple)
    y_split = isinstance(y, tuple)
    if not x_split and not y_split:
        return x + y

    xm, xe = x if x_split else (x, 0.0)
    ym, ye = y if y_split else (y, 0.0)

    # rescale both to the common (larger) exponent, computed arithmetically
    # so no data-dependent branching on the mantissas is needed
    e = max(xe, ye)
    return (xm * 10 ** (xe - e) + ym * 10 ** (ye - e), e)
class ContractionTree:
"""Binary tree representing a tensor network contraction.
Parameters
----------
inputs : sequence of str
The list of input tensor's indices.
output : str
The output indices.
size_dict : dict[str, int]
The size of each index.
track_childless : bool, optional
Whether to dynamically keep track of which nodes are childless. Useful
if you are 'divisively' building the tree.
track_flops : bool, optional
Whether to dynamically keep track of the total number of flops. If
``False`` You can still compute this once the tree is complete.
track_write : bool, optional
Whether to dynamically keep track of the total number of elements
written. If ``False`` You can still compute this once the tree is
complete.
track_size : bool, optional
Whether to dynamically keep track of the largest tensor so far. If
``False`` You can still compute this once the tree is complete.
objective : str or Objective, optional
An default objective function to use for further optimization and
scoring, for example reconfiguring or computing the combo cost. If not
supplied the default is to create a flops objective when needed.
Attributes
----------
children : dict[node, tuple[node]]
Mapping of each node to two children.
info : dict[node, dict]
Information about the tree nodes. The key is the set of inputs (a
set of inputs indices) the node contains. Or in other words, the
subgraph of the node. The value is a dictionary to cache information
about effective 'leg' indices, size, flops of formation etc.
"""
def __init__(
    self,
    inputs,
    output,
    size_dict,
    track_childless=False,
    track_flops=False,
    track_write=False,
    track_size=False,
    objective=None,
):
    """Set up a tree containing only the leaf nodes and the root node, with
    no contractions (``children``) recorded yet. See the class docstring
    for the meaning of each parameter.

    Warns
    -----
    UserWarning
        If the first input term or the output is a ``set``. Costs can still
        be computed, but actual contraction needs ordered indices.
    """
    self.inputs = inputs
    self.output = output

    if isinstance(self.inputs[0], set) or isinstance(self.output, set):
        warnings.warn(
            "The inputs or output of this tree are not ordered. "
            "Costs will be accurate but actually contracting requires "
            "ordered indices corresponding to array axes.",
            stacklevel=2,
        )

    if not isinstance(next(iter(size_dict.values()), 1), int):
        # make sure we are working with python integers to avoid overflow
        # and comparison errors with inf etc. (only the first value is
        # inspected as a cheap check)
        self.size_dict = {k: int(v) for k, v in size_dict.items()}
    else:
        self.size_dict = size_dict

    self.N = len(self.inputs)

    # map each index to its total number of appearances across all inputs
    # (and the output, below). Comparing a node's local count of an index
    # with this total tells us locally whether it is kept or contracted.
    self.appearances = {}
    for term in self.inputs:
        for ix in term:
            self.appearances[ix] = self.appearances.get(ix, 0) + 1
    # adding output appearances ensures these are never contracted away,
    # N.B. if after this step every appearance count is exactly 2,
    # then there are no 'hyper' indices in the contraction
    for ix in self.output:
        self.appearances[ix] = self.appearances.get(ix, 0) + 1

    # potential preprocessing steps that are not part of the main
    # contraction tree but are assumed to have been applied, for example
    # tracing or summing over indices local to a single input
    self.preprocessing = {}

    # mapping of parents to children - the core binary tree object
    self.children = {}

    # cached information about all the nodes
    self.info = {}

    # add constant nodes: the leaves
    for leaf in self.gen_leaves():
        self._add_node(leaf)

    # and the root or top node
    self.root = node_supremum(self.N)
    self._add_node(self.root)

    # whether to keep track of dangling nodes/subgraphs
    self.track_childless = track_childless
    if self.track_childless:
        # the set of dangling nodes
        self.childless = oset([self.root])

    # running largest_intermediate and total flops
    self._track_flops = track_flops
    if track_flops:
        self._flops = 0
    self._track_write = track_write
    if track_write:
        self._write = 0
    self._track_size = track_size
    if track_size:
        self._sizes = MaxCounter()

    # container for caching subtree reconfiguration candidates
    self.already_optimized = dict()

    # info relating to slicing (base constructor is always unsliced)
    self.multiplicity = 1
    self.sliced_inds = {}
    self.sliced_inputs = frozenset()

    # cache for compiled contraction cores
    self.contraction_cores = {}

    # a default objective function useful for
    # further optimization and scoring
    self._default_objective = objective
def set_state_from(self, other):
    """Overwrite this tree's internal state with that of ``other``.

    Attributes that are immutable or never mutated are shared by reference.
    Plain containers are shallow-copied. For ``info`` and
    ``already_optimized``, each value is also copied, so the nested
    per-node entries are not shared between the two trees.
    """
    shared_attrs = (
        "appearances",
        "inputs",
        "multiplicity",
        "N",
        "output",
        "root",
        "size_dict",
        "sliced_inputs",
        "_default_objective",
    )
    copied_attrs = (
        "children",
        "contraction_cores",
        "sliced_inds",
        "preprocessing",
    )
    nested_attrs = ("info", "already_optimized")

    for name in shared_attrs:
        setattr(self, name, getattr(other, name))
    for name in copied_attrs:
        setattr(self, name, getattr(other, name).copy())
    for name in nested_attrs:
        source = getattr(other, name)
        setattr(self, name, {key: val.copy() for key, val in source.items()})

    # optional running trackers: always copy the flag, the value only when
    # it is being tracked (otherwise it may not exist)
    self.track_childless = other.track_childless
    self._track_flops = other._track_flops
    self._track_write = other._track_write
    self._track_size = other._track_size
    if other.track_childless:
        self.childless = other.childless.copy()
    if other._track_flops:
        self._flops = other._flops
    if other._track_write:
        self._write = other._write
    if other._track_size:
        self._sizes = other._sizes.copy()
def copy(self):
    """Return a copy of this tree (of the same class), built by bypassing
    ``__init__`` and transferring state with ``set_state_from``.
    """
    clone = object.__new__(type(self))
    clone.set_state_from(self)
    return clone
def set_default_objective(self, objective):
    """Resolve ``objective`` with ``get_score_fn`` and store it as this
    tree's default objective.
    """
    resolved = get_score_fn(objective)
    self._default_objective = resolved
def get_default_objective(self):
    """Return this tree's default objective, lazily creating (and caching)
    a ``"flops"`` objective if none has been set.
    """
    if self._default_objective is not None:
        return self._default_objective
    self._default_objective = get_score_fn("flops")
    return self._default_objective
def get_default_combo_factor(self):
    """Return the ``factor`` attribute of the default objective, or
    ``DEFAULT_COMBO_FACTOR`` if the objective has no such attribute.
    """
    return getattr(
        self.get_default_objective(), "factor", DEFAULT_COMBO_FACTOR
    )
def get_score(self, objective=None):
    """Score this tree with ``objective``, or with the default objective
    when none is given.
    """
    # ``get_score_fn`` is already imported at module level from .scoring
    if objective is None:
        objective = self.get_default_objective()
    score_fn = get_score_fn(objective)
    return score_fn({"tree": self})
@property
def nslices(self):
    """Number of independent contractions this tree represents overall
    (an alias of ``multiplicity``).
    """
    count = self.multiplicity
    return count
@property
def nchunks(self):
    """Number of 'chunks': the product of the sizes of all sliced indices
    that are not marked ``inner``, i.e. the sliced output indices.
    """
    outer_sizes = (
        info.size for info in self.sliced_inds.values() if not info.inner
    )
    return prod(outer_sizes)
def node_to_terms(self, node):
    """Lazily yield the effective legs of each input tensor in ``node``.

    Each element of ``node`` is an input position. Its leaf node is built
    with ``node_from_single`` and its legs come from ``get_legs``.
    """
    legs_of = self.get_legs
    return (legs_of(node_from_single(pos)) for pos in node)
def gen_leaves(self):
    """Lazily yield the leaf nodes of the tree, one per input tensor, in
    input order.
    """
    return (node_from_single(pos) for pos in range(self.N))
def get_incomplete_nodes(self):
    """Get the current childless nodes, each paired with the parentless
    nodes beneath it. These are the pieces that still need contracting for
    the tree to be complete.

    Every non-leaf node starts as a 'childless' candidate and every
    non-root node as a 'parentless' candidate. Candidates that take part in
    a recorded contraction are then removed. Each remaining parentless node
    is assigned to the smallest childless node that contains it as a
    subgraph.

    Returns
    -------
    groups : dict[frozenset[int], list[frozenset[int]]]
        A mapping of childless nodes to the list of parentless nodes that
        are beneath them.

    See Also
    --------
    autocomplete
    """
    childless = {node: None for node in self.info if len(node) != 1}
    parentless = {node: None for node in self.info if len(node) != self.N}

    for parent, (left, right) in self.children.items():
        del parentless[left]
        del parentless[right]
        del childless[parent]

    groups = {node: [] for node in childless}
    for node in parentless:
        containing = [cand for cand in childless if node.issubset(cand)]
        groups[min(containing, key=len)].append(node)
    return groups
def autocomplete(self, **contract_opts):
    """Complete the tree by contracting, within each childless node, all
    of the parentless nodes grouped under it by
    ``tree.get_incomplete_nodes``.

    Parameters
    ----------
    contract_opts
        Options passed on to ``tree.contract_nodes``.

    See Also
    --------
    get_incomplete_nodes, contract_nodes
    """
    for parentless_subnodes in self.get_incomplete_nodes().values():
        self.contract_nodes(parentless_subnodes, **contract_opts)
@classmethod
def from_path(
    cls,
    inputs,
    output,
    size_dict,
    *,
    path=None,
    ssa_path=None,
    edge_path=None,
    optimize="auto-hq",
    autocomplete="auto",
    check=False,
    **kwargs,
):
    """Create a (completed) ``ContractionTree`` from the usual inputs plus
    a contraction path. Exactly one of ``path``, ``ssa_path`` or
    ``edge_path`` must be supplied.

    Parameters
    ----------
    inputs : Sequence[Sequence[str]]
        The input indices of each tensor, as single unicode characters.
    output : Sequence[str]
        The output indices.
    size_dict : dict[str, int]
        The size of each index.
    path : Sequence[Sequence[int]], optional
        The contraction path, a sequence of groups of tensor ids to
        contract. The ids are linear indices into the list of temporary
        tensors, which are recycled: each contraction pops its tensors and
        appends the result.
    ssa_path : Sequence[Sequence[int]], optional
        The contraction path, a sequence of groups of ids to contract. The
        ids are single use, as if the result of each contraction is
        appended to the end of the list of temporary tensors without
        popping.
    edge_path : Sequence[str], optional
        The contraction path, a sequence of indices to contract in order.
        It is converted to an ``ssa_path`` internally.
    optimize : str, optional
        If a contraction within the path contains 3 or more tensors, how to
        optimize this subcontraction into a binary tree.
    autocomplete : "auto" or bool, optional
        Whether to automatically complete the path, i.e. contract all
        remaining nodes. If "auto", a warning is issued if the path is
        not complete.
    check : bool, optional
        Whether to perform some basic checks while creating the contraction
        nodes.
    kwargs
        Passed on to the class constructor.

    Returns
    -------
    ContractionTree

    Raises
    ------
    ValueError
        If not exactly one of ``path``, ``ssa_path`` or ``edge_path`` is
        supplied.
    """
    if (path is None) + (ssa_path is None) + (edge_path is None) != 2:
        raise ValueError(
            "Exactly one of ``path``, ``ssa_path`` or ``edge_path`` must be "
            "supplied."
        )

    contract_opts = {"optimize": optimize, "check": check}

    if edge_path is not None:
        from .pathfinders.path_basic import edge_path_to_ssa

        ssa_path = edge_path_to_ssa(edge_path, inputs)

    if ssa_path is not None:
        path = ssa_path

    tree = cls(inputs, output, size_dict, **kwargs)

    if ssa_path is not None:
        # ssa path (single use ids)
        nodes = dict(enumerate(tree.gen_leaves()))
        ssa = len(nodes)
        for p in path:
            merge = [nodes.pop(i) for i in p]
            nodes[ssa] = tree.contract_nodes(merge, **contract_opts)
            ssa += 1
        nodes = nodes.values()
    else:
        # regular path ('recycled' ids): pop highest ids first so the
        # lower positions stay valid
        nodes = list(tree.gen_leaves())
        for p in path:
            merge = [nodes.pop(i) for i in sorted(p, reverse=True)]
            nodes.append(tree.contract_nodes(merge, **contract_opts))

    if len(nodes) > 1 and autocomplete:
        if autocomplete == "auto":
            # warn that we are completing
            warnings.warn(
                "Path was not complete - contracting all remaining. "
                "You can silence this warning with `autocomplete=True`. "
                "Or produce an incomplete tree with `autocomplete=False`.",
                stacklevel=2,
            )
        tree.contract_nodes(nodes, **contract_opts)

    return tree
@classmethod
def from_info(cls, info, **kwargs):
    """Create a ``ContractionTree`` from an ``opt_einsum.PathInfo`` object,
    using its subscripts, size dict and path.
    """
    input_terms = info.input_subscripts.split(",")
    return cls.from_path(
        inputs=input_terms,
        output=info.output_subscript,
        size_dict=info.size_dict,
        path=info.path,
        **kwargs,
    )
@classmethod
def from_eq(cls, eq, size_dict, **kwargs):
    """Create an empty ``ContractionTree`` directly from an explicit einsum
    equation and a size dict.

    Parameters
    ----------
    eq : str
        The einsum string equation; it must contain ``"->"``.
    size_dict : dict[str, int]
        The size of each index.
    """
    lhs, output = eq.split("->")
    return cls(lhs.split(","), output, size_dict, **kwargs)
def get_eq(self):
    """Get the einsum equation of the whole (original) contraction, i.e.
    still including any sliced indices.

    Returns
    -------
    eq : str
    """
    terms, out = self.inputs, self.output
    return inputs_output_to_eq(terms, out)
def get_shapes(self):
    """Get the full (unsliced) shape of each input tensor.

    Returns
    -------
    shapes : tuple[tuple[int]]
    """
    sizes = self.size_dict
    return tuple(tuple(map(sizes.__getitem__, term)) for term in self.inputs)
def get_inputs_sliced(self):
    """Get the input indices of a single slice of this tree, i.e. each
    input term with the sliced indices filtered out.

    Returns
    -------
    inputs : tuple[tuple[str]]
    """

    def is_kept(ix):
        return ix not in self.sliced_inds

    return tuple(tuple(filter(is_kept, term)) for term in self.inputs)
def get_output_sliced(self):
    """Get the output indices of a single slice of this tree, i.e. with
    the sliced indices filtered out.

    Returns
    -------
    output : tuple[str]
    """
    sliced = self.sliced_inds
    return tuple(filter(lambda ix: ix not in sliced, self.output))
def get_eq_sliced(self):
    """Get the einsum equation of a single slice of this tree, i.e. with
    the sliced indices removed from both the inputs and the output.

    Returns
    -------
    eq : str
    """
    sliced_inputs = self.get_inputs_sliced()
    sliced_output = self.get_output_sliced()
    return inputs_output_to_eq(sliced_inputs, sliced_output)
def get_shapes_sliced(self):
    """Get the shape of each input tensor for a single slice of this tree,
    i.e. with the dimensions of sliced indices dropped.

    Returns
    -------
    shapes : tuple[tuple[int]]
    """
    sliced, sizes = self.sliced_inds, self.size_dict
    return tuple(
        tuple(sizes[ix] for ix in term if ix not in sliced)
        for term in self.inputs
    )
@classmethod
def from_edge_path(
    cls,
    edge_path,
    inputs,
    output,
    size_dict,
    optimize="auto-hq",
    autocomplete="auto",
    check=False,
    **kwargs,
):
    """Create a ``ContractionTree`` from an edge elimination ordering.

    Deprecated: this emits a warning and forwards to
    ``ContractionTree.from_path(edge_path=edge_path, ...)``.
    """
    # stacklevel=2 attributes the deprecation warning to the caller's line
    # rather than to this method
    warnings.warn(
        "ContractionTree.from_edge_path(edge_path, ...) is deprecated. Use"
        " ContractionTree.from_path(edge_path=edge_path, ...) instead.",
        stacklevel=2,
    )
    return cls.from_path(
        inputs,
        output,
        size_dict,
        edge_path=edge_path,
        optimize=optimize,
        autocomplete=autocomplete,
        check=check,
        **kwargs,
    )
def _add_node(self, node, check=False):
    """Ensure ``node`` has an entry in ``self.info``, leaving any existing
    entry untouched.

    When ``check`` is true, first validate the node count, the branch count
    and the node itself, raising ``ValueError`` on failure.
    """
    if check:
        max_nodes = 2 * self.N - 1
        max_branches = self.N - 1
        if len(self.info) > max_nodes:
            raise ValueError("There are too many children already.")
        if len(self.children) > max_branches:
            raise ValueError("There are too many branches already.")
        if not is_valid_node(node):
            raise ValueError("{} is not a valid node.".format(node))
    if node not in self.info:
        self.info[node] = {}
def _remove_node(self, node):
    """Remove ``node`` from this tree.

    Leaves and the root always stay registered: only their cached info is
    cleared (for leaves, any input preprocessing is dropped too). For other
    nodes the info entry is deleted. Any non-leaf node also loses its
    ``children`` entry, and its contribution is subtracted from the tracked
    size, flops and write.
    """
    extent = len(node)

    if extent == 1:
        # leaf nodes should always exist
        self.info[node].clear()
        # input: remove any associated preprocessing
        self.preprocessing.pop(node_get_single_el(node), None)
        return

    # only non-leaf nodes contribute to size, flops and write
    if self._track_size:
        self._sizes.discard(self.get_size(node))
    if self._track_flops:
        self._flops -= self.get_flops(node)
    if self._track_write:
        self._write -= self.get_size(node)
    del self.children[node]

    if extent == self.N:
        # root node should always exist
        self.info[node].clear()
    else:
        del self.info[node]
def compute_leaf_legs(self, i):
    """Compute the effective 'outer' legs ``{index: count}`` of the ith
    input tensor.

    These can differ from the raw input indices because of A) slicing,
    which removes sliced indices straight away, and B) single-term
    simplifications: repeated indices (diagonals/traces) and indices whose
    every appearance is on this input (summed immediately). When such a
    simplification applies, an einsum equation for it is recorded in
    ``self.preprocessing[i]``. It is applied only during actual contraction
    and is not part of the binary tree.
    """
    sliced = self.sliced_inds
    term = self.inputs[i]
    if sliced:
        term = tuple(ix for ix in term if ix not in sliced)

    counts = {}
    for ix in term:
        counts[ix] = counts.get(ix, 0) + 1

    # N.B. simplifiability is determined *after* slicing
    appearances = self.appearances
    has_repeats = len(term) != len(counts)
    if not has_repeats and all(
        counts[ix] != appearances[ix] for ix in counts
    ):
        return counts

    # drop indices local to this input -> the new effective input legs
    reduced = {ix: c for ix, c in counts.items() if c != appearances[ix]}
    self.preprocessing[i] = inputs_output_to_eq(
        (term,), reduced, canonicalize=True
    )
    return reduced
def has_preprocessing(self):
    """Return whether any input requires a preprocessing step.

    Preprocessing is recorded lazily while leaf legs are computed, so every
    leaf's legs are computed first.
    """
    for leaf in self.gen_leaves():
        self.get_legs(leaf)
    return len(self.preprocessing) > 0
def has_hyper_indices(self):
    """Check for 'hyper' indices: indices that do not appear exactly twice
    across the inputs and output combined.
    """
    return not all(count == 2 for count in self.appearances.values())
@cached_node_property("legs")
def get_legs(self, node):
    """Get the effective 'outer' legs ``{index: count}`` of the subgraph
    ``node``.

    Leaves use ``compute_leaf_legs``. The root uses the unsliced output
    indices (counts set to 0, as they are irrelevant there). Any other node
    keeps the involved indices whose count is still below their total
    appearances.
    """
    extent = len(node)
    if extent == 1:
        return self.compute_leaf_legs(node_get_single_el(node))
    if extent == self.N:
        return {ix: 0 for ix in self.output if ix not in self.sliced_inds}

    try:
        involved = self.get_involved(node)
    except KeyError:
        # no children recorded for this node yet: build from its leaves
        involved = legs_union(self.node_to_terms(node))

    appearances = self.appearances
    return {
        ix: count
        for ix, count in involved.items()
        if count < appearances[ix]
    }
@cached_node_property("involved")
def get_involved(self, node):
    """Get all indices (with summed counts) involved in the pairwise
    contraction forming ``node``; empty for a leaf.
    """
    if len(node) == 1:
        return {}
    child_legs = [self.get_legs(child) for child in self.children[node]]
    return legs_union(child_legs)
@cached_node_property("size")
def get_size(self, node):
    """Get the number of elements of the tensor represented by ``node``."""
    legs = self.get_legs(node)
    return compute_size_by_dict(legs, self.size_dict)
@cached_node_property("flops")
def get_flops(self, node):
    """Get the FLOPs of the pairwise contraction creating ``node``, taken
    as the combined size of all involved indices (0 for a leaf).
    """
    if len(node) == 1:
        return 0
    return compute_size_by_dict(self.get_involved(node), self.size_dict)
@cached_node_property("can_dot")
def get_can_dot(self, node):
    """Get whether the contraction forming ``node`` can be done with
    ``tensordot``, i.e. its output legs are exactly the symmetric difference
    of its children's legs. Otherwise ``einsum`` is needed.
    """
    left, right = self.children[node]
    parent_legs = self.get_legs(node)
    left_legs = self.get_legs(left)
    right_legs = self.get_legs(right)
    return set(parent_legs) == set(left_legs) ^ set(right_legs)
@cached_node_property("inds")
def get_inds(self, node):
    """Get the indices of ``node`` as an ordered string version of
    ``get_legs``.

    Leaves and the root follow the order of their legs (``tree.inputs`` /
    ``tree.output`` with sliced indices filtered out). For intermediate
    nodes, the kept indices are grouped by how the pairwise contraction
    produces them. Indices shared by both children come first, then
    left-only, then right-only, each in the order of the child's indices.
    With no shared kept index this matches the tensordot order
    'ABC,abc->ABCabc'.
    """
    # NB: self.inputs and self.output contain the full (unsliced) indices
    # thus leaf and root inds come from the (filtered) legs
    if len(node) in (1, self.N):
        return "".join(self.get_legs(node))

    legs = self.get_legs(node)
    l_inds, r_inds = map(self.get_inds, self.children[node])
    in_left = set(l_inds)
    in_right = set(r_inds)

    # Grouping stops right-only output indices landing between batches of
    # left-only ones, which on torch CPU tends to create high-rank permuted
    # views that have to be cloned before GEMM/BMM.
    shared = [ix for ix in l_inds if ix in legs and ix in in_right]
    left_only = [ix for ix in l_inds if ix in legs and ix not in in_right]
    right_only = [ix for ix in r_inds if ix in legs and ix not in in_left]
    ordered = list(unique(shared + left_only + right_only))

    # safety net: append any kept index not yet placed, in the default
    # left-then-right order of appearance
    placed = set(ordered)
    for ix in unique(
        filter(legs.__contains__, itertools.chain(l_inds, r_inds))
    ):
        if ix not in placed:
            ordered.append(ix)
    return "".join(ordered)
@cached_node_property("tensordot_axes")
def get_tensordot_axes(self, node):
    """Get the ``axes`` argument for the tensordot contraction producing
    ``node``: positions of the shared indices in the left and right
    children, ordered by appearance on the left.
    """
    left, right = map(self.get_inds, self.children[node])
    l_axes = tuple(pos for pos, ix in enumerate(left) if ix in right)
    r_axes = tuple(right.find(left[pos]) for pos in l_axes)
    return l_axes, r_axes
@cached_node_property("tensordot_perm")
def get_tensordot_perm(self, node):
    """Get the permutation (if any) that takes the natural tensordot output
    order of this node's contraction to ``self.get_inds(node)``. Returns
    ``None`` when no permutation is needed.
    """
    left, right = map(self.get_inds, self.children[node])
    target = self.get_inds(node)
    # tensordot emits left-only then right-only indices, each in order
    position = f"{left}{right}".find
    natural = "".join(sorted(target, key=position))
    if natural == target:
        return None
    return tuple(natural.find(ix) for ix in target)
@cached_node_property("einsum_eq")
def get_einsum_eq(self, node):
    """Get the einsum equation of the contraction producing ``node``.

    Unlike ``get_inds``, the characters are remapped with ``get_symbol``
    (in order of first appearance across the two children), for
    compatibility with e.g. ``numpy.einsum``.
    """
    left_node, right_node = self.children[node]
    l_inds = self.get_inds(left_node)
    r_inds = self.get_inds(right_node)
    p_inds = self.get_inds(node)
    # map any extended unicode characters into ascii
    translation = {}
    for pos, ix in enumerate(unique(itertools.chain(l_inds, r_inds))):
        translation[ord(ix)] = get_symbol(pos)
    return f"{l_inds},{r_inds}->{p_inds}".translate(translation)
def get_centrality(self, node):
    """Return the cached centrality of ``node``, first running
    ``compute_centralities`` if it has not been cached yet.
    """
    if "centrality" not in self.info.get(node, ()):
        self.compute_centralities()
    return self.info[node]["centrality"]
def total_flops(self, dtype=None, log=None):
    """Sum the flops contribution from every node in the tree, scaled by
    ``multiplicity``.

    Parameters
    ----------
    dtype : {'float', 'complex', None}, optional
        Scale the answer depending on the assumed data type: x2 if the
        string contains 'float', x4 if it contains 'complex'.
    log : number, optional
        If given, return the logarithm of the result in this base.

    Raises
    ------
    ValueError
        If ``dtype`` contains neither 'float' nor 'complex'.
    """
    if not self._track_flops:
        # compute once from scratch, then keep tracking incrementally
        self._flops = sum(
            self.get_flops(node) for node, _, _ in self.traverse()
        )
        self._track_flops = True

    cost = self.multiplicity * self._flops
    if dtype is not None:
        if "float" in dtype:
            cost *= 2
        elif "complex" in dtype:
            cost *= 4
        else:
            raise ValueError(f"Unknown dtype {dtype}")

    return cost if log is None else math.log(cost, log)
def total_write(self):
    """Total number of elements written across all intermediates (sum of
    node sizes), scaled by ``multiplicity``; tracked from then on.
    """
    if not self._track_write:
        self._write = sum(
            self.get_size(node) for node, _, _ in self.traverse()
        )
        self._track_write = True
    return self.multiplicity * self._write
def combo_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):
    """Combined time/space cost of the tree.

    For each contraction (each parent in ``children``),
    ``combine((flops, factor * size))`` is computed. These are summed and
    the total is scaled by ``multiplicity``. If ``log`` is given, the
    logarithm in that base is returned.
    """
    total = sum(
        combine((self.get_flops(parent), factor * self.get_size(parent)))
        for parent in self.children
    )
    total *= self.multiplicity
    if log is not None:
        total = math.log(total, log)
    return total

total_cost = combo_cost
def max_size(self, log=None):
    """The size of the largest intermediate tensor.

    Parameters
    ----------
    log : number, optional
        If given, return the logarithm of the size in this base.

    Returns
    -------
    int or float
    """
    if self.N == 1:
        # a single tensor: the only tensor is the root itself. Previously
        # this case returned early and silently ignored ``log``.
        size = self.get_size(self.root)
    else:
        if not self._track_size:
            # compute once from scratch, then keep tracking incrementally
            self._sizes = MaxCounter()
            for node, _, _ in self.traverse():
                self._sizes.add(self.get_size(node))
            self._track_size = True
        size = self._sizes.max()

    if log is not None:
        size = math.log(size, log)
    return size
def peak_size(self, order=None, log=None):
    """Get the peak combined size of tensors held at once. This depends on
    the traversal ``order`` (the exact contraction path), not just on the
    contraction tree. Each step counts its output alongside both inputs.
    """
    current = sum(map(self.get_size, self.gen_leaves()))
    peak = current
    for parent, left, right in self.traverse(order=order):
        # both inputs and the output are alive during the contraction
        current += self.get_size(parent)
        if current > peak:
            peak = current
        current -= self.get_size(left) + self.get_size(right)
    return peak if log is None else math.log(peak, log)
def contract_stats(self, force=False):
    """Simultaneously compute the total flops, write and max size of the
    contraction tree in a single traversal. This is cheaper than calling
    each method separately. Afterwards all three quantities are tracked.

    Parameters
    ----------
    force : bool, optional
        Recompute even if all three are already being tracked.

    Returns
    -------
    stats : dict[str, int]
        The total flops, write (both scaled by ``multiplicity``) and size.
    """
    all_tracked = (
        self._track_flops and self._track_write and self._track_size
    )
    if force or not all_tracked:
        flops = write = 0
        sizes = MaxCounter()
        for node, _, _ in self.traverse():
            flops += self.get_flops(node)
            node_size = self.get_size(node)
            write += node_size
            sizes.add(node_size)
        self._flops, self._write, self._sizes = flops, write, sizes
        self._track_flops = self._track_write = self._track_size = True

    return {
        "flops": self.multiplicity * self._flops,
        "write": self.multiplicity * self._write,
        "size": self._sizes.max(),
    }
def arithmetic_intensity(self):
    """Ratio of total flops to total write; higher generally means better
    achievable computational performance.
    """
    flops = self.total_flops(dtype=None)
    write = self.total_write()
    return flops / write
def contraction_scaling(self):
    """The maximum number of indices involved in any single contraction.
    This matches the scaling exponent when all dimensions are equal.
    """
    return max(map(len, map(self.get_involved, self.info)))
def contraction_cost(self, log=None):
    """Total number of scalar operations (~ time complexity); an alias of
    ``total_flops`` with no dtype scaling.
    """
    return self.total_flops(log=log, dtype=None)
def contraction_width(self, log=2):
    """Logarithm (base ``log``, default 2) of the largest tensor size."""
    width = self.max_size(log=log)
    return width
def compressed_contract_stats(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
):
    """Simulate a compressed contraction of this tree on its hypergraph and
    return the ``CompressedStatsTracker`` that recorded it.

    Parameters
    ----------
    chi : int, optional
        The ``chi`` passed to ``hg.compress`` and the tracker; defaults to
        ``self.get_default_chi()``.
    order : str, optional
        Traversal order passed to ``self.traverse``.
    compress_late : bool, optional
        If true, compress both input tensors just before each contraction,
        otherwise compress the output straight after it. Defaults to
        ``self.get_default_compress_late()``.

    Returns
    -------
    CompressedStatsTracker
    """
    if chi is None:
        chi = self.get_default_chi()
    if compress_late is None:
        compress_late = self.get_default_compress_late()

    hg = self.get_hypergraph(accel="auto")

    # tree node -> hypergraph node id, updated as contractions happen
    hg_ids = dict(zip(self.gen_leaves(), range(hg.get_num_nodes())))
    tracker = CompressedStatsTracker(hg, chi)

    for parent, left, right in self.traverse(order):
        left_id = hg_ids[left]
        right_id = hg_ids[right]

        tracker.update_pre_step()

        if compress_late:
            # compress just before we contract tensors
            tracker.update_pre_compress(hg, left_id, right_id)
            hg.compress(chi=chi, edges=hg.get_node(left_id))
            hg.compress(chi=chi, edges=hg.get_node(right_id))
            tracker.update_post_compress(hg, left_id, right_id)

        tracker.update_pre_contract(hg, left_id, right_id)
        parent_id = hg_ids[parent] = hg.contract(left_id, right_id)
        tracker.update_post_contract(hg, parent_id)

        if not compress_late:
            # compress as soon as we can after contracting tensors
            tracker.update_pre_compress(hg, parent_id)
            hg.compress(chi=chi, edges=hg.get_node(parent_id))
            tracker.update_post_compress(hg, parent_id)

        tracker.update_post_step()

    return tracker
def total_flops_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    dtype=None,
    log=None,
):
    """Estimate the total flops of a compressed contraction of this tree
    with maximum bond size ``chi``. This includes basic estimates of the
    ops for contractions, QRs and SVDs.

    Raises
    ------
    ValueError
        If ``dtype`` is given: only abstract scalar op counts are supported.
    """
    if dtype is not None:
        raise ValueError(
            "Can only estimate cost in terms of "
            "number of abstract scalar ops."
        )
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    flops = stats.flops
    return flops if log is None else math.log(flops, log)

contraction_cost_compressed = total_flops_compressed
def total_write_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    accel="auto",
    log=None,
):
    """Total size of all intermediate tensors when a compressed contraction
    is performed with maximum bond size ``chi``, ordered by ``order``. This
    may matter for time complexity, and for e.g. autodiff space complexity,
    where every intermediate is kept.

    N.B. ``accel`` is accepted but not used by this method.
    """
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    write = stats.write
    return write if log is None else math.log(write, log)
def combo_cost_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    factor=None,
    log=None,
):
    """Combined cost ``flops + factor * write`` of a compressed contraction
    with maximum bond size ``chi``, ordered by ``order``.

    Parameters
    ----------
    factor : float, optional
        Weight of the write term; defaults to
        ``self.get_default_combo_factor()``.
    log : number, optional
        If given, return the logarithm of the cost in this base.
    """
    if factor is None:
        factor = self.get_default_combo_factor()
    # a single compressed-contraction simulation yields both quantities, so
    # run it once rather than once for flops and again for write
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    C = stats.flops + factor * stats.write
    if log is not None:
        C = math.log(C, log)
    return C

total_cost_compressed = combo_cost_compressed
def max_size_compressed(
    self, chi=None, order="surface_order", compress_late=None, log=None
):
    """Size of the largest tensor produced when a compressed contraction is
    performed with maximum bond size ``chi``, ordered by ``order``. This is
    close to the ideal space complexity if only the tensors being directly
    operated on are kept in memory.
    """
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    biggest = stats.max_size
    return biggest if log is None else math.log(biggest, log)
def peak_size_compressed(
self,
chi=None,
order="surface_order",
compress_late=None,
accel="auto",
log=None,
):
"""Compute the peak size of combined intermediate tensors when a
compressed contraction is performed with maximum bond size ``chi``,
ordered by ``order``. This is the practical space complexity if one is
not swapping intermediates in and out of memory.
"""
P = self.compressed_contract_stats(
chi=chi,
order=order,
compress_late=compress_late,
).peak_size
if log is not None:
P = math.log(P, log)
return P
def contraction_width_compressed(
self, chi=None, order="surface_order", compress_late=None, log=2
):
"""Compute log2 of the maximum sized tensor produced when a compressed
contraction is performed with maximum bond size ``chi``, ordered by
``order``.
"""
return self.max_size_compressed(chi, order, compress_late, log=log)
def _update_tracked(self, node):
if self._track_flops:
self._flops += self.get_flops(node)
if self._track_write:
self._write += self.get_size(node)
if self._track_size:
self._sizes.add(self.get_size(node))
def contract_nodes_pair(
self,
x,
y,
legs=None,
cost=None,
size=None,
check=False,
):
"""Contract node ``x`` with node ``y`` in the tree to create a new
parent node, which is returned.
Parameters
----------
x : frozenset[int]
The first node to contract.
y : frozenset[int]
The second node to contract.
legs : dict[str, int], optional
The effective 'legs' of the new node if already known. If not
given, this is computed from the inputs of ``x`` and ``y``.
cost : int, optional
The cost of the contraction if already known. If not given, this is
computed from the inputs of ``x`` and ``y``.
size : int, optional
The size of the new node if already known. If not given, this is
computed from the inputs of ``x`` and ``y``.
check : bool, optional
Whether to check the inputs are valid.
Returns
-------
parent : frozenset[int]
The new parent node of ``x`` and ``y``.
"""
parent = x.union(y)
# make sure info entries exist for all (default dict)
for node in (x, y, parent):
self._add_node(node, check=check)
# enforce left ordering of 'heaviest' subtrees
nx, ny = len(x), len(y)
if nx == ny:
# deterministically break ties
sortx = -min(x)
sorty = -min(y)
else:
sortx = nx
sorty = ny
if sortx > sorty:
lr = (x, y)
else:
lr = (y, x)
self.children[parent] = lr
if self.track_childless:
self.childless.discard(parent)
if x not in self.children and nx > 1:
self.childless.add(x)
if y not in self.children and ny > 1:
self.childless.add(y)
# pre-computed information
if legs is not None:
self.info[parent]["legs"] = legs
if cost is not None:
self.info[parent]["flops"] = cost
if size is not None:
self.info[parent]["size"] = size
self._update_tracked(parent)
return parent
def contract_nodes(
self,
nodes,
optimize="auto-hq",
check=False,
extra_opts=None,
):
"""Contract an arbitrary number of ``nodes`` in the tree to build up a
subtree. The root of this subtree (a new intermediate) is returned.
"""
if len(nodes) == 1:
return next(iter(nodes))
if len(nodes) == 2:
return self.contract_nodes_pair(*nodes, check=check)
from .interface import find_path
# create the bottom and top nodes
grandparent = union_it(nodes)
self._add_node(grandparent, check=check)
for node in nodes:
self._add_node(node, check=check)
# if more than two nodes need to find the path to fill in between
# \
# GN <- 'grandparent'
# / \
# ?????????
# ????????????? <- to be filled with 'temp nodes'
# / \ / / \
# N0 N1 N2 N3 N4 <- ``nodes``, or, subgraphs
# / \ / / \
path_inputs = [tuple(self.get_legs(x)) for x in nodes]
path_output = tuple(self.get_legs(grandparent))
path = find_path(
path_inputs,
path_output,
self.size_dict,
optimize=optimize,
**(extra_opts or {}),
)
# now we have path create the nodes in between
temp_nodes = list(nodes)
for p in path:
to_contract = [temp_nodes.pop(i) for i in sorted(p, reverse=True)]
temp_nodes.append(self.contract_nodes(to_contract, check=check))
(parent,) = temp_nodes
if check:
# final remaining temp input should be the 'grandparent'
assert parent == grandparent
return parent
def is_complete(self):
"""Check every node has two children, unless it is a leaf."""
too_many_nodes = len(self.info) > 2 * self.N - 1
too_many_branches = len(self.children) > self.N - 1
if too_many_nodes or too_many_branches:
raise ValueError("Contraction tree seems to be over complete!")
queue = [self.root]
while queue:
x = queue.pop()
if len(x) == 1:
continue
try:
queue.extend(self.children[x])
except KeyError:
return False
return True
def get_default_order(self):
return "dfs"
def _traverse_dfs(self):
"""Traverse the tree in a depth first, non-recursive, order."""
ready = set(self.gen_leaves())
queue = [self.root]
while queue:
node = queue[-1]
l, r = self.children[node]
# both node's children are ready -> we can yield this contraction
if (l in ready) and (r in ready):
ready.add(queue.pop())
yield node, l, r
continue
if r not in ready:
queue.append(r)
if l not in ready:
queue.append(l)
def _traverse_ordered(self, order):
"""Traverse the tree in the order that minimizes ``order(node)``, but
still constrained to produce children before parents.
"""
from bisect import bisect
if order == "surface_order":
order = self.surface_order
seen = set()
queue = [self.root]
scores = [order(self.root)]
while len(seen) != len(self.children):
i = 0
while i < len(queue):
node = queue[i]
if node not in seen:
for child in self.children[node]:
if len(child) > 1:
# insert child into queue by score + before parent
score = order(child)
ci = bisect(scores[:i], score)
scores.insert(ci, score)
queue.insert(ci, child)
# parent moves extra place to right
i += 1
seen.add(node)
i += 1
for node in queue:
yield (node, *self.children[node])
def traverse(self, order=None):
"""Generate, in order, all the node merges in this tree. Non-recursive!
This ensures children are always visited before their parent.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
generator[tuple[node]]
The bottom up ordered sequence of tree merges, each a
tuple of ``(parent, left_child, right_child)``.
See Also
--------
descend
"""
if self.N == 1:
return
if order is None:
order = self.get_default_order()
if order == "dfs":
yield from self._traverse_dfs()
else:
yield from self._traverse_ordered(order=order)
def descend(self, mode="dfs"):
"""Generate, from root to leaves, all the node merges in this tree.
Non-recursive! This ensures parents are visited before their children.
Parameters
----------
mode : {'dfs', bfs}, optional
How expand from a parent.
Returns
-------
generator[tuple[node]
The top down ordered sequence of tree merges, each a
tuple of ``(parent, left_child, right_child)``.
See Also
--------
traverse
"""
queue = [self.root]
while queue:
if mode == "dfs":
parent = queue.pop(-1)
elif mode == "bfs":
parent = queue.pop(0)
l, r = self.children[parent]
yield parent, l, r
if len(l) > 1:
queue.append(l)
if len(r) > 1:
queue.append(r)
def get_subtree(self, node, size, search="bfs", seed=None):
"""Get a subtree spanning down from ``node`` which will have ``size``
leaves (themselves not necessarily leaves of the actual tree).
Parameters
----------
node : node
The node of the tree to start with.
size : int
How many subtree leaves to aim for.
search : {'bfs', 'dfs', 'random'}, optional
How to build the tree:
- 'bfs': breadth first expansion
- 'dfs': depth first expansion (largest nodes first)
- 'random': random expansion
seed : None, int or random.Random, optional
Random number generator seed, if ``search`` is 'random'.
Returns
-------
sub_leaves : tuple[node]
Nodes which are subtree leaves.
branches : tuple[node]
Nodes which are between the subtree leaves and root.
"""
# nodes which are subtree leaves
branches = []
# actual tree leaves - can't expand
real_leaves = []
# nodes to expand
queue = [node]
if search == "random":
rng = get_rng(seed)
else:
rng = None
if search == "bfs":
i = 0
elif search == "dfs":
i = -1
while (len(queue) + len(real_leaves) < size) and queue:
if rng is not None:
i = rng.randint(0, len(queue) - 1)
p = queue.pop(i)
if len(p) == 1:
real_leaves.append(p)
continue
# the left child is always >= in weight that right child
# if we append it last then ``.pop(-1)`` above perform the
# depth first search sorting by node subgraph size
l, r = self.children[p]
queue.append(r)
queue.append(l)
branches.append(p)
# nodes at the bottom of the subtree
sub_leaves = queue + real_leaves
return tuple(sub_leaves), tuple(branches)
    def remove_ind(self, ind, project=None, inplace=False):
        """Remove (i.e. by default slice) index ``ind`` from this contraction
        tree, taking care to update all relevant information about each node.
        Parameters
        ----------
        ind : str
            The index to remove.
        project : optional
            If ``None`` (the default) the index is sliced: it is recorded with
            its full size from ``size_dict`` and ``multiplicity`` is multiplied
            by that size. Otherwise the index is projected: it is recorded with
            size 1 alongside ``project`` (presumably the fixed value of the
            index -- confirm against ``SliceInfo``), and ``multiplicity`` is
            left unchanged. In both cases node flops and sizes are divided by
            the index's size.
        inplace : bool, optional
            Whether to modify this tree rather than a copy of it.
        Returns
        -------
        ContractionTree
        Raises
        ------
        ValueError
            If ``ind`` is already in ``sliced_inds``.
        """
        tree = self if inplace else self.copy()
        if ind in tree.sliced_inds:
            raise ValueError(f"Index {ind} already sliced.")
        # make sure all flops and size information has been populated
        tree.contract_stats()
        d = tree.size_dict[ind]
        if project is None:
            # we are slicing the index
            si = SliceInfo(ind not in tree.output, ind, d, None)
            tree.multiplicity = tree.multiplicity * d
        else:
            # projecting the index: recorded with size 1, multiplicity unchanged
            si = SliceInfo(ind not in tree.output, ind, 1, project)
        # update the ordered slice information dictionary, but maintain the
        # order such that output sliced indices always appear first ->
        # enforced by the dataclass SliceInfo ordering
        # n.b. the iterable passed to ``sorted`` is evaluated in the enclosing
        # scope, so it includes the new ``si`` before the comprehension's own
        # loop variable shadows that name
        tree.sliced_inds = {
            si.ind: si for si in sorted((*tree.sliced_inds.values(), si))
        }
        # NOTE(review): ``_remove_node`` is called below while iterating
        # ``tree.info.items()`` -- presumably it does not delete the leaf's
        # entry from ``info`` (that would raise RuntimeError); confirm.
        for node, node_info in tree.info.items():
            if len(node) == 1:
                # handle leaves separately
                i = node_get_single_el(node)
                term = tree.inputs[i]
                if ind in term:
                    # n.b. leaves don't contribute to size, flops or write
                    # simply recalculate all information, incl. preprocessing
                    tree._remove_node(node)
                    tree.sliced_inputs = tree.sliced_inputs | frozenset([i])
            else:
                involved = tree.get_involved(node)
                if ind not in involved:
                    # if ind doesn't feature in this node (contraction)
                    # -> nothing to do
                    continue
                # else update all the relevant information about this node
                # -> flops changes for all involved indices
                node_info["involved"] = legs_without(involved, ind)
                old_flops = tree.get_flops(node)
                # NOTE(review): integer division assumes the flops contain a
                # factor of ``d`` since ``ind`` is involved -- confirm
                new_flops = old_flops // d
                node_info["flops"] = new_flops
                tree._flops += new_flops - old_flops
                # -> size and write only changes for node legs (output) indices
                legs = tree.get_legs(node)
                if ind in legs:
                    node_info["legs"] = legs_without(legs, ind)
                    old_size = tree.get_size(node)
                    # swap the old size for the new one in the tracked sizes
                    tree._sizes.discard(old_size)
                    new_size = old_size // d
                    tree._sizes.add(new_size)
                    node_info["size"] = new_size
                    tree._write += new_size - old_size
                # delete info we can't change
                for k in (
                    "inds",
                    "einsum_eq",
                    "can_dot",
                    "tensordot_axes",
                    "tensordot_perm",
                ):
                    tree.info[node].pop(k, None)
        # cached reconfiguration results and compiled contractions no longer
        # correspond to this tree
        tree.already_optimized.clear()
        tree.contraction_cores.clear()
        return tree
remove_ind_ = functools.partialmethod(remove_ind, inplace=True)
    def restore_ind(self, ind, inplace=False):
        """Restore (unslice or un-project) index ``ind`` to this contraction
        tree, taking care to update all relevant information about each node.
        Parameters
        ----------
        ind : str
            The index to restore.
        inplace : bool, optional
            Whether to perform the restoration inplace or not.
        Returns
        -------
        ContractionTree
        Raises
        ------
        KeyError
            If ``ind`` is not currently in ``sliced_inds``.
        """
        tree = self if inplace else self.copy()
        # pop sliced index info
        si = tree.sliced_inds.pop(ind)
        # make sure all flops and size information has been populated
        tree.contract_stats()
        # undo the factor ``remove_ind`` applied to the multiplicity
        tree.multiplicity //= si.size
        # handle inputs
        for i, term in enumerate(tree.inputs):
            # this is the original term with all indices
            if ind in term:
                # drop the leaf's cached information (presumably recomputed
                # lazily later, incl. preprocessing)
                tree._remove_node(node_from_single(i))
                if all(ix not in tree.sliced_inds for ix in term):
                    # mark this input as not sliced
                    tree.sliced_inputs = tree.sliced_inputs - frozenset([i])
        # delete and re-add dependent intermediates -- ``traverse`` yields
        # children before parents, so a parent's children are already rebuilt
        # by the time the parent is checked
        for p, l, r in tree.traverse():
            if ind in tree.get_legs(l) or ind in tree.get_legs(r):
                tree._remove_node(p)
                tree.contract_nodes_pair(l, r)
        # reset caches
        tree.already_optimized.clear()
        tree.contraction_cores.clear()
        return tree
restore_ind_ = functools.partialmethod(restore_ind, inplace=True)
def unslice_rand(self, seed=None, inplace=False):
"""Unslice (restore) a random index from this contraction tree.
Parameters
----------
seed : None, int or random.Random, optional
Random number generator seed.
inplace : bool, optional
Whether to perform the unslicing inplace or not.
Returns
-------
ContractionTree
"""
rng = get_rng(seed)
ix = rng.choice(tuple(self.sliced_inds))
return self.restore_ind(ix, inplace=inplace)
unslice_rand_ = functools.partialmethod(unslice_rand, inplace=True)
def unslice_all(self, inplace=False):
"""Unslice (restore) all sliced indices from this contraction tree.
Parameters
----------
inplace : bool, optional
Whether to perform the unslicing inplace or not.
Returns
-------
ContractionTree
"""
tree = self if inplace else self.copy()
for ind in tuple(tree.sliced_inds):
tree.restore_ind_(ind)
return tree
unslice_all_ = functools.partialmethod(unslice_all, inplace=True)
def calc_subtree_candidates(self, pwr=2, what="flops"):
candidates = list(self.children)
if what == "size":
weights = [self.get_size(x) for x in candidates]
elif what == "flops":
weights = [self.get_flops(x) for x in candidates]
max_weight = max(weights)
# can be bigger than numpy int/float allows
weights = [float(w / max_weight) ** (1 / pwr) for w in weights]
# sort by descending score
candidates, weights = zip(
*sorted(zip(candidates, weights), key=lambda x: -x[1])
)
return list(candidates), list(weights)
def subtree_reconfigure(
self,
subtree_size=8,
subtree_search="bfs",
weight_what="flops",
weight_pwr=2,
select="max",
maxiter=500,
seed=None,
minimize=None,
optimize=None,
inplace=False,
progbar=False,
):
"""Reconfigure subtrees of this tree with locally optimal paths.
Parameters
----------
subtree_size : int, optional
The size of subtree to consider. Cost is exponential in this.
subtree_search : {'bfs', 'dfs', 'random'}, optional
How to build the subtrees:
- 'bfs': breadth-first-search creating balanced subtrees
- 'dfs': depth-first-search creating imbalanced subtrees
- 'random': random subtree building
weight_what : {'flops', 'size'}, optional
When assessing nodes to build and optimize subtrees from whether to
score them by the (local) contraction cost, or tensor size.
weight_pwr : int, optional
When assessing nodes to build and optimize subtrees from, how to
scale their score into a probability: ``score**(1 / weight_pwr)``.
The larger this is the more explorative the algorithm is when
``select='random'``.
select : {'max', 'min', 'random'}, optional
What order to select node subtrees to optimize:
- 'max': choose the highest score first
- 'min': choose the lowest score first
- 'random': choose randomly weighted on score -- see
``weight_pwr``.
maxiter : int, optional
How many subtree optimizations to perform, the algorithm can
terminate before this if all subtrees have been optimized.
seed : int, optional
A random seed (seeds python system random module).
minimize : {'flops', 'size'}, optional
Whether to minimize with respect to contraction flops or size.
inplace : bool, optional
Whether to perform the reconfiguration inplace or not.
progbar : bool, optional
Whether to show live progress of the reconfiguration.
Returns
-------
ContractionTree
"""
tree = self if inplace else self.copy()
# ensure these have been computed and thus are being tracked
tree.contract_stats()
if minimize is None:
minimize = self.get_default_objective()
scorer = get_score_fn(minimize)
if optimize is None:
from .pathfinders.path_basic import OptimalOptimizer
opt = OptimalOptimizer(
minimize=scorer.get_dynamic_programming_minimize()
)
else:
opt = optimize
node_cost = getattr(scorer, "cost_local_tree_node", lambda _: 2)
# different caches as we might want to reconfigure one before other
tree.already_optimized.setdefault(minimize, set())
already_optimized = tree.already_optimized[minimize]
if select == "random":
rng = get_rng(seed)
else:
if select == "max":
i = 0
elif select == "min":
i = -1
rng = None
candidates, weights = tree.calc_subtree_candidates(
pwr=weight_pwr, what=weight_what
)
if progbar:
import tqdm
pbar = tqdm.tqdm()
pbar.set_description(_describe_tree(tree), refresh=False)
r = 0
try:
while candidates and r < maxiter:
if rng is not None:
(i,) = rng.choices(range(len(candidates)), weights=weights)
weights.pop(i)
sub_root = candidates.pop(i)
# get a subtree to possibly reconfigure
sub_leaves, sub_branches = tree.get_subtree(
sub_root, size=subtree_size, search=subtree_search
)
sub_leaves = frozenset(sub_leaves)
# check if its already been optimized
if sub_leaves in already_optimized:
continue
# else remove the branches, keeping track of current cost
current_cost = node_cost(tree, sub_root)
for node in sub_branches:
if minimize == "size":
current_cost = max(current_cost, node_cost(tree, node))
else:
current_cost += node_cost(tree, node)
tree._remove_node(node)
# make the optimizer more efficient by supplying accurate cap
opt.cost_cap = max(2, current_cost)
# and reoptimize the leaves
tree.contract_nodes(sub_leaves, optimize=opt)
already_optimized.add(sub_leaves)
r += 1
if progbar:
pbar.update()
pbar.set_description(_describe_tree(tree), refresh=False)
# if we have reconfigured simply re-add all candidates
candidates, weights = tree.calc_subtree_candidates(
pwr=weight_pwr, what=weight_what
)
finally:
if progbar:
pbar.close()
# invalidate any compiled contractions
tree.contraction_cores.clear()
return tree
subtree_reconfigure_ = functools.partialmethod(
subtree_reconfigure, inplace=True
)
def subtree_reconfigure_forest(
self,
num_trees=8,
num_restarts=10,
restart_fraction=0.5,
subtree_maxiter=100,
subtree_size=10,
subtree_search=("random", "bfs"),
subtree_select=("random",),
subtree_weight_what=("flops", "size"),
subtree_weight_pwr=(2,),
parallel="auto",
parallel_maxiter_steps=4,
minimize=None,
seed=None,
progbar=False,
inplace=False,
):
"""'Forested' version of ``subtree_reconfigure`` which is more
explorative and can be parallelized. It stochastically generates
a 'forest' reconfigured trees, then only keeps some fraction of these
to generate the next forest.
Parameters
----------
num_trees : int, optional
The number of trees to reconfigure at each stage.
num_restarts : int, optional
The number of times to halt, prune and then restart the
tree reconfigurations.
restart_fraction : float, optional
The fraction of trees to keep at each stage and generate the next
forest from.
subtree_maxiter : int, optional
Number of subtree reconfigurations per step.
``num_restarts * subtree_maxiter`` is the max number of total
subtree reconfigurations for the final tree produced.
subtree_size : int, optional
The size of subtrees to search for and reconfigure.
subtree_search : tuple[{'random', 'bfs', 'dfs'}], optional
Tuple of options for the ``search`` kwarg of
:meth:`ContractionTree.subtree_reconfigure` to randomly sample.
subtree_select : tuple[{'random', 'max', 'min'}], optional
Tuple of options for the ``select`` kwarg of
:meth:`ContractionTree.subtree_reconfigure` to randomly sample.
subtree_weight_what : tuple[{'flops', 'size'}], optional
Tuple of options for the ``weight_what`` kwarg of
:meth:`ContractionTree.subtree_reconfigure` to randomly sample.
subtree_weight_pwr : tuple[int], optional
Tuple of options for the ``weight_pwr`` kwarg of
:meth:`ContractionTree.subtree_reconfigure` to randomly sample.
parallel : 'auto', False, True, int, or distributed.Client
Whether to parallelize the search.
parallel_maxiter_steps : int, optional
If parallelizing, how many steps to break each reconfiguration into
in order to evenly saturate many processes.
minimize : {'flops', 'size', ..., Objective}, optional
Whether to minimize the total flops or maximum size of the
contraction tree.
seed : None, int or random.Random, optional
A random seed to use.
progbar : bool, optional
Whether to show live progress.
inplace : bool, optional
Whether to perform the subtree reconfiguration inplace.
Returns
-------
ContractionTree
"""
tree = self if inplace else self.copy()
# some of these might be unpicklable
tree.contraction_cores.clear()
# candidate trees
num_keep = max(1, int(num_trees * restart_fraction))
# how to rank the trees
if minimize is None:
minimize = self.get_default_objective()
score = get_score_fn(minimize)
rng = get_rng(seed)
# set up the initial 'forest' and parallel machinery
pool = parse_parallel_arg(parallel)
is_scatter_pool = can_scatter(pool)
if is_scatter_pool:
is_worker = maybe_leave_pool(pool)
# store the trees as futures for the entire process
forest = [scatter(pool, tree)]
maxiter = subtree_maxiter // parallel_maxiter_steps
else:
forest = [tree]
maxiter = subtree_maxiter
if progbar:
import tqdm
pbar = tqdm.tqdm(total=num_restarts)
pbar.set_description(_describe_tree(tree), refresh=False)
try:
for _ in range(num_restarts):
# on the next round take only the best trees
forest = itertools.cycle(forest[:num_keep])
# select some random configurations
saplings = [
{
"tree": next(forest),
"maxiter": maxiter,
"minimize": minimize,
"subtree_size": subtree_size,
"subtree_search": rng.choice(subtree_search),
"select": rng.choice(subtree_select),
"weight_pwr": rng.choice(subtree_weight_pwr),
"weight_what": rng.choice(subtree_weight_what),
}
for _ in range(num_trees)
]
if pool is None:
forest = [_reconfigure_tree(**s) for s in saplings]
res = [{"tree": t, **_get_tree_info(t)} for t in forest]
elif not is_scatter_pool:
forest_futures = [
submit(pool, _reconfigure_tree, **s) for s in saplings
]
forest = [f.result() for f in forest_futures]
res = [{"tree": t, **_get_tree_info(t)} for t in forest]
else:
# submit in smaller steps to saturate processes
for _ in range(parallel_maxiter_steps):
for s in saplings:
s["tree"] = submit(pool, _reconfigure_tree, **s)
# compute scores remotely then gather
forest_futures = [s["tree"] for s in saplings]
res_futures = [
submit(pool, _get_tree_info, t) for t in forest_futures
]
res = [
{"tree": tree_future, **res_future.result()}
for tree_future, res_future in zip(
forest_futures, res_futures
)
]
# update the order of the new forest
res.sort(key=score)
forest = [r["tree"] for r in res]
if progbar:
pbar.update()
if pool is None:
d = _describe_tree(forest[0])
else:
d = submit(pool, _describe_tree, forest[0]).result()
pbar.set_description(d, refresh=False)
finally:
if progbar:
pbar.close()
if is_scatter_pool:
tree.set_state_from(forest[0].result())
maybe_rejoin_pool(is_worker, pool)
else:
tree.set_state_from(forest[0])
return tree
subtree_reconfigure_forest_ = functools.partialmethod(
subtree_reconfigure_forest, inplace=True
)
simulated_anneal = simulated_anneal_tree
simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
parallel_temper = parallel_temper_tree
parallel_temper_ = functools.partialmethod(parallel_temper, inplace=True)
def slice(
self,
target_size=None,
target_overhead=None,
target_slices=None,
temperature=0.01,
minimize=None,
allow_outer=True,
max_repeats=16,
reslice=False,
seed=None,
inplace=False,
):
"""Slice this tree (turn some indices into indices which are explicitly
summed over rather than being part of contractions). The indices are
stored in ``tree.sliced_inds``, and the contraction width updated to
take account of the slicing. Calling ``tree.contract(arrays)`` moreover
which automatically perform the slicing and summation.
Parameters
----------
target_size : int, optional
The target number of entries in the largest tensor of the sliced
contraction. The search algorithm will terminate after this is
reached.
target_slices : int, optional
The target or minimum number of 'slices' to consider - individual
contractions after slicing indices. The search algorithm will
terminate after this is breached. This is on top of the current
number of slices.
target_overhead : float, optional
The target increase in total number of floating point operations.
For example, a value of ``2.0`` will terminate the search just
before the cost of computing all the slices individually breaches
twice that of computing the original contraction all at once.
temperature : float, optional
How much to randomize the repeated search.
minimize : {'flops', 'size', ..., Objective}, optional
Which metric to score the overhead increase against.
allow_outer : bool, optional
Whether to allow slicing of outer indices.
max_repeats : int, optional
How many times to repeat the search with a slight randomization.
reslice : bool, optional
Whether to reslice the tree, i.e. first remove all currently
sliced indices and start the search again. Generally any 'good'
sliced indices will be easily found again.
seed : None, int or random.Random, optional
A random seed or generator to use for the search.
inplace : bool, optional
Whether the remove the indices from this tree inplace or not.
Returns
-------
ContractionTree
See Also
--------
SliceFinder, ContractionTree.slice_and_reconfigure
"""
from .slicer import SliceFinder
if minimize is None:
minimize = self.get_default_objective()
tree = self if inplace else self.copy()
if reslice:
if target_slices is not None:
target_slices *= tree.nslices
tree.unslice_all_()
sf = SliceFinder(
tree,
target_size=target_size,
target_overhead=target_overhead,
target_slices=target_slices,
temperature=temperature,
minimize=minimize,
allow_outer=allow_outer,
seed=seed,
)
ix_sl, _ = sf.search(max_repeats)
for ix in ix_sl:
tree.remove_ind_(ix)
return tree
slice_ = functools.partialmethod(slice, inplace=True)
def slice_and_reconfigure(
self,
target_size,
step_size=2,
temperature=0.01,
minimize=None,
allow_outer=True,
max_repeats=16,
reslice=False,
reconf_opts=None,
progbar=False,
inplace=False,
):
"""Interleave slicing (removing indices into an exterior sum) with
subtree reconfiguration to minimize the overhead induced by this
slicing.
Parameters
----------
target_size : int
Slice the tree until the maximum intermediate size is this or
smaller.
step_size : int, optional
The minimum size reduction to try and achieve before switching to a
round of subtree reconfiguration.
temperature : float, optional
The temperature to supply to ``SliceFinder`` for searching for
indices.
minimize : {'flops', 'size', ..., Objective}, optional
The metric to minimize when slicing and reconfiguring subtrees.
max_repeats : int, optional
The number of slicing attempts to perform per search.
progbar : bool, optional
Whether to show live progress.
inplace : bool, optional
Whether to perform the slicing and reconfiguration inplace.
reconf_opts : None or dict, optional
Supplied to
:meth:`ContractionTree.subtree_reconfigure` or
:meth:`ContractionTree.subtree_reconfigure_forest`, depending on
`'forested'` key value.
"""
tree = self if inplace else self.copy()
reconf_opts = {} if reconf_opts is None else dict(reconf_opts)
if minimize is None:
minimize = self.get_default_objective()
minimize = get_score_fn(minimize)
reconf_opts.setdefault("minimize", minimize)
forested_reconf = reconf_opts.pop("forested", False)
if progbar:
import tqdm
pbar = tqdm.tqdm()
pbar.set_description(_describe_tree(tree), refresh=False)
try:
while tree.max_size() > target_size:
tree.slice_(
temperature=temperature,
target_slices=step_size,
minimize=minimize,
allow_outer=allow_outer,
max_repeats=max_repeats,
reslice=reslice,
)
if forested_reconf:
tree.subtree_reconfigure_forest_(**reconf_opts)
else:
tree.subtree_reconfigure_(**reconf_opts)
if progbar:
pbar.update()
pbar.set_description(_describe_tree(tree), refresh=False)
finally:
if progbar:
pbar.close()
return tree
slice_and_reconfigure_ = functools.partialmethod(
slice_and_reconfigure, inplace=True
)
def slice_and_reconfigure_forest(
self,
target_size,
step_size=2,
num_trees=8,
restart_fraction=0.5,
temperature=0.02,
max_repeats=32,
reslice=False,
minimize=None,
allow_outer=True,
parallel="auto",
progbar=False,
inplace=False,
reconf_opts=None,
):
"""'Forested' version of :meth:`ContractionTree.slice_and_reconfigure`.
This maintains a 'forest' of trees with different slicing and subtree
reconfiguration attempts, pruning the worst at each step and generating
a new forest from the best.
Parameters
----------
target_size : int
Slice the tree until the maximum intermediate size is this or
smaller.
step_size : int, optional
The minimum size reduction to try and achieve before switching to a
round of subtree reconfiguration.
num_restarts : int, optional
The number of times to halt, prune and then restart the
tree reconfigurations.
restart_fraction : float, optional
The fraction of trees to keep at each stage and generate the next
forest from.
temperature : float, optional
The temperature at which to randomize the sliced index search.
max_repeats : int, optional
The number of slicing attempts to perform per search.
parallel : 'auto', False, True, int, or distributed.Client
Whether to parallelize the search.
progbar : bool, optional
Whether to show live progress.
inplace : bool, optional
Whether to perform the slicing and reconfiguration inplace.
reconf_opts : None or dict, optional
Supplied to
:meth:`ContractionTree.slice_and_reconfigure`.
Returns
-------
ContractionTree
"""
tree = self if inplace else self.copy()
# some of these might be unpicklable
tree.contraction_cores.clear()
# candidate trees
num_keep = max(1, int(num_trees * restart_fraction))
# how to rank the trees
if minimize is None:
minimize = self.get_default_objective()
score = get_score_fn(minimize)
# set up the initial 'forest' and parallel machinery
pool = parse_parallel_arg(parallel)
is_scatter_pool = can_scatter(pool)
if is_scatter_pool:
is_worker = maybe_leave_pool(pool)
# store the trees as futures for the entire process
forest = [scatter(pool, tree)]
else:
forest = [tree]
if progbar:
import tqdm
pbar = tqdm.tqdm()
pbar.set_description(_describe_tree(tree), refresh=False)
next_size = tree.max_size()
try:
while True:
next_size //= step_size
# on the next round take only the best trees
forest = itertools.cycle(forest[:num_keep])
saplings = [
{
"tree": next(forest),
"target_size": next_size,
"step_size": step_size,
"temperature": temperature,
"max_repeats": max_repeats,
"reconf_opts": reconf_opts,
"allow_outer": allow_outer,
"reslice": reslice,
}
for _ in range(num_trees)
]
if pool is None:
forest = [
_slice_and_reconfigure_tree(**s) for s in saplings
]
res = [{"tree": t, **_get_tree_info(t)} for t in forest]
elif not is_scatter_pool:
# simple pool with no pass by reference
forest_futures = [
submit(pool, _slice_and_reconfigure_tree, **s)
for s in saplings
]
forest = [f.result() for f in forest_futures]
res = [{"tree": t, **_get_tree_info(t)} for t in forest]
else:
forest_futures = [
submit(pool, _slice_and_reconfigure_tree, **s)
for s in saplings
]
# compute scores remotely then gather
res_futures = [
submit(pool, _get_tree_info, t) for t in forest_futures
]
res = [
{"tree": tree_future, **res_future.result()}
for tree_future, res_future in zip(
forest_futures, res_futures
)
]
# we want to sort by flops, but also favour sampling as
# many different sliced index combos as possible
# ~ [1, 1, 1, 2, 2, 3] -> [1, 2, 3, 1, 2, 1]
res.sort(key=score)
res = list(
interleave(
groupby(lambda r: r["sliced_ind_set"], res).values()
)
)
# update the order of the new forest
forest = [r["tree"] for r in res]
if progbar:
pbar.update()
if pool is None:
d = _describe_tree(forest[0])
else:
d = submit(pool, _describe_tree, forest[0]).result()
pbar.set_description(d, refresh=False)
if res[0]["size"] <= target_size:
break
finally:
if progbar:
pbar.close()
if is_scatter_pool:
tree.set_state_from(forest[0].result())
maybe_rejoin_pool(is_worker, pool)
else:
tree.set_state_from(forest[0])
return tree
slice_and_reconfigure_forest_ = functools.partialmethod(
slice_and_reconfigure_forest, inplace=True
)
def compressed_reconfigure(
self,
minimize=None,
order_only=False,
max_nodes="auto",
max_time=None,
local_score=None,
exploration_power=0,
best_score=None,
progbar=False,
inplace=False,
):
"""Reconfigure this tree according to ``peak_size_compressed``.
Parameters
----------
chi : int
The maximum bond dimension to consider.
order_only : bool, optional
Whether to only consider the ordering of the current tree
contractions, or all possible contractions, starting with the
current.
max_nodes : int, optional
Set the maximum number of contraction steps to consider.
max_time : float, optional
Set the maximum time to spend on the search.
local_score : callable, optional
A function that assigns a score to a potential contraction, with a
lower score giving more priority to explore that contraction
earlier. It should have signature::
local_score(step, new_score, dsize, new_size)
where ``step`` is the number of steps so far, ``new_score`` is the
score of the contraction so far, ``dsize`` is the change in memory
by the current step, and ``new_size`` is the new memory size after
contraction.
exploration_power : float, optional
If not ``0.0``, the inverse power to which the step is raised in
the default local score function. Higher values favor exploring
more promising branches early on - at the cost of increased memory.
Ignored if ``local_score`` is supplied.
best_score : float, optional
Manually specify an upper bound for best score found so far.
progbar : bool, optional
If ``True``, display a progress bar.
inplace : bool, optional
Whether to perform the reconfiguration inplace on this tree.
Returns
-------
ContractionTree
"""
from .experimental.path_compressed_branchbound import (
CompressedExhaustive,
)
if minimize is None:
minimize = self.get_default_objective()
if max_nodes == "auto":
if max_time is None:
max_nodes = max(10_000, self.N**2)
else:
max_nodes = float("inf")
opt = CompressedExhaustive(
minimize=minimize,
local_score=local_score,
max_nodes=max_nodes,
max_time=max_time,
exploration_power=exploration_power,
best_score=best_score,
progbar=progbar,
)
opt.setup(self.inputs, self.output, self.size_dict)
opt.explore_path(self.get_path_surface(), restrict=order_only)
# rtree = opt.search(self.inputs, self.output, self.size_dict)
opt.run(self.inputs, self.output, self.size_dict)
ssa_path = opt.ssa_path
# ssa_path = opt(self.inputs, self.output, self.size_dict)
rtree = self.__class__.from_path(
self.inputs,
self.output,
self.size_dict,
ssa_path=ssa_path,
objective=minimize,
)
if inplace:
self.set_state_from(rtree)
rtree = self
rtree.contraction_cores.clear()
return rtree
compressed_reconfigure_ = functools.partialmethod(
compressed_reconfigure, inplace=True
)
def windowed_reconfigure(
self,
minimize=None,
order_only=False,
window_size=20,
max_iterations=100,
max_window_tries=1000,
score_temperature=0.0,
queue_temperature=1.0,
scorer=None,
queue_scorer=None,
seed=None,
inplace=False,
progbar=False,
**kwargs,
):
from .pathfinders.path_compressed import WindowedOptimizer
if minimize is None:
minimize = self.get_default_objective()
wo = WindowedOptimizer(
self.inputs,
self.output,
self.size_dict,
minimize=minimize,
ssa_path=self.get_ssa_path(),
seed=seed,
)
wo.refine(
window_size=window_size,
max_iterations=max_iterations,
order_only=order_only,
max_window_tries=max_window_tries,
score_temperature=score_temperature,
queue_temperature=queue_temperature,
scorer=scorer,
queue_scorer=queue_scorer,
progbar=progbar,
**kwargs,
)
ssa_path = wo.get_ssa_path()
rtree = self.__class__.from_path(
self.inputs,
self.output,
self.size_dict,
ssa_path=ssa_path,
objective=minimize,
)
if inplace:
self.set_state_from(rtree)
rtree = self
rtree.contraction_cores.clear()
return rtree
windowed_reconfigure_ = functools.partialmethod(
windowed_reconfigure, inplace=True
)
def flat_tree(self, order=None):
"""Create a nested tuple representation of the contraction tree like::
((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))
Such that the contraction will progress like::
((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))
((0, 12), (34, ((5, 67), 89)))
(012, (34, (567, 89)))
(012, (34, 56789))
(012, 3456789)
0123456789
Where each integer represents a leaf (i.e. single element node).
"""
tups = dict(zip(self.gen_leaves(), range(self.N)))
for parent, l, r in self.traverse(order=order):
tups[parent] = tups[l], tups[r]
return tups[self.root]
def get_leaves_ordered(self):
"""Return the list of leaves as ordered by the contraction tree.
Returns
-------
tuple[frozenset[str]]
"""
if not self.is_complete():
raise ValueError("Can't order the leaves until tree is complete.")
return tuple(
nd
for nd in itertools.chain.from_iterable(self.traverse())
if len(nd) == 1
)
def get_path(self, order=None):
"""Generate a standard path (with linear recycled ids) from the
contraction tree.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
path: tuple[tuple[int, int]]
"""
from bisect import bisect_left
ssa = self.N
ssas = list(range(ssa))
node_to_ssa = dict(zip(self.gen_leaves(), ssas))
path = []
for parent, left, right in self.traverse(order=order):
# map nodes to ssas
lssa = node_to_ssa[left]
rssa = node_to_ssa[right]
# map ssas to linear indices, using bisection
i, j = sorted((bisect_left(ssas, lssa), bisect_left(ssas, rssa)))
# 'contract' nodes
ssas.pop(j)
ssas.pop(i)
path.append((i, j))
ssas.append(ssa)
# update mapping
node_to_ssa[parent] = ssa
ssa += 1
return tuple(path)
path = deprecated(get_path, "path", "get_path")
def get_numpy_path(self, order=None):
"""Generate a path compatible with the `optimize` kwarg of
`numpy.einsum`.
"""
return ["einsum_path", *self.get_path(order=order)]
def get_ssa_path(self, order=None):
"""Generate a single static assignment path from the contraction tree.
Parameters
----------
order : None, "dfs", or callable, optional
How to order the contractions within the tree. If a callable is
given (which should take a node as its argument), try to contract
nodes that minimize this function first.
Returns
-------
ssa_path: tuple[tuple[int, int]]
"""
ssa_path = []
pos = dict(zip(self.gen_leaves(), range(self.N)))
for parent, l, r in self.traverse(order=order):
i, j = sorted((pos[l], pos[r]))
ssa_path.append((i, j))
pos[parent] = len(ssa_path) + self.N - 1
return tuple(ssa_path)
ssa_path = deprecated(get_ssa_path, "ssa_path", "get_ssa_path")
def surface_order(self, node):
return (len(node), self.get_centrality(node))
def set_surface_order_from_path(self, ssa_path):
o = {}
nodes = list(self.gen_leaves())
for j, p in enumerate(ssa_path):
l, r = (nodes[i] for i in p)
p = l.union(r)
nodes.append(p)
o[p] = j
self.surface_order = functools.partial(
get_with_default, obj=o, default=float("inf")
)
def get_path_surface(self):
return self.get_path(order=self.surface_order)
path_surface = deprecated(
get_path_surface, "path_surface", "get_path_surface"
)
def get_ssa_path_surface(self):
return self.get_ssa_path(order=self.surface_order)
ssa_path_surface = deprecated(
get_ssa_path_surface, "ssa_path_surface", "get_ssa_path_surface"
)
    def get_spans(self):
        """Get all (which could mean none) potential embeddings of this
        contraction tree into a spanning tree of the original graph.

        Returns
        -------
        tuple[dict[frozenset[int], frozenset[int]]]
            One mapping per candidate embedding, from nodes of this tree to
            the single-leaf node each is assigned to in the spanning tree.
        """
        # map each index to the set of input positions it appears on
        ind_to_term = collections.defaultdict(set)
        for i, term in enumerate(self.inputs):
            for ix in term:
                ind_to_term[ix].add(i)

        def boundary_pairs(node):
            """Get nodes along the boundary of the bipartition represented by
            ``node``.
            """
            pairs = set()
            involved = self.get_involved(node)
            legs = self.get_legs(node)
            # indices involved in forming ``node`` that do not survive on it
            removed = [ix for ix in involved if ix not in legs]
            for ix in removed:
                # for every index across the contraction
                # NOTE(review): this unpacking assumes the index appears on
                # exactly two inputs - a hyper index would raise ValueError.
                l1, l2 = ind_to_term[ix]
                # can either span from left to right or right to left
                pairs.add((l1, l2))
                pairs.add((l2, l1))
            return pairs

        # first span choice is any nodes across the top level bipart
        candidates = [
            {
                # which intermedate nodes map to which leaf nodes
                "map": {self.root: node_from_single(l2)},
                # the leaf nodes in the spanning tree
                "spine": {l1, l2},
            }
            for l1, l2 in boundary_pairs(self.root)
        ]

        for _, l, r in self.descend():
            for child in (r, l):
                # for each current candidate check all the possible extensions
                # (pop from the front, push extensions to the back, so each
                # existing candidate is processed exactly once per child;
                # candidates with no valid extension are dropped)
                for _ in range(len(candidates)):
                    cand = candidates.pop(0)

                    # don't need to do anything for leaf children: they
                    # simply map to themselves and the candidate is kept
                    if len(child) == 1:
                        candidates.append(
                            {
                                "map": {child: child, **cand["map"]},
                                "spine": cand["spine"].copy(),
                            }
                        )

                    for l1, l2 in boundary_pairs(child):
                        if (l1 in cand["spine"]) or (l2 not in cand["spine"]):
                            # pair does not merge inwards into spine
                            continue

                        # valid extension of spanning tree
                        candidates.append(
                            {
                                "map": {
                                    child: node_from_single(l2),
                                    **cand["map"],
                                },
                                "spine": cand["spine"] | {l1, l2},
                            }
                        )

        return tuple(c["map"] for c in candidates)
def compute_centralities(self, combine="mean"):
"""Compute a centrality for every node in this contraction tree."""
hg = self.get_hypergraph(accel="auto")
cents = hg.simple_centrality()
for i, leaf in enumerate(self.gen_leaves()):
self.info[leaf]["centrality"] = cents[i]
combine = {
"mean": lambda x, y: (x + y) / 2,
"sum": lambda x, y: (x + y),
"max": max,
"min": min,
}.get(combine, combine)
for p, l, r in self.traverse("dfs"):
self.info[p]["centrality"] = combine(
self.info[l]["centrality"], self.info[r]["centrality"]
)
def get_hypergraph(self, accel=False):
"""Get a hypergraph representing the uncontracted network (i.e. the
leaves).
"""
return get_hypergraph(self.inputs, self.output, self.size_dict, accel)
def reset_contraction_indices(self):
"""Reset all information regarding the explicit contraction indices
ordering.
"""
# delete all derived information
for node in self.children:
for k in (
"inds",
"einsum_eq",
"can_dot",
"tensordot_axes",
"tensordot_perm",
):
self.info[node].pop(k, None)
# invalidate any compiled contractions
self.contraction_cores.clear()
    def sort_contraction_indices(
        self,
        priority="flops",
        make_output_contig=True,
        make_contracted_contig=True,
        reset=True,
    ):
        """Set explicit orders for the contraction indices of this tree to
        optimize for one of two things: contiguity in contracted ('k') indices,
        or contiguity of left and right output ('m' and 'n') indices.

        Parameters
        ----------
        priority : {'flops', 'size', 'root', 'leaves'}, optional
            Which order to process the intermediate nodes in. Later nodes
            re-sort previous nodes so are more likely to keep their ordering.
            E.g. for 'flops' the most costly contraction will be processed
            last and thus will be guaranteed to have its indices exactly
            sorted.
        make_output_contig : bool, optional
            When processing a pairwise contraction, sort the parent contraction
            indices so that the order of indices is the order they appear
            from left to right in the two child (input) tensors.
        make_contracted_contig : bool, optional
            When processing a pairwise contraction, sort the child (input)
            tensor indices so that all contracted indices appear contiguously.
        reset : bool, optional
            Reset all indices to the default order before sorting.

        Raises
        ------
        ValueError
            If ``priority`` is not one of the options above.
        """
        if reset:
            self.reset_contraction_indices()

        # choose the processing order - nodes processed later win any
        # conflicts, since they re-sort shared child / parent indices
        if priority == "flops":
            nodes = sorted(
                self.children.items(), key=lambda x: self.get_flops(x[0])
            )
        elif priority == "size":
            nodes = sorted(
                self.children.items(), key=lambda x: self.get_size(x[0])
            )
        elif priority == "root":
            nodes = ((p, (l, r)) for p, l, r in self.traverse())
        elif priority == "leaves":
            nodes = ((p, (l, r)) for p, l, r in self.descend())
        else:
            raise ValueError(priority)

        # NOTE(review): only the 'inds' entries are rewritten below; other
        # cached per-node entries that ``reset_contraction_indices`` clears
        # (e.g. 'einsum_eq', 'tensordot_perm') are not invalidated here, so
        # with ``reset=False`` any such entries computed earlier may be stale
        # relative to the new index orders - confirm whether that is intended.
        for p, (l, r) in nodes:
            p_inds, l_inds, r_inds = map(self.get_inds, (p, l, r))

            if make_output_contig and len(p) != self.N:
                # sort indices by whether they appear in the left or right
                # whether this happens before or after the sort below depends
                # on the order we are processing the nodes
                # (avoid root as don't want to modify output)

                def psort(ix):
                    # group by whether in left or right input:
                    # left-only indices (not found on the right, -1) come
                    # first in left order, then right indices in right order
                    return (r_inds.find(ix), l_inds.find(ix))

                p_inds = "".join(sorted(p_inds, key=psort))
                self.info[p]["inds"] = p_inds

            if make_contracted_contig:
                # sort indices by:
                # 1. if they are going to be contracted
                # 2. what order they appear in the parent indices
                # (but ignore leaf indices)
                if len(l) != 1:

                    def lsort(ix):
                        # left: indices absent from the right first (in
                        # parent order), then shared ones in right order
                        return (r_inds.find(ix), p_inds.find(ix))

                    l_inds = "".join(sorted(self.get_legs(l), key=lsort))
                    self.info[l]["inds"] = l_inds

                if len(r) != 1:

                    def rsort(ix):
                        # right: indices absent from the parent first, in
                        # the (possibly just re-sorted) left order, then the
                        # kept ones in parent order
                        return (p_inds.find(ix), l_inds.find(ix))

                    r_inds = "".join(sorted(self.get_legs(r), key=rsort))
                    self.info[r]["inds"] = r_inds

        # invalidate any compiled contractions
        self.contraction_cores.clear()
    def print_contractions(self, sort=None, show_brackets=True):
        """Print each pairwise contraction, with colorized indices (if
        `colorama` is installed), and other information. The color codes are:

        - blue: index appears on left and is kept
        - green: index appears on right and is kept
        - red: contracted index: appears on both sides and is removed
        - pink: batch index: appears on both sides and is kept

        Any trivial indices that appear only on one term and not in the output
        are removed and shown by the preprocessing steps.

        Parameters
        ----------
        sort : {'flops', 'size'}, optional
            Sort the contractions by either the number of floating point
            operations or the size of the intermediate tensor, largest first.
            By default the contractions are shown in the order they are
            performed. Any other value leaves the default order.
        show_brackets : bool, optional
            Whether to show the brackets around contiguous sections of the same
            type of indices.
        """
        try:
            from colorama import Fore

            RESET = Fore.RESET
            GREY = Fore.WHITE
            PINK = Fore.MAGENTA
            RED = Fore.RED
            BLUE = Fore.BLUE
            GREEN = Fore.GREEN
        except ImportError:
            # no colorama: fall back to plain, uncolored output
            RESET = GREY = PINK = RED = BLUE = GREEN = ""

        entries = []

        if self.has_preprocessing():
            for pi, eq in self.preprocessing.items():
                # eq is with canonical indices, reinsert original inputs
                replacer = dict(zip(eq.split("->")[0], self.inputs[pi]))
                eq = "".join(replacer.get(c, c) for c in eq)
                print(f"{GREY}preprocess input {pi}: {RESET}{eq}")
            print()

        for i, (p, l, r) in enumerate(self.traverse()):
            p_legs, l_legs, r_legs = map(self.get_legs, [p, l, r])
            p_inds, l_inds, r_inds = map(self.get_inds, [p, l, r])

            # print sizes and flops (sizes are shown as log2 'widths')
            p_flops = self.get_flops(p)
            p_sz, l_sz, r_sz = (
                math.log2(self.get_size(node)) for node in [p, l, r]
            )
            # print whether tensordottable
            if self.get_can_dot(p):
                type_msg = "tensordot"
                perm = self.get_tensordot_perm(p)
                if perm is not None:
                    # and whether indices match tensordot
                    type_msg += "+perm"
            else:
                type_msg = "einsum"

            kpt_brck_l = "(" if show_brackets else ""
            kpt_brck_r = ")" if show_brackets else ""
            con_brck_l = "[" if show_brackets else ""
            con_brck_r = "]" if show_brackets else ""
            hyp_brck_l = "{" if show_brackets else ""
            hyp_brck_r = "}" if show_brackets else ""

            # color each index, then merge adjacent bracket groups of the
            # same kind into one group by removing the inner brackets
            pa = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in l_legs) and (ix in r_legs)
                    else GREEN + f"{kpt_brck_l}{ix}{kpt_brck_r}"
                    if ix in r_legs
                    else BLUE + ix
                    for ix in p_inds
                )
                .replace(f"){GREEN}(", "")
                .replace(f"}}{PINK}{{", "")
            )
            la = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in p_legs) and (ix in r_legs)
                    else RED + f"{con_brck_l}{ix}{con_brck_r}"
                    if ix in r_legs
                    else BLUE + ix
                    for ix in l_inds
                )
                .replace(f"]{RED}[", "")
                .replace(f"}}{PINK}{{", "")
            )
            ra = (
                "".join(
                    PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}"
                    if (ix in p_legs) and (ix in l_legs)
                    else RED + f"{con_brck_l}{ix}{con_brck_r}"
                    if ix in l_legs
                    else GREEN + ix
                    for ix in r_inds
                )
                .replace(f"]{RED}[", "")
                .replace(f"}}{PINK}{{", "")
            )

            entries.append(
                (
                    p,
                    f"{GREY}({i}) cost: {RESET}{p_flops:.1e} "
                    f"{GREY}widths: {RESET}{l_sz:.1f},{r_sz:.1f}->{p_sz:.1f} "
                    f"{GREY}type: {RESET}{type_msg}\n"
                    f"{GREY}inputs: {la},{ra}{RESET}->\n"
                    f"{GREY}output: {pa}\n",
                )
            )

        if sort == "flops":
            entries.sort(key=lambda x: self.get_flops(x[0]), reverse=True)
        if sort == "size":
            entries.sort(key=lambda x: self.get_size(x[0]), reverse=True)

        # final entry restores the terminal color
        entries.append((None, f"{RESET}"))

        o = "\n".join(entry for _, entry in entries)
        print(o)
# --------------------- Performing the Contraction ---------------------- #
def get_contractor(
self,
order=None,
prefer_einsum=False,
strip_exponent=False,
check_zero=False,
implementation=None,
autojit=False,
progbar=False,
):
"""Get a reusable function which performs the contraction corresponding
to this tree, cached.
Parameters
----------
tree : ContractionTree
The contraction tree.
order : str or callable, optional
Supplied to :meth:`ContractionTree.traverse`, the order in which
to perform the pairwise contractions given by the tree.
prefer_einsum : bool, optional
Prefer to use ``einsum`` for pairwise contractions, even if
``tensordot`` can perform the contraction.
strip_exponent : bool, optional
If ``True``, the function will eagerly strip the exponent (in
log10) from intermediate tensors to control numerical problems from
leaving the range of the datatype. This method then returns the
scaled 'mantissa' output array and the exponent separately.
check_zero : bool, optional
If ``True``, when ``strip_exponent=True``, explicitly check for
zero-valued intermediates that would otherwise produce ``nan``,
instead terminating early if encountered and returning
``(0.0, 0.0)``.
implementation : str or tuple[callable, callable], optional
What library to use to actually perform the contractions. Options
are:
- None: let cotengra choose.
- "autoray": dispatch with autoray, using the ``tensordot`` and
``einsum`` implementation of the backend.
- "cotengra": use the ``tensordot`` and ``einsum`` implementation
of cotengra, which is based on batch matrix multiplication. This
is faster for some backends like numpy, and also enables
libraries which don't yet provide ``tensordot`` and ``einsum`` to
be used.
- "cuquantum": use the cuquantum library to perform the whole
contraction (not just individual contractions).
- tuple[callable, callable]: manually supply the ``tensordot`` and
``einsum`` implementations to use.
autojit : bool, optional
If ``True``, use :func:`autoray.autojit` to compile the contraction
function.
progbar : bool, optional
Whether to show progress through the contraction by default.
Returns
-------
fn : callable
The contraction function, with signature ``fn(*arrays)``.
"""
key = (
autojit,
order,
prefer_einsum,
strip_exponent,
check_zero,
implementation,
progbar,
)
try:
fn = self.contraction_cores[key]
except KeyError:
fn = self.contraction_cores[key] = make_contractor(
tree=self,
order=order,
prefer_einsum=prefer_einsum,
strip_exponent=strip_exponent,
check_zero=check_zero,
implementation=implementation,
autojit=autojit,
progbar=progbar,
)
return fn
def contract_core(
self,
arrays,
order=None,
prefer_einsum=False,
strip_exponent=False,
check_zero=False,
backend=None,
implementation=None,
autojit="auto",
progbar=False,
):
"""Contract ``arrays`` with this tree. The order of the axes and
output is assumed to be that of ``tree.inputs`` and ``tree.output``,
but with sliced indices removed. This functon contracts the core tree
and thus if indices have been sliced the arrays supplied need to be
sliced as well.
Parameters
----------
arrays : sequence of array
The arrays to contract.
order : str or callable, optional
Supplied to :meth:`ContractionTree.traverse`.
prefer_einsum : bool, optional
Prefer to use ``einsum`` for pairwise contractions, even if
``tensordot`` can perform the contraction.
backend : str, optional
What library to use for ``einsum`` and ``transpose``, will be
automatically inferred from the arrays if not given.
autojit : "auto" or bool, optional
Whether to use ``autoray.autojit`` to jit compile the expression.
If "auto", then let ``cotengra`` choose.
progbar : bool, optional
Show progress through the contraction.
"""
if autojit == "auto":
# choose for the user
autojit = backend == "jax"
fn = self.get_contractor(
order=order,
prefer_einsum=prefer_einsum,
strip_exponent=strip_exponent is not False,
implementation=implementation,
autojit=autojit,
check_zero=check_zero,
progbar=progbar,
)
return fn(*arrays, backend=backend)
def slice_key(self, i, strides=None):
"""Get the combination of sliced index values for overall slice ``i``.
Parameters
----------
i : int
The overall slice index.
Returns
-------
key : dict[str, int]
The value each sliced index takes for slice ``i``.
"""
if strides is None:
strides = get_slice_strides(self.sliced_inds)
key = {}
for (ind, info), stride in zip(self.sliced_inds.items(), strides):
if info.project is None:
key[ind] = i // stride
i %= stride
else:
# size is 1 and i doesn't change
key[ind] = info.project
return key
def slice_arrays(self, arrays, i):
"""Take ``arrays`` and slice the relevant inputs according to
``tree.sliced_inds`` and the dynary representation of ``i``.
"""
temp_arrays = list(arrays)
# e.g. {'a': 2, 'd': 7, 'z': 0}
locations = self.slice_key(i)
for c in self.sliced_inputs:
# the indexing object, e.g. [:, :, 7, :, 2, :, :, 0]
selector = tuple(
locations.get(ix, slice(None)) for ix in self.inputs[c]
)
# re-insert the sliced array
temp_arrays[c] = temp_arrays[c][selector]
return temp_arrays
def contract_slice(self, arrays, i, **kwargs):
"""Get slices ``i`` of ``arrays`` and then contract them."""
return self.contract_core(self.slice_arrays(arrays, i), **kwargs)
def gather_slices(self, slices, backend=None, progbar=False):
"""Gather all the output contracted slices into a single full result.
If none of the sliced indices appear in the output, then this is a
simple sum - otherwise the slices need to be partially summed and
partially stacked.
"""
if progbar:
import tqdm
slices = tqdm.tqdm(slices, total=self.multiplicity)
output_pos = {
ix: i for i, ix in enumerate(self.output) if ix in self.sliced_inds
}
if not output_pos:
# we can just sum everything
return functools.reduce(add_maybe_exponent_stripped, slices)
# first we sum over non-output sliced indices
chunks = {}
for i, s in enumerate(slices):
key_slice = self.slice_key(i)
key = tuple(key_slice[ix] for ix in output_pos)
try:
chunks[key] = add_maybe_exponent_stripped(chunks[key], s)
except KeyError:
chunks[key] = s
if isinstance(next(iter(chunks.values())), tuple):
# have stripped exponents, need to scale to largest
emax = max(v[1] for v in chunks.values())
chunks = {
k: mi * 10 ** (ei - emax) for k, (mi, ei) in chunks.items()
}
else:
emax = None
# then we stack these summed chunks over output sliced indices
def recursively_stack_chunks(loc, remaining):
if not remaining:
return chunks[loc]
arrays = [
recursively_stack_chunks(loc + (d,), remaining[1:])
for d in self.sliced_inds[remaining[0]].sliced_range
]
axes = output_pos[remaining[0]] - len(loc)
return do("stack", arrays, axes, like=backend)
result = recursively_stack_chunks((), tuple(output_pos))
if emax is not None:
# strip_exponent was True, return the exponent separately
return result, emax
return result
def gen_output_chunks(
self, arrays, with_key=False, progbar=False, **contract_opts
):
"""Generate each output chunk of the contraction - i.e. take care of
summing internally sliced indices only first. This assumes that the
``sliced_inds`` are sorted by whether they appear in the output or not
(the default order). Useful for performing some kind of reduction over
the final tensor object like ``fn(x).sum()`` without constructing the
entire thing.
Parameters
----------
arrays : sequence of array
The arrays to contract.
with_key : bool, optional
Whether to yield the output index configuration key along with the
chunk.
progbar : bool, optional
Show progress through the contraction chunks.
Yields
------
chunk : array
A chunk of the contracted result.
key : dict[str, int]
The value each sliced output index takes for this chunk.
"""
# consecutive slices of size ``stepsize`` all belong to the same output
# block because the sliced indices are sorted output first
stepsize = prod(
si.size for si in self.sliced_inds.values() if si.inner
)
if progbar:
import tqdm
it = tqdm.trange(self.nslices // stepsize)
else:
it = range(self.nslices // stepsize)
for o in it:
chunk = self.contract_slice(arrays, o * stepsize, **contract_opts)
if with_key:
output_key = {
ix: x
for ix, x in self.slice_key(o * stepsize).items()
if ix in self.output
}
for j in range(1, stepsize):
i = o * stepsize + j
chunk = chunk + self.contract_slice(arrays, i, **contract_opts)
if with_key:
yield chunk, output_key
else:
yield chunk
def contract(
self,
arrays,
order=None,
prefer_einsum=False,
strip_exponent=False,
check_zero=False,
backend=None,
implementation=None,
autojit="auto",
progbar=False,
):
"""Contract ``arrays`` with this tree. This function takes *unsliced*
arrays and handles the slicing, contractions and gathering. The order
of the axes and output is assumed to match that of ``tree.inputs`` and
``tree.output``.
Parameters
----------
arrays : sequence of array
The arrays to contract.
order : str or callable, optional
Supplied to :meth:`ContractionTree.traverse`.
prefer_einsum : bool, optional
Prefer to use ``einsum`` for pairwise contractions, even if
``tensordot`` can perform the contraction.
strip_exponent : bool, optional
If ``True``, eagerly strip the exponent (in log10) from
intermediate tensors to control numerical problems from leaving the
range of the datatype. This method then returns the scaled
'mantissa' output array and the exponent separately.
check_zero : bool, optional
If ``True``, when ``strip_exponent=True``, explicitly check for
zero-valued intermediates that would otherwise produce ``nan``,
instead terminating early if encountered and returning
``(0.0, 0.0)``.
backend : str, optional
What library to use for ``tensordot``, ``einsum`` and
``transpose``, it will be automatically inferred from the input
arrays if not given.
autojit : bool, optional
Whether to use the 'autojit' feature of `autoray` to compile the
contraction expression.
progbar : bool, optional
Whether to show a progress bar.
Returns
-------
output : array
The contracted output, it will be scaled if
``strip_exponent==True``.
exponent : float
The exponent of the output in base 10, returned only if
``strip_exponent==True``.
See Also
--------
contract_core, contract_slice, slice_arrays, gather_slices
"""
if not self.sliced_inds:
return self.contract_core(
arrays,
order=order,
prefer_einsum=prefer_einsum,
strip_exponent=strip_exponent,
check_zero=check_zero,
backend=backend,
implementation=implementation,
autojit=autojit,
progbar=progbar,
)
slices = (
self.contract_slice(
arrays,
i,
order=order,
prefer_einsum=prefer_einsum,
strip_exponent=strip_exponent,
check_zero=check_zero,
backend=backend,
implementation=implementation,
autojit=autojit,
)
for i in range(self.multiplicity)
)
return self.gather_slices(slices, backend=backend, progbar=progbar)
def contract_mpi(self, arrays, comm=None, root=None, **kwargs):
"""Contract the slices of this tree and sum them in parallel -
*assuming* we are already running under MPI.
Parameters
----------
arrays : sequence of array
The input (unsliced arrays)
comm : None or mpi4py communicator
Defaults to ``mpi4py.MPI.COMM_WORLD`` if not given.
root : None or int, optional
If ``root=None``, an ``Allreduce`` will be performed such that
every process has the resulting tensor, else if an integer e.g.
``root=0``, the result will be exclusively gathered to that
process using ``Reduce``, with every other process returning
``None``.
kwargs
Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.
"""
if not set(self.sliced_inds).isdisjoint(set(self.output)):
raise NotImplementedError(
"Sliced and output indices overlap - currently only a simple "
"sum of result slices is supported currently."
)
if comm is None:
from mpi4py import MPI
comm = MPI.COMM_WORLD
if self.multiplicity < comm.size:
raise ValueError(
f"Need to have more slices than MPI processes, but have "
f"{self.multiplicity} and {comm.size} respectively."
)
# round robin compute each slice, eagerly summing
result_i = None
for i in range(comm.rank, self.multiplicity, comm.size):
# note: fortran ordering is needed for the MPI reduce
x = do("asfortranarray", self.contract_slice(arrays, i, **kwargs))
if result_i is None:
result_i = x
else:
result_i += x
if root is None:
# everyone gets the summed result
result = do("empty_like", result_i)
comm.Allreduce(result_i, result)
return result
# else we only sum reduce the result to process ``root``
if comm.rank == root:
result = do("empty_like", result_i)
else:
result = None
comm.Reduce(result_i, result, root=root)
return result
def benchmark(
self,
dtype,
max_time=60,
min_reps=3,
max_reps=100,
warmup=True,
**contract_opts,
):
"""Benchmark the contraction of this tree.
Parameters
----------
dtype : {"float32", "float64", "complex64", "complex128"}
The datatype to use.
max_time : float, optional
The maximum time to spend benchmarking in seconds.
min_reps : int, optional
The minimum number of repetitions to perform, regardless of time.
max_reps : int, optional
The maximum number of repetitions to perform, regardless of time.
warmup : bool or int, optional
Whether to perform a warmup run before the benchmark. If an int,
the number of warmup runs to perform.
contract_opts
Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.
Returns
-------
dict
A dictionary of benchmarking results. The keys are:
- "time_per_slice" : float
The average time to contract a single slice.
- "est_time_total" : float
The estimated total time to contract all slices.
- "est_gigaflops" : float
The estimated gigaflops of the contraction.
See Also
--------
contract_slice
"""
import time
from .utils import make_arrays_from_inputs
arrays = make_arrays_from_inputs(
self.inputs, self.size_dict, dtype=dtype
)
for i in range(int(warmup)):
self.contract_slice(arrays, i % self.nslices, **contract_opts)
t0 = time.time()
ti = t0
i = 0
while (ti - t0 < max_time) or (i < min_reps):
self.contract_slice(arrays, i % self.nslices, **contract_opts)
ti = time.time()
i += 1
if i >= max_reps:
break
time_per_slice = (ti - t0) / i
est_time_total = time_per_slice * self.nslices
est_gigaflops = self.total_flops(dtype=dtype) / (1e9 * est_time_total)
return {
"time_per_slice": time_per_slice,
"est_time_total": est_time_total,
"est_gigaflops": est_gigaflops,
}
plot_ring = plot_tree_ring
plot_tent = plot_tree_tent
plot_span = plot_tree_span
plot_flat = plot_tree_flat
plot_circuit = plot_tree_circuit
plot_rubberband = plot_tree_rubberband
plot_contractions = plot_contractions
plot_contractions_alt = plot_contractions_alt
@functools.wraps(plot_hypergraph)
def plot_hypergraph(self, **kwargs):
hg = self.get_hypergraph(accel=False)
hg.plot(**kwargs)
def describe(self, info="normal", join=" "):
"""Return a string describing the contraction tree."""
self.contract_stats()
if info == "normal":
return join.join(
(
f"log10[FLOPs]={self.total_flops(log=10):.2f}",
f"log2[SIZE]={self.max_size(log=2):.2f}",
)
)
elif info == "full":
s = [
f"log10[FLOPS]={self.total_flops(log=10):.2f}",
f"log10[COMBO]={self.combo_cost(log=10):.2f}",
f"log2[SIZE]={self.max_size(log=2):.2f}",
f"log2[PEAK]={self.peak_size(log=2):.2f}",
]
if self.sliced_inds:
s.append(f"NSLICES={self.multiplicity:.2f}")
return join.join(s)
elif info == "concise":
s = [
f"F={self.total_flops(log=10):.2f}",
f"C={self.combo_cost(log=10):.2f}",
f"S={self.max_size(log=2):.2f}",
f"P={self.peak_size(log=2):.2f}",
]
if self.sliced_inds:
s.append(f"$={self.multiplicity:.2f}")
return join.join(s)
def __repr__(self):
if self.is_complete():
return f"<{self.__class__.__name__}(N={self.N})>"
else:
s = "<{}(N={}, branches={}, complete={})>"
return s.format(
self.__class__.__name__,
self.N,
len(self.children),
self.is_complete(),
)
def __str__(self):
if not self.is_complete():
return self.__repr__()
else:
d = self.describe("concise", join=", ")
return f"<{self.__class__.__name__}(N={self.N}, {d})>"
def _reconfigure_tree(tree, *args, **kwargs):
    """Module-level wrapper calling ``tree.subtree_reconfigure(*args,
    **kwargs)``, e.g. so it can be submitted to a pool by reference."""
    reconfigure = tree.subtree_reconfigure
    return reconfigure(*args, **kwargs)
def _slice_and_reconfigure_tree(tree, *args, **kwargs):
    """Module-level wrapper calling ``tree.slice_and_reconfigure(*args,
    **kwargs)``, used when submitting work to a pool."""
    method = tree.slice_and_reconfigure
    return method(*args, **kwargs)
def _get_tree_info(tree):
    """Return the contraction statistics of ``tree`` plus a
    ``"sliced_ind_set"`` entry holding the frozenset of its sliced indices.

    The statistics are copied first so the dict returned by
    ``tree.contract_stats()`` is never modified, in case it is cached or
    shared by the tree.
    """
    stats = dict(tree.contract_stats())
    stats["sliced_ind_set"] = frozenset(tree.sliced_inds)
    return stats
def _describe_tree(tree, info="normal"):
    """Return ``tree.describe(info=info)``; module-level counterpart of the
    ``describe`` method.
    """
    describe_fn = tree.describe
    return describe_fn(info=info)
class ContractionTreeCompressed(ContractionTree):
    """A contraction tree for compressed contractions. Currently the only
    difference is that this defaults to the 'surface' traversal ordering.
    """

    def set_state_from(self, other):
        """Copy the state of ``other`` into this tree, then also set this
        tree's 'surface' traversal order from ``other``'s SSA path.
        """
        super().set_state_from(other)
        self.set_surface_order_from_path(other.get_ssa_path())

    @classmethod
    def from_path(
        cls,
        inputs,
        output,
        size_dict,
        *,
        path=None,
        ssa_path=None,
        autocomplete="auto",
        check=False,
        **kwargs,
    ):
        """Create a (completed) ``ContractionTreeCompressed`` from the usual
        inputs plus a standard contraction path or 'ssa_path' - you need to
        supply one. This also sets the default 'surface' traversal ordering
        to be the initial path.

        Parameters
        ----------
        inputs, output, size_dict
            Passed, together with ``**kwargs``, to the class constructor.
        path : sequence of sequence of int, optional
            A linear contraction path; converted with ``linear_to_ssa``.
        ssa_path : sequence of sequence of int, optional
            A path whose entries index into a growing list of terms: the
            leaves first, then each intermediate in order of creation.
        autocomplete : {"auto", True, False}, optional
            If the path leaves the tree incomplete, whether to contract the
            remaining nodes with the 'greedy-compressed' optimizer. ``"auto"``
            does so and also emits a warning.
        check : bool, optional
            Passed to ``contract_nodes``.

        Raises
        ------
        ValueError
            If not exactly one of ``path`` and ``ssa_path`` is supplied.
        """
        if int(path is None) + int(ssa_path is None) != 1:
            raise ValueError(
                "Exactly one of ``path`` or ``ssa_path`` must be supplied."
            )

        if path is not None:
            from .pathfinders.path_basic import linear_to_ssa

            ssa_path = linear_to_ssa(path)

        tree = cls(inputs, output, size_dict, **kwargs)

        # ``terms`` is indexed the same way as the SSA path: leaves first,
        # then each newly contracted node appended as it is created
        terms = list(tree.gen_leaves())
        for p in ssa_path:
            merge = [terms[i] for i in p]
            terms.append(tree.contract_nodes(merge, check=check))

        # the initial path doubles as the default 'surface' traversal order
        tree.set_surface_order_from_path(ssa_path)

        # fewer than N - 1 recorded contractions -> tree is incomplete
        if (len(tree.children) < tree.N - 1) and autocomplete:
            if autocomplete == "auto":
                # warn that we are completing
                warnings.warn(
                    "Path was not complete - contracting all remaining. "
                    "You can silence this warning with `autocomplete=True`. "
                    "Or produce an incomplete tree with `autocomplete=False`.",
                    stacklevel=2,
                )
            tree.autocomplete(optimize="greedy-compressed")

        return tree

    def get_default_order(self):
        """Compressed trees default to the ``"surface_order"`` traversal."""
        return "surface_order"

    def get_default_objective(self):
        """Return the default objective, lazily creating and caching the
        ``"peak-compressed"`` score function on first use.
        """
        if self._default_objective is None:
            self._default_objective = get_score_fn("peak-compressed")
        return self._default_objective

    def get_default_chi(self):
        """Return the default ``chi``: the default objective's ``chi``
        attribute if it has one, and, if that is missing or ``"auto"``, the
        square of the largest index size in ``size_dict``.
        """
        objective = self.get_default_objective()
        try:
            chi = objective.chi
        except AttributeError:
            chi = "auto"
        if chi == "auto":
            chi = max(self.size_dict.values()) ** 2
        return chi

    def get_default_compress_late(self):
        """Return the default objective's ``compress_late`` attribute, or
        ``False`` if it does not have one.
        """
        objective = self.get_default_objective()
        try:
            return objective.compress_late
        except AttributeError:
            return False

    # cost / size metrics default to their ``*_compressed`` variants ...
    total_flops = ContractionTree.total_flops_compressed
    total_write = ContractionTree.total_write_compressed
    combo_cost = ContractionTree.combo_cost_compressed
    total_cost = ContractionTree.total_cost_compressed
    max_size = ContractionTree.max_size_compressed
    peak_size = ContractionTree.peak_size_compressed
    contraction_cost = ContractionTree.contraction_cost_compressed
    contraction_width = ContractionTree.contraction_width_compressed

    # ... while the base-class versions stay available with an ``_exact`` suffix
    total_flops_exact = ContractionTree.total_flops
    total_write_exact = ContractionTree.total_write
    combo_cost_exact = ContractionTree.combo_cost
    total_cost_exact = ContractionTree.total_cost
    max_size_exact = ContractionTree.max_size
    peak_size_exact = ContractionTree.peak_size

    def get_contractor(self, *_, **__):
        """Not supported for compressed trees - always raises.

        Raises
        ------
        NotImplementedError
        """
        raise NotImplementedError(
            "`cotengra` doesn't implement compressed contraction itself. "
            "If you want to use compressed contractions, you need to use "
            "`quimb` and the `TensorNetwork.contract_compressed` method, "
            "with e.g. `optimize=tree.get_path()`."
        )

    def simulated_anneal(
        self,
        minimize=None,
        tfinal=0.0001,
        tstart=0.01,
        tsteps=50,
        numiter=50,
        seed=None,
        inplace=False,
        progbar=False,
        **kwargs,
    ):
        """Perform simulated annealing refinement of this *compressed*
        contraction tree.

        Parameters
        ----------
        minimize : optional
            Objective to minimize; defaults to ``get_default_objective()``.
            Also used as the ``objective`` of the returned tree.
        tfinal, tstart, tsteps, numiter, progbar, **kwargs
            Forwarded to ``WindowedOptimizer.simulated_anneal``.
        seed : optional
            Passed to the ``WindowedOptimizer`` constructor.
        inplace : bool, optional
            If True, copy the refined state into this tree and return it;
            otherwise return a new tree.

        Returns
        -------
        ContractionTreeCompressed
        """
        from .pathfinders.path_compressed import WindowedOptimizer

        if minimize is None:
            minimize = self.get_default_objective()

        wo = WindowedOptimizer(
            self.inputs,
            self.output,
            self.size_dict,
            minimize=minimize,
            ssa_path=self.get_ssa_path(),
            seed=seed,
        )
        wo.simulated_anneal(
            tfinal=tfinal,
            tstart=tstart,
            tsteps=tsteps,
            numiter=numiter,
            progbar=progbar,
            **kwargs,
        )
        ssa_path = wo.get_ssa_path()

        rtree = self.__class__.from_path(
            self.inputs,
            self.output,
            self.size_dict,
            ssa_path=ssa_path,
            objective=minimize,
        )

        if inplace:
            self.set_state_from(rtree)
            rtree = self

        # discard any cached contractors, since the tree structure changed
        rtree.contraction_cores.clear()
        return rtree

    simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
class PartitionTreeBuilder:
    """Function wrapper that takes a function that partitions graphs and
    uses it to build a contraction tree. ``partition_fn`` should have
    signature:

        def partition_fn(inputs, output, size_dict,
                         weight_nodes, weight_edges, **kwargs):
            ...
            return membership

    Where ``weight_nodes`` and ``weight_edges`` describe how to weight the
    nodes and edges of the graph respectively and ``membership`` should be a
    list of integers of length ``len(inputs)`` labelling which partition
    each input node should be put in.
    """

    def __init__(self, partition_fn):
        self.partition_fn = partition_fn

    def build_divide(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        cutoff=10,
        parts=2,
        parts_decay=0.5,
        sub_optimize="greedy",
        super_optimize="auto-hq",
        check=False,
        seed=None,
        **partition_opts,
    ):
        """Build a contraction tree top-down by recursive partitioning.

        While the tree has childless nodes, take one, split its subgraph
        with ``partition_fn`` and contract the resulting groups together.
        Small or unsplittable subgraphs are contracted directly.

        Parameters
        ----------
        inputs, output, size_dict
            The contraction to build a tree for.
        random_strength : float, optional
            Strength of the random jitter applied to ``size_dict`` before it
            is handed to ``partition_fn``.
        cutoff : int, optional
            Subgraphs with at most this many nodes are contracted directly
            with ``sub_optimize`` instead of being partitioned further.
        parts : int, optional
            Base number of partitions. For a subgraph holding a fraction
            ``s`` of all ``N`` nodes, ``max(int(s**parts_decay * parts), 2)``
            parts are requested.
        parts_decay : float, optional
            Exponent controlling how the requested parts scale with ``s``.
        sub_optimize : optional
            ``optimize`` argument for ``contract_nodes`` when contracting a
            small (or unsplittable) subgraph directly.
        super_optimize : optional
            ``optimize`` argument for ``contract_nodes`` when contracting
            the groups produced by a partition.
        check : bool, optional
            Passed to ``contract_nodes``; also asserts the final tree is
            complete.
        seed : optional
            Anything accepted by ``get_rng``; seeds both the jitter and
            ``partition_fn``.
        partition_opts
            Extra keyword arguments for ``partition_fn``. If both
            ``imbalance`` and ``imbalance_decay`` are given, the imbalance is
            rescaled per subgraph size ``s``. If ``fix_output_nodes="auto"``,
            two parts are always requested and ``fix_output_nodes`` is
            enabled only for the top-level subgraph (``s == 1.0``).

        Returns
        -------
        ContractionTree
        """
        tree = ContractionTree(inputs, output, size_dict, track_childless=True)
        rng = get_rng(seed)
        rand_size_dict = jitter_dict(size_dict, random_strength, rng)

        dynamic_imbalance = ("imbalance" in partition_opts) and (
            "imbalance_decay" in partition_opts
        )
        if dynamic_imbalance:
            imbalance = partition_opts.pop("imbalance")
            imbalance_decay = partition_opts.pop("imbalance_decay")
        else:
            imbalance = imbalance_decay = None

        dynamic_fix = partition_opts.get("fix_output_nodes", None) == "auto"

        while tree.childless:
            tree_node = next(iter(tree.childless))
            subgraph = tuple(tree_node)
            subsize = len(subgraph)

            # skip straight to better method
            if subsize <= cutoff:
                tree.contract_nodes(
                    [node_from_single(x) for x in subgraph],
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            # relative subgraph size
            s = subsize / tree.N

            # let the target number of communities depend on subgraph size
            parts_s = max(int(s**parts_decay * parts), 2)

            # let the imbalance either rise or fall
            if dynamic_imbalance:
                if imbalance_decay >= 0:
                    imbalance_s = s**imbalance_decay * imbalance
                else:
                    imbalance_s = 1 - s**-imbalance_decay * (1 - imbalance)
                partition_opts["imbalance"] = imbalance_s

            if dynamic_fix:
                # for the top level subtree (s==1.0) we partition the outputs
                # nodes first into their own bi-partition
                parts_s = 2
                partition_opts["fix_output_nodes"] = s == 1.0

            # partition! get community membership list e.g.
            # [0, 0, 1, 0, 1, 0, 0, 2, 2, ...]
            inputs = tuple(map(tuple, tree.node_to_terms(subgraph)))
            output = tuple(tree.get_legs(tree_node))
            membership = self.partition_fn(
                inputs,
                output,
                rand_size_dict,
                parts=parts_s,
                seed=rng,
                **partition_opts,
            )

            # divide subgraph up e.g. if we enumerate the subgraph index sets
            # (0, 1, 2, 3, 4, 5, 6, 7, 8, ...) ->
            # ({0, 1, 3, 5, 6}, {2, 4}, {7, 8})
            new_subgs = tuple(
                map(node_from_seq, separate(subgraph, membership))
            )

            if len(new_subgs) == 1:
                # no communities found - contract all remaining
                tree.contract_nodes(
                    tuple(map(node_from_single, subgraph)),
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            # update tree structure with newly contracted subgraphs
            tree.contract_nodes(
                new_subgs, optimize=super_optimize, check=check
            )

        if check:
            assert tree.is_complete()

        return tree

    def build_agglom(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        groupsize=4,
        check=False,
        sub_optimize="greedy",
        seed=None,
        **partition_opts,
    ):
        """Build a contraction tree bottom-up by agglomerative partitioning.

        Repeatedly partition the current frontier of nodes into about
        ``len(leaves) // groupsize`` groups (at least 2) with
        ``partition_fn`` and contract each group, until at most
        ``groupsize`` nodes remain. Those are then contracted together.

        Parameters
        ----------
        inputs, output, size_dict
            The contraction to build a tree for.
        random_strength : float, optional
            Strength of the random jitter applied to ``size_dict`` before it
            is handed to ``partition_fn``.
        groupsize : int, optional
            Target number of nodes per group.
        check : bool, optional
            Passed to ``contract_nodes``; also asserts the final tree is
            complete.
        sub_optimize : optional
            ``optimize`` argument for ``contract_nodes`` for each group.
        seed : optional
            Anything accepted by ``get_rng``; seeds both the jitter and
            ``partition_fn``. Previously only the jitter was seeded, so the
            partitions themselves were not reproducible.
        partition_opts
            Extra keyword arguments for ``partition_fn``.

        Returns
        -------
        ContractionTree
        """
        tree = ContractionTree(inputs, output, size_dict, track_childless=True)
        # one generator shared by the jitter and every partition call, as in
        # ``build_divide``
        rng = get_rng(seed)
        rand_size_dict = jitter_dict(size_dict, random_strength, rng)

        leaves = tuple(tree.gen_leaves())
        for node in leaves:
            tree._add_node(node, check=check)

        output = tuple(tree.output)

        while len(leaves) > groupsize:
            parts = max(2, len(leaves) // groupsize)

            inputs = [tuple(tree.get_legs(node)) for node in leaves]
            membership = self.partition_fn(
                inputs,
                output,
                rand_size_dict,
                parts=parts,
                seed=rng,
                **partition_opts,
            )
            leaves = [
                tree.contract_nodes(group, check=check, optimize=sub_optimize)
                for group in separate(leaves, membership)
            ]

        if len(leaves) > 1:
            tree.contract_nodes(leaves, check=check, optimize=sub_optimize)

        if check:
            assert tree.is_complete()

        return tree

    def trial_fn(self, inputs, output, size_dict, **partition_opts):
        """Build a single trial tree with ``build_divide``."""
        return self.build_divide(inputs, output, size_dict, **partition_opts)

    def trial_fn_agglom(self, inputs, output, size_dict, **partition_opts):
        """Build a single trial tree with ``build_agglom``."""
        return self.build_agglom(inputs, output, size_dict, **partition_opts)
def jitter(x, strength, rng):
    """Scale ``x`` up by a random factor ``1 + strength * e``, where ``e`` is
    drawn from ``rng.expovariate(1.0)``.
    """
    noise = rng.expovariate(1.0)
    return x * (1 + strength * noise)
def jitter_dict(d, strength, seed=None):
    """Return a new dict with every value of ``d`` passed through ``jitter``,
    using one generator obtained from ``get_rng(seed)`` and iterating in the
    order of ``d``.
    """
    rng = get_rng(seed)
    jittered = {}
    for key, value in d.items():
        jittered[key] = jitter(value, strength, rng)
    return jittered
def separate(xs, blocks):
    """Group the items of ``xs`` by the corresponding labels in ``blocks``.

    Returns a list of lists, one per distinct label, ordered by ascending
    label. Within each group the items keep their original order.
    """
    groups = collections.defaultdict(list)
    for item, label in zip(xs, blocks):
        groups[label].append(item)
    return [groups[label] for label in sorted(groups)]