# Hydra training configuration.
---
# Hydra defaults list. `_self_` is listed first, so this file is composed
# before the `data` group option and the `pusht` data config may override
# keys defined here.
defaults:
  - _self_
  - data: pusht

# Run identity. Both keys are re-used further down through `${...}`
# interpolation (see `wandb.config.name` and `wandb.config.id`).
output_model_name: lewm
# Resolved by Hydra at runtime from the job id.
subdir: ${hydra:job.id}

# Root-level settings shared across the sections below.
# `num_workers` is referenced by `loader.num_workers` via interpolation.
num_workers: 6
train_split: 0.9
seed: 3072
img_size: 224
patch_size: 14
encoder_scale: tiny
dump_object: true

# Trainer options.
# NOTE(review): presumably forwarded to the training framework's Trainer
# constructor; confirm against the consuming script.
trainer:
  max_epochs: 100
  devices: auto
  accelerator: gpu
  precision: bf16
  gradient_clip_val: 1.0

# Data-loader options.
loader:
  batch_size: 128
  # Absolute interpolation: resolves to the root-level `num_workers` key.
  num_workers: ${num_workers}
  persistent_workers: true
  prefetch_factor: 3
  pin_memory: true

# Optimizer selection and hyper-parameters.
optimizer:
  type: AdamW
  # Exponent floats are written with a decimal point so that YAML 1.1
  # loaders (e.g. PyYAML) resolve them as floats instead of strings.
  lr: 5.0e-5
  weight_decay: 1.0e-3

# Weights & Biases logging.
# NOTE(review): nesting was reconstructed from a flattened copy; the keys
# from `entity` through `log_model` are assumed to belong to `config` --
# confirm against the code that reads `wandb.config`.
wandb:
  enabled: true
  config:
    entity: lewm
    project: lewm
    # Run name and id follow the root-level run identity keys.
    name: ${output_model_name}
    id: ${subdir}
    resume: allow
    log_model: false

# World-model settings.
wm:
  type: lewm
  history_size: 3
  num_preds: 1
  embed_dim: 192

# Predictor network hyper-parameters.
predictor:
  depth: 6
  heads: 16
  mlp_dim: 2048
  dim_head: 64
  dropout: 0.1
  emb_dropout: 0.0

# Loss configuration.
# NOTE(review): nesting was reconstructed from a flattened copy; `knots` and
# `num_proj` are assumed to sit under `sigreg.kwargs` -- confirm against the
# code that consumes `loss.sigreg`.
loss:
  sigreg:
    weight: 0.09
    kwargs:
      knots: 17
      num_proj: 1024