"""Inspect cotengra contraction trees for dominant torch matmul shapes.""" from __future__ import annotations import argparse import importlib import math import pickle from collections import Counter, defaultdict from pathlib import Path def _prod(values): out = 1 for value in values: out *= int(value) return out def _broadcast_batch(a_batch, b_batch): if a_batch == b_batch: return _prod(a_batch) if not a_batch: return _prod(b_batch) if not b_batch: return _prod(a_batch) ndim = max(len(a_batch), len(b_batch)) a_batch = (1,) * (ndim - len(a_batch)) + tuple(a_batch) b_batch = (1,) * (ndim - len(b_batch)) + tuple(b_batch) return _prod(max(a, b) for a, b in zip(a_batch, b_batch)) def _load_tree(path, index): with Path(path).open("rb") as f: payload = pickle.load(f) trees = payload["trees"] if isinstance(payload, dict) else payload if not isinstance(trees, (list, tuple)): trees = [trees] return trees[index] def _analyze_tree(tree): contract_mod = importlib.import_module("cotengra.contract") contractions = contract_mod.extract_contractions(tree) size_dict = tree.size_dict ops = [] counts = Counter() for op_index, (parent, left, right, tdot, arg, perm) in enumerate(contractions): if left is None and right is None: counts["preprocess"] += 1 continue left_inds = tree.get_inds(left) right_inds = tree.get_inds(right) parent_inds = tree.get_inds(parent) left_shape = tuple(size_dict[ix] for ix in left_inds) right_shape = tuple(size_dict[ix] for ix in right_inds) if tdot: parsed = contract_mod._parse_tensordot_axes_to_matmul( arg, left_shape, right_shape, ) else: parsed = contract_mod._parse_eq_to_batch_matmul( arg, left_shape, right_shape, ) ( _eq_a, _eq_b, new_shape_a, new_shape_b, _new_shape_ab, _perm_ab, pure_multiplication, ) = parsed matmul_shape = None matmul_flops = 0 if pure_multiplication: kind = "mul" else: a_shape = tuple(new_shape_a or left_shape) b_shape = tuple(new_shape_b or right_shape) batch = _broadcast_batch(a_shape[:-2], b_shape[:-2]) m, k, n = int(a_shape[-2]), int(a_shape[-1]), int(b_shape[-1]) kind = "mm" if batch == 1 else "bmm" matmul_shape = (batch, m, k, n) matmul_flops = batch * m * k * n tree_flops = int(tree.get_flops(parent)) out_size = int(tree.get_size(parent)) ops.append( { "index": op_index, "kind": kind, "matmul_shape": matmul_shape, "matmul_flops": matmul_flops, "tree_flops": tree_flops, "out_size": out_size, "left_shape": left_shape, "right_shape": right_shape, "left_rank": len(left_inds), "right_rank": len(right_inds), "out_rank": len(parent_inds), "perm": perm, } ) counts[kind] += 1 return contractions, ops, counts def _format_log(value, base): return "-inf" if value <= 0 else f"{math.log(value, base):.3f}" def main(): parser = argparse.ArgumentParser() parser.add_argument("tree", help="Pickle file containing one tree or {'trees': [...]}.") parser.add_argument("--index", type=int, default=0, help="Tree index in the file.") parser.add_argument("--top", type=int, default=20, help="Number of top ops to print.") parser.add_argument( "--dtype-bytes", type=int, default=8, help="Bytes per element for memory estimates, for example 8 for complex64.", ) args = parser.parse_args() tree = _load_tree(args.tree, args.index) contractions, ops, counts = _analyze_tree(tree) nslices = int(getattr(tree, "multiplicity", 1)) per_slice_flops = sum(op["tree_flops"] for op in ops) per_slice_write = sum(op["out_size"] for op in ops) max_out = max((op["out_size"] for op in ops), default=0) all_flops = per_slice_flops * nslices all_write = per_slice_write * nslices print(f"tree={args.tree} index={args.index}") print( "summary " f"slices={nslices} contractions={len(contractions)} " f"counts={dict(counts)}" ) print( "per_slice " f"log10_flops={_format_log(per_slice_flops, 10)} " f"log10_write={_format_log(per_slice_write, 10)} " f"log2_max_output={_format_log(max_out, 2)} " f"max_output_gib={max_out * args.dtype_bytes / 1024**3:.6g}" ) print( "all_slices " f"log10_flops={_format_log(all_flops, 10)} " f"log10_write={_format_log(all_write, 10)}" ) print(f"\ntop_{args.top}_ops_by_flops") for op in sorted(ops, key=lambda item: item["tree_flops"], reverse=True)[: args.top]: print( f"op={op['index']} kind={op['kind']} " f"flops={op['tree_flops']:.6e} out={op['out_size']:.6e} " f"matmul={op['matmul_shape']} " f"ranks=({op['left_rank']},{op['right_rank']}->{op['out_rank']}) " f"lhs={op['left_shape']} rhs={op['right_shape']}" ) by_shape = defaultdict(lambda: [0, 0, 0]) for op in ops: shape = op["matmul_shape"] if shape is None: continue by_shape[shape][0] += 1 by_shape[shape][1] += op["tree_flops"] by_shape[shape][2] += op["out_size"] print(f"\ntop_{args.top}_matmul_shapes_by_flops") for shape, (count, flops, out_size) in sorted( by_shape.items(), key=lambda item: item[1][1], reverse=True, )[: args.top]: print( f"shape={shape} count={count} " f"flops={flops:.6e} output={out_size:.6e}" ) print(f"\ntop_{args.top}_matmul_shapes_by_count") for shape, (count, flops, out_size) in sorted( by_shape.items(), key=lambda item: item[1][0], reverse=True, )[: args.top]: print( f"shape={shape} count={count} " f"flops={flops:.6e} output={out_size:.6e}" ) if __name__ == "__main__": main()