Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
4131 lines
137 KiB
Python
4131 lines
137 KiB
Python
"""Core contraction tree data structure and methods."""
|
|
|
|
import collections
|
|
import functools
|
|
import itertools
|
|
import math
|
|
import warnings
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
|
|
from autoray import do
|
|
|
|
from .contract import make_contractor
|
|
from .hypergraph import get_hypergraph
|
|
from .parallel import (
|
|
can_scatter,
|
|
maybe_leave_pool,
|
|
maybe_rejoin_pool,
|
|
parse_parallel_arg,
|
|
scatter,
|
|
submit,
|
|
)
|
|
from .pathfinders.path_simulated_annealing import (
|
|
parallel_temper_tree,
|
|
simulated_anneal_tree,
|
|
)
|
|
from .plot import (
|
|
plot_contractions,
|
|
plot_contractions_alt,
|
|
plot_hypergraph,
|
|
plot_tree_circuit,
|
|
plot_tree_flat,
|
|
plot_tree_ring,
|
|
plot_tree_rubberband,
|
|
plot_tree_span,
|
|
plot_tree_tent,
|
|
)
|
|
from .scoring import (
|
|
DEFAULT_COMBO_FACTOR,
|
|
CompressedStatsTracker,
|
|
get_score_fn,
|
|
)
|
|
from .utils import (
|
|
MaxCounter,
|
|
compute_size_by_dict,
|
|
deprecated,
|
|
get_rng,
|
|
get_symbol,
|
|
groupby,
|
|
inputs_output_to_eq,
|
|
interleave,
|
|
is_valid_node,
|
|
node_from_seq,
|
|
node_from_single,
|
|
node_get_single_el,
|
|
node_supremum,
|
|
oset,
|
|
prod,
|
|
unique,
|
|
)
|
|
|
|
|
|
def cached_node_property(name):
    """Build a decorator that memoizes a per-node method on the tree.

    The wrapped method's result is stored in ``self.info[node][name]`` and
    served from there on later calls with the same ``node``.
    """

    def decorate(compute):
        @functools.wraps(compute)
        def cached(self, node):
            try:
                return self.info[node][name]
            except KeyError:
                # cache miss: compute first, then record under ``name``
                result = compute(self, node)
                self.info[node][name] = result
                return result

        return cached

    return decorate
|
|
|
|
|
|
def union_it(bs):
    """Union an iterable of set-like objects, given as a single argument.

    The first element's ``union`` method is called with all the remaining
    elements, so the result type follows the first element.
    """
    first, *others = bs
    return first.union(*others)
|
|
|
|
|
|
def legs_union(legs_seq):
    """Merge several leg mappings into one, adding up the counts of any
    index that occurs in more than one of them.

    The first mapping is copied, so none of the inputs are modified.
    """
    head, *tail = legs_seq
    merged = head.copy()
    for other in tail:
        for ix, count in other.items():
            merged[ix] = merged.get(ix, 0) + count
    return merged
|
|
|
|
|
|
def legs_without(legs, ind):
    """Return a copy of ``legs`` with ``ind`` removed (if it was present)."""
    remaining = legs.copy()
    if ind in remaining:
        del remaining[ind]
    return remaining
|
|
|
|
|
|
def get_with_default(k, obj, default):
    """Look up ``k`` via ``obj.get``, falling back to ``default``.

    A key-first wrapper around ``obj.get(k, default)``.
    """
    value = obj.get(k, default)
    return value
|
|
|
|
|
|
@dataclass(order=True, frozen=True)
class SliceInfo:
    """Immutable, orderable record describing a single sliced index.

    Instances compare field by field in declaration order.
    """

    inner: bool
    ind: str
    # number of values enumerated by ``sliced_range`` when not projected
    size: int
    # if not ``None``, the single value ``sliced_range`` is restricted to
    project: Optional[int]

    @property
    def sliced_range(self):
        """The values to iterate over for this index: ``[project]`` when a
        projection is set, otherwise ``range(size)``.
        """
        if self.project is not None:
            return [self.project]
        return range(self.size)
|
|
|
|
|
|
def get_slice_strides(sliced_inds):
    """Compute the 'strides' given the (ordered) dictionary of sliced indices.

    Each stride is the product of the ``size`` of every entry that comes
    after it, so the last stride is always 1.
    """
    sizes = [info.size for info in sliced_inds.values()]
    strides = []
    running = 1
    # walk from the back, accumulating the product of trailing sizes
    for size in reversed(sizes):
        strides.append(running)
        running *= size
    strides.reverse()
    return strides
|
|
|
|
|
|
def add_maybe_exponent_stripped(x, y):
    """Add two values, each of which is either a plain array or a tuple
    ``(mantissa, exponent)`` with a base-10 exponent.

    If neither argument is a tuple the plain sum ``x + y`` is returned.
    Otherwise the result is a tuple ``(mantissa, exponent)`` where the
    exponent is the larger of the two and both mantissas are rescaled to it.
    """
    x_has_exp = isinstance(x, tuple)
    y_has_exp = isinstance(y, tuple)

    if not x_has_exp and not y_has_exp:
        # simple sum without exponent
        return x + y

    # a bare value is treated as having a zero exponent
    xm, xe = x if x_has_exp else (x, 0.0)
    ym, ye = y if y_has_exp else (y, 0.0)

    # rescale both mantissas to the common (larger) exponent
    e = max(xe, ye)
    m = xm * 10 ** (xe - e) + ym * 10 ** (ye - e)

    return (m, e)
|
|
|
|
|
|
class ContractionTree:
|
|
"""Binary tree representing a tensor network contraction.
|
|
|
|
Parameters
|
|
----------
|
|
inputs : sequence of str
|
|
The list of input tensor's indices.
|
|
output : str
|
|
The output indices.
|
|
size_dict : dict[str, int]
|
|
The size of each index.
|
|
track_childless : bool, optional
|
|
Whether to dynamically keep track of which nodes are childless. Useful
|
|
if you are 'divisively' building the tree.
|
|
track_flops : bool, optional
|
|
Whether to dynamically keep track of the total number of flops. If
|
|
``False`` You can still compute this once the tree is complete.
|
|
track_write : bool, optional
|
|
Whether to dynamically keep track of the total number of elements
|
|
written. If ``False`` You can still compute this once the tree is
|
|
complete.
|
|
track_size : bool, optional
|
|
Whether to dynamically keep track of the largest tensor so far. If
|
|
``False`` You can still compute this once the tree is complete.
|
|
objective : str or Objective, optional
|
|
A default objective function to use for further optimization and
|
|
scoring, for example reconfiguring or computing the combo cost. If not
|
|
supplied the default is to create a flops objective when needed.
|
|
|
|
Attributes
|
|
----------
|
|
children : dict[node, tuple[node]]
|
|
Mapping of each node to two children.
|
|
info : dict[node, dict]
|
|
Information about the tree nodes. The key is the set of inputs (a
|
|
set of inputs indices) the node contains. Or in other words, the
|
|
subgraph of the node. The value is a dictionary to cache information
|
|
about effective 'leg' indices, size, flops of formation etc.
|
|
"""
|
|
|
|
def __init__(
    self,
    inputs,
    output,
    size_dict,
    track_childless=False,
    track_flops=False,
    track_write=False,
    track_size=False,
    objective=None,
):
    """Create an empty tree containing only the leaf nodes and the root.

    Parameters
    ----------
    inputs : sequence of str
        The list of input tensor's indices.
    output : str
        The output indices.
    size_dict : dict[str, int]
        The size of each index. If the first value is not a python ``int``
        all values are converted with ``int``.
    track_childless : bool, optional
        Whether to maintain ``self.childless``, initially just the root.
    track_flops, track_write, track_size : bool, optional
        Whether to maintain running totals of flops, written elements
        and tensor sizes respectively.
    objective : str or Objective, optional
        Stored as the default objective for later scoring.
    """
    self.inputs = inputs
    self.output = output

    if isinstance(self.inputs[0], set) or isinstance(self.output, set):
        # n.b. the sentences are separate literals: keep the trailing spaces
        warnings.warn(
            "The inputs or output of this tree are not ordered. "
            "Costs will be accurate but actually contracting requires "
            "ordered indices corresponding to array axes."
        )

    if not isinstance(next(iter(size_dict.values()), 1), int):
        # make sure we are working with python integers to avoid overflow
        # comparison errors with inf etc.
        self.size_dict = {k: int(v) for k, v in size_dict.items()}
    else:
        self.size_dict = size_dict

    self.N = len(self.inputs)

    # the index representation for each input is an ordered mapping of
    # each index to the number of times it has appeared on children. By
    # also tracking the total number of appearances one can efficiently
    # and locally compute which indices should be kept or contracted
    self.appearances = {}
    for term in self.inputs:
        for ix in term:
            self.appearances[ix] = self.appearances.get(ix, 0) + 1
    # adding output appearances ensures these are never contracted away,
    # N.B. if after this step every appearance count is exactly 2,
    # then there are no 'hyper' indices in the contraction
    for ix in self.output:
        self.appearances[ix] = self.appearances.get(ix, 0) + 1

    # this stores potential preprocessing steps that are not part of the
    # main contraction tree, but assumed to have been applied, for example
    # tracing or summing over indices that appear only once
    self.preprocessing = {}

    # mapping of parents to children - the core binary tree object
    self.children = {}

    # information about all the nodes
    self.info = {}

    # add constant nodes: the leaves
    for leaf in self.gen_leaves():
        self._add_node(leaf)
    # and the root or top node
    self.root = node_supremum(self.N)
    self._add_node(self.root)

    # whether to keep track of dangling nodes/subgraphs
    self.track_childless = track_childless
    if self.track_childless:
        # the set of dangling nodes
        self.childless = oset([self.root])

    # running largest_intermediate and total flops
    self._track_flops = track_flops
    if track_flops:
        self._flops = 0

    self._track_write = track_write
    if track_write:
        self._write = 0

    self._track_size = track_size
    if track_size:
        self._sizes = MaxCounter()

    # container for caching subtree reconfiguration candidates
    self.already_optimized = dict()

    # info relating to slicing (base constructor is always unsliced)
    self.multiplicity = 1
    self.sliced_inds = {}
    self.sliced_inputs = frozenset()

    # cache for compiled contraction cores
    self.contraction_cores = {}

    # a default objective function useful for
    # further optimization and scoring
    self._default_objective = objective
|
|
|
|
def set_state_from(self, other):
    """Overwrite this tree's internal state with that of ``other``.

    Attributes treated as immutable are shared by reference, flat
    containers are shallow-copied, and ``info`` / ``already_optimized``
    additionally have each of their values copied.
    """
    # immutable or never mutated properties: share the same objects
    shared = (
        "appearances",
        "inputs",
        "multiplicity",
        "N",
        "output",
        "root",
        "size_dict",
        "sliced_inputs",
        "_default_objective",
    )
    for name in shared:
        setattr(self, name, getattr(other, name))

    # mutable containers: take a shallow copy
    shallow = ("children", "contraction_cores", "sliced_inds", "preprocessing")
    for name in shallow:
        setattr(self, name, getattr(other, name).copy())

    # dicts whose values are themselves mutable: copy one level deeper
    for name in ("info", "already_optimized"):
        nested = getattr(other, name)
        setattr(self, name, {key: val.copy() for key, val in nested.items()})

    # optional trackers: (flag attribute, payload attribute, copy payload?)
    # the payload only exists, and is only transferred, if the flag is set
    trackers = (
        ("track_childless", "childless", True),
        ("_track_flops", "_flops", False),
        ("_track_write", "_write", False),
        ("_track_size", "_sizes", True),
    )
    for flag, payload, copy_payload in trackers:
        enabled = getattr(other, flag)
        setattr(self, flag, enabled)
        if enabled:
            value = getattr(other, payload)
            setattr(self, payload, value.copy() if copy_payload else value)
|
|
|
|
def copy(self):
    """Return a new tree of the same class carrying a copy of this tree's
    state. ``__init__`` is bypassed; state is transferred with
    ``set_state_from``.
    """
    clone = object.__new__(self.__class__)
    clone.set_state_from(self)
    return clone
|
|
|
|
def set_default_objective(self, objective):
    """Store ``objective``, passed through ``get_score_fn``, as the default
    objective of this tree.
    """
    score_fn = get_score_fn(objective)
    self._default_objective = score_fn
|
|
|
|
def get_default_objective(self):
    """Return the default objective of this tree, lazily creating (and
    storing) a ``"flops"`` objective if none has been set.
    """
    objective = self._default_objective
    if objective is None:
        objective = self._default_objective = get_score_fn("flops")
    return objective
|
|
|
|
def get_default_combo_factor(self):
    """Return the ``factor`` attribute of the default objective, or
    ``DEFAULT_COMBO_FACTOR`` if the objective has no such attribute.
    """
    default_objective = self.get_default_objective()
    try:
        factor = default_objective.factor
    except AttributeError:
        return DEFAULT_COMBO_FACTOR
    else:
        return factor
|
|
|
|
def get_score(self, objective=None):
    """Score this tree with ``objective``, or the default objective.

    Parameters
    ----------
    objective : optional
        Anything accepted by ``get_score_fn``. If ``None``, the result of
        ``self.get_default_objective()`` is used.

    Returns
    -------
    The value of calling the resolved objective with ``{"tree": self}``.
    """
    # n.b. ``get_score_fn`` is already imported at module level, so no
    # function-scope import is needed here
    if objective is None:
        objective = self.get_default_objective()

    objective = get_score_fn(objective)

    return objective({"tree": self})
|
|
|
|
@property
def nslices(self):
    """Alias for ``self.multiplicity`` - the number of independent
    contractions this tree represents overall.
    """
    count = self.multiplicity
    return count
|
|
|
|
@property
def nchunks(self):
    """The number of 'chunks': the product of the sizes of all sliced
    indices whose ``inner`` flag is false.
    """
    outer_sizes = (
        info.size for info in self.sliced_inds.values() if not info.inner
    )
    return prod(outer_sizes)
|
|
|
|
def node_to_terms(self, node):
    """Lazily yield, for each element of ``node``, the legs (as given by
    ``get_legs``) of the corresponding single-tensor leaf node.
    """
    return (self.get_legs(node_from_single(tid)) for tid in node)
|
|
|
|
def gen_leaves(self):
    """Lazily generate the leaf nodes of the tree: one node per input
    tensor, built by ``node_from_single`` from each of ``range(self.N)``.
    """
    tensor_ids = range(self.N)
    return map(node_from_single, tensor_ids)
|
|
|
|
def get_incomplete_nodes(self):
    """Find the nodes that still need contracting to complete the tree.

    'Childless' nodes are non-leaf nodes with no entry in
    ``self.children``; 'parentless' nodes are non-root nodes that are not
    a child of any node. Each parentless node is assigned to the smallest
    childless node that contains it.

    Returns
    -------
    groups : dict[frozenset[int], list[frozenset[int]]]
        A mapping of each childless node to the list of parentless nodes
        beneath it.

    See Also
    --------
    autocomplete
    """
    # dicts (rather than sets) are used to keep insertion order
    # candidates: everything but leaves / everything but the root
    childless = {node: None for node in self.info if len(node) != 1}
    parentless = {node: None for node in self.info if len(node) != self.N}

    # strike out every node that already takes part in a contraction
    for parent, (left, right) in self.children.items():
        del parentless[left]
        del parentless[right]
        del childless[parent]

    groups = {}
    for node in childless:
        groups[node] = []

    for orphan in parentless:
        # the smallest childless node that contains this one
        containing = [cand for cand in childless if orphan.issubset(cand)]
        groups[min(containing, key=len)].append(orphan)

    return groups
|
|
|
|
def autocomplete(self, **contract_opts):
    """Complete the tree by contracting each group of parentless nodes
    returned by ``tree.get_incomplete_nodes``.

    Parameters
    ----------
    contract_opts
        Options to pass to ``tree.contract_nodes``.

    See Also
    --------
    get_incomplete_nodes, contract_nodes
    """
    for subnodes in self.get_incomplete_nodes().values():
        self.contract_nodes(subnodes, **contract_opts)
|
|
|
|
@classmethod
def from_path(
    cls,
    inputs,
    output,
    size_dict,
    *,
    path=None,
    ssa_path=None,
    edge_path=None,
    optimize="auto-hq",
    autocomplete="auto",
    check=False,
    **kwargs,
):
    """Create a (completed) ``ContractionTree`` from the usual inputs plus
    exactly one of a standard contraction ``path``, an ``ssa_path`` or an
    ``edge_path``.

    Parameters
    ----------
    inputs : Sequence[Sequence[str]]
        The input indices of each tensor, as single unicode characters.
    output : Sequence[str]
        The output indices.
    size_dict : dict[str, int]
        The size of each index.
    path : Sequence[Sequence[int]], optional
        The contraction path, a sequence of pairs of tensor ids to
        contract. The ids are linear indices into the list of temporary
        tensors, which are recycled as each contraction pops a pair and
        appends the result. One of ``path``, ``ssa_path`` or ``edge_path``
        must be supplied.
    ssa_path : Sequence[Sequence[int]], optional
        The contraction path, a sequence of pairs of indices to contract.
        The indices are single use, as if the result of each contraction is
        appended to the end of the list of temporary tensors without
        popping. One of ``path``, ``ssa_path`` or ``edge_path`` must be
        supplied.
    edge_path : Sequence[str], optional
        The contraction path, a sequence of indices to contract in order.
        It is converted to an ssa path with ``edge_path_to_ssa``.
        One of ``path``, ``ssa_path`` or ``edge_path`` must be supplied.
    optimize : str, optional
        If a contraction within the path contains 3 or more tensors, how to
        optimize this subcontraction into a binary tree.
    autocomplete : "auto" or bool, optional
        Whether to automatically complete the path, i.e. contract all
        remaining nodes. If "auto" then a warning is issued if the path is
        not complete.
    check : bool, optional
        Whether to perform some basic checks while creating the contraction
        nodes.
    kwargs
        Supplied to the class constructor.

    Returns
    -------
    ContractionTree

    Raises
    ------
    ValueError
        If not exactly one of ``path``, ``ssa_path`` and ``edge_path`` is
        supplied.
    """
    if (path is None) + (ssa_path is None) + (edge_path is None) != 2:
        raise ValueError(
            "Exactly one of ``path``, ``ssa_path`` or ``edge_path`` must "
            "be supplied."
        )

    contract_opts = {"optimize": optimize, "check": check}

    if edge_path is not None:
        from .pathfinders.path_basic import edge_path_to_ssa

        ssa_path = edge_path_to_ssa(edge_path, inputs)

    if ssa_path is not None:
        path = ssa_path

    tree = cls(inputs, output, size_dict, **kwargs)

    if ssa_path is not None:
        # ssa path (single use ids)
        nodes = dict(enumerate(tree.gen_leaves()))
        ssa = len(nodes)
        for p in path:
            merge = [nodes.pop(i) for i in p]
            nodes[ssa] = tree.contract_nodes(merge, **contract_opts)
            ssa += 1
        nodes = nodes.values()
    else:
        # regular path ('recycled' ids)
        nodes = list(tree.gen_leaves())
        for p in path:
            # pop the largest position first so earlier ones stay valid
            merge = [nodes.pop(i) for i in sorted(p, reverse=True)]
            nodes.append(tree.contract_nodes(merge, **contract_opts))

    if len(nodes) > 1 and autocomplete:
        if autocomplete == "auto":
            # warn that we are completing
            warnings.warn(
                "Path was not complete - contracting all remaining. "
                "You can silence this warning with `autocomplete=True`. "
                "Or produce an incomplete tree with `autocomplete=False`."
            )

        tree.contract_nodes(nodes, **contract_opts)

    return tree
