@@ -58,7 +58,8 @@ class RandomSampling(SamplingMethod):
5858
5959 @typechecked
6060 def get_next_tokens (self , logits : Float ['*B V' ], rng : PRNGKey ) -> Int ['*B' ]:
61- return jax .random .categorical (rng , logits / self .temperature , axis = - 1 )
61+ scaled_logits = logits if self .temperature < 1e-6 else logits / self .temperature
62+ return jax .random .categorical (rng , scaled_logits , axis = - 1 )
6263
6364
6465@dataclasses .dataclass (frozen = True , kw_only = True )
@@ -74,9 +75,8 @@ def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
7475
7576 batch_size = logits .shape [0 ]
7677 topk_values , topk_indices = jax .lax .top_k (logits , self .k )
77- sampled_topk_indices = jax .random .categorical (
78- rng , topk_values / self .temperature , axis = - 1
79- )
78+ scaled_topk_values = topk_values if self .temperature < 1e-6 else topk_values / self .temperature
79+ sampled_topk_indices = jax .random .categorical (rng , scaled_topk_values , axis = - 1 )
8080 batch_indices = jnp .arange (batch_size )
8181 topk_indices = topk_indices [batch_indices , sampled_topk_indices ]
8282 return enp .unflatten (topk_indices , batch_shape , '...' )
@@ -91,11 +91,10 @@ class TopPSampling(SamplingMethod):
9191
9292 @typechecked
9393 def get_next_tokens (self , logits : Float ['... V' ], rng : PRNGKey ) -> Int ['...' ]:
94- # temperature scaling
95- logits = logits / self .temperature
94+ scaled_logits = logits if self .temperature < 1e-6 else logits / self .temperature
9695
9796 if self .p < 1.0 :
98- sorted_logits = jnp .sort (logits , axis = - 1 , descending = True )
97+ sorted_logits = jnp .sort (scaled_logits , axis = - 1 , descending = True )
9998
10099 cumulative_probs = jnp .cumsum (
101100 jax .nn .softmax (sorted_logits , axis = - 1 ), axis = - 1
@@ -108,11 +107,10 @@ def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
108107 cutoff_logit = jnp .take_along_axis (sorted_logits , cutoff_index , axis = - 1 )
109108
110109 # select logit values that are smaller than the cutoff logit.
111- logits = jnp .where (
112- logits < cutoff_logit ,
113- jnp .finfo (logits .dtype ).min ,
114- logits ,
110+ scaled_logits = jnp .where (
111+ scaled_logits < cutoff_logit ,
112+ jnp .finfo (scaled_logits .dtype ).min ,
113+ scaled_logits ,
115114 )
116115
117- return jax .random .categorical (rng , logits , axis = - 1 )
118-
116+ return jax .random .categorical (rng , scaled_logits , axis = - 1 )
0 commit comments