Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
209 lines
6.4 KiB
Python
209 lines
6.4 KiB
Python
"""Inspect cotengra contraction trees for dominant torch matmul shapes."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import importlib
|
|
import math
|
|
import pickle
|
|
from collections import Counter, defaultdict
|
|
from pathlib import Path
|
|
|
|
|
|
def _prod(values):
    """Return the product of ``values`` after coercing each one with ``int``.

    An empty iterable yields 1. Any iterable (including a generator) is
    accepted and consumed once; the result is always a Python ``int``.
    """
    return math.prod(int(item) for item in values)
|
|
|
|
|
|
def _broadcast_batch(a_batch, b_batch):
    """Return the element count of the broadcast of two batch shapes.

    Shapes are right-aligned and the shorter one is left-padded with 1s;
    each resulting dim is the ``max`` of the aligned pair, and the product
    of those dims (via ``_prod``) is returned.

    NOTE(review): in the padded path, incompatible dims (e.g. 3 vs 4) are
    not rejected -- the larger one is used -- and a 0 paired with a 1
    yields 1, whereas strict NumPy/torch broadcasting would give 0.
    """
    # Shortcuts that skip padding: identical shapes, or one side has no
    # batch dims at all (an empty side contributes nothing to the product).
    if a_batch == b_batch or not b_batch:
        return _prod(a_batch)
    if not a_batch:
        return _prod(b_batch)

    width = max(len(a_batch), len(b_batch))
    padded_a = (1,) * (width - len(a_batch)) + tuple(a_batch)
    padded_b = (1,) * (width - len(b_batch)) + tuple(b_batch)
    return _prod(map(max, padded_a, padded_b))
|
|
|
|
|
|
def _load_tree(path, index):
    """Load one contraction tree from a pickle file.

    The pickle may hold either a mapping with a ``"trees"`` entry or the
    tree(s) directly. If the resulting object is not a list or tuple it is
    treated as a single tree.

    Args:
        path: filesystem path (str or ``Path``) of the pickle file.
        index: position of the tree to return; normal Python indexing
            rules apply, so negative values count from the end.

    Returns:
        The selected tree object.

    Raises:
        KeyError: a mapping payload has no ``"trees"`` key.
        IndexError: ``index`` is out of range.

    WARNING: ``pickle.load`` can execute arbitrary code while unpickling;
    only pass files from a trusted source.
    """
    source = Path(path)
    with source.open("rb") as handle:
        payload = pickle.load(handle)

    if isinstance(payload, dict):
        payload = payload["trees"]
    # Normalise a lone tree into a one-element sequence so indexing works.
    candidates = payload if isinstance(payload, (list, tuple)) else [payload]
    return candidates[index]
|
|
|
|
|
|
def _shape_after_eq(eq, shape):
    """Return ``shape`` rearranged by the single-term einsum string ``eq``.

    ``eq`` is expected in the form ``"abc->ca"`` (one input term, one output
    term, one character per axis). Anything else -- ``None``, a non-string,
    an input term whose length does not match ``shape``, or an output
    character absent from the input -- returns ``shape`` unchanged, so this
    never does worse than using the raw shape.
    """
    if not isinstance(eq, str) or "->" not in eq:
        return tuple(shape)
    src, dst = eq.split("->", 1)
    if len(src) != len(shape):
        return tuple(shape)
    sizes = dict(zip(src, shape))
    if any(ix not in sizes for ix in dst):
        return tuple(shape)
    return tuple(sizes[ix] for ix in dst)


def _matmul_operand_shapes(eq_a, eq_b, new_shape_a, new_shape_b, left_shape, right_shape):
    """Return ``(a_shape, b_shape)`` as they are presumably fed to the matmul.

    An explicit reshape target from the parser is used when present.
    Otherwise the shape is derived by applying the operand's pre-matmul
    einsum (``eq_a``/``eq_b``) to its raw shape, because that equation can
    transpose or reduce axes. NOTE(review): this assumes cotengra applies
    ``eq_*`` before the optional reshape -- confirm against the installed
    version.

    1-D operands are promoted the way ``torch.matmul`` does: ``(k,)`` becomes
    ``(1, k)`` on the left and ``(k, 1)`` on the right.
    """
    a_shape = tuple(new_shape_a) if new_shape_a else _shape_after_eq(eq_a, left_shape)
    b_shape = tuple(new_shape_b) if new_shape_b else _shape_after_eq(eq_b, right_shape)
    if len(a_shape) == 1:
        a_shape = (1,) + a_shape
    if len(b_shape) == 1:
        b_shape = b_shape + (1,)
    return a_shape, b_shape


def _analyze_tree(tree):
    """Classify every pairwise contraction in ``tree`` and record its cost.

    Args:
        tree: a cotengra contraction tree. The code below uses its
            ``size_dict``, ``get_inds``, ``get_flops`` and ``get_size``.

    Returns:
        ``(contractions, ops, counts)``:
          * ``contractions`` -- the raw list from
            ``cotengra.contract.extract_contractions(tree)``;
          * ``ops`` -- one dict per pairwise contraction with keys ``index``,
            ``kind``, ``matmul_shape`` (``(batch, m, k, n)`` or ``None`` for
            pure multiplication), ``matmul_flops``, ``tree_flops``,
            ``out_size``, ``left_shape``, ``right_shape``, ``left_rank``,
            ``right_rank``, ``out_rank`` and ``perm``;
          * ``counts`` -- a ``Counter`` of kinds: ``"preprocess"`` (entries
            with neither a left nor a right child), ``"mul"``, ``"mm"`` or
            ``"bmm"``.
    """
    # Imported at call time rather than at module level.
    contract_mod = importlib.import_module("cotengra.contract")
    contractions = contract_mod.extract_contractions(tree)
    size_dict = tree.size_dict
    ops = []
    counts = Counter()

    for op_index, (parent, left, right, tdot, arg, perm) in enumerate(contractions):
        if left is None and right is None:
            counts["preprocess"] += 1
            continue

        left_inds = tree.get_inds(left)
        right_inds = tree.get_inds(right)
        parent_inds = tree.get_inds(parent)
        left_shape = tuple(size_dict[ix] for ix in left_inds)
        right_shape = tuple(size_dict[ix] for ix in right_inds)

        # ``tdot`` selects the tensordot-axes parser; otherwise ``arg`` goes to
        # the einsum-equation parser. Both are unpacked into the same 7-tuple.
        if tdot:
            parsed = contract_mod._parse_tensordot_axes_to_matmul(
                arg,
                left_shape,
                right_shape,
            )
        else:
            parsed = contract_mod._parse_eq_to_batch_matmul(
                arg,
                left_shape,
                right_shape,
            )

        (
            eq_a,
            eq_b,
            new_shape_a,
            new_shape_b,
            _new_shape_ab,
            _perm_ab,
            pure_multiplication,
        ) = parsed

        matmul_shape = None
        matmul_flops = 0
        if pure_multiplication:
            kind = "mul"
        else:
            a_shape, b_shape = _matmul_operand_shapes(
                eq_a, eq_b, new_shape_a, new_shape_b, left_shape, right_shape
            )
            # Leading dims beyond the trailing matrix pair are batch dims.
            batch = _broadcast_batch(a_shape[:-2], b_shape[:-2])
            m, k, n = int(a_shape[-2]), int(a_shape[-1]), int(b_shape[-1])
            kind = "mm" if batch == 1 else "bmm"
            matmul_shape = (batch, m, k, n)
            matmul_flops = batch * m * k * n

        tree_flops = int(tree.get_flops(parent))
        out_size = int(tree.get_size(parent))
        ops.append(
            {
                "index": op_index,
                "kind": kind,
                "matmul_shape": matmul_shape,
                "matmul_flops": matmul_flops,
                "tree_flops": tree_flops,
                "out_size": out_size,
                "left_shape": left_shape,
                "right_shape": right_shape,
                "left_rank": len(left_inds),
                "right_rank": len(right_inds),
                "out_rank": len(parent_inds),
                "perm": perm,
            }
        )
        counts[kind] += 1

    return contractions, ops, counts
|
|
|
|
|
|
def _format_log(value, base):
    """Format ``log_base(value)`` with three decimals.

    Returns the literal string ``"-inf"`` when ``value <= 0``, since
    ``math.log`` cannot take non-positive arguments.
    """
    if value <= 0:
        return "-inf"
    logarithm = math.log(value, base)
    return f"{logarithm:.3f}"
|
|
|
|
|
|
def _non_negative_int(text):
    """argparse ``type`` callable: parse ``text`` as an ``int`` that is >= 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"expected a non-negative integer, got {value}")
    return value