|
|
|
|
@classmethod
def from_info(cls, info, **kwargs):
    """Create a ``ContractionTree`` from an ``opt_einsum.PathInfo`` object.

    The inputs are obtained by splitting ``info.input_subscripts`` on
    commas; ``info.output_subscript``, ``info.size_dict`` and ``info.path``
    are forwarded as they are, along with ``kwargs``, to ``from_path``.
    """
    input_terms = info.input_subscripts.split(",")
    return cls.from_path(
        inputs=input_terms,
        output=info.output_subscript,
        size_dict=info.size_dict,
        path=info.path,
        **kwargs,
    )
|
|
|
|
@classmethod
def from_eq(cls, eq, size_dict, **kwargs):
    """Create an empty ``ContractionTree`` directly from an equation and
    the size of each index.

    Parameters
    ----------
    eq : str
        The einsum string equation, which must contain exactly one
        ``"->"`` separating the comma-delimited inputs from the output.
    size_dict : dict[str, int]
        The size of each index.
    kwargs
        Supplied to the class constructor.
    """
    input_spec, output_spec = eq.split("->")
    terms = input_spec.split(",")
    return cls(terms, output_spec, size_dict, **kwargs)
|
|
|
|
def get_eq(self):
    """Get the einsum equation for this tree, built from ``self.inputs``
    and ``self.output``. This is the total (original) equation, so it
    includes any indices which have been sliced.

    Returns
    -------
    eq : str
    """
    equation = inputs_output_to_eq(self.inputs, self.output)
    return equation
|
|
|
|
def get_shapes(self):
    """Get the shape of every input tensor, looking up the size of each of
    its indices in ``self.size_dict``.

    Returns
    -------
    shapes : tuple[tuple[int]]
    """
    shapes = []
    for term in self.inputs:
        shapes.append(tuple(self.size_dict[ix] for ix in term))
    return tuple(shapes)
|
|
|
|
def get_inputs_sliced(self):
    """Get the input indices for a single slice of this tree, i.e. each
    input term with any index in ``self.sliced_inds`` dropped.

    Returns
    -------
    inputs : tuple[tuple[str]]
    """
    sliced_terms = []
    for term in self.inputs:
        kept = tuple(ix for ix in term if ix not in self.sliced_inds)
        sliced_terms.append(kept)
    return tuple(sliced_terms)
|
|
|
|
def get_output_sliced(self):
    """Get the output indices for a single slice of this tree, i.e. with
    any index in ``self.sliced_inds`` dropped.

    Returns
    -------
    output : tuple[str]
    """
    kept = [ix for ix in self.output if ix not in self.sliced_inds]
    return tuple(kept)
|
|
|
|
def get_eq_sliced(self):
    """Get the einsum equation for a single slice of this tree, built from
    ``get_inputs_sliced`` and ``get_output_sliced``, i.e. with sliced
    indices removed.

    Returns
    -------
    eq : str
    """
    sliced_inputs = self.get_inputs_sliced()
    sliced_output = self.get_output_sliced()
    return inputs_output_to_eq(sliced_inputs, sliced_output)
|
|
|
|
def get_shapes_sliced(self):
    """Get the shapes of the input tensors for a single slice of this
    tree, i.e. omitting the dimension of any index in ``self.sliced_inds``.

    Returns
    -------
    shapes : tuple[tuple[int]]
    """
    shapes = []
    for term in self.inputs:
        shape = tuple(
            self.size_dict[ix] for ix in term if ix not in self.sliced_inds
        )
        shapes.append(shape)
    return tuple(shapes)
|
|
|
|
@classmethod
def from_edge_path(
    cls,
    edge_path,
    inputs,
    output,
    size_dict,
    optimize="auto-hq",
    autocomplete="auto",
    check=False,
    **kwargs,
):
    """Create a ``ContractionTree`` from an edge elimination ordering.

    Deprecated: emits a warning and then forwards everything to
    ``from_path`` with ``edge_path`` given as a keyword.
    """
    warnings.warn(
        "ContractionTree.from_edge_path(edge_path, ...) is deprecated. Use"
        " ContractionTree.from_path(edge_path=edge_path, ...) instead."
    )
    forwarded = {
        "edge_path": edge_path,
        "optimize": optimize,
        "autocomplete": autocomplete,
        "check": check,
    }
    return cls.from_path(inputs, output, size_dict, **forwarded, **kwargs)
|
|
|
|
def _add_node(self, node, check=False):
|
|
if check:
|
|
if len(self.info) > 2 * self.N - 1:
|
|
raise ValueError("There are too many children already.")
|
|
if len(self.children) > self.N - 1:
|
|
raise ValueError("There are too many branches already.")
|
|
if not is_valid_node(node):
|
|
raise ValueError("{} is not a valid node.".format(node))
|
|
|
|
self.info.setdefault(node, dict())
|
|
|
|
def _remove_node(self, node):
|
|
"""Remove ``node`` from this tree and update the flops and maximum size
|
|
if tracking them respectively, as well as input pre-processing.
|
|
"""
|
|
node_extent = len(node)
|
|
|
|
if node_extent == 1:
|
|
# leaf nodes should always exist
|
|
self.info[node].clear()
|
|
# input: remove any associated preprocessing
|
|
self.preprocessing.pop(node_get_single_el(node), None)
|
|
else:
|
|
# only non-leaf nodes contribute to size, flops and write
|
|
if self._track_size:
|
|
self._sizes.discard(self.get_size(node))
|
|
|
|
if self._track_flops:
|
|
self._flops -= self.get_flops(node)
|
|
|
|
if self._track_write:
|
|
self._write -= self.get_size(node)
|
|
|
|
del self.children[node]
|
|
if node_extent == self.N:
|
|
# root node should always exist
|
|
self.info[node].clear()
|
|
else:
|
|
del self.info[node]
|
|
|
|
def compute_leaf_legs(self, i):
    """Compute the effective 'outer' indices for the ith input tensor, as
    a mapping of index to the number of times it appears on that tensor.

    This is not always simply the ith input indices, due to A) potential
    slicing and B) potential preprocessing. If the term can be simplified
    on its own, the simplifying equation is recorded in
    ``self.preprocessing[i]`` and the simplified legs are returned.
    """
    # indices of input tensor (after slicing which is done immediately)
    term = self.inputs[i]
    if self.sliced_inds:
        term = tuple(ix for ix in term if ix not in self.sliced_inds)

    counts = {}
    for ix in term:
        counts[ix] = counts.get(ix, 0) + 1

    # single term simplifications are treated as a preprocessing step that
    # is only taken into account during actual contraction, and is not
    # represented in the binary tree
    # N.B. simplifiability needs to be computed *after* slicing
    has_repeats = len(term) != len(counts)
    if not has_repeats:
        # an index appearing nowhere else is summed immediately
        has_reducible = any(
            count == self.appearances[ix] for ix, count in counts.items()
        )
        if not has_reducible:
            return counts

    # the simplified legs -> the new effective input legs
    legs = {
        ix: count
        for ix, count in counts.items()
        if count != self.appearances[ix]
    }
    # record the preprocessing step for this input
    self.preprocessing[i] = inputs_output_to_eq(
        (term,), legs, canonicalize=True
    )
    return legs
|
|
|
|
def has_preprocessing(self):
    """Return whether any preprocessing steps are recorded, after first
    requesting the legs of every leaf (preprocessing is lazily computed).
    """
    for leaf in self.gen_leaves():
        # touching the legs populates ``self.preprocessing`` as needed
        self.get_legs(leaf)
    any_recorded = bool(self.preprocessing)
    return any_recorded
|
|
|
|
def has_hyper_indices(self):
    """Check if there are any 'hyper' indices in the contraction, i.e.
    whether any count in ``self.appearances`` differs from 2.
    """
    return not all(count == 2 for count in self.appearances.values())
|
|
|
|
@cached_node_property("legs")
def get_legs(self, node):
    """Get the effective 'outer' indices for the collection of tensors in
    ``node``, as a mapping of index to appearance count.
    """
    extent = len(node)

    if extent == 1:
        # leaf legs are inputs
        return self.compute_leaf_legs(node_get_single_el(node))

    if extent == self.N:
        # root legs are output, after slicing
        # n.b. the index counts are irrelevant for the output
        kept = (ix for ix in self.output if ix not in self.sliced_inds)
        return dict.fromkeys(kept, 0)

    try:
        involved = self.get_involved(node)
    except KeyError:
        # fall back to combining the legs of every leaf in ``node``
        involved = legs_union(self.node_to_terms(node))

    # an index stays as a leg only while it still appears elsewhere
    legs = {}
    for ix, count in involved.items():
        if count < self.appearances[ix]:
            legs[ix] = count
    return legs
|
|
|
|
@cached_node_property("involved")
def get_involved(self, node):
    """Get all the indices involved in the formation of subgraph ``node``:
    the combined legs of its children, or ``{}`` for a leaf.
    """
    if len(node) == 1:
        return {}
    child_legs = (self.get_legs(child) for child in self.children[node])
    return legs_union(child_legs)
|
|
|
|
@cached_node_property("size")
def get_size(self, node):
    """Get the tensor size of ``node``, computed from its legs and
    ``self.size_dict``.
    """
    legs = self.get_legs(node)
    return compute_size_by_dict(legs, self.size_dict)
|
|
|
|
@cached_node_property("flops")
def get_flops(self, node):
    """Get the FLOPs for the pairwise contraction that will create
    ``node``, computed from all the indices involved. Leaves cost 0.
    """
    is_leaf = len(node) == 1
    if is_leaf:
        return 0
    return compute_size_by_dict(self.get_involved(node), self.size_dict)
|
|
|
|
@cached_node_property("can_dot")
def get_can_dot(self, node):
    """Get whether this contraction can be performed as a dot product (i.e.
    with ``tensordot``), or else requires ``einsum``.

    This is the case exactly when the legs of ``node`` are the indices
    appearing on one child but not the other.
    """
    left, right = self.children[node]
    parent_legs = self.get_legs(node)
    left_legs = self.get_legs(left)
    right_legs = self.get_legs(right)
    return set(parent_legs) == set(left_legs) ^ set(right_legs)
|
|
|
|
@cached_node_property("inds")
def get_inds(self, node):
    """Get the indices of this node - an ordered string version of
    ``get_legs``.

    For leaves and the root this is simply the legs joined in order. For
    any other node the kept indices are grouped by how the contraction of
    its two children naturally produces them: indices kept from both
    children first (in left order), then left-only, then right-only.
    """
    # NB: self.inputs and self.output contain the full (unsliced) indices
    # thus we filter even the input legs and output legs

    if len(node) in (1, self.N):
        return "".join(self.get_legs(node))

    legs = self.get_legs(node)
    l_inds, r_inds = map(self.get_inds, self.children[node])
    l_inds_set = set(l_inds)
    r_inds_set = set(r_inds)

    # Interleaving right-only output indices between batches of left-only
    # indices tends, on torch CPU, to create high-rank permuted views that
    # have to be cloned before GEMM/BMM, hence the explicit grouping.
    batch = []
    left_keep = []
    for ix in l_inds:
        if ix in legs:
            (batch if ix in r_inds_set else left_keep).append(ix)

    right_keep = [
        ix for ix in r_inds if (ix in legs) and (ix not in l_inds_set)
    ]

    # n.b. every kept index is on the left or on the right, so these three
    # groups already cover all of them and no fallback ordering is needed
    return "".join(unique(itertools.chain(batch, left_keep, right_keep)))
|
|
|
|
@cached_node_property("tensordot_axes")
def get_tensordot_axes(self, node):
    """Get the ``axes`` arg for a tensordot contraction that produces
    ``node``: the positions, on the left and right child respectively, of
    every index the two children share. The pairs are sorted in order of
    appearance on the left input.
    """
    l_inds, r_inds = map(self.get_inds, self.children[node])
    # ``str.find`` gives -1 for indices absent from the right child
    located = ((i, r_inds.find(ix)) for i, ix in enumerate(l_inds))
    shared = [(i, j) for i, j in located if j != -1]
    l_axes = tuple(i for i, _ in shared)
    r_axes = tuple(j for _, j in shared)
    return l_axes, r_axes
|
|
|
|
@cached_node_property("tensordot_perm")
def get_tensordot_perm(self, node):
    """Get the permutation required, if any, to bring the tensordot output
    of this nodes contraction into line with ``self.get_inds(node)``.
    Returns ``None`` if the two orderings already match.
    """
    l_inds, r_inds = map(self.get_inds, self.children[node])
    # the target output inds
    target = self.get_inds(node)
    # the tensordot output inds: ordered by position in left then right
    combined = f"{l_inds}{r_inds}"
    produced = "".join(sorted(target, key=combined.find))
    if produced == target:
        return None
    return tuple(produced.find(ix) for ix in target)
|
|
|
|
@cached_node_property("einsum_eq")
def get_einsum_eq(self, node):
    """Get the einsum string describing the contraction that produces
    ``node``. Unlike ``get_inds`` the characters are remapped with
    ``get_symbol``, numbered by first appearance across the two inputs,
    for compatibility with ``numpy.einsum`` for example.
    """
    left, right = self.children[node]
    l_inds, r_inds, p_inds = map(self.get_inds, (left, right, node))
    # we need to map any extended unicode characters into ascii
    table = {}
    for pos, ix in enumerate(unique(itertools.chain(l_inds, r_inds))):
        table[ord(ix)] = get_symbol(pos)
    raw_eq = f"{l_inds},{r_inds}->{p_inds}"
    return raw_eq.translate(table)
|
|
|
|
def get_centrality(self, node):
    """Return the cached ``"centrality"`` entry of ``node``, calling
    ``self.compute_centralities()`` first if it is not yet available.
    """
    if "centrality" not in self.info.get(node, ()):
        self.compute_centralities()
    return self.info[node]["centrality"]
|
|
|
|
def total_flops(self, dtype=None, log=None):
    """Sum the flops contribution from every node in the tree, scaled by
    ``self.multiplicity``. If not already tracked, the total is computed
    by traversing the tree and is tracked from then on.

    Parameters
    ----------
    dtype : {'float', 'complex', None}, optional
        Scale the answer depending on the assumed data type: by 2 if it
        contains ``'float'``, by 4 if it contains ``'complex'``.
    log : optional
        If given, return the logarithm of the total in this base.

    Raises
    ------
    ValueError
        If ``dtype`` is given but contains neither 'float' nor 'complex'.
    """
    if not self._track_flops:
        self._flops = 0
        for node, _, _ in self.traverse():
            self._flops += self.get_flops(node)
        self._track_flops = True

    C = self.multiplicity * self._flops

    if dtype is not None:
        if "float" in dtype:
            C *= 2
        elif "complex" in dtype:
            C *= 4
        else:
            raise ValueError(f"Unknown dtype {dtype}")

    return C if log is None else math.log(C, log)
|
|
|
|
def total_write(self):
    """Sum the sizes of every node produced in the tree, scaled by
    ``self.multiplicity``. If not already tracked, the total is computed
    by traversing the tree and is tracked from then on.
    """
    if self._track_write:
        return self.multiplicity * self._write

    self._write = 0
    for node, _, _ in self.traverse():
        self._write += self.get_size(node)
    self._track_write = True

    return self.multiplicity * self._write
|
|
|
|
def combo_cost(self, factor=DEFAULT_COMBO_FACTOR, combine=sum, log=None):
    """Compute a combined flops and size cost for the tree.

    For every parent node, ``combine`` is called on the pair
    ``(flops, factor * size)`` and the results are added up, then scaled
    by ``self.multiplicity``. If ``log`` is given, the logarithm of the
    total in that base is returned instead.
    """
    total = 0
    for parent in self.children:
        flops = self.get_flops(parent)
        size = self.get_size(parent)
        total += combine((flops, factor * size))

    total *= self.multiplicity

    return total if log is None else math.log(total, log)

total_cost = combo_cost
|
|
|
|
def max_size(self, log=None):
    """The size of the largest intermediate tensor.

    If there is only a single tensor (``N == 1``) this is the size of the
    root. Otherwise, if sizes are not already tracked, they are collected
    by traversing the tree and are tracked from then on.

    Parameters
    ----------
    log : optional
        If given, return the logarithm of the size in this base. This is
        applied in the single tensor case as well.
    """
    if self.N == 1:
        size = self.get_size(self.root)
    else:
        if not self._track_size:
            self._sizes = MaxCounter()
            for node, _, _ in self.traverse():
                self._sizes.add(self.get_size(node))
            self._track_size = True

        size = self._sizes.max()

    if log is not None:
        size = math.log(size, log)

    return size
