|
1 | 1 | import inspect |
| 2 | +import logging |
2 | 3 | import os |
| 4 | +from types import SimpleNamespace |
3 | 5 |
|
4 | 6 | import numpy as np |
5 | 7 |
|
6 | | -from faster_whisper import BatchedInferencePipeline, WhisperModel, decode_audio |
| 8 | +from faster_whisper import BatchedInferencePipeline, decode_audio, WhisperModel |
| 9 | + |
| 10 | + |
| 11 | +class _DummyFeatureExtractor: |
| 12 | + sampling_rate = 16000 |
| 13 | + chunk_length = 30 |
| 14 | + |
| 15 | + def __call__(self, audio, chunk_length=None): |
| 16 | + return np.zeros((80, 4), dtype="float32") |
| 17 | + |
| 18 | + |
def _make_dummy_batched_model():
    """Build a minimal object mimicking the WhisperModel attributes that
    BatchedInferencePipeline touches in these tests.

    Returns a SimpleNamespace exposing a dummy feature extractor, frame
    rate, tokenizer placeholder, logger, and an inner model namespace
    marked as non-multilingual.
    """
    inner_model = SimpleNamespace(is_multilingual=False)

    stub = SimpleNamespace(
        feature_extractor=_DummyFeatureExtractor(),
        frames_per_second=50,
        hf_tokenizer=object(),
        logger=logging.getLogger("test.batched_options"),
        model=inner_model,
    )
    return stub
7 | 29 |
|
8 | 30 |
|
9 | 31 | def test_supported_languages(): |
@@ -313,3 +335,71 @@ def test_cliptimestamps_timings(physcisworks_path): |
313 | 335 | assert clip["start"] == segment.start |
314 | 336 | assert clip["end"] == segment.end |
315 | 337 | assert segment.text == transcript |
| 338 | + |
| 339 | + |
def test_batched_transcribe_respects_condition_on_previous_text(monkeypatch):
    """Both True and False must flow through to the decode options.

    The tokenizer and the segment generator are stubbed out so that only
    option plumbing inside ``transcribe`` is exercised.
    """
    pipeline = BatchedInferencePipeline(_make_dummy_batched_model())
    seen_options = []

    # Avoid loading a real tokenizer; any object satisfies the plumbing here.
    monkeypatch.setattr(
        "faster_whisper.transcribe.Tokenizer",
        lambda *args, **kwargs: object(),
    )

    def _capture_generator(*args):
        # The options object is passed as the 5th positional argument;
        # record it and yield no segments.
        seen_options.append(args[4])
        return iter(())

    monkeypatch.setattr(pipeline, "_batched_segments_generator", _capture_generator)

    audio = np.zeros(1600, dtype="float32")
    clips = [{"start": 0.0, "end": 0.1}]

    for flag in (True, False):
        seen_options.clear()
        _, info = pipeline.transcribe(
            audio,
            language="en",
            clip_timestamps=clips,
            condition_on_previous_text=flag,
            suppress_tokens=[],
        )
        # The flag must appear unchanged both on the returned info and on
        # the options actually handed to the segment generator.
        assert info.transcription_options.condition_on_previous_text is flag
        assert seen_options[0].condition_on_previous_text is flag
| 380 | + |
| 381 | + |
def test_batched_transcribe_respects_max_initial_timestamp(monkeypatch):
    """A custom ``max_initial_timestamp`` must reach the decode options
    unchanged, both on the returned info and on the options handed to the
    segment generator."""
    pipeline = BatchedInferencePipeline(_make_dummy_batched_model())
    recorded = []

    # Stub the tokenizer so no model assets are required.
    monkeypatch.setattr(
        "faster_whisper.transcribe.Tokenizer",
        lambda *args, **kwargs: object(),
    )

    def _record_options(*args):
        # 5th positional argument is the options object; emit no segments.
        recorded.append(args[4])
        return iter(())

    monkeypatch.setattr(pipeline, "_batched_segments_generator", _record_options)

    silence = np.zeros(1600, dtype="float32")
    _, info = pipeline.transcribe(
        silence,
        language="en",
        clip_timestamps=[{"start": 0.0, "end": 0.1}],
        max_initial_timestamp=1.7,
        suppress_tokens=[],
    )

    assert info.transcription_options.max_initial_timestamp == 1.7
    assert recorded[0].max_initial_timestamp == 1.7
0 commit comments