IterableDataset with CORRECT length causes validation loop to be skipped #19624
Bug description
This is related to this issue:
#10290
An IterableDataset with a defined __len__ won't trigger a validation epoch, even if the defined length is correct, so long as the following conditions are met:
- Accurate length of IterableDataset defined
- Dataset accurately split between multiple workers with no overlap
- drop_last=True for the dataloader
- Dataset size is not evenly divisible by the batch size
In this case, multiple workers may each be left with an incomplete batch at the very end of the training epoch, so the number of "dropped batches" exceeds one. The dataloader then raises a StopIteration before the reported length is reached, causing the validation epoch to be skipped.
This is standard PyTorch behavior, as the collation function is called per worker in an IterableDataset:
pytorch/pytorch#33413
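To illustrate, here is a sketch of the arithmetic behind the mismatch (no workers are actually spawned; the numbers mirror the repro below). With an accurate __len__, the DataLoader reports floor(len(dataset) / batch_size) batches when drop_last=True, but collation happens independently in each worker, so every worker drops its own incomplete batch:

```python
import numpy as np

n, batch_size, num_workers = 100, 32, 4

# What len(dataloader) would report from the dataset's __len__:
reported = n // batch_size  # 100 // 32 == 3

# What the workers actually produce: each worker gets an even shard
# of the data and drops its own partial batch.
shard_sizes = [len(s) for s in np.array_split(np.arange(n), num_workers)]
actual = sum(k // batch_size for k in shard_sizes)  # 25 // 32 == 0 per worker

print(reported, actual)  # 3 batches reported, 0 batches actually produced
```

With these numbers the gap is extreme: every one of the four 25-sample shards is smaller than the batch size, so the loader yields nothing at all while still reporting a length of 3.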
I am hitting this issue right now; my current fix is to artificially subtract from the length of my IterableDataset to account for this. Unfortunately, I really would like the length to be defined, so I can't set it to inf, which was the hotfix in the previous thread.
The progress bar is useful for me to judge which partition I need to run certain jobs on, plus I use the dataset length to sync up my cyclic learning rate with the number of steps in an epoch.
What version are you seeing the problem on?
master
How to reproduce the bug
import torch as T
import numpy as np
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info
from lightning import LightningModule, Trainer, LightningDataModule

nwrkrs = 4
drop = True


class Data(IterableDataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = np.random.rand(100, 10).astype(np.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_samples = np.array_split(self.data, num_workers)[worker_id]
        for i in worker_samples:
            yield i


class Model(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = T.nn.Linear(10, 1)
        self.did_validation = False

    def forward(self, x: T.Tensor) -> T.Tensor:
        return self.layer(x)

    def training_step(self, batch):
        return self(batch).mean()

    def validation_step(self, batch):
        self.did_validation = True
        return self(batch).mean()

    def configure_optimizers(self):
        return T.optim.Adam(self.parameters())


model = Model()
trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
train_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
valid_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
trainer.fit(model, train_loader, valid_loader)
print("Performed validation:", model.did_validation)

Running the code above with the following settings gives these results:
- nwrkrs = 0, drop = True  -> Performed validation: True
- nwrkrs = 4, drop = False -> Performed validation: True
- nwrkrs = 4, drop = True  -> Performed validation: False
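For reference, the "subtract from the length" workaround mentioned above can be made systematic. This is a hypothetical helper (not an official fix, and `effective_len` is a made-up name): it computes the number of samples that actually survive per-worker drop_last, which a dataset's __len__ could report so the Trainer's expected batch count matches what the workers will yield:

```python
import numpy as np


def effective_len(n_samples: int, batch_size: int, num_workers: int) -> int:
    """Number of samples that survive drop_last when the dataset is
    split evenly across workers (mirroring np.array_split in __iter__)."""
    num_workers = max(num_workers, 1)
    shard_sizes = [len(s) for s in np.array_split(np.arange(n_samples), num_workers)]
    full_batches = sum(k // batch_size for k in shard_sizes)
    return full_batches * batch_size


print(effective_len(100, 32, 1))  # 96 -> single worker keeps 3 full batches
print(effective_len(100, 32, 2))  # 64 -> each 50-sample shard keeps one batch
print(effective_len(100, 32, 4))  # 0  -> every 25-sample shard is dropped
```

The downside, as noted above, is that the reported length no longer reflects the true dataset size, which distorts anything derived from it (progress bar totals, cyclic LR scheduling).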