|
|
|
|
def peak_size(self, order=None, log=None):
    """Get the peak concurrent size of tensors needed - this depends on the
    traversal order, i.e. the exact contraction path, not just the
    contraction tree.

    All inputs are counted from the start; at each contraction the output
    is added, the peak is measured, and then both inputs are released.
    If ``log`` is given, the logarithm of the peak in that base is
    returned.
    """
    current = sum(map(self.get_size, self.gen_leaves()))
    peak = current
    for parent, left, right in self.traverse(order=order):
        current += self.get_size(parent)
        # measure peak assuming we need both inputs and output
        if current > peak:
            peak = current
        current -= self.get_size(left)
        current -= self.get_size(right)

    return peak if log is None else math.log(peak, log)
|
|
|
|
def contract_stats(self, force=False):
    """Simultaneously compute the total flops, write and size of the
    contraction tree. This is more efficient than calling each of the
    individual methods separately. Once computed, each quantity is then
    automatically tracked.

    Parameters
    ----------
    force : bool, optional
        Recompute all three quantities even if already tracked.

    Returns
    -------
    stats : dict[str, int]
        The total flops, write and size.
    """
    needs_recompute = force or not (
        self._track_flops and self._track_write and self._track_size
    )

    if needs_recompute:
        self._flops = self._write = 0
        self._sizes = MaxCounter()

        for node, _, _ in self.traverse():
            self._flops += self.get_flops(node)
            size = self.get_size(node)
            self._write += size
            self._sizes.add(size)

        self._track_flops = self._track_write = self._track_size = True

    # n.b. flops and write scale with the number of slices, size does not
    stats = {}
    stats["flops"] = self.multiplicity * self._flops
    stats["write"] = self.multiplicity * self._write
    stats["size"] = self._sizes.max()
    return stats
|
|
|
|
def arithmetic_intensity(self):
    """Return total flops divided by total write. A higher ratio is better
    for extracting good computational performance.
    """
    flops = self.total_flops(dtype=None)
    write = self.total_write()
    return flops / write
|
|
|
|
def contraction_scaling(self):
    """Return the largest number of indices involved in any single
    contraction. This matches the scaling exponent under the assumption
    that all dimensions are equal.
    """
    counts = map(len, map(self.get_involved, self.info))
    return max(counts)
|
|
|
|
def contraction_cost(self, log=None):
    """Get the total number of scalar operations ~ time complexity.

    Parameters
    ----------
    log : None or float, optional
        Forwarded to ``self.total_flops``.
    """
    cost = self.total_flops(dtype=None, log=log)
    return cost
|
|
|
|
def contraction_width(self, log=2):
    """Get the logarithm of the size of the largest tensor.

    Parameters
    ----------
    log : float, optional
        The base of the logarithm, forwarded to ``self.max_size``. The
        default of 2 yields log2 of the size.
    """
    width = self.max_size(log=log)
    return width
|
|
|
|
def compressed_contract_stats(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
):
    """Simulate a compressed contraction of this tree on its hypergraph
    and collect the resulting statistics.

    Parameters
    ----------
    chi : optional
        The maximum bond size to compress to. If not given,
        ``self.get_default_chi()`` is used.
    order : optional
        The contraction order, forwarded to ``self.traverse``.
    compress_late : bool, optional
        If true, compress the two inputs just before each contraction,
        otherwise compress the output straight after each contraction. If
        not given, ``self.get_default_compress_late()`` is used.

    Returns
    -------
    tracker : CompressedStatsTracker
        The tracker, updated at every step of the simulated contraction.
    """
    if chi is None:
        chi = self.get_default_chi()
    if compress_late is None:
        compress_late = self.get_default_compress_late()

    hg = self.get_hypergraph(accel="auto")

    # tree node -> hypergraph node index, seeded with the leaves and then
    # extended with every intermediate as it is produced
    node_of = {}
    for leaf, index in zip(self.gen_leaves(), range(hg.get_num_nodes())):
        node_of[leaf] = index

    tracker = CompressedStatsTracker(hg, chi)

    for parent, left, right in self.traverse(order):
        li = node_of[left]
        ri = node_of[right]

        tracker.update_pre_step()

        if compress_late:
            # compress the inputs only once they are about to be used
            tracker.update_pre_compress(hg, li, ri)
            hg.compress(chi=chi, edges=hg.get_node(li))
            hg.compress(chi=chi, edges=hg.get_node(ri))
            tracker.update_post_compress(hg, li, ri)

        tracker.update_pre_contract(hg, li, ri)
        pi = hg.contract(li, ri)
        node_of[parent] = pi
        tracker.update_post_contract(hg, pi)

        if not compress_late:
            # compress the new tensor as early as possible
            tracker.update_pre_compress(hg, pi)
            hg.compress(chi=chi, edges=hg.get_node(pi))
            tracker.update_post_compress(hg, pi)

        tracker.update_post_step()

    return tracker
|
|
|
|
def total_flops_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    dtype=None,
    log=None,
):
    """Estimate the total flops of a compressed contraction of this tree
    with maximum bond size ``chi``. The estimate covers the operations
    needed for the contractions, QRs and SVDs.

    Parameters
    ----------
    chi, order, compress_late : optional
        Forwarded to ``self.compressed_contract_stats``.
    dtype : None, optional
        Must be left as ``None``, only abstract scalar operations can be
        counted.
    log : None or float, optional
        If given, return the logarithm of the flops in this base.

    Raises
    ------
    ValueError
        If ``dtype`` is not ``None``.
    """
    if dtype is not None:
        raise ValueError(
            "Can only estimate cost in terms of "
            "number of abstract scalar ops."
        )

    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    flops = stats.flops

    return flops if log is None else math.log(flops, log)


contraction_cost_compressed = total_flops_compressed
|
|
|
|
def total_write_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    accel="auto",
    log=None,
):
    """Compute the total size of all intermediate tensors produced by a
    compressed contraction with maximum bond size ``chi``, ordered by
    ``order``. This may be relevant for time complexity and for e.g.
    autodiff space complexity (where every intermediate is kept).

    Parameters
    ----------
    chi, order, compress_late : optional
        Forwarded to ``self.compressed_contract_stats``.
    accel : optional
        Accepted but not used by this method.
    log : None or float, optional
        If given, return the logarithm of the write in this base.
    """
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    write = stats.write

    return write if log is None else math.log(write, log)
|
|
|
|
def combo_cost_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    factor=None,
    log=None,
):
    """Compute the combined cost ``flops + factor * write`` of a compressed
    contraction of this tree.

    Parameters
    ----------
    chi, order, compress_late : optional
        Forwarded to ``self.compressed_contract_stats``.
    factor : optional
        The weight of the write term. If not given,
        ``self.get_default_combo_factor()`` is used.
    log : None or float, optional
        If given, return the logarithm of the cost in this base.

    Returns
    -------
    cost : int or float
    """
    if factor is None:
        factor = self.get_default_combo_factor()

    # a single simulated contraction provides both quantities, so don't
    # run the (expensive) simulation separately for flops and for write
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    C = stats.flops + factor * stats.write

    if log is not None:
        C = math.log(C, log)

    return C


total_cost_compressed = combo_cost_compressed
|
|
|
|
def max_size_compressed(
    self, chi=None, order="surface_order", compress_late=None, log=None
):
    """Compute the size of the largest tensor produced by a compressed
    contraction with maximum bond size ``chi``, ordered by ``order``. This
    is close to the ideal space complexity if only the tensors being
    directly operated on are kept in memory.

    Parameters
    ----------
    chi, order, compress_late : optional
        Forwarded to ``self.compressed_contract_stats``.
    log : None or float, optional
        If given, return the logarithm of the size in this base.
    """
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    largest = stats.max_size

    return largest if log is None else math.log(largest, log)
|
|
|
|
def peak_size_compressed(
    self,
    chi=None,
    order="surface_order",
    compress_late=None,
    accel="auto",
    log=None,
):
    """Compute the peak combined size of intermediate tensors during a
    compressed contraction with maximum bond size ``chi``, ordered by
    ``order``. This is the practical space complexity if intermediates are
    not swapped in and out of memory.

    Parameters
    ----------
    chi, order, compress_late : optional
        Forwarded to ``self.compressed_contract_stats``.
    accel : optional
        Accepted but not used by this method.
    log : None or float, optional
        If given, return the logarithm of the peak in this base.
    """
    stats = self.compressed_contract_stats(
        chi=chi, order=order, compress_late=compress_late
    )
    peak = stats.peak_size

    return peak if log is None else math.log(peak, log)
|
|
|
|
def contraction_width_compressed(
    self, chi=None, order="surface_order", compress_late=None, log=2
):
    """Compute the logarithm (base ``log``, by default 2) of the largest
    tensor produced by a compressed contraction with maximum bond size
    ``chi``, ordered by ``order``.
    """
    width = self.max_size_compressed(
        chi=chi, order=order, compress_late=compress_late, log=log
    )
    return width
|
|
|
|
def _update_tracked(self, node):
|
|
if self._track_flops:
|
|
self._flops += self.get_flops(node)
|
|
if self._track_write:
|
|
self._write += self.get_size(node)
|
|
if self._track_size:
|
|
self._sizes.add(self.get_size(node))
|
|
|
|
def contract_nodes_pair(
    self,
    x,
    y,
    legs=None,
    cost=None,
    size=None,
    check=False,
):
    """Contract node ``x`` with node ``y`` in the tree, creating their
    parent node, which is returned.

    Parameters
    ----------
    x : frozenset[int]
        The first node to contract.
    y : frozenset[int]
        The second node to contract.
    legs : dict[str, int], optional
        The effective 'legs' of the new node, if already known. Otherwise
        computed from the inputs of ``x`` and ``y``.
    cost : int, optional
        The cost of the contraction, if already known. Otherwise computed
        from the inputs of ``x`` and ``y``.
    size : int, optional
        The size of the new node, if already known. Otherwise computed
        from the inputs of ``x`` and ``y``.
    check : bool, optional
        Whether to check the inputs are valid.

    Returns
    -------
    parent : frozenset[int]
        The new parent node of ``x`` and ``y``.
    """
    parent = x.union(y)

    # register all three nodes so each one has an ``info`` entry
    self._add_node(x, check=check)
    self._add_node(y, check=check)
    self._add_node(parent, check=check)

    # the 'heavier' subtree (more leaves) always goes on the left, equal
    # sized subtrees are ordered deterministically by their smallest leaf
    nx = len(x)
    ny = len(y)
    if nx != ny:
        x_goes_left = nx > ny
    else:
        x_goes_left = min(x) < min(y)

    self.children[parent] = (x, y) if x_goes_left else (y, x)

    if self.track_childless:
        self.childless.discard(parent)
        # non-leaf inputs which have not been built yet are childless
        for child, num_leaves in ((x, nx), (y, ny)):
            if num_leaves > 1 and child not in self.children:
                self.childless.add(child)

    # store any information the caller has already computed
    supplied = (("legs", legs), ("flops", cost), ("size", size))
    for key, value in supplied:
        if value is not None:
            self.info[parent][key] = value

    self._update_tracked(parent)

    return parent
|
|
|
|
def contract_nodes(
    self,
    nodes,
    optimize="auto-hq",
    check=False,
    extra_opts=None,
):
    """Contract an arbitrary number of ``nodes`` in the tree, building up a
    subtree. The root of that subtree (a new intermediate) is returned.

    Parameters
    ----------
    nodes : collection of nodes
        The nodes to contract together.
    optimize : optional
        Passed to ``find_path`` when more than two nodes are supplied.
    check : bool, optional
        Whether to check the nodes are valid as they are added.
    extra_opts : dict, optional
        Extra keyword arguments for ``find_path``.

    Returns
    -------
    parent : node
        The node representing the contraction of all of ``nodes``.
    """
    num_nodes = len(nodes)

    if num_nodes == 1:
        # nothing to contract
        (only,) = nodes
        return only

    if num_nodes == 2:
        return self.contract_nodes_pair(*nodes, check=check)

    from .interface import find_path

    # make sure the top node and all the bottom nodes exist
    grandparent = union_it(nodes)
    self._add_node(grandparent, check=check)
    for node in nodes:
        self._add_node(node, check=check)

    # With more than two nodes a path is needed to fill in the middle:
    #
    #             GN             <- 'grandparent'
    #            /  \
    #        ?????????
    #      ?????????????         <- intermediates found by ``find_path``
    #     /  \    /   / \
    #    N0  N1  N2  N3  N4      <- ``nodes``
    #
    path_inputs = [tuple(self.get_legs(x)) for x in nodes]
    path_output = tuple(self.get_legs(grandparent))

    path = find_path(
        path_inputs,
        path_output,
        self.size_dict,
        optimize=optimize,
        **(extra_opts or {}),
    )

    # replay the path, creating each intermediate node in the tree
    remaining = list(nodes)
    for step in path:
        operands = []
        # pop from the back first so earlier positions remain valid
        for position in sorted(step, reverse=True):
            operands.append(remaining.pop(position))
        remaining.append(self.contract_nodes(operands, check=check))

    (parent,) = remaining

    if check:
        # the single node left over should be the 'grandparent'
        assert parent == grandparent

    return parent
|
|
|
|
def is_complete(self):
    """Check every node has two children, unless it is a leaf.

    Returns
    -------
    bool
        Whether every non-leaf node reachable from the root has children.

    Raises
    ------
    ValueError
        If the tree has more nodes or more branches than a complete
        binary tree on ``self.N`` leaves can have.
    """
    max_nodes = 2 * self.N - 1
    max_branches = self.N - 1

    if len(self.info) > max_nodes or len(self.children) > max_branches:
        raise ValueError("Contraction tree seems to be over complete!")

    # walk down from the root, any non-leaf without children is a gap
    stack = [self.root]
    while stack:
        node = stack.pop()
        if len(node) == 1:
            continue
        try:
            pair = self.children[node]
        except KeyError:
            return False
        stack.extend(pair)

    return True
|
|
|
|
def get_default_order(self):
    """Return the traversal order used when none is given: ``"dfs"``."""
    default_order = "dfs"
    return default_order
|
|
|
|
def _traverse_dfs(self):
|
|
"""Traverse the tree in a depth first, non-recursive, order."""
|
|
ready = set(self.gen_leaves())
|
|
queue = [self.root]
|
|
|
|
while queue:
|
|
node = queue[-1]
|
|
l, r = self.children[node]
|
|
|
|
# both node's children are ready -> we can yield this contraction
|
|
if (l in ready) and (r in ready):
|
|
ready.add(queue.pop())
|
|
yield node, l, r
|
|
continue
|
|
|
|
if r not in ready:
|
|
queue.append(r)
|
|
if l not in ready:
|
|
queue.append(l)
|
|
|
|
def _traverse_ordered(self, order):
|
|
"""Traverse the tree in the order that minimizes ``order(node)``, but
|
|
still constrained to produce children before parents.
|
|
"""
|
|
from bisect import bisect
|
|
|
|
if order == "surface_order":
|
|
order = self.surface_order
|
|
|
|
seen = set()
|
|
queue = [self.root]
|
|
scores = [order(self.root)]
|
|
|
|
while len(seen) != len(self.children):
|
|
i = 0
|
|
while i < len(queue):
|
|
node = queue[i]
|
|
if node not in seen:
|
|
for child in self.children[node]:
|
|
if len(child) > 1:
|
|
# insert child into queue by score + before parent
|
|
score = order(child)
|
|
ci = bisect(scores[:i], score)
|
|
scores.insert(ci, score)
|
|
queue.insert(ci, child)
|
|
# parent moves extra place to right
|
|
i += 1
|
|
seen.add(node)
|
|
i += 1
|
|
|
|
for node in queue:
|
|
yield (node, *self.children[node])
|
|
|
|
def traverse(self, order=None):
    """Generate, in order, all the node merges in this tree. Non-recursive!
    Children are always visited before their parent.

    Parameters
    ----------
    order : None, "dfs", or callable, optional
        How to order the contractions within the tree. ``None`` selects
        ``self.get_default_order()``. If a callable is given (which should
        take a node as its argument), try to contract nodes that minimize
        this function first.

    Returns
    -------
    generator[tuple[node]]
        The bottom up ordered sequence of tree merges, each a
        tuple of ``(parent, left_child, right_child)``.

    See Also
    --------
    descend
    """
    if self.N == 1:
        # a single tensor involves no merges at all
        return

    chosen = self.get_default_order() if order is None else order

    if chosen == "dfs":
        walker = self._traverse_dfs()
    else:
        walker = self._traverse_ordered(order=chosen)

    yield from walker
|
|
|
|
def descend(self, mode="dfs"):
    """Generate, from root to leaves, all the node merges in this tree.
    Non-recursive! This ensures parents are visited before their children.

    Parameters
    ----------
    mode : {'dfs', 'bfs'}, optional
        How to expand from a parent: depth first or breadth first.

    Returns
    -------
    generator[tuple[node]]
        The top down ordered sequence of tree merges, each a
        tuple of ``(parent, left_child, right_child)``. Nothing is
        generated if the root is itself a single leaf.

    Raises
    ------
    ValueError
        If ``mode`` is not ``'dfs'`` or ``'bfs'`` (raised once the
        generator is first advanced).

    See Also
    --------
    traverse
    """
    if mode == "dfs":
        # take the most recently added node
        take = collections.deque.pop
    elif mode == "bfs":
        # take the oldest node, O(1) with a deque unlike ``list.pop(0)``
        take = collections.deque.popleft
    else:
        raise ValueError(f"`mode` should be 'dfs' or 'bfs', got {mode!r}.")

    queue = collections.deque()
    if len(self.root) > 1:
        # a leaf has no merge to report, same rule as for children below
        queue.append(self.root)

    while queue:
        parent = take(queue)
        l, r = self.children[parent]
        yield parent, l, r
        if len(l) > 1:
            queue.append(l)
        if len(r) > 1:
            queue.append(r)
