2121import typing
2222from typing import Literal
2323
24+ import dialog
2425from etils import enp
2526from gemma .gm .data import _functional
2627from gemma .gm .nn import _transformer_like
# * Mode which queues the prompts and computes them asynchronously?
# * Mode which yields tokens as they get predicted?
4546
# Accepted prompt inputs for sampling: a single prompt — raw text or a
# structured `dialog.Conversation` — or a batch (sequence) of either.
type _Prompt = (
    str
    | Sequence[str]
    | dialog.Conversation
    | Sequence[dialog.Conversation]
)
54+
4655
4756@dataclasses .dataclass (frozen = True , kw_only = True )
4857class SamplerOutput :
@@ -158,7 +167,7 @@ def __post_init__(self):
158167 @typing .overload
159168 def sample (
160169 self ,
161- prompt : str ,
170+ prompt : str | dialog . Conversation ,
162171 * ,
163172 images : UInt8 ['N? H W C' ] | None = ...,
164173 max_new_tokens : int | None = ...,
@@ -175,7 +184,7 @@ def sample(
175184 @typing .overload
176185 def sample (
177186 self ,
178- prompt : Sequence [str ],
187+ prompt : Sequence [str | dialog . Conversation ],
179188 * ,
180189 images : Sequence [UInt8 ['N H W C' ]] | None = ...,
181190 max_new_tokens : int | None = ...,
@@ -193,7 +202,7 @@ def sample(
193202 @typing .overload
194203 def sample (
195204 self ,
196- prompt : str | Sequence [ str ] ,
205+ prompt : _Prompt ,
197206 * ,
198207 images : UInt8 ['B? N? H W C' ] | None = ...,
199208 max_new_tokens : int | None = ...,
@@ -258,8 +267,8 @@ def sample(
258267 ```
259268
260269 Args:
      prompt: Prompt(s) to sample from. Can be a single string, a
        `dialog.Conversation`, or a list of those.
263272 images: Images for the prompt. The position where the image should be
264273 inserted in the prompt is determined by the `<start_of_image>` token in
265274 the prompt.
@@ -403,13 +412,13 @@ def _get_inputs(
403412
404413 def _tokenize_prompts (
405414 self ,
406- prompt : str | Sequence [ str ] ,
415+ prompt : _Prompt ,
407416 * ,
408417 add_bos : bool ,
409418 pad_length : int | None = None ,
410419 ) -> Float ['B L' ]:
411420 """Encode the prompts."""
412- prompt = _normalize_prompt (prompt )
421+ prompt = _normalize_prompt (prompt , format = self . tokenizer . FORMAT )
413422 tokens = [self .tokenizer .encode (p , add_bos = add_bos ) for p in prompt ]
414423
415424 # Notice that if pad_length exceeds the maximum length of the prompts,
@@ -500,9 +509,9 @@ def _normalize_tokens(
500509 return tuple (_normalize_token (self .tokenizer , t ) for t in tokens )
501510
502511
503- def _get_has_batch_dim (prompt : str | Sequence [ str ] ) -> bool :
512+ def _get_has_batch_dim (prompt : _Prompt ) -> bool :
504513 """Returns whether the prompt batched or not."""
505- if isinstance (prompt , str ):
514+ if isinstance (prompt , str | dialog . Conversation ):
506515 return False
507516 elif _is_str_array (prompt ): # Scalar str array.
508517 assert isinstance (prompt , np .ndarray )
@@ -511,17 +520,23 @@ def _get_has_batch_dim(prompt: str | Sequence[str]) -> bool:
511520 return True
512521
513522
def _normalize_prompt(prompt: _Prompt, format: dialog.Format) -> list[str]:  # pylint: disable=redefined-builtin
  """Converts any supported prompt input into a flat list of strings.

  Args:
    prompt: A single `str` or `dialog.Conversation`, a sequence of those, or
      a numpy array of strings.
    format: Dialog format used when rendering `dialog.Conversation` objects
      to text.

  Returns:
    The prompt(s), as a list of plain strings.
  """
  # Numpy string arrays (possibly batched) are converted back to Python
  # objects first.
  if _is_str_array(prompt):
    assert isinstance(prompt, np.ndarray)
    prompt = prompt.tolist()

  # A single (non-batched) prompt becomes a batch of one.
  if isinstance(prompt, (str, dialog.Conversation)):
    prompts = [prompt]
  else:
    prompts = list(prompt)

  # Render conversations to text; raw strings are kept as-is.
  normalized = []
  for p in prompts:
    if isinstance(p, dialog.Conversation):
      normalized.append(p.as_text(format=format))
    else:
      normalized.append(p)

  return normalized
526541
527542
0 commit comments