Initial commit
This commit is contained in:
57
utils.py
Normal file
57
utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from stable_pretraining import data as dt
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
|
||||
def get_img_preprocessor(source: str, target: str, img_size: int = 224):
|
||||
imagenet_stats = dt.dataset_stats.ImageNet
|
||||
to_image = dt.transforms.ToImage(**imagenet_stats, source=source, target=target)
|
||||
resize = dt.transforms.Resize(img_size, source=source, target=target)
|
||||
return dt.transforms.Compose(to_image, resize)
|
||||
|
||||
|
||||
def get_column_normalizer(dataset, source: str, target: str):
|
||||
"""Get normalizer for a specific column in the dataset."""
|
||||
col_data = dataset.get_col_data(source)
|
||||
data = torch.from_numpy(np.array(col_data))
|
||||
data = data[~torch.isnan(data).any(dim=1)]
|
||||
mean = data.mean(0, keepdim=True).clone()
|
||||
std = data.std(0, keepdim=True).clone()
|
||||
|
||||
def norm_fn(x):
|
||||
return ((x - mean) / std).float()
|
||||
|
||||
normalizer = dt.transforms.WrapTorchTransform(norm_fn, source=source, target=target)
|
||||
return normalizer
|
||||
|
||||
class ModelObjectCallBack(Callback):
|
||||
"""Callback to pickle model object after each epoch."""
|
||||
|
||||
def __init__(self, dirpath, filename="model_object", epoch_interval: int = 1):
|
||||
super().__init__()
|
||||
self.dirpath = Path(dirpath)
|
||||
self.filename = filename
|
||||
self.epoch_interval = epoch_interval
|
||||
|
||||
def on_train_epoch_end(self, trainer, pl_module):
|
||||
super().on_train_epoch_end(trainer, pl_module)
|
||||
|
||||
output_path = (
|
||||
self.dirpath
|
||||
/ f"{self.filename}_epoch_{trainer.current_epoch + 1}_object.ckpt"
|
||||
)
|
||||
|
||||
if trainer.is_global_zero:
|
||||
if (trainer.current_epoch + 1) % self.epoch_interval == 0:
|
||||
self._dump_model(pl_module.model, output_path)
|
||||
|
||||
# save final epoch
|
||||
if (trainer.current_epoch + 1) == trainer.max_epochs:
|
||||
self._dump_model(pl_module.model, output_path)
|
||||
|
||||
def _dump_model(self, model, path):
|
||||
try:
|
||||
torch.save(model, path)
|
||||
except Exception as e:
|
||||
print(f"Error saving model object: {e}")
|
||||
Reference in New Issue
Block a user