Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion notebooks/minicpm-o-4.5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ For details, please refer to [Installation Guide](../../README.md).

⚠️ **EXPERIMENTAL NOTEBOOK**

This notebook demonstrates a model that has not been fully validated with OpenVINO and is using a custom branch of optimum-intel. It may be fully supported and validated in the future.
This notebook demonstrates a model that has not been fully validated with OpenVINO. It may be fully supported and validated in the future.

<img referrerpolicy="no-referrer-when-downgrade" src="https://static.scarf.sh/a.png?x-pxid=5b5a4db0-7875-4bfb-bdbd-01698b5b1a77&file=notebooks/minicpm-o-4.5/README.md" />
169 changes: 134 additions & 35 deletions notebooks/minicpm-o-4.5/minicpm-o-4.5.ipynb

Large diffs are not rendered by default.

177 changes: 138 additions & 39 deletions notebooks/minicpm-o-4.5/minicpm_o_4_5_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1956,10 +1956,21 @@ def convert_hift(token2wav_model, output_path: Path):

hift = token2wav_model.hift

# source_cache_len: 8 mel frames * 480 samples/frame = 3840 samples
SOURCE_CACHE_LEN = 3840

class HiFTWrapper(nn.Module):
"""
HiFT forward wrapper that outputs raw spectral features before istft.
This allows istft to be performed in Python with better compatibility.
HiFT forward wrapper that outputs raw spectral features before istft,
plus the excitation source signal for streaming cache continuity.

Inputs:
x: mel spectrogram [B, 80, mel_len]
cache_source: cached excitation source [B, 1, SOURCE_CACHE_LEN] from previous chunk
(pass zeros for first chunk)
Outputs:
spectral: raw spectral features [B, n_fft+2, time] before istft
source: excitation source signal [B, 1, T_source] for next chunk's cache
"""

def __init__(self, hift_model):
Expand Down Expand Up @@ -1992,19 +2003,27 @@ def _stft(self, x):
spec = torch.view_as_real(spec) # [B, F, TT, 2]
return spec[..., 0], spec[..., 1]

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, cache_source: torch.Tensor) -> tuple:
"""
HiFT forward that outputs raw conv_post output before istft.
HiFT forward with source cache for streaming continuity.

Input: mel spectrogram (batch, 80, mel_len)
Output: raw spectral features (batch, n_fft+2, time) before istft
The cache_source from the previous chunk is spliced into the head of
the newly generated excitation source, maintaining phase continuity
across streaming chunks (matching the original PyTorch HiFTGenerator).
"""
# mel -> f0
f0 = self.f0_predictor(x)
# f0 -> source
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
# f0 -> source excitation signal
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # [B, T, 1]
s, _, _ = self.m_source(s)
s = s.transpose(1, 2)
s = s.transpose(1, 2) # [B, 1, T_source]

# Splice cached source into head for phase continuity
# cache_source shape: [B, 1, SOURCE_CACHE_LEN]
# For first chunk, cache_source is zeros — the first 3840 audio samples
# are replaced with silence in stream() anyway, so this is harmless.
cache_len = cache_source.shape[2]
s[:, :, :cache_len] = cache_source

# stft of source
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
Expand Down Expand Up @@ -2034,21 +2053,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

x = F.leaky_relu(x)
x = self.conv_post(x)
# Return here without exp, sin, istft, clamp
return x
# Return spectral features + source signal for next chunk's cache
return x, s

hift_wrapper = HiFTWrapper(hift)
hift_wrapper.eval()

__make_16bit_traceable(hift_wrapper)

# Example input: mel spectrogram [batch, 80, mel_len]
# Example inputs: mel spectrogram + source cache
mel_bins = 80
example_input = torch.randn([1, mel_bins, 200], dtype=torch.float32)
example_mel = torch.randn([1, mel_bins, 200], dtype=torch.float32)
example_cache = torch.zeros([1, 1, SOURCE_CACHE_LEN], dtype=torch.float32)
example_input = (example_mel, example_cache)

# Input shapes: fix mel_bins=80, allow dynamic batch and mel_len
# Input shapes: mel has dynamic length, cache_source is fixed size
input_shapes = [
ov.PartialShape([-1, mel_bins, -1]), # [batch, mel_bins=80, mel_len]
ov.PartialShape([-1, mel_bins, -1]), # [batch, 80, mel_len]
ov.PartialShape([-1, 1, SOURCE_CACHE_LEN]), # [batch, 1, 3840]
]

with torch.no_grad():
Expand All @@ -2058,7 +2080,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
del ov_model
cleanup_torchscript_cache()
gc.collect()
print(f"✅ HiFT model saved to {output_path}")
print(f"✅ HiFT model saved to {output_path} (with source cache support)")