|
|
|
|
def get_subtree(self, node, size, search="bfs", seed=None):
    """Get a subtree spanning down from ``node`` which will have ``size``
    leaves (themselves not necessarily leaves of the actual tree).

    Parameters
    ----------
    node : node
        The node of the tree to start with.
    size : int
        How many subtree leaves to aim for.
    search : {'bfs', 'dfs', 'random'}, optional
        How to build the tree:

        - 'bfs': breadth first expansion
        - 'dfs': depth first expansion (largest nodes first)
        - 'random': random expansion

    seed : None, int or random.Random, optional
        Random number generator seed, if ``search`` is 'random'.

    Returns
    -------
    sub_leaves : tuple[node]
        Nodes which are subtree leaves.
    branches : tuple[node]
        Nodes which are between the subtree leaves and root.
    """
    # nodes that have been expanded, i.e. the interior of the subtree
    branches = []

    # leaves of the actual tree, these cannot be expanded any further
    real_leaves = []

    # nodes still available for expansion
    queue = [node]

    rng = None
    if search == "random":
        rng = get_rng(seed)
    elif search == "bfs":
        # always expand the oldest node
        i = 0
    elif search == "dfs":
        # always expand the newest node
        i = -1

    while queue and (len(queue) + len(real_leaves) < size):
        if rng is not None:
            i = rng.randint(0, len(queue) - 1)

        p = queue.pop(i)
        if len(p) == 1:
            real_leaves.append(p)
            continue

        # The left child is always >= the right child in weight, so by
        # appending it last, ``.pop(-1)`` above performs the depth first
        # search sorted by node subgraph size.
        l, r = self.children[p]
        queue.append(r)
        queue.append(l)
        branches.append(p)

    # whatever was not expanded forms the bottom of the subtree
    sub_leaves = queue + real_leaves

    return tuple(sub_leaves), tuple(branches)
|
|
|
|
def remove_ind(self, ind, project=None, inplace=False):
    """Remove (i.e. by default slice) index ``ind`` from this contraction
    tree, taking care to update all relevant information about each node.

    Parameters
    ----------
    ind : str
        The index to remove.
    project : optional
        If ``None`` the index is sliced and the tree multiplicity is
        multiplied by its size. Otherwise the value is recorded in the
        ``SliceInfo`` with a size of 1 and the multiplicity is unchanged
        (presumably the position to project the index onto -- confirm
        against ``SliceInfo``).
    inplace : bool, optional
        Whether to perform the removal inplace or not.

    Returns
    -------
    ContractionTree

    Raises
    ------
    ValueError
        If ``ind`` has already been sliced.
    """
    tree = self if inplace else self.copy()

    if ind in tree.sliced_inds:
        raise ValueError(f"Index {ind} already sliced.")

    # the incremental updates below need flops, write and sizes populated
    tree.contract_stats()

    d = tree.size_dict[ind]
    is_inner = ind not in tree.output
    if project is None:
        # slicing: every slice is a separate contraction
        si = SliceInfo(is_inner, ind, d, None)
        tree.multiplicity = tree.multiplicity * d
    else:
        si = SliceInfo(is_inner, ind, 1, project)

    # Rebuild the ordered slice information dictionary, the ordering of
    # the SliceInfo dataclass keeps output sliced indices first.
    ordered = sorted((*tree.sliced_inds.values(), si))
    tree.sliced_inds = {info.ind: info for info in ordered}

    # cached entries that cannot be patched and must be recomputed
    stale_keys = (
        "inds",
        "einsum_eq",
        "can_dot",
        "tensordot_axes",
        "tensordot_perm",
    )

    for node, node_info in tree.info.items():
        if len(node) == 1:
            # Leaves don't contribute to size, flops or write, so simply
            # drop their information (incl. preprocessing) for recomputing.
            i = node_get_single_el(node)
            if ind in tree.inputs[i]:
                tree._remove_node(node)
                tree.sliced_inputs = tree.sliced_inputs | frozenset([i])
            continue

        involved = tree.get_involved(node)
        if ind not in involved:
            # this contraction does not feature ``ind`` at all
            continue

        # the flops change for any contraction involving the index
        node_info["involved"] = legs_without(involved, ind)
        old_flops = tree.get_flops(node)
        new_flops = old_flops // d
        node_info["flops"] = new_flops
        tree._flops += new_flops - old_flops

        # size and write only change if the index is on the node output
        legs = tree.get_legs(node)
        if ind in legs:
            node_info["legs"] = legs_without(legs, ind)
            old_size = tree.get_size(node)
            tree._sizes.discard(old_size)
            new_size = old_size // d
            tree._sizes.add(new_size)
            node_info["size"] = new_size
            tree._write += new_size - old_size

        # NOTE(review): indentation was lost in this source, this cleanup
        # is assumed to apply to every non-leaf node involving ``ind``
        for key in stale_keys:
            tree.info[node].pop(key, None)

    tree.already_optimized.clear()
    tree.contraction_cores.clear()

    return tree


remove_ind_ = functools.partialmethod(remove_ind, inplace=True)
|
|
|
|
def restore_ind(self, ind, inplace=False):
    """Restore (unslice or un-project) index ``ind`` to this contraction
    tree, taking care to update all relevant information about each node.

    Parameters
    ----------
    ind : str
        The index to restore.
    inplace : bool, optional
        Whether to perform the restoration inplace or not.

    Returns
    -------
    ContractionTree

    Raises
    ------
    KeyError
        If ``ind`` is not currently a sliced index.
    """
    tree = self if inplace else self.copy()

    # forget the index as sliced, keeping its info for the size
    si = tree.sliced_inds.pop(ind)

    # the updates below need flops and size information populated
    tree.contract_stats()
    tree.multiplicity //= si.size

    # inputs: ``term`` is the original term with all of its indices
    for i, term in enumerate(tree.inputs):
        if ind not in term:
            continue
        tree._remove_node(node_from_single(i))
        if not any(ix in tree.sliced_inds for ix in term):
            # no sliced index left on this input
            tree.sliced_inputs = tree.sliced_inputs - frozenset([i])

    # intermediates: delete and re-add any whose inputs carry the index
    for parent, left, right in tree.traverse():
        if ind in tree.get_legs(left) or ind in tree.get_legs(right):
            tree._remove_node(parent)
            tree.contract_nodes_pair(left, right)

    # reset caches
    tree.already_optimized.clear()
    tree.contraction_cores.clear()

    return tree


restore_ind_ = functools.partialmethod(restore_ind, inplace=True)
|
|
|
|
def unslice_rand(self, seed=None, inplace=False):
    """Unslice (restore) a random index from this contraction tree.

    Parameters
    ----------
    seed : None, int or random.Random, optional
        Random number generator seed.
    inplace : bool, optional
        Whether to perform the unslicing inplace or not.

    Returns
    -------
    ContractionTree
    """
    options = tuple(self.sliced_inds)
    chosen = get_rng(seed).choice(options)
    return self.restore_ind(chosen, inplace=inplace)


unslice_rand_ = functools.partialmethod(unslice_rand, inplace=True)
|
|
|
|
def unslice_all(self, inplace=False):
    """Unslice (restore) all sliced indices from this contraction tree.

    Parameters
    ----------
    inplace : bool, optional
        Whether to perform the unslicing inplace or not.

    Returns
    -------
    ContractionTree
    """
    tree = self if inplace else self.copy()

    # snapshot the indices first as restoring modifies ``sliced_inds``
    pending = tuple(tree.sliced_inds)
    for ind in pending:
        tree.restore_ind_(ind)

    return tree


unslice_all_ = functools.partialmethod(unslice_all, inplace=True)
|
|
|
|
def calc_subtree_candidates(self, pwr=2, what="flops"):
    """Rank every non-leaf node of the tree as a candidate subtree root.

    Parameters
    ----------
    pwr : int, optional
        Each weight, normalized by the largest one, is raised to the power
        ``1 / pwr``.
    what : {'flops', 'size'}, optional
        Whether to weight nodes by ``self.get_flops`` or ``self.get_size``.

    Returns
    -------
    candidates : list[node]
        The nodes in ``self.children``, by descending weight.
    weights : list[float]
        The matching weights. Both lists are empty if there are no
        candidates.

    Raises
    ------
    ValueError
        If ``what`` is not ``'flops'`` or ``'size'``.
    """
    if what == "size":
        weigh = self.get_size
    elif what == "flops":
        weigh = self.get_flops
    else:
        raise ValueError(
            f"`what` should be one of 'flops' or 'size', got {what!r}."
        )

    candidates = list(self.children)
    if not candidates:
        # nothing to rank, avoid ``max`` and ``zip`` failing on empty input
        return [], []

    weights = [weigh(x) for x in candidates]
    max_weight = max(weights)

    # can be bigger than numpy int/float allows
    weights = [float(w / max_weight) ** (1 / pwr) for w in weights]

    # sort by descending score
    ranked = sorted(zip(candidates, weights), key=lambda cw: -cw[1])

    return [c for c, _ in ranked], [w for _, w in ranked]
|
|
|
|
def subtree_reconfigure(
    self,
    subtree_size=8,
    subtree_search="bfs",
    weight_what="flops",
    weight_pwr=2,
    select="max",
    maxiter=500,
    seed=None,
    minimize=None,
    optimize=None,
    inplace=False,
    progbar=False,
):
    """Reconfigure subtrees of this tree with locally optimal paths.

    Parameters
    ----------
    subtree_size : int, optional
        The size of subtree to consider. Cost is exponential in this.
    subtree_search : {'bfs', 'dfs', 'random'}, optional
        How to build the subtrees:

        - 'bfs': breadth-first-search creating balanced subtrees
        - 'dfs': depth-first-search creating imbalanced subtrees
        - 'random': random subtree building

    weight_what : {'flops', 'size'}, optional
        When assessing nodes to build and optimize subtrees from whether to
        score them by the (local) contraction cost, or tensor size.
    weight_pwr : int, optional
        When assessing nodes to build and optimize subtrees from, how to
        scale their score into a probability: ``score**(1 / weight_pwr)``.
        The larger this is the more explorative the algorithm is when
        ``select='random'``.
    select : {'max', 'min', 'random'}, optional
        What order to select node subtrees to optimize:

        - 'max': choose the highest score first
        - 'min': choose the lowest score first
        - 'random': choose randomly weighted on score -- see
          ``weight_pwr``.

    maxiter : int, optional
        How many subtree optimizations to perform, the algorithm can
        terminate before this if all subtrees have been optimized.
    seed : None, int or random.Random, optional
        Random number generator seed, used when ``select='random'``.
    minimize : {'flops', 'size', ...}, optional
        What to minimize, passed to ``get_score_fn``. If not given,
        ``self.get_default_objective()`` is used.
    optimize : optional
        The optimizer passed to ``contract_nodes`` for every subtree, its
        ``cost_cap`` attribute is set before each use. If not given, an
        ``OptimalOptimizer`` is created.
    inplace : bool, optional
        Whether to perform the reconfiguration inplace or not.
    progbar : bool, optional
        Whether to show live progress of the reconfiguration.

    Returns
    -------
    ContractionTree
    """
    tree = self if inplace else self.copy()

    # ensure these have been computed and thus are being tracked
    tree.contract_stats()

    if minimize is None:
        minimize = self.get_default_objective()
    scorer = get_score_fn(minimize)

    if optimize is None:
        from .pathfinders.path_basic import OptimalOptimizer

        opt = OptimalOptimizer(
            minimize=scorer.get_dynamic_programming_minimize()
        )
    else:
        opt = optimize

    # the fallback has to accept the same ``(tree, node)`` arguments that
    # ``scorer.cost_local_tree_node`` is called with below
    node_cost = getattr(scorer, "cost_local_tree_node", lambda *_: 2)

    # different caches as we might want to reconfigure one before other
    tree.already_optimized.setdefault(minimize, set())
    already_optimized = tree.already_optimized[minimize]

    if select == "random":
        rng = get_rng(seed)
    else:
        # candidates are sorted by descending score
        if select == "max":
            i = 0
        elif select == "min":
            i = -1
        rng = None

    candidates, weights = tree.calc_subtree_candidates(
        pwr=weight_pwr, what=weight_what
    )

    if progbar:
        import tqdm

        pbar = tqdm.tqdm()
        pbar.set_description(_describe_tree(tree), refresh=False)

    r = 0
    try:
        while candidates and r < maxiter:
            if rng is not None:
                (i,) = rng.choices(range(len(candidates)), weights=weights)

            weights.pop(i)
            sub_root = candidates.pop(i)

            # get a subtree to possibly reconfigure
            sub_leaves, sub_branches = tree.get_subtree(
                sub_root, size=subtree_size, search=subtree_search
            )

            sub_leaves = frozenset(sub_leaves)

            # check if its already been optimized
            if sub_leaves in already_optimized:
                continue

            # else remove the branches, keeping track of current cost
            current_cost = node_cost(tree, sub_root)
            for node in sub_branches:
                if minimize == "size":
                    current_cost = max(current_cost, node_cost(tree, node))
                else:
                    current_cost += node_cost(tree, node)
                tree._remove_node(node)

            # make the optimizer more efficient by supplying accurate cap
            opt.cost_cap = max(2, current_cost)

            # and reoptimize the leaves
            tree.contract_nodes(sub_leaves, optimize=opt)
            already_optimized.add(sub_leaves)

            r += 1

            if progbar:
                pbar.update()
                pbar.set_description(_describe_tree(tree), refresh=False)

            # if we have reconfigured simply re-add all candidates
            candidates, weights = tree.calc_subtree_candidates(
                pwr=weight_pwr, what=weight_what
            )
    finally:
        if progbar:
            pbar.close()

    # invalidate any compiled contractions
    tree.contraction_cores.clear()

    return tree


subtree_reconfigure_ = functools.partialmethod(
    subtree_reconfigure, inplace=True
)
|
|
|
|
def subtree_reconfigure_forest(
    self,
    num_trees=8,
    num_restarts=10,
    restart_fraction=0.5,
    subtree_maxiter=100,
    subtree_size=10,
    subtree_search=("random", "bfs"),
    subtree_select=("random",),
    subtree_weight_what=("flops", "size"),
    subtree_weight_pwr=(2,),
    parallel="auto",
    parallel_maxiter_steps=4,
    minimize=None,
    seed=None,
    progbar=False,
    inplace=False,
):
    """'Forested' version of ``subtree_reconfigure`` which is more
    explorative and can be parallelized. It stochastically generates a
    'forest' of reconfigured trees, then keeps only some fraction of these
    to generate the next forest from.

    Parameters
    ----------
    num_trees : int, optional
        The number of trees to reconfigure at each stage.
    num_restarts : int, optional
        The number of times to halt, prune and then restart the
        tree reconfigurations.
    restart_fraction : float, optional
        The fraction of trees to keep at each stage and generate the next
        forest from.
    subtree_maxiter : int, optional
        Number of subtree reconfigurations per step.
        ``num_restarts * subtree_maxiter`` is the max number of total
        subtree reconfigurations for the final tree produced.
    subtree_size : int, optional
        The size of subtrees to search for and reconfigure.
    subtree_search : tuple[{'random', 'bfs', 'dfs'}], optional
        Tuple of options for the ``search`` kwarg of
        :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
    subtree_select : tuple[{'random', 'max', 'min'}], optional
        Tuple of options for the ``select`` kwarg of
        :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
    subtree_weight_what : tuple[{'flops', 'size'}], optional
        Tuple of options for the ``weight_what`` kwarg of
        :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
    subtree_weight_pwr : tuple[int], optional
        Tuple of options for the ``weight_pwr`` kwarg of
        :meth:`ContractionTree.subtree_reconfigure` to randomly sample.
    parallel : 'auto', False, True, int, or distributed.Client
        Whether to parallelize the search.
    parallel_maxiter_steps : int, optional
        If parallelizing, how many steps to break each reconfiguration into
        in order to evenly saturate many processes.
    minimize : {'flops', 'size', ..., Objective}, optional
        Whether to minimize the total flops or maximum size of the
        contraction tree.
    seed : None, int or random.Random, optional
        A random seed to use.
    progbar : bool, optional
        Whether to show live progress.
    inplace : bool, optional
        Whether to perform the subtree reconfiguration inplace.

    Returns
    -------
    ContractionTree
    """
    tree = self if inplace else self.copy()

    # compiled contractions might be unpicklable, so drop them up front
    tree.contraction_cores.clear()

    # how many of the best trees seed each new round
    num_keep = max(1, int(num_trees * restart_fraction))

    # how to rank the trees
    if minimize is None:
        minimize = self.get_default_objective()
    score = get_score_fn(minimize)

    rng = get_rng(seed)

    # set up the initial 'forest' and parallel machinery
    pool = parse_parallel_arg(parallel)
    is_scatter_pool = can_scatter(pool)
    if is_scatter_pool:
        is_worker = maybe_leave_pool(pool)
        # the trees are stored as futures for the entire process
        forest = [scatter(pool, tree)]
        maxiter = subtree_maxiter // parallel_maxiter_steps
    else:
        forest = [tree]
        maxiter = subtree_maxiter

    if progbar:
        import tqdm

        pbar = tqdm.tqdm(total=num_restarts)
        pbar.set_description(_describe_tree(tree), refresh=False)

    try:
        for _ in range(num_restarts):
            # only the best trees of the last round are carried forward
            parents = itertools.cycle(forest[:num_keep])

            # pair each parent with a randomly sampled configuration
            saplings = []
            for _ in range(num_trees):
                saplings.append(
                    {
                        "tree": next(parents),
                        "maxiter": maxiter,
                        "minimize": minimize,
                        "subtree_size": subtree_size,
                        "subtree_search": rng.choice(subtree_search),
                        "select": rng.choice(subtree_select),
                        "weight_pwr": rng.choice(subtree_weight_pwr),
                        "weight_what": rng.choice(subtree_weight_what),
                    }
                )

            if pool is None:
                # serial: reconfigure and score in this process
                grown = [_reconfigure_tree(**s) for s in saplings]
                res = [{"tree": t, **_get_tree_info(t)} for t in grown]
            elif not is_scatter_pool:
                # plain pool: reconfigure remotely, score locally
                pending = [
                    submit(pool, _reconfigure_tree, **s) for s in saplings
                ]
                grown = [f.result() for f in pending]
                res = [{"tree": t, **_get_tree_info(t)} for t in grown]
            else:
                # submit in smaller steps to saturate processes, each
                # step consuming the future produced by the previous one
                for _ in range(parallel_maxiter_steps):
                    for s in saplings:
                        s["tree"] = submit(pool, _reconfigure_tree, **s)

                # compute scores remotely then gather
                tree_futures = [s["tree"] for s in saplings]
                info_futures = [
                    submit(pool, _get_tree_info, t) for t in tree_futures
                ]
                res = [
                    {"tree": tree_future, **info_future.result()}
                    for tree_future, info_future in zip(
                        tree_futures, info_futures
                    )
                ]

            # the new forest, ordered from best to worst
            res.sort(key=score)
            forest = [r["tree"] for r in res]

            if progbar:
                pbar.update()
                if pool is None:
                    d = _describe_tree(forest[0])
                else:
                    d = submit(pool, _describe_tree, forest[0]).result()
                pbar.set_description(d, refresh=False)

    finally:
        if progbar:
            pbar.close()

    # adopt the state of the best tree found
    if is_scatter_pool:
        tree.set_state_from(forest[0].result())
        maybe_rejoin_pool(is_worker, pool)
    else:
        tree.set_state_from(forest[0])

    return tree


