Skip to content

Commit bd0dabf

Browse files
ConchylicultorThe gemma Authors
authored andcommitted
Rename AbstractPartialLoader to InitTransform
PiperOrigin-RevId: 869688526
1 parent b790cdc commit bd0dabf

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

gemma/gm/ckpts/_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646

4747
# TODO(epot): Should be part of core Kauldron
4848
@dataclasses.dataclass(frozen=True)
49-
class LoadCheckpoint(kd.ckpts.AbstractPartialLoader):
49+
class LoadCheckpoint(kd.ckpts.InitTransform):
5050
"""Loads weights from a Gemma checkpoint.
5151
5252
Note: The checpoint only contains the Gemma transformer weights, not the

gemma/gm/ckpts/_lora.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434

3535

3636
@dataclasses.dataclass(frozen=True)
37-
class SkipLoRA(kd.ckpts.AbstractPartialLoader):
37+
class SkipLoRA(kd.ckpts.InitTransform):
3838
"""Wraps a partial loader to not restore the LoRA weights."""
3939

40-
wrapped: kd.ckpts.AbstractPartialLoader
40+
wrapped: kd.ckpts.InitTransform
4141

4242
def transform(self, state: _StateT) -> _StateT: # pytype: disable=signature-mismatch
4343
# Remove the LoRA weights from the params structure so it can be restored

gemma/gm/ckpts/_policy.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
@dataclasses.dataclass(frozen=True, kw_only=True)
26-
class AnchoredPolicyLoader(kd.ckpts.AbstractPartialLoader):
26+
class AnchoredPolicyLoader(kd.ckpts.InitTransform):
2727
"""Loader for `gm.nn.AnchoredPolicy` models.
2828
2929
Loaded load policy and anchor separately by providing
@@ -33,8 +33,8 @@ class AnchoredPolicyLoader(kd.ckpts.AbstractPartialLoader):
3333
modifying the rest of the state.
3434
"""
3535

36-
policy: kd.ckpts.AbstractPartialLoader
37-
anchor: kd.ckpts.AbstractPartialLoader | None = None
36+
policy: kd.ckpts.InitTransform
37+
anchor: kd.ckpts.InitTransform | None = None
3838

3939
def transform(self, state: kd.train.TrainState) -> kd.train.TrainState:
4040
if set(state.params.keys()) != {'policy', 'anchor'}:

0 commit comments

Comments
 (0)