diff --git a/notebooks/minicpm-o-4.5/README.md b/notebooks/minicpm-o-4.5/README.md index d086e4992d1..82c5bd747f4 100644 --- a/notebooks/minicpm-o-4.5/README.md +++ b/notebooks/minicpm-o-4.5/README.md @@ -58,6 +58,6 @@ For details, please refer to [Installation Guide](../../README.md). ⚠️ **EXPERIMENTAL NOTEBOOK** -This notebook demonstrates a model that has not been fully validated with OpenVINO and is using a custom branch of optimum-intel. It may be fully supported and validated in the future. +This notebook demonstrates a model that has not been fully validated with OpenVINO. It may be fully supported and validated in the future. diff --git a/notebooks/minicpm-o-4.5/minicpm-o-4.5.ipynb b/notebooks/minicpm-o-4.5/minicpm-o-4.5.ipynb index e1225820d26..60efa6dfedf 100644 --- a/notebooks/minicpm-o-4.5/minicpm-o-4.5.ipynb +++ b/notebooks/minicpm-o-4.5/minicpm-o-4.5.ipynb @@ -43,7 +43,7 @@ "\n", "⚠️ **EXPERIMENTAL NOTEBOOK**\n", "\n", - "This notebook demonstrates a model that has not been fully validated with OpenVINO and is using a custom branch of optimum-intel. It may be fully supported and validated in the future.\n", + "This notebook demonstrates a model that has not been fully validated with OpenVINO. It may be fully supported and validated in the future.\n", "\n", "\n", "\n", @@ -190,7 +190,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "0999436d", "metadata": {}, "outputs": [], @@ -206,7 +206,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "8db2e50e", "metadata": {}, "outputs": [ @@ -251,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "1b640aac", "metadata": {}, "outputs": [ @@ -266,7 +266,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "D:\\openvino_notebooks\\minicpm_test\\Lib\\site-packages\\huggingface_hub\\file_download.py:986: UserWarning: `local_dir_use_symlinks` parameter is deprecated and will be ignored. The process to download files to a local folder has been updated and do not rely on symlinks anymore. You only need to pass a destination folder as`local_dir`.\n", + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\huggingface_hub\\file_download.py:979: UserWarning: `local_dir_use_symlinks` parameter is deprecated and will be ignored. The process to download files to a local folder has been updated and do not rely on symlinks anymore. You only need to pass a destination folder as`local_dir`.\n", "For more details, check out https://huggingface.co/docs/huggingface_hub/main/en/guides/download#download-files-to-local-folder.\n", " warnings.warn(\n" ] @@ -274,7 +274,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "99fb056681c240e08ccab07f12270607", + "model_id": "3666882db2304cee8f6dab9b25912149", "version_major": 2, "version_minor": 0 }, @@ -347,14 +347,14 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "ff859149", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "2ca7425b6c3547eb91817b9a1119a791", + "model_id": "b94a2c078a264b0c938fc6018cb52246", "version_major": 2, "version_minor": 0 }, @@ -362,7 +362,7 @@ "Dropdown(description='Device:', options=('CPU', 'GPU', 'AUTO'), value='CPU')" ] }, - "execution_count": 5, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -377,14 +377,14 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "60e68498", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "72062ca0687d463d86088c74983a72ec", + "model_id": "b392389b89184e67acc4a3e967c65727", "version_major": 2, "version_minor": 0 }, @@ -392,7 +392,7 @@ "Dropdown(description='Device:', options=('CPU', 'GPU', 'AUTO'), value='CPU')" ] }, - "execution_count": 6, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -416,7 +416,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "243deffa", "metadata": {}, "outputs": [ @@ -424,9 +424,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "D:\\openvino_notebooks\\minicpm_test\\Lib\\site-packages\\librosa\\util\\files.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\librosa\\util\\files.py:10: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n", " from pkg_resources import resource_filename\n", - "D:\\openvino_notebooks\\minicpm_test\\Lib\\site-packages\\transformers\\models\\auto\\image_processing_auto.py:604: FutureWarning: The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead\n", + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\transformers\\models\\auto\\image_processing_auto.py:604: FutureWarning: The image_processor_class argument is deprecated and will be removed in v4.42. Please use `slow_image_processor_class`, or `fast_image_processor_class` instead\n", " warnings.warn(\n", "Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.\n" ] @@ -475,7 +475,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "492d8be8", "metadata": {}, "outputs": [ @@ -509,7 +509,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "257693f5", "metadata": {}, "outputs": [ @@ -517,7 +517,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "The image shows a relaxed, gray tabby cat lying on its back inside a plain cardboard box. The cat appears comfortable and content, with its eyes partially closed and paws slightly curled. It’s resting on a soft, light-colored carpet in a bright, cozy indoor setting. In the background, part of a light-colored sofa is visible, suggesting a living room environment. The overall mood is peaceful and charming — a classic “cat in a box” scene that many pet lovers find endearing.\n" + "The image shows a relaxed, gray tabby cat lying comfortably on its back inside a plain cardboard box. The cat appears to be sleeping or dozing peacefully, with its eyes partially closed and paws slightly curled. Its tail is stretched out behind it, adding to the cozy, playful posture.\n", + "\n", + "The scene takes place indoors on a soft, light-colored carpet. In the background, part of a light-colored sofa or piece of furniture is visible, along with bright, diffused natural light coming from a nearby window — suggesting a calm, sunny day.\n", + "\n", + "This is a classic “cat in a box” moment — cats often enjoy boxes for their enclosed, secure feeling, even if they’re just lounging in them.\n", + "\n", + "Overall, the image conveys warmth, comfort, and the charming, quirky behavior typical of domestic cats.\n" ] } ], @@ -545,7 +551,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "ea07fc20", "metadata": {}, "outputs": [ @@ -553,7 +559,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "Based on the bright, soft lighting and lack of heavy clothing or seasonal decorations, this picture might have been taken during spring or summer. The indoor setting with light curtains suggests a pleasant day outside, but without direct evidence like windows showing foliage or weather conditions, it's hard to pinpoint exactly\n" + "Based on the bright, soft natural light coming through the window and the absence of any heavy clothing or seasonal decorations, this picture was likely taken during spring or summer.\n", + "\n", + "The indoor setting is clean and comfortable, with no signs of cold weather gear like blankets piled high or visible heaters. The cat’s relaxed posture also suggests a warm environment — cats often seek out cozy spots when it’s comfortably cool indoors, which fits well with warmer seasons.\n", + "\n", + "So while we can’t be certain without more context, the lighting and overall ambiance strongly suggest that this photo captures a peaceful moment in late spring or summer.\n" ] } ], @@ -600,7 +610,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "e6535fc1", "metadata": {}, "outputs": [ @@ -633,7 +643,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "ASR Result: for a while we built a hyperloop test track adjacent to spacex just for a student competition to encourage innovative ideas in transport it actually ends up being the biggest vacuum chamber in the world after the large hadron collider\n" + "ASR Result: for a while we built a hyperloop test track adjacent to spacex just for a student competition to encourage innovative ideas in transport it actually ends up being the biggest vacuum chamber in the world after the large hadron collider\n" ] } ], @@ -682,7 +692,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "9812e6b1", "metadata": {}, "outputs": [ @@ -690,7 +700,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Speaker Analysis: According to the content of this speaker, it seams that the gender of the speaker is male, and the speaker is in the state of sigh. What's more, I deduce that the speaker is a Youth. The health condition is: Healthy.\n", + "Speaker Analysis: According to the content of this speaker, it seams that the gender of the speaker is male, and the speaker is in the state of sigh. What's more, I gauss that the speaker is a Youth. The health condition is: Healthy.\n", "Sound Scene Tag: Speech\n" ] } @@ -762,7 +772,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "5441dcc7", "metadata": {}, "outputs": [ @@ -770,11 +780,13 @@ "name": "stdout", "output_type": "stream", "text": [ + " 📐 Auto-setting hift_input_len=64 for GPU (avoids dynamic shape recompilation)\n", "⌛ Loading OpenVINO Flow embeddings model from MiniCPM-o-4_5-OV\\openvino_flow_embeddings_model.xml...\n", "✅ Flow embeddings model loaded\n", "⌛ Loading OpenVINO Flow estimator model from MiniCPM-o-4_5-OV\\openvino_flow_estimator_model.xml...\n", "✅ Flow estimator model loaded\n", "⌛ Loading OpenVINO HiFT model from MiniCPM-o-4_5-OV\\openvino_hift_model.xml...\n", + " 📐 Reshaped HiFT to fixed: mel=[1,80,64], cache=[1,1,3840]\n", "✅ HiFT model loaded on GPU\n", "⌛ Loading s3tokenizer model...\n", "✅ s3tokenizer model loaded\n", @@ -841,7 +853,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "14c6f089", "metadata": {}, "outputs": [ @@ -854,7 +866,7 @@ "⌛ Loading OpenVINO Flow estimator model from MiniCPM-o-4_5-OV\\openvino_flow_estimator_model.xml...\n", "✅ Flow estimator model loaded\n", "⌛ Loading OpenVINO HiFT model from MiniCPM-o-4_5-OV\\openvino_hift_model.xml...\n", - " 📐 Reshaped HiFT to fixed input: [1, 80, 100]\n", + " 📐 Reshaped HiFT to fixed: mel=[1,80,100], cache=[1,1,3840]\n", "✅ HiFT model loaded on GPU\n", "⌛ Loading s3tokenizer model...\n", "✅ s3tokenizer model loaded\n", @@ -866,7 +878,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "D:\\openvino_notebooks\\minicpm_test\\Lib\\site-packages\\torchaudio\\_backend\\utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n", + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\torchaudio\\_backend\\utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n", " warnings.warn(\n" ] }, @@ -884,7 +896,7 @@ "text/html": [ "\n", " \n", " " @@ -991,7 +1003,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "e068b4bb", "metadata": {}, "outputs": [ @@ -999,7 +1011,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "D:\\openvino_notebooks\\minicpm_test\\Lib\\site-packages\\torchaudio\\_backend\\utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n", + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\torchaudio\\_backend\\utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n", " warnings.warn(\n" ] }, @@ -1008,7 +1020,7 @@ "output_type": "stream", "text": [ "fixed s3 encode\n", - "Text: Okay, so for a while we built a Hyperloop test track adjacent to SpaceX just for a student competition to encourage innovative ideas in transport. It actually ends up being the biggest vacuum chamber in the world after the Large Hadron Collider.<|tts_eos|>\n", + "Text: For a while we built a hyperloop test track adjacent to SpaceX just for a student competition to encourage innovative ideas in transport. It actually ends up being the biggest vacuum chamber in the world after the Large Hadron Collider.<|tts_eos|>\n", "Audio saved to output_realtime.wav\n" ] }, @@ -1017,7 +1029,7 @@ "text/html": [ "\n", " \n", " " @@ -1152,10 +1164,97 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "id": "de3a13c8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "✅ en_video downloaded: assets/omni_duplex1.mp4\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\openvino_notebooks\\py_env\\Lib\\site-packages\\torchaudio\\_backend\\utils.py:213: UserWarning: In 2.9, this function's implementation will be changed to use torchaudio.load_with_torchcodec` under the hood. Some parameters like ``normalize``, ``format``, ``buffer_size``, and ``backend`` will be ignored. We recommend that you port your code to rely directly on TorchCodec's decoder instead: https://docs.pytorch.org/torchcodec/stable/generated/torchcodec.decoders.AudioDecoder.html#torchcodec.decoders.AudioDecoder.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "fixed s3 encode\n", + "✅ Duplex mode prepared: sliding_window=off, ref_audio=yes\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "speak> Sure, it's\n", + "speak> currently on the\n", + "speak> 20th floor.\n", + "speak> \n", + "listen...\n", + "listen...\n", + "speak> 23rd\n", + "speak> floor now.\n", + "speak> \n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n", + "listen...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:audio cache length 1507 exceed 1500, reset.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "listen...\n", + "listen...\n", + "speak> Yes you did just\n", + "speak> now.\n", + "speak> \n", + "listen...\n", + "listen...\n" + ] + }, + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import librosa\n", "import torch\n", diff --git a/notebooks/minicpm-o-4.5/minicpm_o_4_5_helper.py b/notebooks/minicpm-o-4.5/minicpm_o_4_5_helper.py index d80f9a8ac50..b15cc1fa55e 100644 --- a/notebooks/minicpm-o-4.5/minicpm_o_4_5_helper.py +++ b/notebooks/minicpm-o-4.5/minicpm_o_4_5_helper.py @@ -1956,10 +1956,21 @@ def convert_hift(token2wav_model, output_path: Path): hift = token2wav_model.hift + # source_cache_len: 8 mel frames * 480 samples/frame = 3840 samples + SOURCE_CACHE_LEN = 3840 + class HiFTWrapper(nn.Module): """ - HiFT forward wrapper that outputs raw spectral features before istft. - This allows istft to be performed in Python with better compatibility. + HiFT forward wrapper that outputs raw spectral features before istft, + plus the excitation source signal for streaming cache continuity. + + Inputs: + x: mel spectrogram [B, 80, mel_len] + cache_source: cached excitation source [B, 1, SOURCE_CACHE_LEN] from previous chunk + (pass zeros for first chunk) + Outputs: + spectral: raw spectral features [B, n_fft+2, time] before istft + source: excitation source signal [B, 1, T_source] for next chunk's cache """ def __init__(self, hift_model): @@ -1992,19 +2003,27 @@ def _stft(self, x): spec = torch.view_as_real(spec) # [B, F, TT, 2] return spec[..., 0], spec[..., 1] - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, cache_source: torch.Tensor) -> tuple: """ - HiFT forward that outputs raw conv_post output before istft. + HiFT forward with source cache for streaming continuity. - Input: mel spectrogram (batch, 80, mel_len) - Output: raw spectral features (batch, n_fft+2, time) before istft + The cache_source from the previous chunk is spliced into the head of + the newly generated excitation source, maintaining phase continuity + across streaming chunks (matching the original PyTorch HiFTGenerator). """ # mel -> f0 f0 = self.f0_predictor(x) - # f0 -> source - s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t + # f0 -> source excitation signal + s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # [B, T, 1] s, _, _ = self.m_source(s) - s = s.transpose(1, 2) + s = s.transpose(1, 2) # [B, 1, T_source] + + # Splice cached source into head for phase continuity + # cache_source shape: [B, 1, SOURCE_CACHE_LEN] + # For first chunk, cache_source is zeros — the first 3840 audio samples + # are replaced with silence in stream() anyway, so this is harmless. + cache_len = cache_source.shape[2] + s[:, :, :cache_len] = cache_source # stft of source s_stft_real, s_stft_imag = self._stft(s.squeeze(1)) @@ -2034,21 +2053,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.leaky_relu(x) x = self.conv_post(x) - # Return here without exp, sin, istft, clamp - return x + # Return spectral features + source signal for next chunk's cache + return x, s hift_wrapper = HiFTWrapper(hift) hift_wrapper.eval() __make_16bit_traceable(hift_wrapper) - # Example input: mel spectrogram [batch, 80, mel_len] + # Example inputs: mel spectrogram + source cache mel_bins = 80 - example_input = torch.randn([1, mel_bins, 200], dtype=torch.float32) + example_mel = torch.randn([1, mel_bins, 200], dtype=torch.float32) + example_cache = torch.zeros([1, 1, SOURCE_CACHE_LEN], dtype=torch.float32) + example_input = (example_mel, example_cache) - # Input shapes: fix mel_bins=80, allow dynamic batch and mel_len + # Input shapes: mel has dynamic length, cache_source is fixed size input_shapes = [ - ov.PartialShape([-1, mel_bins, -1]), # [batch, mel_bins=80, mel_len] + ov.PartialShape([-1, mel_bins, -1]), # [batch, 80, mel_len] + ov.PartialShape([-1, 1, SOURCE_CACHE_LEN]), # [batch, 1, 3840] ] with torch.no_grad(): @@ -2058,7 +2080,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: del ov_model cleanup_torchscript_cache() gc.collect() - print(f"✅ HiFT model saved to {output_path}") + print(f"✅ HiFT model saved to {output_path} (with source cache support)") def convert_token2wav( @@ -2178,12 +2200,22 @@ def __init__(self, flow, hift): else: print(f"✅ Flow Estimator model already exists at {flow_est_path}") - # Convert HiFT - if not hift_path.exists(): + # Convert HiFT (with cache_source support for streaming audio continuity) + _need_hift_export = True + if hift_path.exists(): + try: + _hift_model = core.read_model(str(hift_path)) + if any("cache_source" in inp.any_name for inp in _hift_model.inputs): + _need_hift_export = False + print(f"✅ HiFT model already exists at {hift_path} (with cache_source)") + else: + print(f"⚠️ Legacy HiFT model found (no cache_source), re-exporting...") + del _hift_model + except Exception: + pass + if _need_hift_export: print("⌛ Converting HiFT model...") convert_hift(token2wav, hift_path) - else: - print(f"✅ HiFT model already exists at {hift_path}") # Copy ONNX models (s3tokenizer and campplus) # These models are small and kept as ONNX for compatibility with s3tokenizer library @@ -3376,6 +3408,10 @@ def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_time model_dir: Path to token2wav model directory enable_float16: Whether to use float16 n_timesteps: Number of diffusion steps + hift_input_len: Fixed mel length for HiFT (0 = auto). + When 0 and tts_device is GPU, auto-set to 64 + (25 tokens → 56 mel frames + 8 cache = 64) to avoid + dynamic shape recompilation stalls. Returns: The audio tokenizer (OVToken2wav) @@ -3395,12 +3431,18 @@ def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_time print(f"⚠️ token2wav model directory not found: {model_dir}") return None + # Auto-enable fixed HiFT shape on GPU to prevent dynamic shape recompilation + # Streaming mel: 25 audio tokens → 56 mel frames + 8 mel cache = 64 max + if hift_input_len == 0 and self.tts_device.upper() not in ("CPU",): + hift_input_len = 64 + print(f" 📐 Auto-setting hift_input_len={hift_input_len} for {self.tts_device} (avoids dynamic shape recompilation)") + self.token2wav = OVToken2wav( model_dir=model_dir, device=self.tts_device, float16=enable_float16, n_timesteps=n_timesteps, - hift_input_len=hift_input_len, # dynamic shape for streaming mel+cache (50+8=58 frames) + hift_input_len=hift_input_len, # flow_emb_token_len=50, # flow_emb_prompt_len=200, ) @@ -7693,12 +7735,32 @@ def __init__(self, model_path: str, device: str = "CPU", hift_input_len: int = 0 self.ov_device = device self.hift_input_len = hift_input_len + # Source cache length (must match export-time SOURCE_CACHE_LEN) + self.source_cache_len_fixed = 3840 + # Load OpenVINO model with optional reshape print(f"⌛ Loading OpenVINO HiFT model from {model_path}...") model = core.read_model(str(model_path)) + + # Detect whether model has cache_source input (new) or not (legacy) + self._has_cache_source = len(model.inputs) >= 2 and any("cache_source" in inp.any_name for inp in model.inputs) + if self.hift_input_len > 0: - model.reshape([1, 80, self.hift_input_len]) - print(f" 📐 Reshaped HiFT to fixed input: [1, 80, {self.hift_input_len}]") + if self._has_cache_source: + # Reshape both inputs to fully static shapes for GPU efficiency + model.reshape( + { + model.inputs[0].any_name: [1, 80, self.hift_input_len], + model.inputs[1].any_name: [1, 1, self.source_cache_len_fixed], + } + ) + print(f" 📐 Reshaped HiFT to fixed: mel=[1,80,{self.hift_input_len}], cache=[1,1,{self.source_cache_len_fixed}]") + else: + model.reshape([1, 80, self.hift_input_len]) + print(f" 📐 Reshaped HiFT to fixed input: [1, 80, {self.hift_input_len}]") + + if not self._has_cache_source: + print(" ⚠️ Legacy HiFT model (no cache_source). Re-export with reexport_hift.py for streaming audio continuity.") self.hift = core.compile_model(model, device) print(f"✅ HiFT model loaded on {device}") @@ -7729,22 +7791,45 @@ def _istft(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor: ) return inverse_transform - def inference(self, speech_feat: torch.Tensor) -> torch.Tensor: + def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = None): """ Run HiFT inference using OpenVINO. Args: speech_feat: Mel spectrogram (batch, 80, mel_len) + cache_source: Cached excitation source from previous chunk [B, 1, 3840]. + Pass None or zeros for first chunk. Returns: speech: Generated waveform (batch, samples) + source_out: Excitation source signal [B, 1, T] for next chunk's cache """ - # Convert to numpy for potential padding + # Convert mel to numpy for potential padding if isinstance(speech_feat, torch.Tensor): mel_input = speech_feat.cpu().numpy() else: mel_input = speech_feat + # Prepare cache_source input (fixed size: SOURCE_CACHE_LEN=3840) + source_cache_len = 3840 + if cache_source is None: + cache_source_np = np.zeros((1, 1, source_cache_len), dtype=np.float32) + elif isinstance(cache_source, torch.Tensor): + cache_source_np = cache_source.cpu().numpy().astype(np.float32) + # Pad or trim to fixed size + cur_len = cache_source_np.shape[2] + if cur_len < source_cache_len: + cache_source_np = np.pad( + cache_source_np, + ((0, 0), (0, 0), (0, source_cache_len - cur_len)), + mode="constant", + constant_values=0, + ) + elif cur_len > source_cache_len: + cache_source_np = cache_source_np[:, :, -source_cache_len:] + else: + cache_source_np = cache_source + # Fixed shape mode: pad input to target length, trim output after original_len = mel_input.shape[2] if self.hift_input_len > 0: @@ -7758,11 +7843,19 @@ def inference(self, speech_feat: torch.Tensor) -> torch.Tensor: mel_input = mel_input[:, :, :target_len] original_len = target_len - # Run OpenVINO inference - output is (batch, n_fft+2, time) + # Run OpenVINO inference start_time = time.time() - result = self.hift(mel_input) + if self._has_cache_source: + # New model: 2 inputs (mel + cache_source), 2 outputs (spectral + source) + result = self.hift([mel_input, cache_source_np]) + x = torch.from_numpy(result[0].copy()) + source_out = torch.from_numpy(result[1].copy()) + else: + # Legacy model: 1 input (mel), 1 output (spectral), no source cache + result = self.hift(mel_input) + x = torch.from_numpy(result[0].copy()) + source_out = torch.zeros(1, 1, 0) elapsed = time.time() - start_time - x = torch.from_numpy(result[0].copy()) # Post-processing: exp, sin, istft, clamp n_fft = self.istft_params["n_fft"] @@ -7776,8 +7869,9 @@ def inference(self, speech_feat: torch.Tensor) -> torch.Tensor: if self.hift_input_len > 0 and original_len < self.hift_input_len: original_samples = original_len * self.mel_to_samples_ratio speech = speech[:, :original_samples] + source_out = source_out[:, :, :original_samples] - return speech + return speech, source_out class OVToken2wav: @@ -7937,8 +8031,8 @@ def __call__(self, generated_speech_tokens, prompt_wav): self.n_timesteps, ) - # Run HiFT vocoder - wav = self.hift.inference(speech_feat=mel) + # Run HiFT vocoder (non-streaming: no source cache needed) + wav, _ = self.hift.inference(speech_feat=mel) output = io.BytesIO() torchaudio.save(output, wav.cpu(), sample_rate=24000, format="wav") @@ -8013,12 +8107,15 @@ def stream(self, tokens, prompt_wav, last_chunk=False, return_waveform=False): else: mel_combined = mel - # Run HiFT inference - speech = self.hift.inference(speech_feat=mel_combined) + # Run HiFT inference with source cache for phase continuity + hift_cache_source = self.hift_cache_dict.get("source", None) + speech, source_out = self.hift.inference(speech_feat=mel_combined, cache_source=hift_cache_source) # Speech overlap smoothing with cached speech + # Aligned with original fade_in_out(speech_new, speech_old, window): + # new_head * window[:half] + old_tail * window[half:] + # where hamming window[:half] = rising (0→1), window[half:] = falling (1→0) if not is_first_chunk and hift_cache_speech.shape[-1] > 0: - # Fade overlap between cached tail and new speech head overlap_len = min(self.source_cache_len, speech.shape[-1], hift_cache_speech.shape[-1]) if overlap_len > 0: fade_window = ( @@ -8026,19 +8123,21 @@ def stream(self, tokens, prompt_wav, last_chunk=False, return_waveform=False): if 2 * overlap_len <= self.speech_window.shape[0] else torch.from_numpy(np.hamming(2 * overlap_len).astype(np.float32)) ) - fade_in = fade_window[overlap_len:] # second half: 0→1 - fade_out = fade_window[:overlap_len] # first half: 1→0 + # window[:half] = rising (0→1) → applied to new speech head (fade IN) + # window[half:] = falling (1→0) → applied to old speech tail (fade OUT) + w_new = fade_window[:overlap_len] # rising: 0→1 + w_old = fade_window[overlap_len:] # falling: 1→0 overlap_new = speech[:, :overlap_len] overlap_old = hift_cache_speech[:, -overlap_len:] - blended = overlap_old * fade_out.unsqueeze(0) + overlap_new * fade_in.unsqueeze(0) + blended = overlap_new * w_new.unsqueeze(0) + overlap_old * w_old.unsqueeze(0) speech = torch.cat([blended, speech[:, overlap_len:]], dim=1) - # Update HiFT cache + # Update HiFT cache — source_out now properly carries the excitation signal self.hift_cache_dict = { "mel": mel_combined[:, :, -self.mel_cache_len :].clone() if mel_combined.shape[2] >= self.mel_cache_len else mel_combined.clone(), - "source": torch.zeros(1, 1, 0), + "source": source_out[:, :, -self.source_cache_len :].clone() if source_out.shape[2] >= self.source_cache_len else source_out.clone(), "speech": speech[:, -self.source_cache_len :].clone() if speech.shape[-1] >= self.source_cache_len else speech.clone(), }