
Commit d3b25f8

Add log_key_prefix parameter to LearningRateMonitor (#21612)
* Add `log_key_prefix` parameter to `LearningRateMonitor` callback

  Allow users to prepend a configurable prefix to all metric names logged by `LearningRateMonitor`. This is useful for grouping learning rate metrics in loggers like TensorBoard (e.g., `optim/lr-Adam` instead of `lr-Adam`).

  Fixes #21590

* Update src/lightning/pytorch/CHANGELOG.md

Co-authored-by: Deependu <deependujha21@gmail.com>
1 parent 612ab08 commit d3b25f8
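
For reference, a minimal usage sketch of the new parameter, mirroring the tests added in this commit. The `BoringModel` demo model from `lightning.pytorch.demos` stands in for any `LightningModule`, and the prefix value `"optim/"` is just an example.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.demos.boring_classes import BoringModel

# Prepend "optim/" to every metric the callback logs,
# e.g. "optim/lr-SGD" instead of "lr-SGD".
lr_monitor = LearningRateMonitor(logging_interval="step", log_key_prefix="optim/")

trainer = Trainer(max_epochs=2, callbacks=[lr_monitor])
trainer.fit(BoringModel())
```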

File tree

3 files changed: 119 additions & 2 deletions

- src/lightning/pytorch/CHANGELOG.md
- src/lightning/pytorch/callbacks/lr_monitor.py
- tests/tests_pytorch/callbacks/test_lr_monitor.py

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Added `log_key_prefix` parameter to `LearningRateMonitor` callback for prefixing logged metric names ([#21612](https://github.com/Lightning-AI/pytorch-lightning/issues/21612))

 ### Changed

src/lightning/pytorch/callbacks/lr_monitor.py

Lines changed: 10 additions & 1 deletion
@@ -46,6 +46,10 @@ class LearningRateMonitor(Callback):
             has the ``momentum`` or ``betas`` attribute. Defaults to ``False``.
         log_weight_decay: option to also log the weight decay values of the optimizer. Defaults to
             ``False``.
+        log_key_prefix: optional string prefix to prepend to all logged metric names. Useful for
+            grouping learning rate metrics in loggers like TensorBoard. For example, setting
+            ``log_key_prefix="optim/"`` would log ``optim/lr-Adam`` instead of ``lr-Adam``.
+            Defaults to ``None`` (no prefix).

     Raises:
         MisconfigurationException:

@@ -96,13 +100,15 @@ def __init__(
         logging_interval: Optional[Literal["step", "epoch"]] = None,
         log_momentum: bool = False,
         log_weight_decay: bool = False,
+        log_key_prefix: Optional[str] = None,
     ) -> None:
         if logging_interval not in (None, "step", "epoch"):
             raise MisconfigurationException("logging_interval should be `step` or `epoch` or `None`.")

         self.logging_interval = logging_interval
         self.log_momentum = log_momentum
         self.log_weight_decay = log_weight_decay
+        self.log_key_prefix = log_key_prefix or ""

         self.lrs: dict[str, list[float]] = {}
         self.last_momentum_values: dict[str, Optional[list[float]]] = {}

@@ -361,4 +367,7 @@ def _check_duplicates_and_update_name(
         )

         name = self._add_prefix(name, optimizer_cls, seen_optimizer_types)
-        return [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
+        names = [self._add_suffix(name, param_groups, i) for i in range(len(param_groups))]
+        if self.log_key_prefix:
+            names = [f"{self.log_key_prefix}{n}" for n in names]
+        return names
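
To make the resulting names explicit, here is a standalone sketch (not library code) of the prefixing step added to `_check_duplicates_and_update_name`: a `None` prefix is normalized to `""` in `__init__`, so the no-prefix path is unchanged, and otherwise the prefix is simply concatenated onto each already-resolved metric name. The helper name `apply_prefix` is illustrative only.

```python
from typing import Optional


def apply_prefix(names: list[str], log_key_prefix: Optional[str]) -> list[str]:
    """Illustrative only: mirrors `log_key_prefix or ""` plus the prefixing loop above."""
    prefix = log_key_prefix or ""
    if not prefix:
        return names
    return [f"{prefix}{n}" for n in names]


print(apply_prefix(["lr-Adam", "lr-SGD"], "optim/"))  # ['optim/lr-Adam', 'optim/lr-SGD']
print(apply_prefix(["lr-SGD"], None))  # ['lr-SGD']
```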

tests/tests_pytorch/callbacks/test_lr_monitor.py

Lines changed: 108 additions & 0 deletions
@@ -709,3 +709,111 @@ def configure_optimizers(self):
     # Verify the callback metric tensor was created successfully
     assert "lr-SGD" in trainer.callback_metrics
     assert isinstance(trainer.callback_metrics["lr-SGD"], torch.Tensor)
+
+
+def test_lr_monitor_log_key_prefix(tmp_path):
+    """Test that learning rate metric names are correctly prefixed when log_key_prefix is set."""
+    model = BoringModel()
+
+    lr_monitor = LearningRateMonitor(log_key_prefix="optim/")
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        limit_val_batches=0.1,
+        limit_train_batches=0.5,
+        callbacks=[lr_monitor],
+        logger=CSVLogger(tmp_path),
+    )
+    trainer.fit(model)
+
+    assert lr_monitor.lrs, "No learning rates logged"
+    assert list(lr_monitor.lrs) == ["optim/lr-SGD"]
+    assert "optim/lr-SGD" in trainer.callback_metrics
+
+
+def test_lr_monitor_log_key_prefix_with_momentum_and_weight_decay(tmp_path):
+    """Test that prefix is applied to momentum and weight decay metric names as well."""
+
+    class CustomModel(BoringModel):
+        def configure_optimizers(self):
+            optimizer = optim.Adam(self.parameters(), lr=1e-2, betas=(0.9, 0.999), weight_decay=0.01)
+            lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+    model = CustomModel()
+    lr_monitor = LearningRateMonitor(log_momentum=True, log_weight_decay=True, log_key_prefix="train/")
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        limit_val_batches=2,
+        limit_train_batches=5,
+        log_every_n_steps=1,
+        callbacks=[lr_monitor],
+        logger=CSVLogger(tmp_path),
+    )
+    trainer.fit(model)
+
+    assert list(lr_monitor.lrs) == ["train/lr-Adam"]
+    assert all(k == "train/lr-Adam-momentum" for k in lr_monitor.last_momentum_values)
+    assert all(k == "train/lr-Adam-weight_decay" for k in lr_monitor.last_weight_decay_values)
+
+
+def test_lr_monitor_log_key_prefix_multi_optimizers(tmp_path):
+    """Test that prefix is applied correctly with multiple optimizers."""
+
+    class MultiOptModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.automatic_optimization = False
+
+        def training_step(self, batch, batch_idx):
+            opt1, opt2 = self.optimizers()
+
+            loss = self.loss(self.step(batch))
+            opt1.zero_grad()
+            self.manual_backward(loss)
+            opt1.step()
+
+            loss = self.loss(self.step(batch))
+            opt2.zero_grad()
+            self.manual_backward(loss)
+            opt2.step()
+
+        def configure_optimizers(self):
+            optimizer1 = optim.Adam(self.parameters(), lr=1e-2)
+            optimizer2 = optim.SGD(self.parameters(), lr=1e-2)
+            return [optimizer1, optimizer2]
+
+    model = MultiOptModel()
+    lr_monitor = LearningRateMonitor(log_key_prefix="hparams/")
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        limit_val_batches=0.1,
+        limit_train_batches=5,
+        log_every_n_steps=1,
+        callbacks=[lr_monitor],
+        logger=CSVLogger(tmp_path),
+    )
+    trainer.fit(model)
+
+    assert lr_monitor.lrs, "No learning rates logged"
+    assert list(lr_monitor.lrs) == ["hparams/lr-Adam", "hparams/lr-SGD"]
+
+
+def test_lr_monitor_log_key_prefix_none(tmp_path):
+    """Test that when log_key_prefix is None (default), metric names are unchanged."""
+    model = BoringModel()
+
+    lr_monitor = LearningRateMonitor(log_key_prefix=None)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        limit_val_batches=0.1,
+        limit_train_batches=0.5,
+        callbacks=[lr_monitor],
+        logger=CSVLogger(tmp_path),
+    )
+    trainer.fit(model)
+
+    assert list(lr_monitor.lrs) == ["lr-SGD"]
