IterableDataset with CORRECT length causes validation loop to be skipped #19624

@mattcleigh

Description

Bug description

This is related to this issue:
#10290

An IterableDataset with a defined length won't trigger a validation epoch, even when that length is correct, as long as the following conditions are met:

  1. The IterableDataset defines an accurate length
  2. The dataset is accurately split between multiple workers with no overlap
  3. drop_last=True is set on the dataloader
  4. The dataset size does not divide evenly into batches

In this case, multiple workers may each be left with an incomplete batch at the very end of the training epoch, so the number of "dropped batches" exceeds one. The dataloader then raises a StopIteration before the reported length is reached, causing the validation epoch to be skipped.

This is standard PyTorch behavior, as the collation function is called per worker for an IterableDataset.
pytorch/pytorch#33413
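The mismatch can be seen with a bit of arithmetic. Here is a pure-Python sketch of the batch counts in the repro below (100 samples, batch_size=32, num_workers=4, drop_last=True); the shard sizes mirror what np.array_split produces:

```python
n_samples, batch_size, num_workers = 100, 32, 4

# What DataLoader.__len__ reports for a sized IterableDataset with drop_last=True:
reported_batches = n_samples // batch_size

# What the workers actually produce: each worker batches only its own shard,
# and drop_last discards every worker's final partial batch separately.
base, rem = divmod(n_samples, num_workers)
shard_sizes = [base + 1] * rem + [base] * (num_workers - rem)  # [25, 25, 25, 25]
actual_batches = sum(size // batch_size for size in shard_sizes)

print(f"reported={reported_batches}, actual={actual_batches}")  # reported=3, actual=0
```

With 4 workers, each shard holds only 25 samples, which is less than one full batch of 32, so every worker drops everything and zero batches are produced against a reported length of three.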

I am hitting this issue right now; my current fix is to artificially subtract from the length of my IterableDataset to account for the per-worker drops. Unfortunately I really need the length to be defined, so I can't set it to inf, which was the hotfix in the previous thread.
The progress bar is useful for judging which partition I need to run certain jobs on, and I also use the dataset length to sync my cyclic learning rate with the number of steps in an epoch.
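The subtraction workaround can be sketched roughly like this (the class and attribute names here are illustrative, not my actual code): report a length with the per-worker dropped samples already removed, so the batch count DataLoader derives from it matches what the workers will actually yield.

```python
class LengthAwareData:
    """Illustrative sketch: a __len__ that subtracts the samples each
    worker's drop_last will discard. Assumes the same np.array_split-style
    sharding as the repro below (first n % k shards get one extra sample)."""

    def __init__(self, n_samples: int, batch_size: int, num_workers: int) -> None:
        self.n_samples = n_samples
        self.batch_size = batch_size
        self.num_workers = max(num_workers, 1)

    def __len__(self) -> int:
        base, rem = divmod(self.n_samples, self.num_workers)
        shard_sizes = [base + 1] * rem + [base] * (self.num_workers - rem)
        # Keep only the samples that survive each worker's drop_last.
        return sum((s // self.batch_size) * self.batch_size for s in shard_sizes)

# 200 samples, 4 workers, batch_size 32: shards of 50, each worker keeps one
# full batch of 32, so the corrected length is 4 * 32 = 128.
print(len(LengthAwareData(200, 32, 4)))  # 128
```

With drop_last=True, DataLoader would then report 128 // 32 = 4 batches, matching the four batches the workers actually emit.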

What version are you seeing the problem on?

master

How to reproduce the bug

import torch as T
import numpy as np
from torch.utils.data import DataLoader, Dataset, IterableDataset, get_worker_info

from lightning import LightningModule, Trainer, LightningDataModule

nwrkrs = 4
drop = True

class Data(IterableDataset):
    def __init__(self) -> None:
        super().__init__()
        self.data = np.random.rand(100, 10).astype(np.float32)

    def __len__(self) -> int:
        return len(self.data)

    def __iter__(self):
        worker_info = get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        num_workers = 1 if worker_info is None else worker_info.num_workers
        worker_samples = np.array_split(self.data, num_workers)[worker_id]

        for i in worker_samples:
            yield i

class Model(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = T.nn.Linear(10, 1)
        self.did_validation = False

    def forward(self, x: T.Tensor) -> T.Tensor:
        return self.layer(x)

    def training_step(self, batch):
        return self(batch).mean()

    def validation_step(self, batch):
        self.did_validation = True
        return self(batch).mean()

    def configure_optimizers(self):
        return T.optim.Adam(self.parameters())

model = Model()
trainer = Trainer(logger=False, max_epochs=2, num_sanity_val_steps=0)
train_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
valid_loader = DataLoader(Data(), num_workers=nwrkrs, batch_size=32, drop_last=drop)
trainer.fit(model, train_loader, valid_loader)
print("Performed validation:", model.did_validation)

Setting up the code above and running it with the following settings gives these results:

nwrkrs = 0, drop = True

Performed validation: True

nwrkrs = 4, drop = False

Performed validation: True

nwrkrs = 4, drop = True

Performed validation: False

cc @justusschock @awaelchli