
Commit 519481d
Respect caller timestamp options in batched transcribe
1 parent: ed9a06c

2 files changed: 96 additions & 6 deletions

faster_whisper/transcribe.py (5 additions & 5 deletions)

@@ -359,14 +359,14 @@ def transcribe(
             condition_on_previous_text: If True, the previous output of the model is provided
               as a prompt for the next window; disabling may make the text inconsistent across
               windows, but the model becomes less prone to getting stuck in a failure loop,
-              such as repetition looping or timestamps going out of sync. Set as False
+              such as repetition looping or timestamps going out of sync.
             prompt_reset_on_temperature: Resets prompt if temperature is above this value.
               Arg has effect only if condition_on_previous_text is True. Set at 0.5
             prefix: Optional text to provide as a prefix at the beginning of each window.
-            max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
+            max_initial_timestamp: The initial timestamp cannot be later than this.
             hallucination_silence_threshold: Optional[float]
               When word_timestamps is True, skip silent periods longer than this threshold
-              (in seconds) when a possible hallucination is detected. set as None.
+              (in seconds) when a possible hallucination is detected. Set as None.
         Returns:
           A tuple with:
 

@@ -544,12 +544,12 @@ def transcribe(
             hotwords=hotwords,
             word_timestamps=word_timestamps,
             hallucination_silence_threshold=None,
-            condition_on_previous_text=False,
+            condition_on_previous_text=condition_on_previous_text,
             clip_timestamps=clip_timestamps,
             prompt_reset_on_temperature=0.5,
             multilingual=multilingual,
             without_timestamps=without_timestamps,
-            max_initial_timestamp=0.0,
+            max_initial_timestamp=max_initial_timestamp,
         )
 
         info = TranscriptionInfo(
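
With these two arguments forwarded, the batched pipeline honors the caller's settings instead of the previously hardcoded condition_on_previous_text=False and max_initial_timestamp=0.0. A minimal sketch of a call that now takes effect; the "tiny" model size and the audio path are placeholders, not part of this commit:

from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("tiny")
pipeline = BatchedInferencePipeline(model)

# Both options are passed through to the per-batch TranscriptionOptions
# and are visible on the returned TranscriptionInfo.
segments, info = pipeline.transcribe(
    "audio.wav",  # placeholder path
    condition_on_previous_text=True,
    max_initial_timestamp=1.0,
)
assert info.transcription_options.condition_on_previous_text is True
assert info.transcription_options.max_initial_timestamp == 1.0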

tests/test_transcribe.py (91 additions & 1 deletion)

@@ -1,9 +1,31 @@
 import inspect
+import logging
 import os
+from types import SimpleNamespace
 
 import numpy as np
 
-from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio
+from faster_whisper import BatchedInferencePipeline, decode_audio, WhisperModel
+
+
+class _DummyFeatureExtractor:
+    sampling_rate = 16000
+    chunk_length = 30
+
+    def __call__(self, audio, chunk_length=None):
+        return np.zeros((80, 4), dtype="float32")
+
+
+def _make_dummy_batched_model():
+    logger = logging.getLogger("test.batched_options")
+
+    return SimpleNamespace(
+        feature_extractor=_DummyFeatureExtractor(),
+        frames_per_second=50,
+        hf_tokenizer=object(),
+        logger=logger,
+        model=SimpleNamespace(is_multilingual=False),
+    )
 
 
 def test_supported_languages():
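
The SimpleNamespace above stands in for a real WhisperModel so the tests need no model weights: feature_extractor, frames_per_second, hf_tokenizer, logger, and model.is_multilingual appear to be the minimal attributes BatchedInferencePipeline.transcribe touches before the patched _batched_segments_generator short-circuits the actual inference.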
@@ -313,3 +335,71 @@ def test_cliptimestamps_timings(physcisworks_path):
         assert clip["start"] == segment.start
         assert clip["end"] == segment.end
         assert segment.text == transcript
+
+
+def test_batched_transcribe_respects_condition_on_previous_text(monkeypatch):
+    pipeline = BatchedInferencePipeline(_make_dummy_batched_model())
+    captured_options = []
+
+    monkeypatch.setattr(
+        "faster_whisper.transcribe.Tokenizer",
+        lambda *args, **kwargs: object(),
+    )
+    monkeypatch.setattr(
+        pipeline,
+        "_batched_segments_generator",
+        lambda *args: captured_options.append(args[4]) or iter(()),
+    )
+
+    audio = np.zeros(1600, dtype="float32")
+    clip_timestamps = [{"start": 0.0, "end": 0.1}]
+
+    _, info = pipeline.transcribe(
+        audio,
+        language="en",
+        clip_timestamps=clip_timestamps,
+        condition_on_previous_text=True,
+        suppress_tokens=[],
+    )
+
+    assert info.transcription_options.condition_on_previous_text is True
+    assert captured_options[0].condition_on_previous_text is True
+
+    captured_options.clear()
+
+    _, info = pipeline.transcribe(
+        audio,
+        language="en",
+        clip_timestamps=clip_timestamps,
+        condition_on_previous_text=False,
+        suppress_tokens=[],
+    )
+
+    assert info.transcription_options.condition_on_previous_text is False
+    assert captured_options[0].condition_on_previous_text is False
+
+
+def test_batched_transcribe_respects_max_initial_timestamp(monkeypatch):
+    pipeline = BatchedInferencePipeline(_make_dummy_batched_model())
+    captured_options = []
+
+    monkeypatch.setattr(
+        "faster_whisper.transcribe.Tokenizer",
+        lambda *args, **kwargs: object(),
+    )
+    monkeypatch.setattr(
+        pipeline,
+        "_batched_segments_generator",
+        lambda *args: captured_options.append(args[4]) or iter(()),
+    )
+
+    _, info = pipeline.transcribe(
+        np.zeros(1600, dtype="float32"),
+        language="en",
+        clip_timestamps=[{"start": 0.0, "end": 0.1}],
+        max_initial_timestamp=1.7,
+        suppress_tokens=[],
+    )
+
+    assert info.transcription_options.max_initial_timestamp == 1.7
+    assert captured_options[0].max_initial_timestamp == 1.7
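
A note on the capture idiom used in both tests: list.append returns None, so `captured_options.append(args[4]) or iter(())` records the fifth positional argument (the options object, judging by the assertions on it) and still hands the pipeline an empty iterator to drain. A standalone sketch of the same pattern:

captured = []
fake = lambda *args: captured.append(args) or iter(())
assert list(fake("a", "b")) == []  # the caller sees an empty iterator
assert captured == [("a", "b")]    # the call was recorded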

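To run only the new tests, assuming pytest is installed and the command is issued from the repository root:

python -m pytest tests/test_transcribe.py -k "batched_transcribe_respects" -q

Neither test loads model weights or audio fixtures, since the pipeline is built around the dummy model above.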