File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 4646
4747# TODO(epot): Should be part of core Kauldron
4848@dataclasses .dataclass (frozen = True )
49- class LoadCheckpoint (kd .ckpts .AbstractPartialLoader ):
49+ class LoadCheckpoint (kd .ckpts .InitTransform ):
5050 """Loads weights from a Gemma checkpoint.
5151
5252 Note: The checpoint only contains the Gemma transformer weights, not the
Original file line number Diff line number Diff line change 3434
3535
3636@dataclasses .dataclass (frozen = True )
37- class SkipLoRA (kd .ckpts .AbstractPartialLoader ):
37+ class SkipLoRA (kd .ckpts .InitTransform ):
3838 """Wraps a partial loader to not restore the LoRA weights."""
3939
40- wrapped : kd .ckpts .AbstractPartialLoader
40+ wrapped : kd .ckpts .InitTransform
4141
4242 def transform (self , state : _StateT ) -> _StateT : # pytype: disable=signature-mismatch
4343 # Remove the LoRA weights from the params structure so it can be restored
Original file line number Diff line number Diff line change 2323
2424
2525@dataclasses .dataclass (frozen = True , kw_only = True )
26- class AnchoredPolicyLoader (kd .ckpts .AbstractPartialLoader ):
26+ class AnchoredPolicyLoader (kd .ckpts .InitTransform ):
2727 """Loader for `gm.nn.AnchoredPolicy` models.
2828
2929 Loaded load policy and anchor separately by providing
@@ -33,8 +33,8 @@ class AnchoredPolicyLoader(kd.ckpts.AbstractPartialLoader):
3333 modifying the rest of the state.
3434 """
3535
36- policy : kd .ckpts .AbstractPartialLoader
37- anchor : kd .ckpts .AbstractPartialLoader | None = None
36+ policy : kd .ckpts .InitTransform
37+ anchor : kd .ckpts .InitTransform | None = None
3838
3939 def transform (self , state : kd .train .TrainState ) -> kd .train .TrainState :
4040 if set (state .params .keys ()) != {'policy' , 'anchor' }:
You can’t perform that action at this time.
0 commit comments