diff --git a/.gitignore b/.gitignore index 7dbdfee6a..3076b7d65 100644 --- a/.gitignore +++ b/.gitignore @@ -168,4 +168,7 @@ cython_debug/ #.idea/ # PyPI configuration file -.pypirc \ No newline at end of file +.pypirc + +# Models files added for offline diarization +models/*.bin diff --git a/README.md b/README.md index 08eb0757e..0f009330f 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ This repository provides fast automatic speech recognition (70x realtime with la

New🚨

+- Offline diarization support added! Use local models without Hugging Face API access. - 1st place at [Ego4d transcription challenge](https://eval.ai/web/challenges/challenge-page/1637/leaderboard/3931/WER) πŸ† - _WhisperX_ accepted at INTERSPEECH 2023 - v3 transcript segment-per-sentence: using nltk sent_tokenize for better subtitlting & better diarization @@ -110,6 +111,19 @@ To **enable Speaker Diarization**, include your Hugging Face access token (read) > **Note**
> As of Oct 11, 2023, there is a known issue regarding slow performance with pyannote/Speaker-Diarization-3.0 in whisperX. It is due to dependency conflicts between faster-whisper and pyannote-audio 3.0.0. Please see [this issue](https://github.com/m-bain/whisperX/issues/499) for more details and potential workarounds. +#### Offline Diarization + +You can now use offline diarization without needing a Hugging Face token. This requires: + +1. Download the required pyannote model files (`pyannote_model_segmentation-3.0.bin` and `pyannote_model_wespeaker-voxceleb-resnet34-LM.bin`, as referenced in the config) into a local `models` directory +2. Configure the diarization pipeline using the provided `models/pyannote_diarization_config.yaml` file + +To use offline diarization with the command line: + +```bash +whisperx path/to/audio.wav --model large-v2 --diarize --diarize_offline --diarize_config models/pyannote_diarization_config.yaml +``` +

Usage πŸ’¬ (command line)

### English @@ -137,6 +151,10 @@ To label the transcript with speaker ID's (set number of speakers if known e.g. whisperx path/to/audio.wav --model large-v2 --diarize --highlight_words True +To use offline diarization mode: + + whisperx path/to/audio.wav --model large-v2 --diarize --diarize_offline --diarize_config models/pyannote_diarization_config.yaml + To run on CPU instead of GPU (and for running on Mac OS X): whisperx path/to/audio.wav --compute_type int8 @@ -192,8 +210,15 @@ print(result["segments"]) # after alignment # import gc; gc.collect(); torch.cuda.empty_cache(); del model_a # 3. Assign speaker labels +# Option 1: HF-token based diarization (online) diarize_model = whisperx.DiarizationPipeline(use_auth_token=YOUR_HF_TOKEN, device=device) +# Option 2: Local/offline diarization +# diarize_model = whisperx.OfflineDiarizationPipeline( +# config_path="models/pyannote_diarization_config.yaml", +# device=device +# ) + # add min/max number of speakers if known diarize_segments = diarize_model(audio) # diarize_model(audio, min_speakers=min_speakers, max_speakers=max_speakers) @@ -267,6 +292,8 @@ Bug finding and pull requests are also highly appreciated to keep this project g * [x] Allow silero-vad as alternative VAD option +* [x] Add offline diarization support + * [ ] Improve diarization (word level). 
*Harder than first thought...* diff --git a/models/pyannote_diarization_config.yaml b/models/pyannote_diarization_config.yaml new file mode 100644 index 000000000..d54512b4e --- /dev/null +++ b/models/pyannote_diarization_config.yaml @@ -0,0 +1,19 @@ +version: 3.1.0 + +pipeline: + name: pyannote.audio.pipelines.SpeakerDiarization + params: + clustering: AgglomerativeClustering + embedding: models/pyannote_model_wespeaker-voxceleb-resnet34-LM.bin + embedding_batch_size: 32 + embedding_exclude_overlap: true + segmentation: models/pyannote_model_segmentation-3.0.bin + segmentation_batch_size: 32 + +params: + clustering: + method: centroid + min_cluster_size: 12 + threshold: 0.7045654963945799 + segmentation: + min_duration_off: 0.0 diff --git a/whisperx/offline_diarize.py b/whisperx/offline_diarize.py new file mode 100644 index 000000000..cd6d201f4 --- /dev/null +++ b/whisperx/offline_diarize.py @@ -0,0 +1,65 @@ +import os +from pathlib import Path +import numpy as np +import pandas as pd +import torch +from pyannote.audio import Pipeline + +from .audio import load_audio, SAMPLE_RATE + +class OfflineDiarizationPipeline: + def __init__( + self, + config_path, + device="cpu", + ): + if isinstance(device, str): + device = torch.device(device) + + # Load the pipeline with local config + self.model = self._load_pipeline_from_pretrained(config_path).to(device) + + def _load_pipeline_from_pretrained(self, path_to_config): + path_to_config = Path(path_to_config) + + if not path_to_config.exists(): + raise FileNotFoundError(f"Config file not found: {path_to_config}") + + print(f"Loading pyannote pipeline from {path_to_config}...") + # the paths in the config are relative to the current working directory + # so we need to change the working directory to the model path + # and then change it back + + cwd = Path.cwd().resolve() # store current working directory + + # first .parent is the folder of the config, second .parent is the folder containing the 'models' folder + cd_to = 
path_to_config.parent.parent.resolve() + + print(f"Changing working directory to {cd_to}") + os.chdir(cd_to) + + pipeline = Pipeline.from_pretrained(path_to_config) + + print(f"Changing working directory back to {cwd}") + os.chdir(cwd) + + return pipeline + + def __call__( + self, + audio, + num_speakers=None, + min_speakers=None, + max_speakers=None, + ): + if isinstance(audio, str): + audio = load_audio(audio) + audio_data = { + 'waveform': torch.from_numpy(audio[None, :]), + 'sample_rate': SAMPLE_RATE + } + segments = self.model(audio_data, num_speakers=num_speakers, min_speakers=min_speakers, max_speakers=max_speakers) + diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker']) + diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start) + diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end) + return diarize_df diff --git a/whisperx/transcribe.py b/whisperx/transcribe.py index a67787c1b..f256a9843 100644 --- a/whisperx/transcribe.py +++ b/whisperx/transcribe.py @@ -10,6 +10,11 @@ from whisperx.asr import load_model from whisperx.audio import load_audio from whisperx.diarize import DiarizationPipeline, assign_word_speakers +# Import the offline diarization pipeline +try: + from .offline_diarize import OfflineDiarizationPipeline +except ImportError: + OfflineDiarizationPipeline = None from whisperx.types import AlignedTranscriptionResult, TranscriptionResult from whisperx.utils import ( LANGUAGES, @@ -56,6 +61,9 @@ def cli(): parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word") parser.add_argument("--min_speakers", default=None, type=int, help="Minimum number of speakers to in audio file") parser.add_argument("--max_speakers", default=None, type=int, help="Maximum number of speakers to in audio file") + # Add offline diarization params + parser.add_argument("--diarize_offline", action="store_true", help="Use offline 
diarization models instead of downloading from HF") + parser.add_argument("--diarize_config", type=str, default="models/pyannote_diarization_config.yaml", help="Path to the diarization config file for offline mode") parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling") parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature") @@ -120,6 +128,8 @@ def cli(): chunk_size: int = args.pop("chunk_size") diarize: bool = args.pop("diarize") + diarize_offline: bool = args.pop("diarize_offline") + diarize_config: str = args.pop("diarize_config") min_speakers: int = args.pop("min_speakers") max_speakers: int = args.pop("max_speakers") print_progress: bool = args.pop("print_progress") @@ -238,12 +248,24 @@ def cli(): # >> Diarize if diarize: - if hf_token is None: - print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") tmp_results = results print(">>Performing diarization...") results = [] - diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + + if diarize_offline: + if OfflineDiarizationPipeline is None: + raise ImportError("offline_diarize.py must be in the same directory as transcribe.py. Please ensure it's properly installed.") + + print(f"Using offline diarization with config: {diarize_config}") + diarize_model = OfflineDiarizationPipeline( + config_path=diarize_config, + device=device + ) + else: + if hf_token is None: + print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...") + diarize_model = DiarizationPipeline(use_auth_token=hf_token, device=device) + for result, input_audio_path in tmp_results: diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers) result = assign_word_speakers(diarize_segments, result)