forked from Delta-R/CBFAN-Bearing-RUL-Prediction
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
57 lines (48 loc) · 2.15 KB
/
__init__.py
File metadata and controls
57 lines (48 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
from dataset.dataset import DGM4_Dataset
from dataset.randaugment import RandomAugment
def create_dataset(config):
    """Build the train and validation DGM4 datasets with their transforms.

    Args:
        config: dict-like configuration with keys 'image_res' (int, eval
            resize target), 'train_file' / 'val_file' (annotation paths),
            and 'max_words' (caption truncation length).

    Returns:
        Tuple ``(train_dataset, val_dataset)`` of ``DGM4_Dataset`` instances.
    """
    # CLIP's published image normalization statistics.
    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                     (0.26862954, 0.26130258, 0.27577711))

    # Training pipeline: photometric-only random augmentation.
    # NOTE(review): there is no Resize here, so training images are assumed
    # to already arrive at a fixed resolution — confirm against DGM4_Dataset.
    train_transform = transforms.Compose([
        RandomAugment(2, 7, isPIL=True,
                      augs=['Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness']),
        transforms.ToTensor(),
        normalize,
    ])

    # Evaluation pipeline: deterministic bicubic resize to the configured
    # resolution. Uses torchvision's InterpolationMode enum instead of the
    # deprecated PIL ``Image.BICUBIC`` constant (deprecated in Pillow >= 9.1;
    # torchvision >= 0.9 warns when handed raw PIL constants).
    test_transform = transforms.Compose([
        transforms.Resize((config['image_res'], config['image_res']),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = DGM4_Dataset(config=config, ann_file=config['train_file'],
                                 transform=train_transform,
                                 max_words=config['max_words'], is_train=True)
    val_dataset = DGM4_Dataset(config=config, ann_file=config['val_file'],
                               transform=test_transform,
                               max_words=config['max_words'], is_train=False)
    return train_dataset, val_dataset
def create_sampler(datasets, shuffles, num_tasks, global_rank):
    """Create one DistributedSampler per dataset.

    Args:
        datasets: sequence of datasets to draw samples from.
        shuffles: sequence of bools (one per dataset) toggling shuffling.
        num_tasks: total number of distributed processes (world size).
        global_rank: rank of the current process within the world.

    Returns:
        List of ``torch.utils.data.DistributedSampler``, one per dataset,
        each partitioning its dataset across ``num_tasks`` ranks.
    """
    return [
        torch.utils.data.DistributedSampler(
            ds, num_replicas=num_tasks, rank=global_rank, shuffle=do_shuffle
        )
        for ds, do_shuffle in zip(datasets, shuffles)
    ]
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Wrap each dataset in a DataLoader with its own per-dataset settings.

    Args:
        datasets: sequence of datasets.
        samplers: sequence of samplers (or None entries), one per dataset.
        batch_size: sequence of batch sizes, one per dataset.
        num_workers: sequence of worker counts, one per dataset.
        is_trains: sequence of bools marking training loaders.
        collate_fns: sequence of collate functions (or None), one per dataset.

    Returns:
        List of ``torch.utils.data.DataLoader``, one per dataset.
    """
    loaders = []
    for ds, sampler, bs, workers, training, collate in zip(
            datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
        # Training loaders shuffle only when no sampler is supplied (a
        # DistributedSampler handles its own shuffling) and drop the final
        # partial batch; evaluation loaders do neither.
        shuffle = training and sampler is None
        loaders.append(DataLoader(
            ds,
            batch_size=bs,
            num_workers=workers,
            pin_memory=True,
            sampler=sampler,
            shuffle=shuffle,
            collate_fn=collate,
            drop_last=bool(training),
        ))
    return loaders