subtree_reconfigure_forest_ = functools.partialmethod(
    subtree_reconfigure_forest, inplace=True
)
|
|
|
|
# Bind the ``simulated_anneal_tree`` function, imported from
# ``.pathfinders.path_simulated_annealing``, under the name
# ``simulated_anneal`` in the enclosing scope.
simulated_anneal = simulated_anneal_tree
|
|
# Variant of ``simulated_anneal`` with ``inplace=True`` pre-bound.
simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
|
|
# Likewise bind ``parallel_temper_tree`` (same module) as
# ``parallel_temper``.
parallel_temper = parallel_temper_tree
|
|
# Variant of ``parallel_temper`` with ``inplace=True`` pre-bound.
parallel_temper_ = functools.partialmethod(parallel_temper, inplace=True)
|
|
|
|
def slice(
    self,
    target_size=None,
    target_overhead=None,
    target_slices=None,
    temperature=0.01,
    minimize=None,
    allow_outer=True,
    max_repeats=16,
    reslice=False,
    seed=None,
    inplace=False,
):
    """Slice this tree, i.e. turn some indices into indices which are
    explicitly summed over rather than being part of contractions. The
    indices are stored in ``tree.sliced_inds``, and the contraction width is
    updated to take account of the slicing. Calling
    ``tree.contract(arrays)`` will then automatically perform the slicing
    and summation.

    Parameters
    ----------
    target_size : int, optional
        The target number of entries in the largest tensor of the sliced
        contraction. The search algorithm will terminate after this is
        reached.
    target_overhead : float, optional
        The target increase in total number of floating point operations.
        For example, a value of ``2.0`` will terminate the search just
        before the cost of computing all the slices individually breaches
        twice that of computing the original contraction all at once.
    target_slices : int, optional
        The target or minimum number of 'slices' to consider - individual
        contractions after slicing indices. The search algorithm will
        terminate after this is breached. This is on top of the current
        number of slices.
    temperature : float, optional
        How much to randomize the repeated search.
    minimize : {'flops', 'size', ..., Objective}, optional
        Which metric to score the overhead increase against. If ``None``
        the default objective of this tree is used.
    allow_outer : bool, optional
        Whether to allow slicing of outer indices.
    max_repeats : int, optional
        How many times to repeat the search with a slight randomization.
    reslice : bool, optional
        Whether to reslice the tree, i.e. first remove all currently
        sliced indices and start the search again. Generally any 'good'
        sliced indices will be easily found again.
    seed : None, int or random.Random, optional
        A random seed or generator to use for the search.
    inplace : bool, optional
        Whether to remove the indices from this tree inplace or not.

    Returns
    -------
    ContractionTree

    See Also
    --------
    SliceFinder, ContractionTree.slice_and_reconfigure
    """
    from .slicer import SliceFinder

    if minimize is None:
        minimize = self.get_default_objective()

    tree = self if inplace else self.copy()

    if reslice:
        # the requested slice count is relative to the current one, so
        # fold the existing slices in before discarding them
        if target_slices is not None:
            target_slices *= tree.nslices
        tree.unslice_all_()

    finder_opts = {
        "target_size": target_size,
        "target_overhead": target_overhead,
        "target_slices": target_slices,
        "temperature": temperature,
        "minimize": minimize,
        "allow_outer": allow_outer,
        "seed": seed,
    }
    finder = SliceFinder(tree, **finder_opts)

    # only the chosen indices are needed, not the second search result
    chosen_inds, _ = finder.search(max_repeats)
    for ind in chosen_inds:
        tree.remove_ind_(ind)

    return tree


slice_ = functools.partialmethod(slice, inplace=True)
|
|
|
|
def slice_and_reconfigure(
    self,
    target_size,
    step_size=2,
    temperature=0.01,
    minimize=None,
    allow_outer=True,
    max_repeats=16,
    reslice=False,
    reconf_opts=None,
    progbar=False,
    inplace=False,
):
    """Alternate between slicing (removing indices into an exterior sum)
    and subtree reconfiguration, so as to minimize the overhead induced by
    the slicing.

    Parameters
    ----------
    target_size : int
        Slice the tree until the maximum intermediate size is this or
        smaller.
    step_size : int, optional
        The minimum size reduction to try and achieve before switching to a
        round of subtree reconfiguration.
    temperature : float, optional
        The temperature to supply to ``SliceFinder`` for searching for
        indices.
    minimize : {'flops', 'size', ..., Objective}, optional
        The metric to minimize when slicing and reconfiguring subtrees.
        If ``None`` the default objective of this tree is used.
    allow_outer : bool, optional
        Forwarded to :meth:`ContractionTree.slice` at every slicing step.
    max_repeats : int, optional
        The number of slicing attempts to perform per search.
    reslice : bool, optional
        Forwarded to :meth:`ContractionTree.slice` at every slicing step.
    reconf_opts : None or dict, optional
        Supplied to
        :meth:`ContractionTree.subtree_reconfigure` or
        :meth:`ContractionTree.subtree_reconfigure_forest`, depending on
        `'forested'` key value. The supplied dict is copied, not modified.
    progbar : bool, optional
        Whether to show live progress.
    inplace : bool, optional
        Whether to perform the slicing and reconfiguration inplace.

    Returns
    -------
    ContractionTree
    """
    tree = self if inplace else self.copy()

    # work on a private copy so the caller's options are left untouched
    reconf_opts = dict(reconf_opts) if reconf_opts is not None else {}

    if minimize is None:
        minimize = self.get_default_objective()
    minimize = get_score_fn(minimize)

    reconf_opts.setdefault("minimize", minimize)
    # 'forested' only selects the reconfiguration method, it is not one
    # of that method's own options
    forested_reconf = reconf_opts.pop("forested", False)
    if forested_reconf:
        reconfigure = tree.subtree_reconfigure_forest_
    else:
        reconfigure = tree.subtree_reconfigure_

    slice_opts = {
        "temperature": temperature,
        "target_slices": step_size,
        "minimize": minimize,
        "allow_outer": allow_outer,
        "max_repeats": max_repeats,
        "reslice": reslice,
    }

    pbar = None
    if progbar:
        import tqdm

        pbar = tqdm.tqdm()
        pbar.set_description(_describe_tree(tree), refresh=False)

    try:
        while tree.max_size() > target_size:
            tree.slice_(**slice_opts)
            reconfigure(**reconf_opts)

            if pbar is not None:
                pbar.update()
                pbar.set_description(_describe_tree(tree), refresh=False)
    finally:
        if pbar is not None:
            pbar.close()

    return tree


