diff --git a/.gitignore b/.gitignore
index 7dbdfee6a..3076b7d65 100644
--- a/.gitignore
+++ b/.gitignore
@@ -168,4 +168,7 @@ cython_debug/
#.idea/
# PyPI configuration file
-.pypirc
\ No newline at end of file
+.pypirc
+
+# Models files added for offline diarization
+models/*.bin
diff --git a/README.md b/README.md
index 08eb0757e..0f009330f 100644
--- a/README.md
+++ b/README.md
@@ -54,6 +54,7 @@ This repository provides fast automatic speech recognition (70x realtime with la
Newπ¨
+- Offline diarization support added! Use local models without Hugging Face API access.
- 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) π
- _WhisperX_ accepted at INTERSPEECH 2023
- v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization
@@ -110,6 +111,19 @@ To **enable Speaker Diarization**, include your Hugging Face access token (read)
> **Note**
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds.
+#### Offline Diarization
+
+You can now use offline diarization without needing a Hugging Face token. This requires:
+
+1. Download the necessary model files to a local `models` directory
+2. Configure the diarization pipeline using the provided `models/pyannote_diarization_config.yaml` file
+
+To use offline diarization with the command line:
+
+```bash
+whisperx path/to/audio.wav --model large-v2 --diarize --diarize_offline --diarize_config models/pyannote_diarization_config.yaml
+```
+
Usage π¬ (command line)
### English
@@ -137,6 +151,10 @@ To label the transcript with speaker ID's (set number of speakers if known e.g.
whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True
+To use offline diarization mode:
+
+ whisperx path/to/audio.wav --model large-v2 --diarize --diarize_offline --diarize_config models/pyannote_diarization_config.yaml
+
To run on CPU instead of GPU (and for running on Mac OS X):
whisperx path/to/audio.wav --compute_type int8
@@ -192,8 +210,15 @@ print(result["segments"]) # after alignment
# import gc; gc.collect(); torch.cuda.empty_cache(); del model_a
# 3. Assign speaker labels
+# Option 1: HF-token based diarization (online)
diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device)
+# Option 2: Local/offline diarization
+# diarize_model = whisperx.OfflineDiarizationPipeline(
+# config_path="models/pyannote_diarization_config.yaml",
+# device=device
+# )
+
# add min/max number of speakers if known
diarize_segments = diarize_model(audio)
# diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
@@ -267,6 +292,8 @@ Bug finding and pull requests are also highly appreciated to keep this project g
* [x] Allow silero-vad as alternative VAD option
+* [x] Add offline diarization support
+
* [ ] Improve diarization (word level). *Harder than first thought...*
diff --git a/models/pyannote_diarization_config.yaml b/models/pyannote_diarization_config.yaml
new file mode 100644
index 000000000..d54512b4e
--- /dev/null
+++ b/models/pyannote_diarization_config.yaml
@@ -0,0 +1,22 @@
+# pyannote.audio speaker-diarization 3.1 pipeline config for offline use.
+# Model paths below are relative to this file's parent's parent directory:
+# whisperx/offline_diarize.py chdirs there before loading the pipeline.
+version: 3.1.0
+
+pipeline:
+  name: pyannote.audio.pipelines.SpeakerDiarization
+  params:
+    clustering: AgglomerativeClustering
+    embedding: models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin
+    embedding_batch_size: 32
+    embedding_exclude_overlap: true
+    segmentation: models/pyannote_model_segmentation-3.0.bin
+    segmentation_batch_size: 32
+
+params:
+  clustering:
+    method: centroid
+    min_cluster_size: 12
+    threshold: 0.7045654963945799
+  segmentation:
+    min_duration_off: 0.0
diff --git a/whisperx/offline_diarize.py b/whisperx/offline_diarize.py
new file mode 100644
index 000000000..cd6d201f4
--- /dev/null
+++ b/whisperx/offline_diarize.py
@@ -0,0 +1,96 @@
+"""Offline speaker diarization for WhisperX.
+
+Loads a pyannote.audio pipeline from a local YAML config so diarization
+works without a Hugging Face token or network access.
+"""
+import os
+from pathlib import Path
+
+import pandas as pd
+import torch
+from pyannote.audio import Pipeline
+
+from .audio import load_audio, SAMPLE_RATE
+
+
+class OfflineDiarizationPipeline:
+    """Speaker diarization pipeline backed by local model files."""
+
+    def __init__(
+        self,
+        config_path,
+        device="cpu",
+    ):
+        """Load the pyannote pipeline described by ``config_path``.
+
+        Args:
+            config_path: Path to a pyannote pipeline YAML config whose
+                model paths are relative to the config's parent's parent.
+            device: torch device (or device string) to run inference on.
+        """
+        if isinstance(device, str):
+            device = torch.device(device)
+
+        # Load the pipeline with local config
+        self.model = self._load_pipeline_from_pretrained(config_path).to(device)
+
+    def _load_pipeline_from_pretrained(self, path_to_config):
+        """Instantiate a pyannote Pipeline from a local config file.
+
+        The model paths inside the config are relative to the directory
+        containing the 'models' folder, so we temporarily chdir there.
+        """
+        # Resolve up front so the path stays valid after the chdir below.
+        path_to_config = Path(path_to_config).resolve()
+
+        if not path_to_config.exists():
+            raise FileNotFoundError(f"Config file not found: {path_to_config}")
+
+        print(f"Loading pyannote pipeline from {path_to_config}...")
+
+        cwd = Path.cwd().resolve()  # store current working directory
+
+        # first .parent is the folder of the config, second .parent is the
+        # folder containing the 'models' folder
+        cd_to = path_to_config.parent.parent
+
+        print(f"Changing working directory to {cd_to}")
+        os.chdir(cd_to)
+        try:
+            pipeline = Pipeline.from_pretrained(path_to_config)
+        finally:
+            # Always restore the caller's cwd, even if loading fails.
+            print(f"Changing working directory back to {cwd}")
+            os.chdir(cwd)
+
+        return pipeline
+
+    def __call__(
+        self,
+        audio,
+        num_speakers=None,
+        min_speakers=None,
+        max_speakers=None,
+    ):
+        """Run diarization and return a DataFrame of speaker segments.
+
+        Args:
+            audio: Path to an audio file or a 1-D float32 waveform array.
+            num_speakers / min_speakers / max_speakers: optional speaker
+                count constraints forwarded to the pyannote pipeline.
+
+        Returns:
+            pandas.DataFrame with 'segment', 'label', 'speaker', 'start'
+            and 'end' columns (start/end in seconds).
+        """
+        if isinstance(audio, str):
+            audio = load_audio(audio)
+        audio_data = {
+            'waveform': torch.from_numpy(audio[None, :]),
+            'sample_rate': SAMPLE_RATE
+        }
+        segments = self.model(audio_data, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
+        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
+        diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
+        diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
+        return diarize_df
diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py
index a67787c1b..f256a9843 100644
--- a/whisperx/transcribe.py
+++ b/whisperx/transcribe.py
@@ -10,6 +10,11 @@
 from whisperx.asr import load_model
 from whisperx.audio import load_audio
 from whisperx.diarize import DiarizationPipeline, assign_word_speakers
+# Optional offline diarization (needs pyannote.audio); None if unavailable.
+try:
+    from whisperx.offline_diarize import OfflineDiarizationPipeline
+except ImportError:
+    OfflineDiarizationPipeline = None
 from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
 from whisperx.utils import (
     LANGUAGES,
@@ -56,6 +61,9 @@ def cli():
parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file")
parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file")
+ # Add offline diarization params
+ parser.add_argument("--diarize_offline", action="store_true", help="Use offline diarization models instead of downloading from HF")
+ parser.add_argument("--diarize_config", type=str, default="models/pyannote_diarization_config.yaml", help="Path to the diarization config file for offline mode")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
@@ -120,6 +128,8 @@ def cli():
chunk_size: int = args.pop("chunk_size")
diarize: bool = args.pop("diarize")
+ diarize_offline: bool = args.pop("diarize_offline")
+ diarize_config: str = args.pop("diarize_config")
min_speakers: int = args.pop("min_speakers")
max_speakers: int = args.pop("max_speakers")
print_progress: bool = args.pop("print_progress")
@@ -238,12 +248,24 @@ def cli():
# >> Diarize
if diarize:
- if hf_token is None:
- print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
tmp_results = results
print(">>Performing diarization...")
results = []
- diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+
+ if diarize_offline:
+ if OfflineDiarizationPipeline is None:
+ raise ImportError("offline_diarize.py must be in the same directory as transcribe.py. Please ensure it's properly installed.")
+
+ print(f"Using offline diarization with config: {diarize_config}")
+ diarize_model = OfflineDiarizationPipeline(
+ config_path=diarize_config,
+ device=device
+ )
+ else:
+ if hf_token is None:
+ print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
+ diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device)
+
for result, input_audio_path in tmp_results:
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
result = assign_word_speakers(diarize_segments, result)