302 lines
12 KiB
Python
302 lines
12 KiB
Python
"""JEPA Implementation"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
class JEPA(nn.Module):
    """Action-conditioned joint-embedding predictive model.

    ``encoder`` (followed by ``projector``) maps pixel frames to embeddings,
    ``action_encoder`` embeds actions, and ``predictor`` (followed by
    ``pred_proj``) predicts the next embedding from a window of state and
    action embeddings.

    ``rollout``, ``criterion`` and ``get_cost`` score candidate action
    sequences against a goal observation. A solver calls ``get_cost``
    repeatedly for the same observation and goal, so the model keeps
    runtime-only caches (plain attributes, not registered buffers) of device
    copies of input tensors and of the encoded initial/goal frames. Entries
    are validated with ``_tensor_signature`` and keep a reference to their
    source tensor, so the source's memory cannot be handed to a different
    tensor while the entry is live. The caches do not track changes to the
    model's weights.
    """

    def __init__(
        self,
        encoder,
        predictor,
        action_encoder,
        projector=None,
        pred_proj=None,
    ):
        super().__init__()

        self.encoder = encoder
        self.predictor = predictor
        self.action_encoder = action_encoder
        # Optional heads default to a no-op. Explicit ``None`` checks avoid
        # replacing a container module whose ``len()`` happens to be 0.
        self.projector = projector if projector is not None else nn.Identity()
        self.pred_proj = pred_proj if pred_proj is not None else nn.Identity()

        self._ensure_runtime_caches()

    def _ensure_runtime_caches(self):
        """Create any missing runtime-cache attribute with its empty value.

        Every cache helper calls this first, so instances that lack the
        attributes (e.g. restored without running ``__init__``) still work.
        """
        if not hasattr(self, "_cached_device_tensors"):
            # key -> (signature, prepared tensor, source tensor)
            self._cached_device_tensors = {}
        for name in (
            "_cached_init_signature",
            "_cached_init_source",
            "_cached_init_emb",
            "_cached_goal_signature",
            "_cached_goal_source",
            "_cached_goal_emb",
        ):
            if not hasattr(self, name):
                setattr(self, name, None)

    @staticmethod
    def _tensor_signature(tensor: torch.Tensor):
        """Return a hashable fingerprint of ``tensor``'s layout and version.

        Equal signatures mean the same device, dtype, shape, strides, storage
        offset and data pointer, with an unchanged version counter (PyTorch
        bumps it on in-place ops). Inference tensors have no version counter:
        reading ``_version`` raises ``RuntimeError`` and the slot is ``None``,
        so in-place updates to them are not detected. Writes that bypass
        PyTorch (e.g. through a shared NumPy buffer) are not detected either.
        """
        try:
            version = tensor._version
        except RuntimeError:
            version = None
        return (
            str(tensor.device),
            tensor.dtype,
            tuple(tensor.shape),
            tuple(tensor.stride()),
            tensor.storage_offset(),
            tensor.data_ptr(),
            version,
        )

    def _get_cached_device_tensor(
        self,
        key: str,
        tensor: torch.Tensor,
        device: torch.device,
        *,
        ensure_contiguous: bool = False,
    ):
        """Return ``tensor`` on ``device`` (contiguous if requested), reusing a cached copy.

        Args:
            key: cache slot name; one prepared tensor is kept per key.
            tensor: the tensor to prepare.
            device: target device.
            ensure_contiguous: also require a contiguous result.

        Returns:
            ``tensor`` itself when it already satisfies the requirements,
            otherwise a (possibly cached) prepared copy.
        """
        self._ensure_runtime_caches()
        if tensor.device == device and (not ensure_contiguous or tensor.is_contiguous()):
            # Nothing to prepare; drop any stale copy under this key so it
            # does not keep memory alive.
            self._cached_device_tensors.pop(key, None)
            return tensor

        signature = (self._tensor_signature(tensor), str(device), ensure_contiguous)
        cached = self._cached_device_tensors.get(key)
        if cached is not None and len(cached) == 3 and cached[0] == signature:
            return cached[1]

        prepared = tensor.to(device, non_blocking=True)
        if ensure_contiguous and not prepared.is_contiguous():
            prepared = prepared.contiguous()
        # Holding ``tensor`` keeps its memory from being freed and reused by a
        # new tensor, so a later matching data_ptr really is the same memory.
        self._cached_device_tensors[key] = (signature, prepared, tensor)
        return prepared

    def _ensure_info_device(self, info_dict: dict, device: torch.device):
        """Move every tensor value of ``info_dict`` to ``device`` (contiguous), in place.

        Keys with the ``_lewm_`` prefix are skipped: that prefix names internal
        cache slots (e.g. ``_lewm_action_candidates``), and skipping it keeps
        info entries from sharing those slots. Returns ``info_dict``.
        """
        for key, value in list(info_dict.items()):
            if isinstance(key, str) and key.startswith("_lewm_"):
                continue
            if torch.is_tensor(value):
                info_dict[key] = self._get_cached_device_tensor(
                    key,
                    value,
                    device,
                    ensure_contiguous=True,
                )
        return info_dict

    def _encode_detached(self, pixels: torch.Tensor) -> torch.Tensor:
        """Encode frames ``pixels`` (B, T, ...) without recording autograd history; returns (B, T, D)."""
        with torch.no_grad():
            return self.encode({"pixels": pixels})["emb"]

    def _get_cached_init_emb(self, info_dict: dict):
        """Return the embeddings of ``info_dict["pixels"][:, 0]``, re-encoding only when the pixels change."""
        self._ensure_runtime_caches()
        pixels = info_dict["pixels"]
        signature = self._tensor_signature(pixels)
        if self._cached_init_source is None or self._cached_init_signature != signature:
            self._cached_init_emb = self._encode_detached(pixels[:, 0])
            self._cached_init_signature = signature
            # Keep the source alive so its data_ptr cannot be reused (see class doc).
            self._cached_init_source = pixels
        return self._cached_init_emb

    def _get_cached_goal_emb(self, info_dict: dict):
        """Return the last-frame embedding (B, 1, D) of ``info_dict["goal"][:, 0]``, cached like the init embedding."""
        self._ensure_runtime_caches()
        goal = info_dict["goal"]
        signature = self._tensor_signature(goal)
        if self._cached_goal_source is None or self._cached_goal_signature != signature:
            self._cached_goal_emb = self._encode_detached(goal[:, 0])[:, -1:, :]
            self._cached_goal_signature = signature
            self._cached_goal_source = goal
        return self._cached_goal_emb

    def encode(self, info):
        """Encode observations (and actions, if present) into embeddings.

        Args:
            info: dict with ``"pixels"`` of shape (B, T, ...) and optionally
                ``"action"``. It is updated in place.

        Returns:
            ``info`` with ``"emb"`` (B, T, D), the projected first token of
            ``encoder(...).last_hidden_state`` for each frame, and, when
            ``"action"`` is present, ``"act_emb"`` from ``action_encoder``.
        """
        with torch.profiler.record_function("lewm.encode"):
            pixels = info['pixels'].float()
            b, t = pixels.shape[:2]
            pixels = pixels.reshape(b * t, *pixels.shape[2:])  # flatten for encoding
            output = self.encoder(pixels, interpolate_pos_encoding=True)
            pixels_emb = output.last_hidden_state[:, 0]  # cls token
            emb = self.projector(pixels_emb)
            info["emb"] = emb.reshape(b, t, -1)

            if "action" in info:
                info["act_emb"] = self.action_encoder(info["action"])

            return info

    def predict(self, emb, act_emb):
        """Predict next-state embeddings.

        Args:
            emb: (B, T, D) state embeddings.
            act_emb: (B, T, A_emb) action embeddings.

        Returns:
            ``pred_proj(predictor(emb, act_emb))``.
        """
        with torch.profiler.record_function("lewm.predict"):
            preds = self.predictor(emb, act_emb)
            preds = self.pred_proj(preds)
            return preds

    ####################
    ## Inference only ##
    ####################

    def rollout(self, info, action_sequence, history_size: int = 3):
        """Roll the predictor forward through every candidate action sequence.

        Args:
            info: dict with ``"pixels"`` of shape (B, S, H, C, height, width).
                ``pixels[:, 0]`` is encoded once (cached across calls) as the
                history shared by all samples; ``H`` is the number of
                observed frames.
            action_sequence: (B, S, T, action_dim), where S is the number of
                action plan samples and T >= H is the time horizon. Steps
                ``H - hist_len .. H - 1`` pair with the observed frames;
                steps ``H .. T - 1`` are rolled out.
            history_size: maximum number of (state, action) steps fed to the
                predictor at once.

        Returns:
            ``info`` with ``"predicted_emb"`` set to the prediction after the
            last action, shape (B, S, 1, D).

        Raises:
            ValueError: if ``history_size < 1`` or ``T < H``.
        """
        with torch.profiler.record_function("lewm.rollout"):
            assert "pixels" in info, "pixels not in info_dict"
            if history_size < 1:
                raise ValueError("history_size must be >= 1")

            H = info["pixels"].size(2)
            B, S, T = action_sequence.shape[:3]
            if T < H:
                raise ValueError(
                    f"action_sequence horizon ({T}) must be >= history length ({H})"
                )

            # Cache the encoded initial state across solver iterations.
            init_emb = self._get_cached_init_emb(info)
            HS = history_size
            hist_len = min(HS, init_emb.size(1), H)
            if hist_len < 1:
                raise ValueError("rollout requires at least one history step")

            # (B, hist_len, D) -> (B * S, hist_len, D): all samples share the history.
            init_hist = init_emb[:, -hist_len:]
            init_hist = init_hist.unsqueeze(1).expand(-1, S, -1, -1)
            init_hist = init_hist.reshape(B * S, hist_len, init_hist.size(-1)).contiguous()

            flat_actions = action_sequence.contiguous().view(B * S, T, -1)
            action_emb = self.action_encoder(flat_actions)
            act_hist = action_emb[:, H - hist_len : H]
            act_future = action_emb[:, H:]

            if HS == 1:
                pred_rollout = self._rollout_single_step(init_hist, act_hist, act_future)
            elif self._rollout_tracks_grad(action_sequence, action_emb):
                pred_rollout = self._rollout_concat(init_hist, act_hist, act_future, HS)
            else:
                pred_rollout = self._rollout_ring_buffer(init_hist, act_hist, act_future, HS)

            info["predicted_emb"] = pred_rollout.reshape(B, S, *pred_rollout.shape[1:])
            return info

    def _rollout_tracks_grad(self, action_sequence, action_emb) -> bool:
        """Return True when autograd will record the rollout's predictor calls.

        The ring-buffer rollout overwrites its buffers in place after feeding
        them to the predictor, which invalidates tensors autograd saved for
        backward; it is used only when nothing is being recorded.
        """
        if not torch.is_grad_enabled():
            return False
        if action_sequence.requires_grad or action_emb.requires_grad:
            return True
        return any(
            param.requires_grad
            for module in (self.predictor, self.pred_proj)
            if isinstance(module, nn.Module)
            for param in module.parameters()
        )

    def _rollout_single_step(self, init_hist, act_hist, act_future):
        """Roll out with a one-step history (``history_size == 1``); returns (N, 1, D)."""
        emb = init_hist[:, -1:]
        act = act_hist[:, -1:]
        for t in range(act_future.size(1)):
            emb = self.predict(emb, act)[:, -1:]
            act = act_future[:, t : t + 1]
        return self.predict(emb, act)[:, -1:]

    def _rollout_concat(self, init_hist, act_hist, act_future, history_size):
        """Autograd-safe rollout: rebuild each history window with ``torch.cat`` (history_size >= 2)."""
        emb_slots = init_hist.split(1, dim=1)
        act_slots = act_hist.split(1, dim=1)
        for t in range(act_future.size(1)):
            emb_view = torch.cat(emb_slots[-history_size:], dim=1)
            act_view = torch.cat(act_slots[-history_size:], dim=1)
            pred_emb = self.predict(emb_view, act_view)[:, -1:]
            # Keep the newest history_size - 1 slots and append the new step.
            emb_slots = (*emb_slots[-(history_size - 1) :], pred_emb)
            act_slots = (*act_slots[-(history_size - 1) :], act_future[:, t : t + 1])
        emb_view = torch.cat(emb_slots[-history_size:], dim=1)
        act_view = torch.cat(act_slots[-history_size:], dim=1)
        return self.predict(emb_view, act_view)[:, -1:]

    def _rollout_ring_buffer(self, init_hist, act_hist, act_future, history_size):
        """Allocation-free rollout using preallocated ring buffers (no autograd recording)."""
        hist_len = init_hist.size(1)
        n = init_hist.size(0)
        emb_hist = init_hist.new_empty((n, history_size, init_hist.size(-1)))
        act_emb_hist = act_hist.new_empty((n, history_size, act_hist.size(-1)))
        emb_hist[:, :hist_len].copy_(init_hist)
        act_emb_hist[:, :hist_len].copy_(act_hist)

        # Row k lists the slots oldest-first when the oldest entry is in slot k.
        steps = torch.arange(history_size, device=act_hist.device)
        history_order = (steps.unsqueeze(0) + steps.unsqueeze(1)) % history_size

        filled = hist_len
        next_slot = hist_len % history_size
        for t in range(act_future.size(1)):
            emb_view, act_view = self._ring_view(
                emb_hist, act_emb_hist, filled, next_slot, history_order
            )
            pred_emb = self.predict(emb_view, act_view)[:, -1:]
            # Overwrite the oldest slot with the newest step.
            emb_hist[:, next_slot : next_slot + 1].copy_(pred_emb)
            act_emb_hist[:, next_slot : next_slot + 1].copy_(act_future[:, t : t + 1])
            if filled < history_size:
                filled += 1
            next_slot = (next_slot + 1) % history_size

        emb_view, act_view = self._ring_view(
            emb_hist, act_emb_hist, filled, next_slot, history_order
        )
        return self.predict(emb_view, act_view)[:, -1:]

    @staticmethod
    def _ring_view(emb_hist, act_emb_hist, filled, next_slot, history_order):
        """Return the ring-buffer contents in chronological (oldest-first) order."""
        if filled < emb_hist.size(1):
            # Not wrapped yet: slots 0..filled-1 are already in order.
            return emb_hist[:, :filled], act_emb_hist[:, :filled]
        if next_slot == 0:
            return emb_hist, act_emb_hist
        order = history_order[next_slot]
        return emb_hist.index_select(1, order), act_emb_hist.index_select(1, order)

    def criterion(self, info_dict: dict):
        """Compute the cost between predicted embeddings and goal embeddings.

        ``info_dict["predicted_emb"]`` is the ``rollout`` output, (B, S, 1, D).
        ``info_dict["goal_emb"]`` either has one dimension fewer, e.g. the
        (B, 1, D) goal embedding produced by ``get_cost`` (a sample axis is
        then inserted), or the same rank as the prediction.

        Returns:
            Squared error of the last step summed over all trailing dims,
            i.e. one cost per action candidate: (B, S).
        """
        with torch.profiler.record_function("lewm.criterion"):
            pred_emb = info_dict["predicted_emb"]
            goal_emb = info_dict["goal_emb"]
            if goal_emb.ndim == pred_emb.ndim - 1:
                goal_emb = goal_emb.unsqueeze(1)

            # Broadcast explicitly: same values as implicit broadcasting, but
            # without F.mse_loss's size-mismatch UserWarning.
            pred_last, goal_last = torch.broadcast_tensors(
                pred_emb[..., -1:, :],
                goal_emb[..., -1:, :].detach(),
            )
            cost = F.mse_loss(pred_last, goal_last, reduction="none")
            return cost.sum(dim=tuple(range(2, pred_emb.ndim)))  # (B, S)

    def get_cost(self, info_dict: dict, action_candidates: torch.Tensor):
        """Compute the cost of action candidates given an info dict with goal and initial state.

        ``info_dict`` is updated in place: its tensors are moved to the model's
        device and ``"goal_emb"`` / ``"predicted_emb"`` are added.

        Args:
            info_dict: dict with ``"pixels"`` (see ``rollout``) and ``"goal"``.
            action_candidates: (B, S, T, action_dim) candidate action sequences.

        Returns:
            (B, S) cost per candidate (see ``criterion``).
        """
        with torch.profiler.record_function("lewm.get_cost"):
            assert "goal" in info_dict, "goal not in info_dict"

            self._ensure_runtime_caches()
            device = next(self.parameters()).device
            info_dict = self._ensure_info_device(info_dict, device)
            action_candidates = self._get_cached_device_tensor(
                "_lewm_action_candidates",
                action_candidates,
                device,
                ensure_contiguous=True,
            )

            info_dict["goal_emb"] = self._get_cached_goal_emb(info_dict)
            info_dict = self.rollout(info_dict, action_candidates)
            return self.criterion(info_dict)
|