Skip to content

Commit b790cdc

Browse files
Etienne Pot and The Gemma Authors
authored and committed
Support dialog in the standard sampler
PiperOrigin-RevId: 869246989
1 parent 1ded5e5 commit b790cdc

1 file changed

Lines changed: 26 additions & 11 deletions

File tree

gemma/gm/text/_sampler.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import typing
2222
from typing import Literal
2323

24+
import dialog
2425
from etils import enp
2526
from gemma.gm.data import _functional
2627
from gemma.gm.nn import _transformer_like
@@ -43,6 +44,14 @@
4344
# * Mode which queue the prompts and compute them asynchronously ?
4445
# * Mode which yields tokens as they get predicted ?
4546

47+
type _Prompt = (
48+
# Prompts can be:
49+
str
50+
| Sequence[str]
51+
| dialog.Conversation
52+
| Sequence[dialog.Conversation]
53+
)
54+
4655

4756
@dataclasses.dataclass(frozen=True, kw_only=True)
4857
class SamplerOutput:
@@ -158,7 +167,7 @@ def __post_init__(self):
158167
@typing.overload
159168
def sample(
160169
self,
161-
prompt: str,
170+
prompt: str | dialog.Conversation,
162171
*,
163172
images: UInt8['N? H W C'] | None = ...,
164173
max_new_tokens: int | None = ...,
@@ -175,7 +184,7 @@ def sample(
175184
@typing.overload
176185
def sample(
177186
self,
178-
prompt: Sequence[str],
187+
prompt: Sequence[str | dialog.Conversation],
179188
*,
180189
images: Sequence[UInt8['N H W C']] | None = ...,
181190
max_new_tokens: int | None = ...,
@@ -193,7 +202,7 @@ def sample(
193202
@typing.overload
194203
def sample(
195204
self,
196-
prompt: str | Sequence[str],
205+
prompt: _Prompt,
197206
*,
198207
images: UInt8['B? N? H W C'] | None = ...,
199208
max_new_tokens: int | None = ...,
@@ -258,8 +267,8 @@ def sample(
258267
```
259268
260269
Args:
261-
prompt: Prompt to sample from. Can be a single string or a list of
262-
strings.
270+
prompt: Prompt(s) to sample from. Can be a single string or
271+
`dialog.Conversation` or a list of those.
263272
images: Images for the prompt. The position where the image should be
264273
inserted in the prompt is determined by the `<start_of_image>` token in
265274
the prompt.
@@ -403,13 +412,13 @@ def _get_inputs(
403412

404413
def _tokenize_prompts(
405414
self,
406-
prompt: str | Sequence[str],
415+
prompt: _Prompt,
407416
*,
408417
add_bos: bool,
409418
pad_length: int | None = None,
410419
) -> Float['B L']:
411420
"""Encode the prompts."""
412-
prompt = _normalize_prompt(prompt)
421+
prompt = _normalize_prompt(prompt, format=self.tokenizer.FORMAT)
413422
tokens = [self.tokenizer.encode(p, add_bos=add_bos) for p in prompt]
414423

415424
# Notice that if pad_length exceeds the maximum length of the prompts,
@@ -500,9 +509,9 @@ def _normalize_tokens(
500509
return tuple(_normalize_token(self.tokenizer, t) for t in tokens)
501510

502511

503-
def _get_has_batch_dim(prompt: str | Sequence[str]) -> bool:
512+
def _get_has_batch_dim(prompt: _Prompt) -> bool:
504513
"""Returns whether the prompt batched or not."""
505-
if isinstance(prompt, str):
514+
if isinstance(prompt, str | dialog.Conversation):
506515
return False
507516
elif _is_str_array(prompt): # Scalar str array.
508517
assert isinstance(prompt, np.ndarray)
@@ -511,17 +520,23 @@ def _get_has_batch_dim(prompt: str | Sequence[str]) -> bool:
511520
return True
512521

513522

514-
def _normalize_prompt(prompt: str | Sequence[str]) -> list[str]:
523+
def _normalize_prompt(prompt: _Prompt, format: dialog.Format) -> list[str]: # pylint: disable=redefined-builtin
515524
"""Normalize the inputs."""
516525
if _is_str_array(prompt): # Supports batched input array
517526
assert isinstance(prompt, np.ndarray)
518527
prompt = prompt.tolist()
519528

520-
if isinstance(prompt, str):
529+
if isinstance(prompt, str | dialog.Conversation):
521530
prompt = [prompt]
522531
else:
523532
prompt = list(prompt)
524533

534+
# Normalize the prompt to strings.
535+
prompt = [
536+
c.as_text(format=format) if isinstance(c, dialog.Conversation) else c
537+
for c in prompt
538+
]
539+
525540
return prompt
526541

527542

0 commit comments

Comments (0)