def _positive_int(text):
    """argparse ``type`` callable: parse ``text`` as an ``int`` that is > 0."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def _group_by_matmul_shape(ops):
    """Aggregate ``[count, total tree_flops, total out_size]`` per matmul shape.

    Ops whose ``matmul_shape`` is ``None`` (pure multiplications) are skipped.
    """
    by_shape = defaultdict(lambda: [0, 0, 0])
    for op in ops:
        shape = op["matmul_shape"]
        if shape is None:
            continue
        stats = by_shape[shape]
        stats[0] += 1
        stats[1] += op["tree_flops"]
        stats[2] += op["out_size"]
    return by_shape


def _print_shape_table(title, by_shape, sort_column, top):
    """Print ``title`` and the ``top`` shapes ranked by aggregate column ``sort_column``.

    ``sort_column`` indexes the ``[count, flops, out_size]`` aggregate
    (0 = count, 1 = flops); rows are printed in descending order.
    """
    print(title)
    ranked = sorted(by_shape.items(), key=lambda item: item[1][sort_column], reverse=True)
    for shape, (count, flops, out_size) in ranked[:top]:
        print(
            f"shape={shape} count={count} "
            f"flops={flops:.6e} output={out_size:.6e}"
        )


def main():
    """Command-line entry point: load a tree, analyze it, and print a report.

    Prints a summary line, per-slice and all-slice cost estimates, the top
    ``--top`` ops by flops, and the top ``--top`` matmul shapes ranked by
    total flops and by occurrence count.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("tree", help="Pickle file containing one tree or {'trees': [...]}.")
    parser.add_argument("--index", type=int, default=0, help="Tree index in the file.")
    # Validated as >= 0: a negative value would otherwise slice off the tail
    # of each ranking (``xs[:-k]``) instead of limiting it.
    parser.add_argument(
        "--top", type=_non_negative_int, default=20, help="Number of top ops to print."
    )
    parser.add_argument(
        "--dtype-bytes",
        type=_positive_int,
        default=8,
        help="Bytes per element for memory estimates, for example 8 for complex64.",
    )
    args = parser.parse_args()

    tree = _load_tree(args.tree, args.index)
    contractions, ops, counts = _analyze_tree(tree)

    # The tree's ``multiplicity`` is used as the slice count (default 1).
    nslices = int(getattr(tree, "multiplicity", 1))
    per_slice_flops = sum(op["tree_flops"] for op in ops)
    per_slice_write = sum(op["out_size"] for op in ops)
    max_out = max((op["out_size"] for op in ops), default=0)
    all_flops = per_slice_flops * nslices
    all_write = per_slice_write * nslices

    print(f"tree={args.tree} index={args.index}")
    print(
        "summary "
        f"slices={nslices} contractions={len(contractions)} "
        f"counts={dict(counts)}"
    )
    print(
        "per_slice "
        f"log10_flops={_format_log(per_slice_flops, 10)} "
        f"log10_write={_format_log(per_slice_write, 10)} "
        f"log2_max_output={_format_log(max_out, 2)} "
        f"max_output_gib={max_out * args.dtype_bytes / 1024**3:.6g}"
    )
    print(
        "all_slices "
        f"log10_flops={_format_log(all_flops, 10)} "
        f"log10_write={_format_log(all_write, 10)}"
    )

    print(f"\ntop_{args.top}_ops_by_flops")
    for op in sorted(ops, key=lambda item: item["tree_flops"], reverse=True)[: args.top]:
        print(
            f"op={op['index']} kind={op['kind']} "
            f"flops={op['tree_flops']:.6e} out={op['out_size']:.6e} "
            f"matmul={op['matmul_shape']} "
            f"ranks=({op['left_rank']},{op['right_rank']}->{op['out_rank']}) "
            f"lhs={op['left_shape']} rhs={op['right_shape']}"
        )

    by_shape = _group_by_matmul_shape(ops)
    _print_shape_table(f"\ntop_{args.top}_matmul_shapes_by_flops", by_shape, 1, args.top)
    _print_shape_table(f"\ntop_{args.top}_matmul_shapes_by_count", by_shape, 0, args.top)


if __name__ == "__main__":
    main()
|