Files
qibotn/tools/inspect_contraction_tree.py
jaunatisblue 915c24dc7b
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
赛前稳定版
2026-05-15 09:32:26 +08:00

209 lines
6.4 KiB
Python

"""Inspect cotengra contraction trees for dominant torch matmul shapes."""
from __future__ import annotations
import argparse
import importlib
import math
import pickle
from collections import Counter, defaultdict
from pathlib import Path
def _prod(values):
    """Return the product of ``values`` with every item coerced to ``int``.

    Args:
        values: Iterable of objects accepted by ``int()``.

    Returns:
        The integer product; ``1`` for an empty iterable.
    """
    # Coerce first so the result is always an arbitrary-precision Python int,
    # whatever integer-like type the caller passes in.
    return math.prod(int(value) for value in values)
def _broadcast_batch(a_batch, b_batch):
    """Return the total batch count after broadcasting two batch shapes.

    Args:
        a_batch: Sequence of leading (batch) dimension sizes of the left
            operand; may be empty.
        b_batch: Sequence of leading (batch) dimension sizes of the right
            operand; may be empty.

    Returns:
        The product of the broadcast batch dimensions as an ``int``.

    Note:
        Sizes are not validated for broadcast compatibility: each aligned
        pair simply contributes ``max(a, b)``.
    """
    # Fast paths: identical shapes, or one operand without batch dimensions.
    if a_batch == b_batch:
        return _prod(a_batch)
    if not a_batch:
        return _prod(b_batch)
    if not b_batch:
        return _prod(a_batch)
    # Right-align the two shapes by left-padding the shorter one with ones.
    ndim = max(len(a_batch), len(b_batch))
    a_batch = (1,) * (ndim - len(a_batch)) + tuple(a_batch)
    b_batch = (1,) * (ndim - len(b_batch)) + tuple(b_batch)
    return _prod(max(a, b) for a, b in zip(a_batch, b_batch))
def _load_tree(path, index):
    """Load one contraction tree from a pickle file.

    The file may hold a single tree, a list/tuple of trees, or a dict whose
    ``"trees"`` entry is either of those.

    Args:
        path: Path of the pickle file.
        index: Position of the wanted tree; a lone tree is treated as a
            one-element list, so only ``0`` or ``-1`` are valid for it.

    Returns:
        The object stored at ``index``.

    Raises:
        KeyError: If the payload is a dict without a ``"trees"`` key.
        IndexError: If ``index`` is out of range.
    """
    # SECURITY: unpickling can execute arbitrary code; only open files that
    # come from a trusted source.
    with Path(path).open("rb") as f:
        payload = pickle.load(f)
    trees = payload["trees"] if isinstance(payload, dict) else payload
    if not isinstance(trees, (list, tuple)):
        trees = [trees]
    return trees[index]
def _analyze_tree(tree):
    """Classify every pairwise contraction of ``tree`` as mul / mm / bmm.

    Args:
        tree: Contraction tree exposing ``size_dict``, ``get_inds``,
            ``get_flops`` and ``get_size``, and accepted by
            ``cotengra.contract.extract_contractions``.

    Returns:
        A ``(contractions, ops, counts)`` tuple where

        * ``contractions`` is the raw result of ``extract_contractions``;
        * ``ops`` is a list with one dict per non-preprocess contraction
          (keys: ``index``, ``kind``, ``matmul_shape``, ``matmul_flops``,
          ``tree_flops``, ``out_size``, ``left_shape``, ``right_shape``,
          ``left_rank``, ``right_rank``, ``out_rank``, ``perm``);
        * ``counts`` is a ``Counter`` keyed by ``"preprocess"``, ``"mul"``,
          ``"mm"`` and ``"bmm"``.
    """
    # Imported lazily so the module can be imported without cotengra.
    contract_mod = importlib.import_module("cotengra.contract")
    contractions = contract_mod.extract_contractions(tree)
    size_dict = tree.size_dict
    ops = []
    counts = Counter()
    for op_index, (parent, left, right, tdot, arg, perm) in enumerate(contractions):
        # Entries without operands are only tallied, never analysed.
        if left is None and right is None:
            counts["preprocess"] += 1
            continue
        left_inds = tree.get_inds(left)
        right_inds = tree.get_inds(right)
        parent_inds = tree.get_inds(parent)
        left_shape = tuple(size_dict[ix] for ix in left_inds)
        right_shape = tuple(size_dict[ix] for ix in right_inds)
        # NOTE(review): both parsers are private cotengra helpers; the
        # 7-tuple layout unpacked below must match the installed version.
        if tdot:
            parsed = contract_mod._parse_tensordot_axes_to_matmul(
                arg,
                left_shape,
                right_shape,
            )
        else:
            parsed = contract_mod._parse_eq_to_batch_matmul(
                arg,
                left_shape,
                right_shape,
            )
        (
            _eq_a,
            _eq_b,
            new_shape_a,
            new_shape_b,
            _new_shape_ab,
            _perm_ab,
            pure_multiplication,
        ) = parsed
        matmul_shape = None
        matmul_flops = 0
        if pure_multiplication:
            kind = "mul"
        else:
            # Use the raw operand shape when the parser gives no reshape
            # (a falsy value).
            a_shape = tuple(new_shape_a or left_shape)
            b_shape = tuple(new_shape_b or right_shape)
            # Everything before the trailing two axes counts as batch.
            batch = _broadcast_batch(a_shape[:-2], b_shape[:-2])
            m, k, n = int(a_shape[-2]), int(a_shape[-1]), int(b_shape[-1])
            kind = "mm" if batch == 1 else "bmm"
            matmul_shape = (batch, m, k, n)
            matmul_flops = batch * m * k * n
        # Recorded for every op, including pure multiplications.
        tree_flops = int(tree.get_flops(parent))
        out_size = int(tree.get_size(parent))
        ops.append(
            {
                "index": op_index,
                "kind": kind,
                "matmul_shape": matmul_shape,
                "matmul_flops": matmul_flops,
                "tree_flops": tree_flops,
                "out_size": out_size,
                "left_shape": left_shape,
                "right_shape": right_shape,
                "left_rank": len(left_inds),
                "right_rank": len(right_inds),
                "out_rank": len(parent_inds),
                "perm": perm,
            }
        )
        counts[kind] += 1
    return contractions, ops, counts