def convert_token2wav(
Expand Down Expand Up @@ -2178,12 +2200,22 @@ def __init__(self, flow, hift):
else:
print(f"✅ Flow Estimator model already exists at {flow_est_path}")

# Convert HiFT
if not hift_path.exists():
# Convert HiFT (with cache_source support for streaming audio continuity)
_need_hift_export = True
if hift_path.exists():
try:
_hift_model = core.read_model(str(hift_path))
if any("cache_source" in inp.any_name for inp in _hift_model.inputs):
_need_hift_export = False
print(f"✅ HiFT model already exists at {hift_path} (with cache_source)")
else:
print(f"⚠️ Legacy HiFT model found (no cache_source), re-exporting...")
del _hift_model
except Exception:
pass
if _need_hift_export:
print("⌛ Converting HiFT model...")
convert_hift(token2wav, hift_path)
else:
print(f"✅ HiFT model already exists at {hift_path}")

# Copy ONNX models (s3tokenizer and campplus)
# These models are small and kept as ONNX for compatibility with s3tokenizer library
Expand Down Expand Up @@ -3376,6 +3408,10 @@ def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_time
model_dir: Path to token2wav model directory
enable_float16: Whether to use float16
n_timesteps: Number of diffusion steps
hift_input_len: Fixed mel length for HiFT (0 = auto).
When 0 and tts_device is GPU, auto-set to 64
(25 tokens → 56 mel frames + 8 cache = 64) to avoid
dynamic shape recompilation stalls.

Returns:
The audio tokenizer (OVToken2wav)
Expand All @@ -3395,12 +3431,18 @@ def init_tts(self, streaming=False, model_dir=None, enable_float16=False, n_time
print(f"⚠️ token2wav model directory not found: {model_dir}")
return None

# Auto-enable fixed HiFT shape on GPU to prevent dynamic shape recompilation
# Streaming mel: 25 audio tokens → 56 mel frames + 8 mel cache = 64 max
if hift_input_len == 0 and self.tts_device.upper() not in ("CPU",):
hift_input_len = 64
print(f" 📐 Auto-setting hift_input_len={hift_input_len} for {self.tts_device} (avoids dynamic shape recompilation)")

self.token2wav = OVToken2wav(
model_dir=model_dir,
device=self.tts_device,
float16=enable_float16,
n_timesteps=n_timesteps,
hift_input_len=hift_input_len, # dynamic shape for streaming mel+cache (50+8=58 frames)
hift_input_len=hift_input_len,
# flow_emb_token_len=50,
# flow_emb_prompt_len=200,
)
Expand Down Expand Up @@ -7693,12 +7735,32 @@ def __init__(self, model_path: str, device: str = "CPU", hift_input_len: int = 0
self.ov_device = device
self.hift_input_len = hift_input_len

# Source cache length (must match export-time SOURCE_CACHE_LEN)
self.source_cache_len_fixed = 3840

# Load OpenVINO model with optional reshape
print(f"⌛ Loading OpenVINO HiFT model from {model_path}...")
model = core.read_model(str(model_path))

# Detect whether model has cache_source input (new) or not (legacy)
self._has_cache_source = len(model.inputs) >= 2 and any("cache_source" in inp.any_name for inp in model.inputs)

if self.hift_input_len > 0:
model.reshape([1, 80, self.hift_input_len])
print(f" 📐 Reshaped HiFT to fixed input: [1, 80, {self.hift_input_len}]")
if self._has_cache_source:
# Reshape both inputs to fully static shapes for GPU efficiency
model.reshape(
{
model.inputs[0].any_name: [1, 80, self.hift_input_len],
model.inputs[1].any_name: [1, 1, self.source_cache_len_fixed],
}
)
print(f" 📐 Reshaped HiFT to fixed: mel=[1,80,{self.hift_input_len}], cache=[1,1,{self.source_cache_len_fixed}]")
else:
model.reshape([1, 80, self.hift_input_len])
print(f" 📐 Reshaped HiFT to fixed input: [1, 80, {self.hift_input_len}]")

if not self._has_cache_source:
print(" ⚠️ Legacy HiFT model (no cache_source). Re-export with reexport_hift.py for streaming audio continuity.")
self.hift = core.compile_model(model, device)
print(f"✅ HiFT model loaded on {device}")

Expand Down Expand Up @@ -7729,22 +7791,45 @@ def _istft(self, magnitude: torch.Tensor, phase: torch.Tensor) -> torch.Tensor:
)
return inverse_transform

def inference(self, speech_feat: torch.Tensor) -> torch.Tensor:
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = None):
"""
Run HiFT inference using OpenVINO.

Args:
speech_feat: Mel spectrogram (batch, 80, mel_len)
cache_source: Cached excitation source from previous chunk [B, 1, 3840].
Pass None or zeros for first chunk.

Returns:
speech: Generated waveform (batch, samples)
source_out: Excitation source signal [B, 1, T] for next chunk's cache
"""
# Convert to numpy for potential padding
# Convert mel to numpy for potential padding
if isinstance(speech_feat, torch.Tensor):
mel_input = speech_feat.cpu().numpy()
else:
mel_input = speech_feat

# Prepare cache_source input (fixed size: SOURCE_CACHE_LEN=3840)
source_cache_len = 3840
if cache_source is None:
cache_source_np = np.zeros((1, 1, source_cache_len), dtype=np.float32)
elif isinstance(cache_source, torch.Tensor):
cache_source_np = cache_source.cpu().numpy().astype(np.float32)
# Pad or trim to fixed size
cur_len = cache_source_np.shape[2]
if cur_len < source_cache_len:
cache_source_np = np.pad(
cache_source_np,
((0, 0), (0, 0), (0, source_cache_len - cur_len)),
mode="constant",
constant_values=0,
)
elif cur_len > source_cache_len:
cache_source_np = cache_source_np[:, :, -source_cache_len:]
else:
cache_source_np = cache_source

# Fixed shape mode: pad input to target length, trim output after
original_len = mel_input.shape[2]
if self.hift_input_len > 0:
Expand All @@ -7758,11 +7843,19 @@ def inference(self, speech_feat: torch.Tensor) -> torch.Tensor:
mel_input = mel_input[:, :, :target_len]
original_len = target_len

# Run OpenVINO inference - output is (batch, n_fft+2, time)
# Run OpenVINO inference
start_time = time.time()
result = self.hift(mel_input)
if self._has_cache_source:
# New model: 2 inputs (mel + cache_source), 2 outputs (spectral + source)
result = self.hift([mel_input, cache_source_np])
x = torch.from_numpy(result[0].copy())
source_out = torch.from_numpy(result[1].copy())
else:
# Legacy model: 1 input (mel), 1 output (spectral), no source cache
result = self.hift(mel_input)
x = torch.from_numpy(result[0].copy())
source_out = torch.zeros(1, 1, 0)
elapsed = time.time() - start_time
x = torch.from_numpy(result[0].copy())

# Post-processing: exp, sin, istft, clamp
n_fft = self.istft_params["n_fft"]
Expand All @@ -7776,8 +7869,9 @@ def inference(self, speech_feat: torch.Tensor) -> torch.Tensor:
if self.hift_input_len > 0 and original_len < self.hift_input_len:
original_samples = original_len * self.mel_to_samples_ratio
speech = speech[:, :original_samples]
source_out = source_out[:, :, :original_samples]

return speech
return speech, source_out


class OVToken2wav:
Expand Down Expand Up @@ -7937,8 +8031,8 @@ def __call__(self, generated_speech_tokens, prompt_wav):
self.n_timesteps,
)

