feat: PEFT fine-tuning pipeline (LoRA/DoRA, multi-GPU) for TimesFM 2.5#398
Conversation
- Replace timesfm_jax.egg-info/ with generic *.egg-info/ glob
- Add uv.lock (lockfile is environment-specific)
- Add peft_checkpoints/ (training output directory)
Implement production-grade PEFT adapters targeting the 2.5 architecture:
- LoRALinear: low-rank A/B decomposition with scaling (alpha/rank)
- DoRALinear: weight-decomposed LoRA (magnitude + direction)
- inject_adapters(): freezes base weights, wraps target nn.Linear modules
- Supports fused QKV (qkv_proj), attention output, and FFN layers
- num_adapter_layers controls how many top layers get adapters (0=all)
- target_modules selects 'all', 'attention', or 'ffn'
- merge_adapters(): folds adapter deltas back into base nn.Linear
- save/load_adapter_weights(): safetensors adapter-only checkpoints
- PEFTConfig dataclass with all hyperparameters

References:
- LoRA — https://arxiv.org/abs/2106.09685
- DoRA — https://arxiv.org/abs/2402.09353
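For readers new to the technique, here is a minimal numpy sketch of the LoRALinear idea described above. It is illustrative only (the actual implementation wraps torch nn.Linear); shapes and the zero-initialization of B follow the LoRA paper.

```python
import numpy as np

class LoRALinear:
    """Illustrative LoRA-wrapped linear layer (numpy sketch, not the real class)."""

    def __init__(self, weight, rank=8, alpha=16, seed=0):
        rng = np.random.default_rng(seed)
        self.weight = weight                         # frozen base weight, shape (out, in)
        out_features, in_features = weight.shape
        self.A = rng.normal(0, 0.02, (rank, in_features))  # trainable down-projection
        self.B = np.zeros((out_features, rank))            # trainable up-projection, zero-init
        self.scaling = alpha / rank                        # the alpha/rank scaling

    def __call__(self, x):
        # base path plus scaled low-rank update: y = x W^T + s * (x A^T) B^T
        return x @ self.weight.T + self.scaling * (x @ self.A.T) @ self.B.T

    def merge(self):
        # fold the adapter delta back into the base weight (merge_adapters analogue)
        return self.weight + self.scaling * (self.B @ self.A)
```

Because B starts at zero, the wrapped layer is an exact no-op at initialization, so fine-tuning begins from the pretrained model's behavior.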
Sliding-window dataset that produces (context, mask, target) tuples:
- Accepts list of arrays, long-format, or wide-format DataFrames
- Context length auto-rounded to multiple of patch_len (32)
- Left-pads short series with proper masking
- Configurable stride for window overlap
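A rough numpy sketch of the windowing rules above (function and argument names are illustrative, not the dataset's actual API):

```python
import numpy as np

def make_windows(series, context_len, horizon, stride=1, patch_len=32):
    """Illustrative sliding-window generator producing (context, mask, target)."""
    # context length auto-rounded up to a multiple of patch_len
    context_len = -(-context_len // patch_len) * patch_len
    total = context_len + horizon
    pad = max(total - len(series), 0)
    if pad:
        # left-pad short series; the mask marks real (1) vs padded (0) positions
        series = np.concatenate([np.zeros(pad), series])
    windows = []
    for start in range(0, len(series) - total + 1, stride):
        window = series[start:start + total]
        mask = (np.arange(start, start + context_len) >= pad).astype(float)
        windows.append((window[:context_len], mask, window[context_len:]))
    return windows
```

For a 40-point series with a requested context of 50, the context rounds up to 64 (two patches of 32) and the first 40 context positions are zero-padded and masked out.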
PEFTTrainer with production-grade training loop:
- PyTorch DDP multi-GPU via torchrun
- Mixed-precision training (fp16/bf16) with GradScaler
- Gradient checkpointing for long contexts
- Cosine-with-warmup LR schedule
- MSE loss + optional pinball quantile loss (9 channels)
- Early stopping on validation loss
- Adapter-only checkpointing (safetensors)
- W&B logging (rank-0 only)
- Differentiable training forward that replicates the 2.5 patch -> RevIN -> transformer -> output-head -> un-RevIN path
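The pinball (quantile) loss mentioned above can be sketched in a few lines; this numpy version is for illustration and assumes one prediction channel per quantile:

```python
import numpy as np

def pinball_loss(pred, target, quantiles):
    """Pinball loss sketch: pred is (batch, horizon, num_quantiles),
    target is (batch, horizon), quantiles is a list like [0.1, ..., 0.9]."""
    err = target[..., None] - pred
    q = np.asarray(quantiles)
    # under-prediction is penalized by q, over-prediction by (1 - q)
    return np.mean(np.maximum(q * err, (q - 1) * err))
```

With nine quantile channels (0.1 through 0.9), minimizing this loss trains each channel toward its respective conditional quantile rather than the mean.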
- finetune.py: argparse CLI with all config flags, CSV data loading, chronological train/val split, and full training pipeline
  - Usage: python -m peft.finetune --data_path data.csv --value_col y
  - Multi-GPU: torchrun --nproc_per_node=4 -m peft.finetune ...
- finetune.sh: env-var driven launch script for single/multi-GPU
  - Usage: DATA_PATH=data.csv NUM_GPUS=4 bash peft/finetune.sh
Covers quick start, Python API, CLI reference, adapter loading/merging, architecture overview with parameter counts, and file layout.
Apply changes from PR google-research#396 by @shahrukhx01:
- Fix typo 'complied' -> 'compiled' in ForecastConfig docstrings
- Replace bare print() with logging.info() in load_checkpoint()
Apply changes from PR google-research#394 by @cj-wong:
- tests/__init__.py: package marker
- tests/test_base_utils.py: strip_leading_nans + linear_interpolation tests
- tests/test_configs.py: frozen dataclass, defaults, replace, equality tests
- tests/test_torch_layers.py: ResidualBlock, RMSNorm, RandomFourierFeatures
- tests/test_torch_utils.py: update_running_stats, revin, DecodeCache tests
Apply changes from PR google-research#393 by @MarcoGorworworelli:
- Normalize covariates per-input instead of batch-wide to prevent each input's result from depending on batch composition
- Fit separate ridge regressions per time series instead of a single batched regression, preventing cross-series data leakage
- Applied to both src/timesfm/utils/xreg_lib.py and v1/src/timesfm/xreg_lib.py
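The batch-composition issue this fixes is easy to demonstrate. A per-input normalization sketch (illustrative, not the xreg_lib code):

```python
import numpy as np

def normalize_per_input(batch):
    """Normalize each series by its own mean/std, so the result for any one
    input does not depend on what else happens to be in the batch."""
    normalized = []
    for x in batch:
        mu, sigma = x.mean(), x.std()
        normalized.append((x - mu) / (sigma if sigma > 0 else 1.0))
    return normalized
```

Under batch-wide normalization, adding a second series with a different scale would change the first series' normalized values; per-input normalization makes each result invariant to batch composition.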
Apply changes from PR google-research#391 by @MarcoGorworworelli:
- Fix train_gen() to iterate in proper batch_size chunks instead of yielding all time series at once when permute=False
- Add test_data_loader.py to verify batch boundaries
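The corrected iteration pattern amounts to standard chunked yielding; a minimal sketch (not the actual data_loader code):

```python
def train_gen(series_list, batch_size):
    """Yield fixed-size chunks instead of the whole list at once."""
    for i in range(0, len(series_list), batch_size):
        yield series_list[i:i + batch_size]
```

With ten series and batch_size=4 this yields batches of 4, 4, and 2, which is what the new test_data_loader.py asserts about batch boundaries.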
Apply changes from PR google-research#390 by @amansinghbais:
- Fix link to point to the actual SKILL.md file instead of the directory
Apply changes from PR google-research#367 by @Copilot:
- actions/checkout v2 -> v6
- actions/setup-python v2 -> v6
Apply changes from PR google-research#366 by @cj-wong:
- Correct xreg_mode docstring: descriptions for 'xreg + timesfm' and 'timesfm + xreg' were swapped
- Fix 'covaraites' -> 'covariates' typo in error message
- Add Apr. 2026 update entry for PEFT pipeline, unit tests, and community fixes
- Replace 'under construction' numbered list with checklist of completed items: Flax model, covariate support, docs/examples, PEFT pipeline, unit tests
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
- Initialize LoRA parameters on the same device as the base linear layer
- Load adapter weights directly to the model device instead of hardcoded CPU
- Slice XReg linear regression outputs to match the specified sequence lengths
- Replace batch-wide covariate normalization with per-input normalization in create_covariate_matrix to prevent cross-input scale leakage
- Refactor BatchedInContextXRegLinear.fit to solve ridge regression per instance rather than as a single global matrix solve, avoiding cross-contamination between batched inputs
- Truncate JAX regression outputs to the actual train/test lengths after the padded matrix multiply, fixing shape mismatches for non-power-of-2 horizons (e.g. horizon=24 was returning 32 elements)
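The horizon=24 shape mismatch comes from padding lengths up to a power of two before the matrix multiply; the fix is to slice the padded output back down. A toy sketch (shapes are hypothetical, chosen only to illustrate the slicing):

```python
import numpy as np

def next_pow2(n):
    """Smallest power of two >= n."""
    return 1 << (n - 1).bit_length()

# The padded matmul produces next_pow2(horizon) outputs per series;
# the fix slices them back to the requested horizon.
horizon = 24
padded_out = np.zeros((2, next_pow2(horizon)))  # (batch, 32) from the padded compute
fixed_out = padded_out[:, :horizon]             # (batch, 24), matching the request
```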
@siriuz42 @kashif @heidmotron @sitic @abhidas Can we get an idea of how long it would take to have the PR approved?
It's already supported; see: https://github.com/huggingface/notebooks/blob/main/examples/timesfm2_5.ipynb
Remove the custom peft/ directory (LoRA/DoRA adapters, trainer, data pipeline) in favor of a lightweight fine-tuning example that uses the standard HuggingFace Transformers + PEFT ecosystem. The new example at timesfm-forecasting/examples/finetuning/ demonstrates LoRA fine-tuning via TimesFm2_5ModelForPrediction and the peft library, based on the approach by @kashif at HuggingFace.

- Remove peft/ (8 files)
- Add timesfm-forecasting/examples/finetuning/finetune_lora.py
- Add timesfm-forecasting/examples/finetuning/README.md
- Update README.md to reference new example
- Clean up .gitignore (remove peft_checkpoints/)
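The shape of such a script, as a hedged sketch: the checkpoint name, the exact model class import, and the target module names below are assumptions that depend on the Transformers TimesFM integration; only LoraConfig and get_peft_model are standard peft APIs.

```python
# Sketch only — assumes the HF Transformers TimesFM integration and the
# `peft` package; checkpoint name and target modules are assumptions.
from peft import LoraConfig, get_peft_model
from transformers import TimesFm2_5ModelForPrediction  # class name from the PR description

model = TimesFm2_5ModelForPrediction.from_pretrained("google/timesfm-2.5-200m-pytorch")
lora_config = LoraConfig(
    r=8,                           # adapter rank
    lora_alpha=16,                 # scaling numerator (alpha / r)
    target_modules=["qkv_proj"],   # assumed module name
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # only the LoRA weights remain trainable
```

From here the wrapped model can be trained with the usual Transformers training API, which is the workflow the notebook demonstrates interactively.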
Thanks @kashif — you're absolutely right, and great work on the Transformers integration! I wasn't aware of your notebook when I opened this PR.

I've reworked the PR accordingly: the custom PEFT pipeline is gone. Instead, this now ships a lightweight Python script finetune_lora.py that uses the standard Transformers + PEFT workflow — same dataset, same LoRA config, same training API as your notebook. Think of it as a runnable, CLI-friendly version of your example that lives alongside the other TimesFM examples in this repo. The notebook is excellent for interactive exploration, but having a standalone script here makes fine-tuning more discoverable for users who start from this repo rather than the HuggingFace notebooks collection.

The rest of the PR includes bug fixes, CI upgrades, and documentation — all independent of the fine-tuning example. Happy to adjust anything based on your feedback.
Thanks @darkpowerxo! Working on it now.
Some initial comments:
@kashif Can we get an LGTM from you regarding the HF Transformers part?
…le-research#390 (SKILL.md link) per maintainer feedback
Thanks for the review @siriuz42! Done — I've reverted both:
- #393 (per-input ridge regression in xreg) — understood, batch behavior is intended and users can run batch_size=1 for per-input semantics.
- Fine-tuning example (finetuning) — a CLI-friendly Python script using the standard HuggingFace Transformers + PEFT workflow (based on @kashif's notebook)
    torch_compile = kwargs["torch_compile"]
    if torch_compile:
      print("Compiling model...")
      logging.info("Compiling model...")
Let's keep both; print can come in handy in a notebook.
In hindsight let's merge first.
Thanks for the contribution, darkpowerxo!
Summary
Adds a full PEFT (Parameter-Efficient Fine-Tuning) pipeline for TimesFM 2.5 with LoRA and DoRA adapters, multi-GPU support via PyTorch DDP, and mixed-precision training. Also includes unit tests, bug fixes, documentation updates, and several cherry-picked community fixes.
What this adds
PEFT Pipeline (`peft/`)

Key features:
- LoRA and DoRA adapters targeting the `qkv_proj`, `output_proj`, and `ff_layer` modules
- `num_adapter_layers` parameter to freeze all but the last N transformer layers
- Multi-GPU training via `torchrun` and `DistributedDataParallel`
- Mixed-precision training with `torch.amp`

Unit Tests (`tests/`)

58 tests covering configs, torch layers, torch utilities, and base utilities.

Bug Fixes & Improvements

- Per-input covariate normalization and per-series ridge regression in `xreg_lib.py` (both `src/` and `v1/`) to prevent data leakage across batch items
- Proper `batch_size` iteration in `v1/src/timesfm/data_loader.py` when `permute=False`
- SKILL.md link fix (`tree` → `blob` path)
- Typo `complied` → `compiled` in `configs.py`
- Corrected swapped `xreg_mode` descriptions and `covaraites` typo in `timesfm_2p5_base.py`
- `print` → `logging.info` in `timesfm_2p5_torch.py`
- GitHub Actions upgrades (`checkout`/`setup-python` v2 → v6)

No breaking changes

All source-level modifications are backward-compatible:
- `configs.py`, `timesfm_2p5_base.py`: docstring/typo fixes only
- `timesfm_2p5_torch.py`: `print` → `logging.info`
- `xreg_lib.py`: per-input regression refactor preserves the existing 5-tuple return signature

Testing

All 58 tests pass. PEFT imports, `import timesfm`, and `ForecastConfig` all verified working.

Incorporated community PRs