-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathutils.py
More file actions
77 lines (67 loc) · 2.42 KB
/
utils.py
File metadata and controls
77 lines (67 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
import random
import numpy as np
import cv2
from tqdm import tqdm
import torch
from sklearn.utils import shuffle
""" Seeding the randomness. """
def seeding(seed):
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    every visible CUDA device), and configures cuDNN so that convolution
    algorithms are chosen deterministically.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: the running interpreter's hash seed is fixed at start-up; setting
    # PYTHONHASHSEED here only influences subprocesses launched afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU (torch.cuda.manual_seed only targets
    # the current device). This is a lazy no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes kernels per run and may pick different
    # algorithms each time; turn it off so the deterministic flag above
    # actually yields reproducible results.
    torch.backends.cudnn.benchmark = False
""" Create a directory """
def create_dir(path):
    """Create directory *path*, including missing parents, if it does not exist.

    Args:
        path: Directory path to create.

    Raises:
        FileExistsError: If *path* already exists but is not a directory.
    """
    # exist_ok=True removes the check-then-create race of a separate
    # os.path.exists() test: if another process creates the directory in
    # between, makedirs would otherwise raise FileExistsError.
    os.makedirs(path, exist_ok=True)
""" Shuffle the dataset. """
def shuffling(x, y, random_state=42):
    """Shuffle two aligned collections with one shared permutation.

    Args:
        x: First collection (e.g. inputs); any indexable accepted by
            ``sklearn.utils.shuffle``.
        y: Second collection (e.g. targets); must have the same length as *x*.
        random_state: Seed or RNG forwarded to ``sklearn.utils.shuffle``.
            The default of 42 keeps the historical behaviour, where every call
            applies the same permutation. Pass ``None`` or a different seed
            for a fresh order.

    Returns:
        Tuple ``(x, y)`` of shuffled copies that are still aligned element-wise.
    """
    x, y = shuffle(x, y, random_state=random_state)
    return x, y
def epoch_time(start_time, end_time):
    """Split the span between two timestamps into whole minutes and seconds.

    Timestamps are assumed to be in seconds. Both parts are truncated toward
    zero with ``int``, so fractional seconds are discarded.

    Args:
        start_time: Start timestamp.
        end_time: End timestamp.

    Returns:
        Tuple ``(minutes, seconds)`` of ints.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    return minutes, int(duration - 60 * minutes)
def load_checkpoint(model, optimizer, checkpoint_path, map_location=None):
    """Restore model and optimizer state from *checkpoint_path* if it exists.

    The checkpoint must be a dict with keys ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'epoch'``.

    Args:
        model: Module whose parameters are overwritten via ``load_state_dict``.
        optimizer: Optimizer whose state is overwritten via ``load_state_dict``.
        checkpoint_path: Path of the checkpoint file.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a GPU-saved checkpoint on a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        ``(model, optimizer, epoch)``: the same objects that were passed in,
        plus the stored epoch, or ``0`` when no checkpoint file exists.
        NOTE: this is the epoch recorded at save time; whether training
        resumes at it or at the next one is up to the caller.
    """
    if not os.path.exists(checkpoint_path):
        print("No checkpoint found, starting from scratch.")
        return model, optimizer, 0

    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    print(f"Checkpoint loaded from epoch {epoch}")
    return model, optimizer, epoch
def save_checkpoint(model, optimizer, epoch, loss, checkpoint_path):
    """Serialise model/optimizer state plus bookkeeping to *checkpoint_path*.

    The saved object is a dict with keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'``.

    When *checkpoint_path* is a filesystem path, the data is first written to
    ``<path>.tmp`` and then moved into place with ``os.replace``, which is
    atomic on POSIX. A crash or Ctrl-C mid-save therefore cannot replace the
    previous good checkpoint with a truncated file. File-like objects are
    written directly.

    Args:
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        epoch: Epoch number stored alongside the state.
        loss: Loss value stored alongside the state.
        checkpoint_path: Destination path or writable file-like object.
    """
    state = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    if isinstance(checkpoint_path, (str, os.PathLike)):
        tmp_path = os.fspath(checkpoint_path) + ".tmp"
        try:
            torch.save(state, tmp_path)
            os.replace(tmp_path, checkpoint_path)
        except BaseException:
            # Don't leave a half-written temp file behind; re-raise the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    else:
        torch.save(state, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")
class EarlyStopping:
    """Flag training for termination once validation loss stops improving.

    Call the instance once per epoch with the latest validation loss.

    - Improvement: a loss lower than the best seen so far by more than
      ``delta`` resets the counter and saves a checkpoint via
      ``save_checkpoint``.
    - No improvement: the counter increments. Once it reaches ``patience``,
      ``early_stop`` becomes ``True``.

    The trigger message is printed on every non-improving call made after
    that point.

    Args:
        patience: Consecutive non-improving calls tolerated before
            ``early_stop`` is set.
        verbose: If true, print the counter after every non-improving call.
        delta: Minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        # Tracking state updated by __call__.
        self.best_loss = float('inf')  # lowest validation loss seen so far
        self.counter = 0               # consecutive calls without improvement
        self.early_stop = False        # set once counter >= patience

    def __call__(self, val_loss, model, optimizer, epoch, checkpoint_path):
        """Record *val_loss* for *epoch*, checkpointing on improvement.

        Args:
            val_loss: Validation loss for the current epoch.
            model: Forwarded to ``save_checkpoint`` on improvement.
            optimizer: Forwarded to ``save_checkpoint`` on improvement.
            epoch: Current epoch number, stored in the checkpoint.
            checkpoint_path: Destination for the checkpoint.
        """
        improved = self.best_loss - val_loss > self.delta
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            save_checkpoint(model, optimizer, epoch, val_loss, checkpoint_path)
            return

        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True
            print("Early stopping triggered")