def _format_log(value, base):
    """Format the base-``base`` logarithm of ``value`` with three decimals.

    Args:
        value: Number whose logarithm is wanted.
        base: Logarithm base passed to ``math.log``.

    Returns:
        ``"-inf"`` when ``value`` is zero or negative, otherwise the
        logarithm rendered as ``"%.3f"``.
    """
    # Non-positive values have no real logarithm; report them uniformly.
    if value <= 0:
        return "-inf"
    return f"{math.log(value, base):.3f}"
def main(argv=None):
    """Command-line entry point: load a tree and print its matmul profile.

    Prints a summary line, per-slice and all-slice totals, the top ops ranked
    by ``tree_flops``, and the matmul shapes ranked by accumulated flops and
    by occurrence count.

    Args:
        argv: Optional list of command-line arguments; ``None`` makes
            ``argparse`` read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("tree", help="Pickle file containing one tree or {'trees': [...]}.")
    parser.add_argument("--index", type=int, default=0, help="Tree index in the file.")
    parser.add_argument("--top", type=int, default=20, help="Number of top ops to print.")
    parser.add_argument(
        "--dtype-bytes",
        type=int,
        default=8,
        help="Bytes per element for memory estimates, for example 8 for complex64.",
    )
    args = parser.parse_args(argv)

    tree = _load_tree(args.tree, args.index)
    contractions, ops, counts = _analyze_tree(tree)

    # Trees without a ``multiplicity`` attribute count as a single slice.
    nslices = int(getattr(tree, "multiplicity", 1))
    per_slice_flops = sum(op["tree_flops"] for op in ops)
    per_slice_write = sum(op["out_size"] for op in ops)
    max_out = max((op["out_size"] for op in ops), default=0)
    all_flops = per_slice_flops * nslices
    all_write = per_slice_write * nslices

    print(f"tree={args.tree} index={args.index}")
    print(
        "summary "
        f"slices={nslices} contractions={len(contractions)} "
        f"counts={dict(counts)}"
    )
    print(
        "per_slice "
        f"log10_flops={_format_log(per_slice_flops, 10)} "
        f"log10_write={_format_log(per_slice_write, 10)} "
        f"log2_max_output={_format_log(max_out, 2)} "
        f"max_output_gib={max_out * args.dtype_bytes / 1024**3:.6g}"
    )
    print(
        "all_slices "
        f"log10_flops={_format_log(all_flops, 10)} "
        f"log10_write={_format_log(all_write, 10)}"
    )

    print(f"\ntop_{args.top}_ops_by_flops")
    for op in sorted(ops, key=lambda item: item["tree_flops"], reverse=True)[: args.top]:
        print(
            f"op={op['index']} kind={op['kind']} "
            f"flops={op['tree_flops']:.6e} out={op['out_size']:.6e} "
            f"matmul={op['matmul_shape']} "
            f"ranks=({op['left_rank']},{op['right_rank']}->{op['out_rank']}) "
            f"lhs={op['left_shape']} rhs={op['right_shape']}"
        )

    # Per matmul shape accumulator: [count, summed tree_flops, summed out_size].
    by_shape = defaultdict(lambda: [0, 0, 0])
    for op in ops:
        shape = op["matmul_shape"]
        if shape is None:
            # Pure multiplications have no matmul shape.
            continue
        by_shape[shape][0] += 1
        by_shape[shape][1] += op["tree_flops"]
        by_shape[shape][2] += op["out_size"]

    # Same table twice, ranked by a different accumulator column each time.
    for label, column in (("flops", 1), ("count", 0)):
        print(f"\ntop_{args.top}_matmul_shapes_by_{label}")
        ranked = sorted(
            by_shape.items(),
            key=lambda item, column=column: item[1][column],
            reverse=True,
        )
        for shape, (count, flops, out_size) in ranked[: args.top]:
            print(
                f"shape={shape} count={count} "
                f"flops={flops:.6e} output={out_size:.6e}"
            )
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()