# Run HiFT vocoder
wav = self.hift.inference(speech_feat=mel)
# Run HiFT vocoder (non-streaming: no source cache needed)
wav, _ = self.hift.inference(speech_feat=mel)

output = io.BytesIO()
torchaudio.save(output, wav.cpu(), sample_rate=24000, format="wav")
Expand Down Expand Up @@ -8013,32 +8107,37 @@ def stream(self, tokens, prompt_wav, last_chunk=False, return_waveform=False):
else:
mel_combined = mel

# Run HiFT inference
speech = self.hift.inference(speech_feat=mel_combined)
# Run HiFT inference with source cache for phase continuity
hift_cache_source = self.hift_cache_dict.get("source", None)
speech, source_out = self.hift.inference(speech_feat=mel_combined, cache_source=hift_cache_source)

# Speech overlap smoothing with cached speech
# Aligned with original fade_in_out(speech_new, speech_old, window):
# new_head * window[:half] + old_tail * window[half:]
# where hamming window[:half] = rising (0→1), window[half:] = falling (1→0)
if not is_first_chunk and hift_cache_speech.shape[-1] > 0:
# Fade overlap between cached tail and new speech head
overlap_len = min(self.source_cache_len, speech.shape[-1], hift_cache_speech.shape[-1])
if overlap_len > 0:
fade_window = (
self.speech_window[: 2 * overlap_len]
if 2 * overlap_len <= self.speech_window.shape[0]
else torch.from_numpy(np.hamming(2 * overlap_len).astype(np.float32))
)
fade_in = fade_window[overlap_len:] # second half: 0→1
fade_out = fade_window[:overlap_len] # first half: 1→0
# window[:half] = rising (0→1) → applied to new speech head (fade IN)
# window[half:] = falling (1→0) → applied to old speech tail (fade OUT)
w_new = fade_window[:overlap_len] # rising: 0→1
w_old = fade_window[overlap_len:] # falling: 1→0

overlap_new = speech[:, :overlap_len]
overlap_old = hift_cache_speech[:, -overlap_len:]

blended = overlap_old * fade_out.unsqueeze(0) + overlap_new * fade_in.unsqueeze(0)
blended = overlap_new * w_new.unsqueeze(0) + overlap_old * w_old.unsqueeze(0)
speech = torch.cat([blended, speech[:, overlap_len:]], dim=1)

# Update HiFT cache
# Update HiFT cache — source_out now properly carries the excitation signal
self.hift_cache_dict = {
"mel": mel_combined[:, :, -self.mel_cache_len :].clone() if mel_combined.shape[2] >= self.mel_cache_len else mel_combined.clone(),
"source": torch.zeros(1, 1, 0),
"source": source_out[:, :, -self.source_cache_len :].clone() if source_out.shape[2] >= self.source_cache_len else source_out.clone(),
"speech": speech[:, -self.source_cache_len :].clone() if speech.shape[-1] >= self.source_cache_len else speech.clone(),
}

Expand Down
Loading