-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloader.py
More file actions
16 lines (16 loc) · 727 Bytes
/
loader.py
File metadata and controls
16 lines (16 loc) · 727 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torchvision
import torch
img_size = 64


def _build_train_transform(size):
    """Return the training-time preprocessing pipeline for square images.

    The steps run in order:
      1. resize every image to ``size`` x ``size``;
      2. randomly mirror it horizontally (data augmentation);
      3. convert it to a tensor;
      4. normalize each of 3 channels with mean 0.5 and std 0.5.
         After ToTensor this maps values from [0, 1] to [-1, 1].
    """
    tf = torchvision.transforms
    half = (0.5, 0.5, 0.5)
    steps = [
        tf.Resize((size, size)),
        tf.RandomHorizontalFlip(),
        tf.ToTensor(),
        tf.Normalize(half, half),
    ]
    return tf.Compose(steps)


# Built once at import time; train_loader_fn hands it to the training dataset.
transform = _build_train_transform(img_size)
def train_loader_fn(batch_size, *, root="~/torch_datasets/img_align_celeba",
                    num_workers=0):
    """Build a shuffled DataLoader over the training image folder.

    Images are read with ``torchvision.datasets.ImageFolder``. Each image goes
    through the module-level ``transform``: square resize to ``img_size``,
    random horizontal flip, tensor conversion and normalization.

    Args:
        batch_size: Number of samples per batch during training.
        root: Directory passed to ``ImageFolder``. It defaults to the
            previously hard-coded location, so existing callers are
            unaffected.
            NOTE: ``ImageFolder`` expects the images to live inside at least
            one sub-directory of ``root``. Each sub-directory becomes a class
            index.
        num_workers: Number of worker processes the DataLoader uses for
            loading. 0, which is the DataLoader default, loads in the main
            process.

    Returns:
        A ``torch.utils.data.DataLoader`` that reshuffles the dataset every
        epoch. It yields batches of the (image, class-index) pairs produced
        by ``ImageFolder``.
    """
    train_dataset = torchvision.datasets.ImageFolder(root=root, transform=transform)
    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    )