Skip to content

Commit 76b5cfe

Browse files
fix: prevent division by zero in sampling when temperature is 0.0
1 parent d16b705 commit 76b5cfe

2 files changed

Lines changed: 14 additions & 13 deletions

File tree

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ poetry.lock
2727
# Ignore generated docs
2828
docs/_build
2929
docs/api
30+
31+
# virtual environments
32+
.venv/

gemma/gm/text/_sampling.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ class RandomSampling(SamplingMethod):
5858

5959
@typechecked
6060
def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
61-
return jax.random.categorical(rng, logits / self.temperature, axis=-1)
61+
scaled_logits = logits if self.temperature < 1e-6 else logits / self.temperature
62+
return jax.random.categorical(rng, scaled_logits, axis=-1)
6263

6364

6465
@dataclasses.dataclass(frozen=True, kw_only=True)
@@ -74,9 +75,8 @@ def get_next_tokens(self, logits: Float['*B V'], rng: PRNGKey) -> Int['*B']:
7475

7576
batch_size = logits.shape[0]
7677
topk_values, topk_indices = jax.lax.top_k(logits, self.k)
77-
sampled_topk_indices = jax.random.categorical(
78-
rng, topk_values / self.temperature, axis=-1
79-
)
78+
scaled_topk_values = topk_values if self.temperature < 1e-6 else topk_values / self.temperature
79+
sampled_topk_indices = jax.random.categorical(rng, scaled_topk_values, axis=-1)
8080
batch_indices = jnp.arange(batch_size)
8181
topk_indices = topk_indices[batch_indices, sampled_topk_indices]
8282
return enp.unflatten(topk_indices, batch_shape, '...')
@@ -91,11 +91,10 @@ class TopPSampling(SamplingMethod):
9191

9292
@typechecked
9393
def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
94-
# temperature scaling
95-
logits = logits / self.temperature
94+
scaled_logits = logits if self.temperature < 1e-6 else logits / self.temperature
9695

9796
if self.p < 1.0:
98-
sorted_logits = jnp.sort(logits, axis=-1, descending=True)
97+
sorted_logits = jnp.sort(scaled_logits, axis=-1, descending=True)
9998

10099
cumulative_probs = jnp.cumsum(
101100
jax.nn.softmax(sorted_logits, axis=-1), axis=-1
@@ -108,11 +107,10 @@ def get_next_tokens(self, logits: Float['... V'], rng: PRNGKey) -> Int['...']:
108107
cutoff_logit = jnp.take_along_axis(sorted_logits, cutoff_index, axis=-1)
109108

110109
# select logit values that are smaller than the cutoff logit.
111-
logits = jnp.where(
112-
logits < cutoff_logit,
113-
jnp.finfo(logits.dtype).min,
114-
logits,
110+
scaled_logits = jnp.where(
111+
scaled_logits < cutoff_logit,
112+
jnp.finfo(scaled_logits.dtype).min,
113+
scaled_logits,
115114
)
116115

117-
return jax.random.categorical(rng, logits, axis=-1)
118-
116+
return jax.random.categorical(rng, scaled_logits, axis=-1)

0 commit comments

Comments
 (0)