feat: PEFT fine-tuning pipeline (LoRA/DoRA, multi-GPU) for TimesFM 2.5#398

Merged
siriuz42 merged 17 commits into google-research:master from darkpowerxo:feat/peft-finetuning-pipeline-2.5
Apr 15, 2026

Conversation

@darkpowerxo
Contributor

Summary

Adds a full PEFT (Parameter-Efficient Fine-Tuning) pipeline for TimesFM 2.5 with LoRA and DoRA adapters, multi-GPU support via PyTorch DDP, and mixed-precision training. Also includes unit tests, bug fixes, documentation updates, and several cherry-picked community fixes.

What this adds

PEFT Pipeline (peft/)

peft/
├── __init__.py       ← public API exports
├── config.py         ← PEFTConfig dataclass (adapter type, rank, target modules, etc.)
├── adapters.py       ← LoRA/DoRA layers, inject/merge/save/load adapter weights
├── data.py           ← TimeSeriesDataset (sliding-window from arrays or DataFrames)
├── trainer.py        ← PEFTTrainer: DDP, AMP, cosine-warmup, early stopping, W&B logging
├── finetune.py       ← CLI entry-point (argparse)
├── finetune.sh       ← torchrun launch script with NUM_GPUS control
└── README.md         ← full documentation and usage examples

Key features:

  • LoRA and DoRA adapters targeting qkv_proj, output_proj, and ff_layer modules
  • num_adapter_layers parameter to freeze all but the last N transformer layers
  • Multi-GPU training via torchrun and DistributedDataParallel
  • Mixed-precision (FP16/BF16) via torch.amp
  • safetensors adapter-only checkpoints (~4 MB at rank-4)
  • Gradient checkpointing for memory-constrained setups
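As background for the adapter features above, the LoRA idea can be sketched as a thin wrapper around a frozen nn.Linear. This is an illustrative stand-alone example, not the adapters.py implementation from the PR:

```python
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen nn.Linear plus a trainable low-rank delta (illustrative sketch)."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False          # freeze the pretrained weights
        # A is small-random, B is zero, so the wrapped layer starts out
        # numerically identical to the base layer
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank          # standard LoRA scaling

    def forward(self, x):
        # y = W x + (alpha/r) * B A x
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

layer = LoRALinear(nn.Linear(64, 64), rank=4)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
```

At rank 4 on a 64x64 layer this trains 512 parameters (two 4x64 matrices) against 4,160 frozen ones, which is why adapter-only checkpoints stay in the single-digit-MB range.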

Unit Tests (tests/)

58 tests covering configs, torch layers, torch utilities, and base utilities:

tests/
├── __init__.py
├── test_base_utils.py
├── test_configs.py
├── test_torch_layers.py
└── test_torch_utils.py

Bug Fixes & Improvements

  • fix: per-input ridge regression in xreg_lib.py (both src/ and v1/) to prevent data leakage across batch items
  • fix: batch_size iteration in v1/src/timesfm/data_loader.py when permute=False
  • fix: correct SKILL.md link in README (treeblob path)
  • fix: typo 'complied' -> 'compiled' in configs.py
  • docs: fix swapped xreg_mode descriptions and covaraites typo in timesfm_2p5_base.py
  • refactor: print -> logging.info in timesfm_2p5_torch.py
  • ci: upgrade GitHub Actions (checkout / setup-python v2 → v6)
  • docs: update README — replace "under construction" section with completed feature checklist

No breaking changes

All source-level modifications are backward-compatible:

  • configs.py, timesfm_2p5_base.py: docstring/typo fixes only
  • timesfm_2p5_torch.py: print -> logging.info
  • xreg_lib.py: per-input regression refactor preserves the existing 5-tuple return signature

Testing

pytest tests/ -q
# 58 passed in 1.46s

All 58 tests pass. PEFT imports, import timesfm, and ForecastConfig were all verified to work.

Incorporated community PRs

  PR    Description
  #396  fix: correct typo + replace print with logging
  #394  test: add unit tests for configs, torch layers, utils
  #393  fix: per-input ridge regression in xreg
  #391  fix: batch_size in v1 data_loader
  #390  fix: SKILL.md link in README
  #367  ci: upgrade GitHub Actions to v6
  #366  docs: fix swapped xreg_mode descriptions

- Replace timesfm_jax.egg-info/ with generic *.egg-info/ glob
- Add uv.lock (lockfile is environment-specific)
- Add peft_checkpoints/ (training output directory)

Implement production-grade PEFT adapters targeting the 2.5 architecture:

- LoRALinear: low-rank A/B decomposition with scaling (alpha/rank)
- DoRALinear: weight-decomposed LoRA (magnitude + direction)
- inject_adapters(): freezes base weights, wraps target nn.Linear modules
  - Supports fused QKV (qkv_proj), attention output, and FFN layers
  - num_adapter_layers controls how many top layers get adapters (0=all)
  - target_modules selects 'all', 'attention', or 'ffn'
- merge_adapters(): folds adapter deltas back into base nn.Linear
- save/load_adapter_weights(): safetensors adapter-only checkpoints
- PEFTConfig dataclass with all hyperparameters

References:
  LoRA — https://arxiv.org/abs/2106.09685
  DoRA — https://arxiv.org/abs/2402.09353
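The magnitude/direction split that DoRA adds on top of LoRA can be sketched as follows. This is illustrative only (norms are taken per output row here); the real adapters.py may differ in detail:

```python
import torch
import torch.nn as nn

class DoRALinear(nn.Module):
    """Weight-decomposed LoRA: trainable magnitude times a LoRA-updated,
    re-normalized direction (illustrative sketch)."""
    def __init__(self, base: nn.Linear, rank: int = 4, alpha: float = 8.0):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False
        self.lora_A = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank
        # magnitude starts at the row-norms of the pretrained weight, so the
        # layer is initially identical to the base layer
        self.magnitude = nn.Parameter(base.weight.norm(p=2, dim=1))

    def forward(self, x):
        delta = self.scaling * (self.lora_B @ self.lora_A)
        w = self.base.weight + delta                  # un-normalized direction
        norm = w.norm(p=2, dim=1, keepdim=True)       # per-output-row norm
        w = self.magnitude.unsqueeze(1) * w / norm    # learned magnitude rescale
        return nn.functional.linear(x, w, self.base.bias)
```

Because B starts at zero and the magnitude starts at the base weight's row norms, the first forward pass reproduces the base layer exactly; training then adjusts magnitude and direction separately.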

Sliding-window dataset that produces (context, mask, target) tuples:
- Accepts list of arrays, long-format, or wide-format DataFrames
- Context length auto-rounded to multiple of patch_len (32)
- Left-pads short series with proper masking
- Configurable stride for window overlap
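A minimal sketch of the windowing logic described above (function name and defaults are illustrative, not the TimeSeriesDataset API):

```python
import numpy as np

def make_windows(series, context_len=100, horizon=16, stride=8, patch_len=32):
    """Slide over a 1-D series, yielding (context, mask, target) tuples (sketch)."""
    # round the context length up to a multiple of the patch length (100 -> 128)
    context_len = -(-context_len // patch_len) * patch_len
    total = context_len + horizon
    out = []
    # if the series is shorter than one window, emit a single left-padded window
    for s in range(0, max(len(series) - total, 0) + 1, stride):
        chunk = np.asarray(series[s:s + total], dtype=np.float32)
        pad = total - len(chunk)
        mask = np.zeros(total, dtype=bool)
        if pad > 0:                      # left-pad short series, mask the padding
            chunk = np.concatenate([np.zeros(pad, np.float32), chunk])
            mask[:pad] = True
        out.append((chunk[:context_len], mask[:context_len], chunk[context_len:]))
    return out
```

For a 200-point series with the defaults above, this yields 8 overlapping windows with a 128-step context and a 16-step target each; a 50-point series yields a single window whose first 94 positions are padding flagged by the mask.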

PEFTTrainer with production-grade training loop:
- PyTorch DDP multi-GPU via torchrun
- Mixed-precision training (fp16/bf16) with GradScaler
- Gradient checkpointing for long contexts
- Cosine-with-warmup LR schedule
- MSE loss + optional pinball quantile loss (9 channels)
- Early stopping on validation loss
- Adapter-only checkpointing (safetensors)
- W&B logging (rank-0 only)
- Differentiable training forward that replicates the 2.5
  patch -> RevIN -> transformer -> output-head -> un-RevIN path
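The cosine-with-warmup schedule from the list above can be sketched as a plain LR-multiplier function (illustrative; the trainer's actual constants may differ):

```python
import math

def cosine_with_warmup(step, warmup_steps=100, total_steps=1000):
    """LR multiplier: linear warmup to 1.0, then cosine decay to 0 (sketch)."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))

# In PyTorch this plugs straight into a LambdaLR scheduler, e.g.:
#   sched = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_with_warmup)
```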

- finetune.py: argparse CLI with all config flags, CSV data loading,
  chronological train/val split, and full training pipeline
  Usage: python -m peft.finetune --data_path data.csv --value_col y
  Multi-GPU: torchrun --nproc_per_node=4 -m peft.finetune ...

- finetune.sh: env-var driven launch script for single/multi-GPU
  Usage: DATA_PATH=data.csv NUM_GPUS=4 bash peft/finetune.sh

Covers quick start, Python API, CLI reference, adapter loading/merging,
architecture overview with parameter counts, and file layout.

Apply changes from PR google-research#396 by @shahrukhx01:
- Fix typo 'complied' -> 'compiled' in ForecastConfig docstrings
- Replace bare print() with logging.info() in load_checkpoint()

Apply changes from PR google-research#394 by @cj-wong:
- tests/__init__.py: package marker
- tests/test_base_utils.py: strip_leading_nans + linear_interpolation tests
- tests/test_configs.py: frozen dataclass, defaults, replace, equality tests
- tests/test_torch_layers.py: ResidualBlock, RMSNorm, RandomFourierFeatures
- tests/test_torch_utils.py: update_running_stats, revin, DecodeCache tests

Apply changes from PR google-research#393 by @MarcoGorworworelli:
- Normalize covariates per-input instead of batch-wide to prevent
  each input's result from depending on batch composition
- Fit separate ridge regressions per time series instead of a single
  batched regression, preventing cross-series data leakage
- Applied to both src/timesfm/utils/xreg_lib.py and v1/src/timesfm/xreg_lib.py
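The per-series fit described above can be sketched with the closed-form ridge solution. This is illustrative only; the real xreg_lib code additionally handles covariate construction, padding, and JAX/NumPy details:

```python
import numpy as np

def ridge_per_series(X_list, y_list, lam=1.0):
    """Fit an independent ridge regression per series, so no series' coefficients
    depend on what else happens to be in the batch (illustrative sketch)."""
    betas = []
    for X, y in zip(X_list, y_list):
        d = X.shape[1]
        # closed-form ridge: beta = (X'X + lam*I)^{-1} X'y, solved per series
        betas.append(np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ y))
    return betas
```

A single batched solve would instead stack all series into one design matrix, letting one series' covariate scale influence another's fit; solving per series removes that coupling.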

Apply changes from PR google-research#391 by @MarcoGorworworelli:
- Fix train_gen() to iterate in proper batch_size chunks instead of
  yielding all time series at once when permute=False
- Add test_data_loader.py to verify batch boundaries

Apply changes from PR google-research#390 by @amansinghbais:
- Fix link to point to the actual SKILL.md file instead of the directory

Apply changes from PR google-research#367 by @Copilot:
- actions/checkout v2 -> v6
- actions/setup-python v2 -> v6

Apply changes from PR google-research#366 by @cj-wong:
- Correct xreg_mode docstring: descriptions for 'xreg + timesfm' and
  'timesfm + xreg' were swapped
- Fix 'covaraites' -> 'covariates' typo in error message

- Add Apr. 2026 update entry for PEFT pipeline, unit tests, and community fixes
- Replace 'under construction' numbered list with checklist of completed items:
  Flax model, covariate support, docs/examples, PEFT pipeline, unit tests
@google-cla

google-cla bot commented Apr 8, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@darkpowerxo force-pushed the feat/peft-finetuning-pipeline-2.5 branch from fb9b380 to ad192b7 on April 8, 2026 22:37

- Initialize LoRA parameters on the same device as the base linear layer
- Load adapter weights directly to the model device instead of hardcoded CPU
- Slice XReg linear regression outputs to match the specified sequence lengths

- Replace batch-wide covariate normalization with per-input normalization
  in create_covariate_matrix to prevent cross-input scale leakage.
- Refactor BatchedInContextXRegLinear.fit to solve ridge regression
  per instance rather than as a single global matrix solve, avoiding
  cross-contamination between batched inputs.
- Truncate JAX regression outputs to the actual train/test lengths
  after the padded matrix multiply, fixing shape mismatches for
  non-power-of-2 horizons (e.g. horizon=24 was returning 32 elements).
@darkpowerxo force-pushed the feat/peft-finetuning-pipeline-2.5 branch from 9da7093 to 18d5eb2 on April 9, 2026 01:44
@darkpowerxo
Contributor Author

@siriuz42 @kashif @heidmotron @sitic @abhidas

Can we get an idea of how long it would take to have the PR approved,
so I can start working on other PRs?

@kashif
Contributor

kashif commented Apr 9, 2026

Remove the custom peft/ directory (LoRA/DoRA adapters, trainer, data
pipeline) in favor of a lightweight fine-tuning example that uses the
standard HuggingFace Transformers + PEFT ecosystem.

The new example at timesfm-forecasting/examples/finetuning/ demonstrates
LoRA fine-tuning via TimesFm2_5ModelForPrediction and the peft library,
based on the approach by @kashif at HuggingFace.

- Remove peft/ (8 files)
- Add timesfm-forecasting/examples/finetuning/finetune_lora.py
- Add timesfm-forecasting/examples/finetuning/README.md
- Update README.md to reference new example
- Clean up .gitignore (remove peft_checkpoints/)
@darkpowerxo
Contributor Author

Thanks @kashif — you're absolutely right, and great work on the Transformers integration! I wasn't aware of your notebook when I opened this PR.

I've reworked the PR accordingly: the custom PEFT pipeline is gone. Instead, this now ships a lightweight Python script finetune_lora.py that uses the standard Transformers + PEFT workflow — same dataset, same LoRA config, same training API as your notebook. Think of it as a runnable, CLI-friendly version of your example that lives alongside the other TimesFM examples in this repo.

The notebook is excellent for interactive exploration, but having a standalone script here makes fine-tuning more discoverable for users who start from this repo rather than the HuggingFace notebooks collection.

The rest of the PR includes bug fixes, CI upgrades, and documentation — all independent of the fine-tuning example. Happy to adjust anything based on your feedback.

@siriuz42
Collaborator

siriuz42 commented Apr 9, 2026

Thanks @darkpowerxo! Working on it now.

@siriuz42
Collaborator

siriuz42 commented Apr 10, 2026

Some initial comments:

  1. Let's not merge #393 ("fix: per-input regression in forecast_with_covariates to prevent data leakage"). The batch behavior is intended. One can run batch_size=1 inference to not cross the examples.

  2. Let's not merge #391 ("fix(v1): honor batch size in train_gen without permutation"). Let's keep the deprecated v1 code as is.

  3. Let's not merge #390 ("Fix link to SKILL.md in README"). It's intended to link to that page.

@siriuz42 siriuz42 self-requested a review April 10, 2026 00:30
@siriuz42
Collaborator

@kashif Can we get a LGTM from you regarding the HF transformers part?

@darkpowerxo
Contributor Author

Thanks for the review @siriuz42!

Done — I've reverted both:

#393 (per-input ridge regression in xreg) — understood, batch behavior is intended and users can run batch_size=1 for per-input semantics.
#390 (SKILL.md link) — reverted to the original directory link.
The remaining changes in this PR are:

Fine-tuning example (finetuning) — a CLI-friendly Python script using the standard HuggingFace Transformers + PEFT workflow (based on @kashif's notebook)
Unit tests (tests) for core layers, configs, and utilities
Bug fixes in configs and model code (e.g., missing freq attribute, batch dim assertions)
CI workflow upgrades (actions v6)
Documentation updates

torch_compile = kwargs["torch_compile"]
if torch_compile:
-   print("Compiling model...")
+   logging.info("Compiling model...")
Collaborator

Let's keep both. print can come in handy in a notebook.

Collaborator

@siriuz42 Apr 14, 2026

In hindsight let's merge first.

@siriuz42 siriuz42 merged commit eacf761 into google-research:master Apr 15, 2026
1 check passed
@siriuz42
Collaborator

Thanks for the contribution, darkpowerxo!