slice_and_reconfigure_ = functools.partialmethod(
    slice_and_reconfigure,
    inplace=True,
)
|
|
|
|
def slice_and_reconfigure_forest(
    self,
    target_size,
    step_size=2,
    num_trees=8,
    restart_fraction=0.5,
    temperature=0.02,
    max_repeats=32,
    reslice=False,
    minimize=None,
    allow_outer=True,
    parallel="auto",
    progbar=False,
    inplace=False,
    reconf_opts=None,
):
    """'Forested' version of :meth:`ContractionTree.slice_and_reconfigure`.
    This maintains a 'forest' of trees with different slicing and subtree
    reconfiguration attempts, pruning the worst at each step and generating
    a new forest from the best.

    Parameters
    ----------
    target_size : int
        Slice the tree until the maximum intermediate size is this or
        smaller.
    step_size : int, optional
        The factor by which the intermediate size target is reduced at
        each round, also forwarded to each tree's
        slice-and-reconfigure call. The intermediate target is never taken
        below ``target_size``.
    num_trees : int, optional
        The number of trees generated for the forest at each round.
    restart_fraction : float, optional
        The fraction of trees to keep at each stage and generate the next
        forest from (at least one tree is always kept).
    temperature : float, optional
        The temperature at which to randomize the sliced index search.
    max_repeats : int, optional
        The number of slicing attempts to perform per search.
    reslice : bool, optional
        Forwarded to the slice-and-reconfigure call of every tree.
    minimize : {'flops', 'size', ..., Objective}, optional
        The metric used to rank the trees of the forest. If ``None`` the
        default objective of this tree is used.
    allow_outer : bool, optional
        Forwarded to the slice-and-reconfigure call of every tree.
    parallel : 'auto', False, True, int, or distributed.Client
        Whether to parallelize the search.
    progbar : bool, optional
        Whether to show live progress.
    inplace : bool, optional
        Whether to perform the slicing and reconfiguration inplace.
    reconf_opts : None or dict, optional
        Supplied to
        :meth:`ContractionTree.slice_and_reconfigure`.

    Returns
    -------
    ContractionTree
    """
    tree = self if inplace else self.copy()

    # some of these might be unpicklable
    tree.contraction_cores.clear()

    # candidate trees
    num_keep = max(1, int(num_trees * restart_fraction))

    # how to rank the trees
    # NOTE(review): ``minimize`` is only used for ranking here, it is not
    # part of the sapling options below - confirm that is intended.
    if minimize is None:
        minimize = self.get_default_objective()
    score = get_score_fn(minimize)

    # set up the initial 'forest' and parallel machinery
    pool = parse_parallel_arg(parallel)
    is_scatter_pool = can_scatter(pool)
    if is_scatter_pool:
        is_worker = maybe_leave_pool(pool)
        # store the trees as futures for the entire process
        forest = [scatter(pool, tree)]
    else:
        forest = [tree]

    if progbar:
        import tqdm

        pbar = tqdm.tqdm()
        pbar.set_description(_describe_tree(tree), refresh=False)

    next_size = tree.max_size()

    try:
        while True:
            # shrink the intermediate target, but never aim below the
            # requested size: otherwise a tree that already fits would
            # still be sliced, and the last round could overshoot
            next_size = max(next_size // step_size, target_size)

            # on the next round take only the best trees
            forest = itertools.cycle(forest[:num_keep])

            saplings = [
                {
                    "tree": next(forest),
                    "target_size": next_size,
                    "step_size": step_size,
                    "temperature": temperature,
                    "max_repeats": max_repeats,
                    "reconf_opts": reconf_opts,
                    "allow_outer": allow_outer,
                    "reslice": reslice,
                }
                for _ in range(num_trees)
            ]

            if pool is None:
                forest = [
                    _slice_and_reconfigure_tree(**s) for s in saplings
                ]
                res = [{"tree": t, **_get_tree_info(t)} for t in forest]

            elif not is_scatter_pool:
                # simple pool with no pass by reference
                forest_futures = [
                    submit(pool, _slice_and_reconfigure_tree, **s)
                    for s in saplings
                ]
                forest = [f.result() for f in forest_futures]
                res = [{"tree": t, **_get_tree_info(t)} for t in forest]

            else:
                forest_futures = [
                    submit(pool, _slice_and_reconfigure_tree, **s)
                    for s in saplings
                ]

                # compute scores remotely then gather
                res_futures = [
                    submit(pool, _get_tree_info, t) for t in forest_futures
                ]
                res = [
                    {"tree": tree_future, **res_future.result()}
                    for tree_future, res_future in zip(
                        forest_futures, res_futures
                    )
                ]

            # we want to sort by score, but also favour sampling as
            # many different sliced index combos as possible
            #    ~ [1, 1, 1, 2, 2, 3] -> [1, 2, 3, 1, 2, 1]
            res.sort(key=score)
            res = list(
                interleave(
                    groupby(lambda r: r["sliced_ind_set"], res).values()
                )
            )

            # update the order of the new forest
            forest = [r["tree"] for r in res]

            if progbar:
                pbar.update()
                if pool is None:
                    d = _describe_tree(forest[0])
                else:
                    d = submit(pool, _describe_tree, forest[0]).result()
                pbar.set_description(d, refresh=False)

            if res[0]["size"] <= target_size:
                break

    finally:
        if progbar:
            pbar.close()

    if is_scatter_pool:
        # the forest holds futures in this mode, so gather the best one
        tree.set_state_from(forest[0].result())
        maybe_rejoin_pool(is_worker, pool)
    else:
        tree.set_state_from(forest[0])

    return tree


slice_and_reconfigure_forest_ = functools.partialmethod(
    slice_and_reconfigure_forest, inplace=True
)
|
|
|
|
def compressed_reconfigure(
    self,
    minimize=None,
    order_only=False,
    max_nodes="auto",
    max_time=None,
    local_score=None,
    exploration_power=0,
    best_score=None,
    progbar=False,
    inplace=False,
):
    """Reconfigure this tree using the experimental ``CompressedExhaustive``
    optimizer, seeded with the current surface path of this tree.

    Parameters
    ----------
    minimize : optional
        The objective handed to the optimizer and to the rebuilt tree. If
        ``None`` the default objective of this tree is used.
    order_only : bool, optional
        Whether to only consider the ordering of the current tree
        contractions, or all possible contractions, starting with the
        current.
    max_nodes : int or "auto", optional
        Set the maximum number of contraction steps to consider. If
        ``"auto"`` this is ``max(10_000, N**2)`` when no ``max_time`` is
        given, and unlimited otherwise.
    max_time : float, optional
        Set the maximum time to spend on the search.
    local_score : callable, optional
        A function that assigns a score to a potential contraction, with a
        lower score giving more priority to explore that contraction
        earlier. It should have signature::

            local_score(step, new_score, dsize, new_size)

        where ``step`` is the number of steps so far, ``new_score`` is the
        score of the contraction so far, ``dsize`` is the change in memory
        by the current step, and ``new_size`` is the new memory size after
        contraction.
    exploration_power : float, optional
        If not ``0.0``, the inverse power to which the step is raised in
        the default local score function. Higher values favor exploring
        more promising branches early on - at the cost of increased memory.
        Ignored if ``local_score`` is supplied.
    best_score : float, optional
        Manually specify an upper bound for best score found so far.
    progbar : bool, optional
        If ``True``, display a progress bar.
    inplace : bool, optional
        Whether to perform the reconfiguration inplace on this tree.

    Returns
    -------
    ContractionTree
    """
    from .experimental.path_compressed_branchbound import (
        CompressedExhaustive,
    )

    if minimize is None:
        minimize = self.get_default_objective()

    if max_nodes == "auto":
        # with a time budget the search is bounded by time alone
        max_nodes = (
            max(10_000, self.N**2) if max_time is None else float("inf")
        )

    network = (self.inputs, self.output, self.size_dict)

    searcher = CompressedExhaustive(
        minimize=minimize,
        local_score=local_score,
        max_nodes=max_nodes,
        max_time=max_time,
        exploration_power=exploration_power,
        best_score=best_score,
        progbar=progbar,
    )
    searcher.setup(*network)
    # start the exploration from the path this tree currently encodes
    searcher.explore_path(self.get_path_surface(), restrict=order_only)
    searcher.run(*network)

    new_tree = self.__class__.from_path(
        *network,
        ssa_path=searcher.ssa_path,
        objective=minimize,
    )

    if inplace:
        self.set_state_from(new_tree)
        result = self
    else:
        result = new_tree

    # any compiled contractions refer to the old structure
    result.contraction_cores.clear()
    return result


compressed_reconfigure_ = functools.partialmethod(
    compressed_reconfigure,
    inplace=True,
)
|
|
|
|
def windowed_reconfigure(
    self,
    minimize=None,
    order_only=False,
    window_size=20,
    max_iterations=100,
    max_window_tries=1000,
    score_temperature=0.0,
    queue_temperature=1.0,
    scorer=None,
    queue_scorer=None,
    seed=None,
    inplace=False,
    progbar=False,
    **kwargs,
):
    """Refine this tree with a ``WindowedOptimizer`` started from the
    current ssa path, then rebuild the tree from the refined path.

    Parameters
    ----------
    minimize : optional
        The objective handed to the optimizer and to the rebuilt tree. If
        ``None`` the default objective of this tree is used.
    order_only, window_size, max_iterations, max_window_tries : optional
        Forwarded unchanged to ``WindowedOptimizer.refine``.
    score_temperature, queue_temperature, scorer, queue_scorer : optional
        Forwarded unchanged to ``WindowedOptimizer.refine``.
    seed : optional
        Forwarded to the ``WindowedOptimizer`` constructor.
    inplace : bool, optional
        Whether to update this tree with the result, or return a new tree.
    progbar : bool, optional
        Forwarded to ``WindowedOptimizer.refine``.
    kwargs
        Any further options are forwarded to ``WindowedOptimizer.refine``.

    Returns
    -------
    ContractionTree
    """
    from .pathfinders.path_compressed import WindowedOptimizer

    if minimize is None:
        minimize = self.get_default_objective()

    optimizer = WindowedOptimizer(
        self.inputs,
        self.output,
        self.size_dict,
        minimize=minimize,
        ssa_path=self.get_ssa_path(),
        seed=seed,
    )

    refine_opts = {
        "window_size": window_size,
        "max_iterations": max_iterations,
        "order_only": order_only,
        "max_window_tries": max_window_tries,
        "score_temperature": score_temperature,
        "queue_temperature": queue_temperature,
        "scorer": scorer,
        "queue_scorer": queue_scorer,
        "progbar": progbar,
    }
    optimizer.refine(**refine_opts, **kwargs)

    new_tree = self.__class__.from_path(
        self.inputs,
        self.output,
        self.size_dict,
        ssa_path=optimizer.get_ssa_path(),
        objective=minimize,
    )

    if inplace:
        self.set_state_from(new_tree)
        result = self
    else:
        result = new_tree

    # any compiled contractions refer to the old structure
    result.contraction_cores.clear()
    return result


windowed_reconfigure_ = functools.partialmethod(
    windowed_reconfigure,
    inplace=True,
)
|
|
|
|
def flat_tree(self, order=None):
    """Build a nested tuple representation of the contraction tree, e.g.::

        ((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))

    Such that the contraction will progress like::

        ((0, (1, 2)), ((3, 4), ((5, (6, 7)), (8, 9))))
        ((0, 12), (34, ((5, 67), 89)))
        (012, (34, (567, 89)))
        (012, (34, 56789))
        (012, 3456789)
        0123456789

    Each integer is the position of a leaf in ``self.gen_leaves()``.

    Parameters
    ----------
    order : optional
        Forwarded to ``self.traverse``.
    """
    nested = {}

    # leaves are represented by their position
    for position, leaf in zip(range(self.N), self.gen_leaves()):
        nested[leaf] = position

    # every parent is the pair of its (already built) children
    for parent, left, right in self.traverse(order=order):
        nested[parent] = (nested[left], nested[right])

    return nested[self.root]
|
|
|
|
def get_leaves_ordered(self):
    """Return the leaves in the order they are met when traversing the
    contraction tree.

    Returns
    -------
    tuple[frozenset[str]]

    Raises
    ------
    ValueError
        If the tree is not complete.
    """
    if not self.is_complete():
        raise ValueError("Can't order the leaves until tree is complete.")

    ordered = []
    for contraction in self.traverse():
        # each contraction is a (parent, left, right) triple, leaves are
        # the single element nodes among them
        for nd in contraction:
            if len(nd) == 1:
                ordered.append(nd)

    return tuple(ordered)
|
|
|
|
def get_path(self, order=None):
    """Generate a standard path (with linear recycled ids) from the
    contraction tree.

    Parameters
    ----------
    order : None, "dfs", or callable, optional
        How to order the contractions within the tree. If a callable is
        given (which should take a node as its argument), try to contract
        nodes that minimize this function first.

    Returns
    -------
    path: tuple[tuple[int, int]]
    """
    from bisect import bisect_left

    next_id = self.N
    # ssa ids of the tensors still present, always kept sorted
    live = list(range(next_id))
    ids = dict(zip(self.gen_leaves(), live))
    steps = []

    for parent, left, right in self.traverse(order=order):
        # the linear position of a tensor is its rank among live ids
        pos_left = bisect_left(live, ids[left])
        pos_right = bisect_left(live, ids[right])
        if pos_left <= pos_right:
            lo, hi = pos_left, pos_right
        else:
            lo, hi = pos_right, pos_left

        # remove the higher position first so the lower stays valid
        del live[hi]
        del live[lo]
        steps.append((lo, hi))

        # the new intermediate gets the next, largest, id
        live.append(next_id)
        ids[parent] = next_id
        next_id += 1

    return tuple(steps)
|
|
|
|
# deprecated alias kept for the old ``path`` name
path = deprecated(get_path, "path", "get_path")


def get_numpy_path(self, order=None):
    """Generate a path compatible with the `optimize` kwarg of
    `numpy.einsum`, i.e. the standard path prefixed by ``"einsum_path"``.

    Parameters
    ----------
    order : optional
        Forwarded to ``self.get_path``.

    Returns
    -------
    list
    """
    numpy_path = ["einsum_path"]
    numpy_path.extend(self.get_path(order=order))
    return numpy_path
|
|
|
|
def get_ssa_path(self, order=None):
    """Generate a single static assignment path from the contraction tree.

    Parameters
    ----------
    order : None, "dfs", or callable, optional
        How to order the contractions within the tree. If a callable is
        given (which should take a node as its argument), try to contract
        nodes that minimize this function first.

    Returns
    -------
    ssa_path: tuple[tuple[int, int]]
    """
    ids = {}
    for position, leaf in zip(range(self.N), self.gen_leaves()):
        ids[leaf] = position

    steps = []
    for parent, left, right in self.traverse(order=order):
        first, second = ids[left], ids[right]
        if second < first:
            first, second = second, first
        steps.append((first, second))
        # the k-th contraction (counting from 0) creates tensor N + k
        ids[parent] = self.N + len(steps) - 1

    return tuple(steps)
|
|
|
|
# deprecated alias kept for the old ``ssa_path`` name
ssa_path = deprecated(get_ssa_path, "ssa_path", "get_ssa_path")


def surface_order(self, node):
    """Default key used to order contractions for the 'surface' paths:
    the pair ``(len(node), centrality of node)``.
    """
    extent = len(node)
    centrality = self.get_centrality(node)
    return (extent, centrality)
|
|
|
|
def set_surface_order_from_path(self, ssa_path):
    """Replace ``self.surface_order`` with a lookup of the step at which
    each node is created in ``ssa_path``.

    Parameters
    ----------
    ssa_path : sequence of pairs of int
        Each pair indexes into the list of leaves followed by the
        intermediates created so far.

    Notes
    -----
    The new ``surface_order`` is a partial of ``get_with_default`` with a
    default of infinity.
    """
    step_of = {}
    nodes = list(self.gen_leaves())

    for step, pair in enumerate(ssa_path):
        left, right = (nodes[i] for i in pair)
        merged = left.union(right)
        # ssa ids index into ``nodes``, so new intermediates are appended
        nodes.append(merged)
        step_of[merged] = step

    self.surface_order = functools.partial(
        get_with_default,
        obj=step_of,
        default=float("inf"),
    )
|
|
|
|
def get_path_surface(self):
    """Return the standard path of this tree, with the contractions
    ordered by ``self.surface_order``.
    """
    ordering = self.surface_order
    return self.get_path(order=ordering)


# deprecated alias kept for the old ``path_surface`` name
path_surface = deprecated(
    get_path_surface,
    "path_surface",
    "get_path_surface",
)
|
|
|
|
def get_ssa_path_surface(self):
    """Return the ssa path of this tree, with the contractions ordered by
    ``self.surface_order``.
    """
    ordering = self.surface_order
    return self.get_ssa_path(order=ordering)


# deprecated alias kept for the old ``ssa_path_surface`` name
ssa_path_surface = deprecated(
    get_ssa_path_surface,
    "ssa_path_surface",
    "get_ssa_path_surface",
)
|
|
|
|
def get_spans(self):
    """Get all (which could mean none) potential embeddings of this
    contraction tree into a spanning tree of the original graph.

    Every index removed at a node is unpacked into exactly two input
    terms, so an index shared by a different number of inputs raises a
    ``ValueError``.

    Returns
    -------
    tuple[dict[frozenset[int], frozenset[int]]]
    """
    ind_to_term = collections.defaultdict(set)
    for i, term in enumerate(self.inputs):
        for ix in term:
            ind_to_term[ix].add(i)

    def boundary_pairs(node):
        """Get nodes along the boundary of the bipartition represented by
        ``node``.
        """
        pairs = set()
        involved = self.get_involved(node)
        legs = self.get_legs(node)
        removed = [ix for ix in involved if ix not in legs]
        for ix in removed:
            # for every index across the contraction
            l1, l2 = ind_to_term[ix]

            # can either span from left to right or right to left
            pairs.add((l1, l2))
            pairs.add((l2, l1))

        return pairs

    # first span choice is any nodes across the top level bipart
    # (a deque gives O(1) removal from the front of the work queue)
    candidates = collections.deque(
        {
            # which intermediate nodes map to which leaf nodes
            "map": {self.root: node_from_single(l2)},
            # the leaf nodes in the spanning tree
            "spine": {l1, l2},
        }
        for l1, l2 in boundary_pairs(self.root)
    )

    for _, l, r in self.descend():
        for child in (r, l):
            if not candidates:
                # nothing left to extend
                continue

            # the boundary depends only on the child, so compute it once
            # rather than once per candidate
            child_pairs = boundary_pairs(child)
            child_is_leaf = len(child) == 1

            # for each current candidate check all the possible extensions
            for _ in range(len(candidates)):
                cand = candidates.popleft()

                # don't need to do anything for leaves, they map to
                # themselves
                if child_is_leaf:
                    candidates.append(
                        {
                            "map": {child: child, **cand["map"]},
                            "spine": cand["spine"].copy(),
                        }
                    )

                for l1, l2 in child_pairs:
                    if (l1 in cand["spine"]) or (l2 not in cand["spine"]):
                        # pair does not merge inwards into spine
                        continue

                    # valid extension of spanning tree
                    candidates.append(
                        {
                            "map": {
                                child: node_from_single(l2),
                                **cand["map"],
                            },
                            "spine": cand["spine"] | {l1, l2},
                        }
                    )

    return tuple(c["map"] for c in candidates)
|
|
|
|
def compute_centralities(self, combine="mean"):
    """Compute a centrality for every node in this contraction tree and
    store it under ``self.info[node]["centrality"]``.

    Parameters
    ----------
    combine : {"mean", "sum", "max", "min"} or callable, optional
        How the centralities of two children are combined into that of
        their parent. Anything that is not one of the named options is
        used directly as a two argument function.
    """
    leaf_scores = self.get_hypergraph(accel="auto").simple_centrality()

    info = self.info
    for i, leaf in enumerate(self.gen_leaves()):
        info[leaf]["centrality"] = leaf_scores[i]

    def _mean(x, y):
        return (x + y) / 2

    def _sum(x, y):
        return x + y

    named = {"mean": _mean, "sum": _sum, "max": max, "min": min}
    merge = named.get(combine, combine)

    # depth first, so both children are always scored before the parent
    for parent, left, right in self.traverse("dfs"):
        info[parent]["centrality"] = merge(
            info[left]["centrality"],
            info[right]["centrality"],
        )
|
|
|
|
def get_hypergraph(self, accel=False):
    """Get a hypergraph representing the uncontracted network (i.e. the
    leaves), built from ``self.inputs``, ``self.output`` and
    ``self.size_dict``.

    Parameters
    ----------
    accel : optional
        Passed on as the last argument of the hypergraph constructor.
    """
    network = (self.inputs, self.output, self.size_dict)
    return get_hypergraph(*network, accel)
|
|
|
|
def reset_contraction_indices(self):
    """Reset all information regarding the explicit contraction indices
    ordering.
    """
    derived_keys = (
        "inds",
        "einsum_eq",
        "can_dot",
        "tensordot_axes",
        "tensordot_perm",
    )

    # drop everything that was derived from a particular index order
    for node in self.children:
        node_info = self.info[node]
        for key in derived_keys:
            node_info.pop(key, None)

    # invalidate any compiled contractions
    self.contraction_cores.clear()
|
|
|
|
def sort_contraction_indices(
    self,
    priority="flops",
    make_output_contig=True,
    make_contracted_contig=True,
    reset=True,
):
    """Set explicit orders for the contraction indices of this tree to
    optimize for one of two things: contiguity in contracted ('k') indices,
    or contiguity of left and right output ('m' and 'n') indices.

    Parameters
    ----------
    priority : {'flops', 'size', 'root', 'leaves'}, optional
        Which order to process the intermediate nodes in. Later nodes
        re-sort previous nodes so are more likely to keep their ordering.
        E.g. for 'flops' the most costly contraction will be processed last
        and thus will be guaranteed to have its indices exactly sorted.
    make_output_contig : bool, optional
        When processing a pairwise contraction, sort the parent contraction
        indices so that the order of indices is the order they appear
        from left to right in the two child (input) tensors.
    make_contracted_contig : bool, optional
        When processing a pairwise contraction, sort the child (input)
        tensor indices so that all contracted indices appear contiguously.
    reset : bool, optional
        Reset all indices to the default order before sorting.

    Raises
    ------
    ValueError
        If ``priority`` is not one of the supported options.
    """
    if reset:
        self.reset_contraction_indices()

    if priority in ("flops", "size"):
        # cheapest first, so the most expensive node is processed last
        cost = self.get_flops if priority == "flops" else self.get_size
        nodes = sorted(self.children.items(), key=lambda item: cost(item[0]))
    elif priority in ("root", "leaves"):
        walk = self.traverse if priority == "root" else self.descend
        nodes = ((p, (l, r)) for p, l, r in walk())
    else:
        raise ValueError(priority)

    for parent, (left, right) in nodes:
        p_inds = self.get_inds(parent)
        l_inds = self.get_inds(left)
        r_inds = self.get_inds(right)

        if make_output_contig and len(parent) != self.N:
            # sort indices by whether they appear in the left or right
            # whether this happens before or after the sort below depends
            # on the order we are processing the nodes
            # (avoid root as don't want to modify output)
            p_inds = "".join(
                sorted(
                    p_inds,
                    key=lambda ix: (r_inds.find(ix), l_inds.find(ix)),
                )
            )
            self.info[parent]["inds"] = p_inds

        if not make_contracted_contig:
            continue

        # sort indices by:
        # 1. if they are going to be contracted
        # 2. what order they appear in the parent indices
        # (but ignore leaf indices)
        if len(left) != 1:
            l_inds = "".join(
                sorted(
                    self.get_legs(left),
                    key=lambda ix: (r_inds.find(ix), p_inds.find(ix)),
                )
            )
            self.info[left]["inds"] = l_inds

        if len(right) != 1:
            # uses the freshly sorted left indices as the tie breaker
            r_inds = "".join(
                sorted(
                    self.get_legs(right),
                    key=lambda ix: (p_inds.find(ix), l_inds.find(ix)),
                )
            )
            self.info[right]["inds"] = r_inds

    # invalidate any compiled contractions
    self.contraction_cores.clear()
|
|
|
|
def print_contractions(self, sort=None, show_brackets=True):
    """Print each pairwise contraction, with colorized indices (if
    `colorama` is installed), and other information. The color codes are:

        - blue: index appears on left and is kept
        - green: index appears on right and is kept
        - red: contracted index: appears on both sides and is removed
        - pink: batch index: appears on both sides and is kept

    Any trivial indices that appear only on one term and not in the output
    are removed and shown by the preprocessing steps.

    Parameters
    ----------
    sort : {'flops', 'size'}, optional
        Sort the contractions by either the number of floating point
        operations or the size of the intermediate tensor. By default the
        contractions are shown in the order they are performed.
    show_brackets : bool, optional
        Whether to show the brackets around contiguous sections of the same
        type of indices.
    """
    try:
        from colorama import Fore
    except ImportError:
        # no color support: fall back to plain text
        RESET = GREY = PINK = RED = BLUE = GREEN = ""
    else:
        RESET, GREY, PINK = Fore.RESET, Fore.WHITE, Fore.MAGENTA
        RED, BLUE, GREEN = Fore.RED, Fore.BLUE, Fore.GREEN

    if show_brackets:
        kpt_brck_l, kpt_brck_r = "(", ")"
        con_brck_l, con_brck_r = "[", "]"
        hyp_brck_l, hyp_brck_r = "{", "}"
    else:
        kpt_brck_l = kpt_brck_r = ""
        con_brck_l = con_brck_r = ""
        hyp_brck_l = hyp_brck_r = ""

    # marks where two adjacent pink groups meet and should be merged
    pink_seam = f"}}{PINK}{{"

    def paint(inds, pink_a, pink_b, marked_legs, marked, plain_color, seam):
        """Colorize ``inds``: pink if in both ``pink_a`` and ``pink_b``,
        else ``marked`` (color, open, close) if in ``marked_legs``, else
        ``plain_color``. Adjacent groups of the same kind are then merged.
        """
        color, opener, closer = marked
        pieces = []
        for ix in inds:
            if (ix in pink_a) and (ix in pink_b):
                pieces.append(PINK + f"{hyp_brck_l}{ix}{hyp_brck_r}")
            elif ix in marked_legs:
                pieces.append(color + f"{opener}{ix}{closer}")
            else:
                pieces.append(plain_color + ix)
        return "".join(pieces).replace(seam, "").replace(pink_seam, "")

    entries = []

    if self.has_preprocessing():
        for pi, eq in self.preprocessing.items():
            # eq is with canonical indices, reinsert original inputs
            replacer = dict(zip(eq.split("->")[0], self.inputs[pi]))
            eq = "".join(replacer.get(c, c) for c in eq)
            print(f"{GREY}preprocess input {pi}: {RESET}{eq}")
        print()

    kept = (GREEN, kpt_brck_l, kpt_brck_r)
    contracted = (RED, con_brck_l, con_brck_r)

    for i, (p, l, r) in enumerate(self.traverse()):
        p_legs, l_legs, r_legs = map(self.get_legs, [p, l, r])
        p_inds, l_inds, r_inds = map(self.get_inds, [p, l, r])

        # sizes and flops
        p_flops = self.get_flops(p)
        p_sz, l_sz, r_sz = (
            math.log2(self.get_size(node)) for node in [p, l, r]
        )

        # whether tensordottable
        if self.get_can_dot(p):
            type_msg = "tensordot"
            perm = self.get_tensordot_perm(p)
            if perm is not None:
                # and whether indices match tensordot
                type_msg += "+perm"
        else:
            type_msg = "einsum"

        pa = paint(
            p_inds, l_legs, r_legs, r_legs, kept, BLUE, f"){GREEN}("
        )
        la = paint(
            l_inds, p_legs, r_legs, r_legs, contracted, BLUE, f"]{RED}["
        )
        ra = paint(
            r_inds, p_legs, l_legs, l_legs, contracted, GREEN, f"]{RED}["
        )

        entries.append(
            (
                p,
                f"{GREY}({i}) cost: {RESET}{p_flops:.1e} "
                f"{GREY}widths: {RESET}{l_sz:.1f},{r_sz:.1f}->{p_sz:.1f} "
                f"{GREY}type: {RESET}{type_msg}\n"
                f"{GREY}inputs: {la},{ra}{RESET}->\n"
                f"{GREY}output: {pa}\n",
            )
        )

    if sort == "flops":
        entries.sort(key=lambda x: self.get_flops(x[0]), reverse=True)
    if sort == "size":
        entries.sort(key=lambda x: self.get_size(x[0]), reverse=True)

    # make sure the terminal color is restored at the end
    entries.append((None, f"{RESET}"))

    o = "\n".join(entry for _, entry in entries)
    print(o)
|
|
|
|
# --------------------- Performing the Contraction ---------------------- #
|
|
|
|
def get_contractor(
    self,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    implementation=None,
    autojit=False,
    progbar=False,
):
    """Get a reusable function which performs the contraction corresponding
    to this tree. The function is cached on the tree, keyed by the full
    combination of options.

    Parameters
    ----------
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`, the order in which
        to perform the pairwise contractions given by the tree.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        If ``True``, the function will eagerly strip the exponent (in
        log10) from intermediate tensors to control numerical problems from
        leaving the range of the datatype. This method then returns the
        scaled 'mantissa' output array and the exponent separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, 0.0)``.
    implementation : str or tuple[callable, callable], optional
        What library to use to actually perform the contractions. Options
        are:

        - None: let cotengra choose.
        - "autoray": dispatch with autoray, using the ``tensordot`` and
          ``einsum`` implementation of the backend.
        - "cotengra": use the ``tensordot`` and ``einsum`` implementation
          of cotengra, which is based on batch matrix multiplication. This
          is faster for some backends like numpy, and also enables
          libraries which don't yet provide ``tensordot`` and ``einsum`` to
          be used.
        - "cuquantum": use the cuquantum library to perform the whole
          contraction (not just individual contractions).
        - tuple[callable, callable]: manually supply the ``tensordot`` and
          ``einsum`` implementations to use.

    autojit : bool, optional
        If ``True``, use :func:`autoray.autojit` to compile the contraction
        function.
    progbar : bool, optional
        Whether to show progress through the contraction by default.

    Returns
    -------
    fn : callable
        The contraction function, with signature ``fn(*arrays)``.
    """
    key = (
        autojit,
        order,
        prefer_einsum,
        strip_exponent,
        check_zero,
        implementation,
        progbar,
    )

    cache = self.contraction_cores
    if key not in cache:
        # first request for this combination of options: build and store
        cache[key] = make_contractor(
            tree=self,
            order=order,
            prefer_einsum=prefer_einsum,
            strip_exponent=strip_exponent,
            check_zero=check_zero,
            implementation=implementation,
            autojit=autojit,
            progbar=progbar,
        )

    return cache[key]
|
|
|
|
def contract_core(
    self,
    arrays,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    backend=None,
    implementation=None,
    autojit="auto",
    progbar=False,
):
    """Contract ``arrays`` using the core (unsliced) tree.

    The axes of the arrays and of the result are assumed to follow
    ``tree.inputs`` and ``tree.output`` with any sliced indices removed, so
    if the tree has sliced indices the supplied arrays must already be
    sliced.

    Parameters
    ----------
    arrays : sequence of array
        The arrays to contract.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        Forwarded to ``get_contractor``; any value other than ``False`` is
        treated as ``True``.
    check_zero : bool, optional
        Forwarded unchanged to ``get_contractor``.
    backend : str, optional
        What library to use for ``einsum`` and ``transpose``, will be
        automatically inferred from the arrays if not given.
    implementation : str or tuple[callable, callable], optional
        Forwarded unchanged to ``get_contractor``.
    autojit : "auto" or bool, optional
        Whether to use ``autoray.autojit`` to jit compile the expression.
        If "auto", then let ``cotengra`` choose.
    progbar : bool, optional
        Show progress through the contraction.
    """
    # "auto" resolves to jit-compiling only when the jax backend was
    # explicitly requested
    use_jit = (backend == "jax") if autojit == "auto" else autojit

    contractor_opts = {
        "order": order,
        "prefer_einsum": prefer_einsum,
        "strip_exponent": strip_exponent is not False,
        "implementation": implementation,
        "autojit": use_jit,
        "check_zero": check_zero,
        "progbar": progbar,
    }
    contractor = self.get_contractor(**contractor_opts)
    return contractor(*arrays, backend=backend)
|
|
|
|
def slice_key(self, i, strides=None):
    """Get the combination of sliced index values for overall slice ``i``.

    Parameters
    ----------
    i : int
        The overall slice index.
    strides : sequence of int, optional
        One stride per entry of ``self.sliced_inds``, in the same order. If
        not given they are computed with ``get_slice_strides``.

    Returns
    -------
    key : dict[str, int]
        The value each sliced index takes for slice ``i``.
    """
    if strides is None:
        strides = get_slice_strides(self.sliced_inds)

    remainder = i
    values = {}
    for (ix, ix_info), step in zip(self.sliced_inds.items(), strides):
        if ix_info.project is not None:
            # projected index: fixed value, does not consume any of ``i``
            values[ix] = ix_info.project
            continue
        values[ix], remainder = divmod(remainder, step)

    return values
|
|
|
|
def slice_arrays(self, arrays, i):
    """Return a list copy of ``arrays`` in which every input listed in
    ``tree.sliced_inputs`` has been indexed at the sliced index values
    that :meth:`slice_key` gives for slice ``i``.
    """
    # e.g. {'a': 2, 'd': 7, 'z': 0}
    fixed = self.slice_key(i)
    sliced = list(arrays)

    for pos in self.sliced_inputs:
        # full slice on untouched axes, fixed integer on sliced ones,
        # e.g. [:, :, 7, :, 2, :, :, 0]
        index = tuple(
            fixed[ix] if ix in fixed else slice(None)
            for ix in self.inputs[pos]
        )
        sliced[pos] = sliced[pos][index]

    return sliced
|
|
|
|
def contract_slice(self, arrays, i, **kwargs):
    """Slice ``arrays`` for overall slice ``i`` and contract the result
    with :meth:`contract_core`, forwarding ``kwargs`` to it.
    """
    sliced_arrays = self.slice_arrays(arrays, i)
    return self.contract_core(sliced_arrays, **kwargs)
|
|
|
|
def gather_slices(self, slices, backend=None, progbar=False):
    """Combine the contracted slices into the single full result.

    When no sliced index appears in the output every slice is simply
    summed. Otherwise slices sharing the same output sliced index values
    are summed into chunks, and the chunks are then stacked along the
    output sliced axes.

    Parameters
    ----------
    slices : iterable
        The contracted slices, in overall slice order. Each is either an
        array or a ``(mantissa, exponent)`` tuple.
    backend : str, optional
        Passed as ``like`` to the ``stack`` call.
    progbar : bool, optional
        Wrap ``slices`` in a ``tqdm`` progress bar.

    Returns
    -------
    result : array or tuple
        The gathered result; in the stacked case with exponent-stripped
        slices this is ``(result, exponent)``.
    """
    if progbar:
        import tqdm

        slices = tqdm.tqdm(slices, total=self.multiplicity)

    # output axis of every sliced index that also appears in the output
    output_pos = {}
    for pos, ix in enumerate(self.output):
        if ix in self.sliced_inds:
            output_pos[ix] = pos

    if not output_pos:
        # nothing to stack, the result is the sum of all slices
        return functools.reduce(add_maybe_exponent_stripped, slices)

    # sum together slices that differ only in non-output sliced indices
    chunks = {}
    for i, s in enumerate(slices):
        slice_values = self.slice_key(i)
        key = tuple(slice_values[ix] for ix in output_pos)
        if key in chunks:
            chunks[key] = add_maybe_exponent_stripped(chunks[key], s)
        else:
            chunks[key] = s

    first_chunk = next(iter(chunks.values()))
    if isinstance(first_chunk, tuple):
        # exponents were stripped: rescale all mantissas to the largest
        emax = max(e for _, e in chunks.values())
        chunks = {k: m * 10 ** (e - emax) for k, (m, e) in chunks.items()}
    else:
        emax = None

    stack_order = tuple(output_pos)

    def stack_from(depth, loc):
        # ``loc`` holds the values chosen for the first ``depth`` indices
        if depth == len(stack_order):
            return chunks[loc]
        ix = stack_order[depth]
        parts = [
            stack_from(depth + 1, loc + (d,))
            for d in self.sliced_inds[ix].sliced_range
        ]
        # earlier sliced output axes are not yet present in ``parts``, so
        # the target axis is shifted down by their count
        return do("stack", parts, output_pos[ix] - depth, like=backend)

    result = stack_from(0, ())

    if emax is None:
        return result

    # exponent-stripped slices: return the exponent separately
    return result, emax
|
|
|
|
def gen_output_chunks(
    self, arrays, with_key=False, progbar=False, **contract_opts
):
    """Generate each output chunk of the contraction - i.e. take care of
    summing internally sliced indices only first. This assumes that the
    ``sliced_inds`` are sorted by whether they appear in the output or not
    (the default order). Useful for performing some kind of reduction over
    the final tensor object like ``fn(x).sum()`` without constructing the
    entire thing.

    Parameters
    ----------
    arrays : sequence of array
        The arrays to contract.
    with_key : bool, optional
        Whether to yield the output index configuration key along with the
        chunk.
    progbar : bool, optional
        Show progress through the contraction chunks.
    contract_opts
        Supplied to :meth:`contract_slice` for every slice.

    Yields
    ------
    chunk : array
        A chunk of the contracted result. Slices are accumulated with
        ``add_maybe_exponent_stripped``, the same helper
        :meth:`gather_slices` uses, so ``(mantissa, exponent)`` slices are
        combined rather than concatenated as tuples.
    key : dict[str, int]
        The value each sliced output index takes for this chunk, only
        yielded if ``with_key=True``.
    """
    # consecutive slices of size ``stepsize`` all belong to the same output
    # block because the sliced indices are sorted output first
    stepsize = prod(
        si.size for si in self.sliced_inds.values() if si.inner
    )
    nchunks = self.nslices // stepsize

    if progbar:
        import tqdm

        it = tqdm.trange(nchunks)
    else:
        it = range(nchunks)

    for o in it:
        start = o * stepsize
        chunk = self.contract_slice(arrays, start, **contract_opts)

        for i in range(start + 1, start + stepsize):
            # plain ``+`` would concatenate (mantissa, exponent) tuples
            chunk = add_maybe_exponent_stripped(
                chunk, self.contract_slice(arrays, i, **contract_opts)
            )

        if with_key:
            output_key = {
                ix: x
                for ix, x in self.slice_key(start).items()
                if ix in self.output
            }
            yield chunk, output_key
        else:
            yield chunk
|
|
|
|
def contract(
    self,
    arrays,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    backend=None,
    implementation=None,
    autojit="auto",
    progbar=False,
):
    """Contract ``arrays`` with this tree. This function takes *unsliced*
    arrays and handles the slicing, contractions and gathering. The order
    of the axes and output is assumed to match that of ``tree.inputs`` and
    ``tree.output``.

    Parameters
    ----------
    arrays : sequence of array
        The arrays to contract.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        If ``True``, eagerly strip the exponent (in log10) from
        intermediate tensors to control numerical problems from leaving the
        range of the datatype. This method then returns the scaled
        'mantissa' output array and the exponent separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, 0.0)``.
    backend : str, optional
        What library to use for ``tensordot``, ``einsum`` and
        ``transpose``, it will be automatically inferred from the input
        arrays if not given.
    implementation : str or tuple[callable, callable], optional
        Forwarded to :meth:`contract_core` / :meth:`contract_slice`.
    autojit : "auto" or bool, optional
        Whether to use the 'autojit' feature of `autoray` to compile the
        contraction expression.
    progbar : bool, optional
        Whether to show a progress bar. Without sliced indices it is
        passed to :meth:`contract_core`, otherwise to
        :meth:`gather_slices`.

    Returns
    -------
    output : array
        The contracted output, it will be scaled if
        ``strip_exponent==True``.
    exponent : float
        The exponent of the output in base 10, returned only if
        ``strip_exponent==True``.

    See Also
    --------
    contract_core, contract_slice, slice_arrays, gather_slices
    """
    # options common to the unsliced and the per-slice contraction
    shared_opts = {
        "order": order,
        "prefer_einsum": prefer_einsum,
        "strip_exponent": strip_exponent,
        "check_zero": check_zero,
        "backend": backend,
        "implementation": implementation,
        "autojit": autojit,
    }

    if not self.sliced_inds:
        # nothing sliced: a single core contraction is the whole result
        return self.contract_core(arrays, progbar=progbar, **shared_opts)

    # lazily contract one slice at a time as the gatherer consumes them
    slices = (
        self.contract_slice(arrays, i, **shared_opts)
        for i in range(self.multiplicity)
    )
    return self.gather_slices(slices, backend=backend, progbar=progbar)
|
|
|
|
def contract_mpi(self, arrays, comm=None, root=None, **kwargs):
    """Contract the slices of this tree and sum them in parallel -
    *assuming* we are already running under MPI.

    Parameters
    ----------
    arrays : sequence of array
        The input (unsliced arrays)
    comm : None or mpi4py communicator
        Defaults to ``mpi4py.MPI.COMM_WORLD`` if not given.
    root : None or int, optional
        If ``root=None``, an ``Allreduce`` will be performed such that
        every process has the resulting tensor, else if an integer e.g.
        ``root=0``, the result will be exclusively gathered to that
        process using ``Reduce``, with every other process returning
        ``None``.
    kwargs
        Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.

    Raises
    ------
    NotImplementedError
        If any sliced index also appears in the output.
    ValueError
        If there are fewer slices than MPI processes.
    """
    overlap = set(self.sliced_inds) & set(self.output)
    if overlap:
        raise NotImplementedError(
            "Sliced and output indices overlap - currently only a simple "
            "sum of result slices is supported currently."
        )

    if comm is None:
        from mpi4py import MPI

        comm = MPI.COMM_WORLD

    if self.multiplicity < comm.size:
        raise ValueError(
            f"Need to have more slices than MPI processes, but have "
            f"{self.multiplicity} and {comm.size} respectively."
        )

    # each rank handles slices rank, rank + size, rank + 2 * size, ...
    # accumulating its own partial sum as it goes
    partial_sum = None
    for i in range(comm.rank, self.multiplicity, comm.size):
        # note: fortran ordering is needed for the MPI reduce
        x = do("asfortranarray", self.contract_slice(arrays, i, **kwargs))
        if partial_sum is None:
            partial_sum = x
        else:
            partial_sum += x

    if root is None:
        # every process receives the full sum
        total = do("empty_like", partial_sum)
        comm.Allreduce(partial_sum, total)
        return total

    # otherwise only ``root`` allocates a receive buffer and gets the sum
    total = do("empty_like", partial_sum) if comm.rank == root else None
    comm.Reduce(partial_sum, total, root=root)
    return total
|
|
|
|
def benchmark(
    self,
    dtype,
    max_time=60,
    min_reps=3,
    max_reps=100,
    warmup=True,
    **contract_opts,
):
    """Benchmark the contraction of this tree.

    Parameters
    ----------
    dtype : {"float32", "float64", "complex64", "complex128"}
        The datatype to use.
    max_time : float, optional
        The maximum time to spend benchmarking in seconds.
    min_reps : int, optional
        The minimum number of repetitions to perform, regardless of time.
    max_reps : int, optional
        The maximum number of repetitions to perform, regardless of time.
    warmup : bool or int, optional
        Whether to perform a warmup run before the benchmark. If an int,
        the number of warmup runs to perform.
    contract_opts
        Supplied to :meth:`~cotengra.ContractionTree.contract_slice`.

    Returns
    -------
    dict
        A dictionary of benchmarking results. The keys are:

        - "time_per_slice" : float
            The average time to contract a single slice.
        - "est_time_total" : float
            The estimated total time to contract all slices.
        - "est_gigaflops" : float
            The estimated gigaflops of the contraction.

    See Also
    --------
    contract_slice
    """
    import time

    from .utils import make_arrays_from_inputs

    arrays = make_arrays_from_inputs(
        self.inputs, self.size_dict, dtype=dtype
    )

    # untimed warmup runs, cycling through the slices
    for i in range(int(warmup)):
        self.contract_slice(arrays, i % self.nslices, **contract_opts)

    # ``perf_counter`` is monotonic and high resolution, unlike the wall
    # clock ``time.time`` which can jump when the system clock is adjusted
    t0 = time.perf_counter()
    ti = t0
    i = 0
    # keep going until both the time budget is spent and ``min_reps`` is
    # reached, but never beyond ``max_reps``
    while (ti - t0 < max_time) or (i < min_reps):
        self.contract_slice(arrays, i % self.nslices, **contract_opts)
        ti = time.perf_counter()
        i += 1
        if i >= max_reps:
            break

    time_per_slice = (ti - t0) / i
    est_time_total = time_per_slice * self.nslices
    est_gigaflops = self.total_flops(dtype=dtype) / (1e9 * est_time_total)

    return {
        "time_per_slice": time_per_slice,
        "est_time_total": est_time_total,
        "est_gigaflops": est_gigaflops,
    }
|
|
|
|
plot_ring = plot_tree_ring
|
|
plot_tent = plot_tree_tent
|
|
plot_span = plot_tree_span
|
|
plot_flat = plot_tree_flat
|
|
plot_circuit = plot_tree_circuit
|
|
plot_rubberband = plot_tree_rubberband
|
|
plot_contractions = plot_contractions
|
|
plot_contractions_alt = plot_contractions_alt
|
|
|
|
@functools.wraps(plot_hypergraph)
def plot_hypergraph(self, **kwargs):
    # The decorator copies the metadata (including the docstring) of the
    # module-level ``plot_hypergraph`` onto this method, so documentation
    # is kept in comments here.
    hg = self.get_hypergraph(accel=False)
    # hand back whatever the plot call produces instead of discarding it,
    # matching the directly aliased ``plot_*`` methods of this class
    return hg.plot(**kwargs)
|
|
|
|
def describe(self, info="normal", join=" "):
    """Return a string describing the contraction tree.

    Parameters
    ----------
    info : {"normal", "full", "concise"}, optional
        How much detail to include and how to label it.
    join : str, optional
        The separator placed between the individual entries.

    Returns
    -------
    str
        The joined description entries.

    Raises
    ------
    ValueError
        If ``info`` is not one of the supported options (previously this
        silently returned ``None``).
    """
    # called before any cost is formatted; its return value is not needed
    self.contract_stats()

    if info == "normal":
        parts = [
            f"log10[FLOPs]={self.total_flops(log=10):.2f}",
            f"log2[SIZE]={self.max_size(log=2):.2f}",
        ]
    elif info == "full":
        parts = [
            f"log10[FLOPS]={self.total_flops(log=10):.2f}",
            f"log10[COMBO]={self.combo_cost(log=10):.2f}",
            f"log2[SIZE]={self.max_size(log=2):.2f}",
            f"log2[PEAK]={self.peak_size(log=2):.2f}",
        ]
        if self.sliced_inds:
            parts.append(f"NSLICES={self.multiplicity:.2f}")
    elif info == "concise":
        parts = [
            f"F={self.total_flops(log=10):.2f}",
            f"C={self.combo_cost(log=10):.2f}",
            f"S={self.max_size(log=2):.2f}",
            f"P={self.peak_size(log=2):.2f}",
        ]
        if self.sliced_inds:
            parts.append(f"$={self.multiplicity:.2f}")
    else:
        raise ValueError(
            f"Unknown `info` option {info!r}, should be one of "
            "'normal', 'full' or 'concise'."
        )

    return join.join(parts)
|
|
|
|
def __repr__(self):
|
|
if self.is_complete():
|
|
return f"<{self.__class__.__name__}(N={self.N})>"
|
|
else:
|
|
s = "<{}(N={}, branches={}, complete={})>"
|
|
return s.format(
|
|
self.__class__.__name__,
|
|
self.N,
|
|
len(self.children),
|
|
self.is_complete(),
|
|
)
|
|
|
|
def __str__(self):
|
|
if not self.is_complete():
|
|
return self.__repr__()
|
|
else:
|
|
d = self.describe("concise", join=", ")
|
|
return f"<{self.__class__.__name__}(N={self.N}, {d})>"
|
|
|
|
|
|
def _reconfigure_tree(tree, *args, **kwargs):
|
|
return tree.subtree_reconfigure(*args, **kwargs)
|
|
|
|
|
|
def _slice_and_reconfigure_tree(tree, *args, **kwargs):
|
|
return tree.slice_and_reconfigure(*args, **kwargs)
|
|
|
|
|
|
def _get_tree_info(tree):
|
|
stats = tree.contract_stats()
|
|
stats["sliced_ind_set"] = frozenset(tree.sliced_inds)
|
|
return stats
|
|
|
|
|
|
def _describe_tree(tree, info="normal"):
|
|
return tree.describe(info=info)
|
|
|
|
|
|
class ContractionTreeCompressed(ContractionTree):
    """A contraction tree for compressed contractions. Currently the only
    difference is that this defaults to the 'surface' traversal ordering.
    """

    def set_state_from(self, other):
        """Copy the state of ``other`` and then set the surface order from
        the ssa path of ``other``.
        """
        super().set_state_from(other)
        self.set_surface_order_from_path(other.get_ssa_path())

    @classmethod
    def from_path(
        cls,
        inputs,
        output,
        size_dict,
        *,
        path=None,
        ssa_path=None,
        autocomplete="auto",
        check=False,
        **kwargs,
    ):
        """Create a (completed) ``ContractionTreeCompressed`` from the usual
        inputs plus a standard contraction path or 'ssa_path' - you need to
        supply one. This also sets the default 'surface' traversal ordering
        to be the initial path.

        Parameters
        ----------
        inputs, output, size_dict
            Passed to the class constructor.
        path : sequence, optional
            A standard (linear) contraction path, converted to an ssa path
            with ``linear_to_ssa``.
        ssa_path : sequence, optional
            An ssa contraction path.
        autocomplete : "auto" or bool, optional
            What to do if the path leaves the tree incomplete. If truthy
            the remaining nodes are contracted with the
            ``"greedy-compressed"`` optimizer, with a warning emitted when
            it is ``"auto"``. If falsy the tree is returned incomplete.
        check : bool, optional
            Passed to ``contract_nodes``.
        kwargs
            Passed to the class constructor.

        Raises
        ------
        ValueError
            If not exactly one of ``path`` and ``ssa_path`` is supplied.
        """
        if int(path is None) + int(ssa_path is None) != 1:
            raise ValueError(
                "Exactly one of ``path`` or ``ssa_path`` must be supplied."
            )

        if path is not None:
            from .pathfinders.path_basic import linear_to_ssa

            ssa_path = linear_to_ssa(path)

        tree = cls(inputs, output, size_dict, **kwargs)
        terms = list(tree.gen_leaves())

        # ssa ids index into ``terms``: leaves first, then each new
        # intermediate appended in order of creation
        for p in ssa_path:
            merge = [terms[i] for i in p]
            terms.append(tree.contract_nodes(merge, check=check))

        tree.set_surface_order_from_path(ssa_path)

        if (len(tree.children) < tree.N - 1) and autocomplete:
            if autocomplete == "auto":
                # warn that we are completing
                warnings.warn(
                    "Path was not complete - contracting all remaining. "
                    "You can silence this warning with `autocomplete=True`. "
                    "Or produce an incomplete tree with `autocomplete=False`."
                )

            tree.autocomplete(optimize="greedy-compressed")

        return tree

    def get_default_order(self):
        """Return the default traversal order name, ``"surface_order"``."""
        return "surface_order"

    def get_default_objective(self):
        """Return the default objective, lazily creating (and caching) the
        ``"peak-compressed"`` score function if none is set.
        """
        if self._default_objective is None:
            self._default_objective = get_score_fn("peak-compressed")
        return self._default_objective

    def get_default_chi(self):
        """Return the ``chi`` attribute of the default objective. If the
        objective has no such attribute, or it is ``"auto"``, the square of
        the largest value in ``size_dict`` is returned instead.
        """
        objective = self.get_default_objective()
        try:
            chi = objective.chi
        except AttributeError:
            chi = "auto"

        if chi == "auto":
            chi = max(self.size_dict.values()) ** 2

        return chi

    def get_default_compress_late(self):
        """Return the ``compress_late`` attribute of the default objective,
        or ``False`` if the objective does not have one.
        """
        objective = self.get_default_objective()
        try:
            return objective.compress_late
        except AttributeError:
            return False

    # the standard cost method names map to their compressed versions ...
    total_flops = ContractionTree.total_flops_compressed
    total_write = ContractionTree.total_write_compressed
    combo_cost = ContractionTree.combo_cost_compressed
    total_cost = ContractionTree.total_cost_compressed
    max_size = ContractionTree.max_size_compressed
    peak_size = ContractionTree.peak_size_compressed
    contraction_cost = ContractionTree.contraction_cost_compressed
    contraction_width = ContractionTree.contraction_width_compressed

    # ... while the base class versions stay reachable as ``*_exact``
    total_flops_exact = ContractionTree.total_flops
    total_write_exact = ContractionTree.total_write
    combo_cost_exact = ContractionTree.combo_cost
    total_cost_exact = ContractionTree.total_cost
    max_size_exact = ContractionTree.max_size
    peak_size_exact = ContractionTree.peak_size

    def get_contractor(self, *_, **__):
        """Always raises ``NotImplementedError``, any arguments are
        ignored.
        """
        raise NotImplementedError(
            "`cotengra` doesn't implement compressed contraction itself. "
            "If you want to use compressed contractions, you need to use "
            "`quimb` and the `TensorNetwork.contract_compressed` method, "
            "with e.g. `optimize=tree.get_path()`."
        )

    def simulated_anneal(
        self,
        minimize=None,
        tfinal=0.0001,
        tstart=0.01,
        tsteps=50,
        numiter=50,
        seed=None,
        inplace=False,
        progbar=False,
        **kwargs,
    ):
        """Perform simulated annealing refinement of this *compressed*
        contraction tree.

        Parameters
        ----------
        minimize : optional
            The objective to minimize, defaults to
            :meth:`get_default_objective`.
        tfinal, tstart, tsteps, numiter, progbar
            Passed to ``WindowedOptimizer.simulated_anneal``.
        seed : optional
            Passed to the ``WindowedOptimizer`` constructor.
        inplace : bool, optional
            If ``True``, load the refined state into this tree and return
            it, otherwise return the newly built tree.
        kwargs
            Passed to ``WindowedOptimizer.simulated_anneal``.

        Returns
        -------
        ContractionTreeCompressed
            The refined tree, with its ``contraction_cores`` cleared.
        """
        from .pathfinders.path_compressed import WindowedOptimizer

        if minimize is None:
            minimize = self.get_default_objective()

        wo = WindowedOptimizer(
            self.inputs,
            self.output,
            self.size_dict,
            minimize=minimize,
            ssa_path=self.get_ssa_path(),
            seed=seed,
        )

        wo.simulated_anneal(
            tfinal=tfinal,
            tstart=tstart,
            tsteps=tsteps,
            numiter=numiter,
            progbar=progbar,
            **kwargs,
        )
        ssa_path = wo.get_ssa_path()

        # rebuild a tree of the same class from the refined path
        rtree = self.__class__.from_path(
            self.inputs,
            self.output,
            self.size_dict,
            ssa_path=ssa_path,
            objective=minimize,
        )

        if inplace:
            self.set_state_from(rtree)
            rtree = self

        rtree.contraction_cores.clear()
        return rtree

    simulated_anneal_ = functools.partialmethod(simulated_anneal, inplace=True)
|
|
|
|
|
|
class PartitionTreeBuilder:
    """Function wrapper that takes a function that partitions graphs and
    uses it to build a contraction tree. ``partition_fn`` should have
    signature:

        def partition_fn(inputs, output, size_dict,
                weight_nodes, weight_edges, **kwargs):
            ...
            return membership

    Where ``weight_nodes`` and ``weight_edges`` describe how to weight the
    nodes and edges of the graph respectively and ``membership`` should be a
    list of integers of length ``len(inputs)`` labelling which partition
    each input node should be put in. Both builders additionally call it
    with the keyword arguments ``parts`` and ``seed``.
    """

    def __init__(self, partition_fn):
        self.partition_fn = partition_fn

    def build_divide(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        cutoff=10,
        parts=2,
        parts_decay=0.5,
        sub_optimize="greedy",
        super_optimize="auto-hq",
        check=False,
        seed=None,
        **partition_opts,
    ):
        """Build a tree top-down by repeatedly partitioning childless nodes.

        Parameters
        ----------
        inputs, output, size_dict
            The contraction specification, passed to ``ContractionTree``.
        random_strength : float, optional
            Strength of the jitter applied to the sizes seen by
            ``partition_fn``.
        cutoff : int, optional
            Subgraphs with at most this many nodes are contracted directly
            using ``sub_optimize`` rather than partitioned.
        parts : int, optional
            Base number of partitions to request.
        parts_decay : float, optional
            Exponent applied to the relative subgraph size when scaling
            ``parts``; at least 2 parts are always requested.
        sub_optimize : optional
            Optimizer for contracting the nodes of a small subgraph.
        super_optimize : optional
            Optimizer for contracting the partitions found.
        check : bool, optional
            Passed to ``contract_nodes``, and if set the final tree is
            asserted to be complete.
        seed : optional
            Passed to ``get_rng``; the resulting generator is used for the
            jitter and passed to ``partition_fn`` as ``seed``.
        partition_opts
            Passed to ``partition_fn``. If both ``imbalance`` and
            ``imbalance_decay`` are given the imbalance is rescaled with
            the relative subgraph size, and ``fix_output_nodes="auto"``
            is replaced per subgraph by ``s == 1.0``.

        Returns
        -------
        ContractionTree
        """
        tree = ContractionTree(inputs, output, size_dict, track_childless=True)

        rng = get_rng(seed)
        rand_size_dict = jitter_dict(size_dict, random_strength, rng)

        dynamic_imbalance = ("imbalance" in partition_opts) and (
            "imbalance_decay" in partition_opts
        )
        if dynamic_imbalance:
            imbalance = partition_opts.pop("imbalance")
            imbalance_decay = partition_opts.pop("imbalance_decay")
        else:
            imbalance = imbalance_decay = None

        dynamic_fix = partition_opts.get("fix_output_nodes", None) == "auto"

        while tree.childless:
            tree_node = next(iter(tree.childless))
            subgraph = tuple(tree_node)
            subsize = len(subgraph)

            # skip straight to better method
            if subsize <= cutoff:
                tree.contract_nodes(
                    [node_from_single(x) for x in subgraph],
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            # relative subgraph size
            s = subsize / tree.N

            # let the target number of communities depend on subgraph size
            parts_s = max(int(s**parts_decay * parts), 2)

            # let the imbalance either rise or fall
            if dynamic_imbalance:
                if imbalance_decay >= 0:
                    imbalance_s = s**imbalance_decay * imbalance
                else:
                    imbalance_s = 1 - s**-imbalance_decay * (1 - imbalance)
                partition_opts["imbalance"] = imbalance_s

            if dynamic_fix:
                # for the top level subtree (s==1.0) we partition the outputs
                # nodes first into their own bi-partition
                parts_s = 2
                partition_opts["fix_output_nodes"] = s == 1.0

            # partition! get community membership list e.g.
            # [0, 0, 1, 0, 1, 0, 0, 2, 2, ...]
            sub_inputs = tuple(map(tuple, tree.node_to_terms(subgraph)))
            sub_output = tuple(tree.get_legs(tree_node))
            membership = self.partition_fn(
                sub_inputs,
                sub_output,
                rand_size_dict,
                parts=parts_s,
                seed=rng,
                **partition_opts,
            )

            # divide subgraph up e.g. if we enumerate the subgraph index sets
            # (0, 1, 2, 3, 4, 5, 6, 7, 8, ...) ->
            # ({0, 1, 3, 5, 6}, {2, 4}, {7, 8})
            new_subgs = tuple(
                map(node_from_seq, separate(subgraph, membership))
            )

            if len(new_subgs) == 1:
                # no communities found - contract all remaining
                tree.contract_nodes(
                    tuple(map(node_from_single, subgraph)),
                    optimize=sub_optimize,
                    check=check,
                )
                continue

            # update tree structure with newly contracted subgraphs
            tree.contract_nodes(
                new_subgs, optimize=super_optimize, check=check
            )

        if check:
            assert tree.is_complete()

        return tree

    def build_agglom(
        self,
        inputs,
        output,
        size_dict,
        random_strength=0.01,
        groupsize=4,
        check=False,
        sub_optimize="greedy",
        seed=None,
        **partition_opts,
    ):
        """Build a tree bottom-up by repeatedly partitioning the current
        nodes into groups and contracting each group.

        Parameters
        ----------
        inputs, output, size_dict
            The contraction specification, passed to ``ContractionTree``.
        random_strength : float, optional
            Strength of the jitter applied to the sizes seen by
            ``partition_fn``.
        groupsize : int, optional
            Target number of nodes per group; partitioning stops once at
            most this many nodes remain.
        check : bool, optional
            Passed to the tree building calls, and if set the final tree is
            asserted to be complete.
        sub_optimize : optional
            Optimizer for contracting the nodes within a group.
        seed : optional
            Passed to ``get_rng``; the resulting generator is used for the
            jitter and, as in :meth:`build_divide`, passed to
            ``partition_fn`` as ``seed`` so seeded builds are reproducible.
        partition_opts
            Passed to ``partition_fn``.

        Returns
        -------
        ContractionTree
        """
        tree = ContractionTree(inputs, output, size_dict, track_childless=True)

        # a single generator drives both the jitter and the partitioner
        rng = get_rng(seed)
        rand_size_dict = jitter_dict(size_dict, random_strength, rng)

        leaves = tuple(tree.gen_leaves())
        for node in leaves:
            tree._add_node(node, check=check)
        output = tuple(tree.output)

        while len(leaves) > groupsize:
            parts = max(2, len(leaves) // groupsize)

            # the current nodes act as the inputs of a coarsened graph
            node_inputs = [tuple(tree.get_legs(node)) for node in leaves]
            membership = self.partition_fn(
                node_inputs,
                output,
                rand_size_dict,
                parts=parts,
                seed=rng,
                **partition_opts,
            )
            leaves = [
                tree.contract_nodes(group, check=check, optimize=sub_optimize)
                for group in separate(leaves, membership)
            ]

        if len(leaves) > 1:
            tree.contract_nodes(leaves, check=check, optimize=sub_optimize)

        if check:
            assert tree.is_complete()

        return tree

    def trial_fn(self, inputs, output, size_dict, **partition_opts):
        """Build a tree with :meth:`build_divide`."""
        return self.build_divide(inputs, output, size_dict, **partition_opts)

    def trial_fn_agglom(self, inputs, output, size_dict, **partition_opts):
        """Build a tree with :meth:`build_agglom`."""
        return self.build_agglom(inputs, output, size_dict, **partition_opts)
|
|
|
|
|
|
def jitter(x, strength, rng):
    """Scale ``x`` by ``1 + strength * r`` where ``r`` is drawn from
    ``rng.expovariate(1.0)``.
    """
    noise = rng.expovariate(1.0)
    return x * (1 + strength * noise)
|
|
|
|
|
|
def jitter_dict(d, strength, seed=None):
    """Return a new dict with every value of ``d`` passed through
    ``jitter``, all drawing from one generator obtained from
    ``get_rng(seed)``.
    """
    rng = get_rng(seed)
    jittered = {}
    for key, value in d.items():
        jittered[key] = jitter(value, strength, rng)
    return jittered
|
|
|
|
|
|
def separate(xs, blocks):
    """Group the items of ``xs`` by their corresponding label in
    ``blocks``, returning one list per distinct label, ordered by sorted
    label. Items keep their original relative order within each list.
    """
    groups = {}
    for x, label in zip(xs, blocks):
        groups.setdefault(label, []).append(x)
    return [groups[label] for label in sorted(groups)]
|