Skip to content

Commit b342f70

Browse files
committed
refactor(mtmd): introduce omni-modal media pipeline with experimental audio support
This commit significantly overhauls the media parsing and loading pipeline in `MTMDChatHandler` to gracefully handle both vision and audio inputs, establishing a true omni-modal architecture. Key structural changes: - Backend Capability Sniffing: `_init_mtmd_context` now actively probes the C++ backend for `ctx_v` (vision) and `ctx_a` (audio) encoders, enabling proactive fail-fast validation before media processing. - Unified Media Extraction: Replaced `get_image_urls` and `split_text_on_image_urls` with a robust `_get_media_items` method. This safely parses `image_url`, `input_audio`, and `audio_url` while strictly maintaining the chronological order of user prompts and enforcing OpenAI format specs. - Media Dispatcher & Magic Bytes: Introduced a unified `load_media` dispatcher. Added a new `_load_audio` method and a rigorous `detect_audio_format` static method that accurately mimics `llama.cpp`'s C++ magic bytes sniffing (RIFF/WAVE, ID3/MPEG, fLaC) to prevent fatal backend crashes. - Concurrent Omni-Decoding: The ThreadPoolExecutor in `_process_mtmd_prompt` has been upgraded to concurrently fetch and decode both image and audio payloads into unified `mtmd_bitmap` structures. Note: Audio processing capabilities in the underlying llama.cpp engine are currently in an experimental stage. Signed-off-by: JamePeng <jame_peng@sina.com>
1 parent f5f76a1 commit b342f70

1 file changed

Lines changed: 183 additions & 68 deletions

File tree

llama_cpp/llama_chat_format.py

Lines changed: 183 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from __future__ import annotations
22

3-
import os
4-
import sys
5-
import json
3+
import base64
64
import ctypes
75
import dataclasses
86
import datetime
7+
import json
8+
import os
99
import random
1010
import string
11+
import sys
1112

1213
from contextlib import ExitStack
1314
from typing import (
@@ -29,6 +30,9 @@
2930
import numpy as np
3031
import numpy.typing as npt
3132

33+
import urllib.request
34+
from urllib.error import URLError, HTTPError
35+
3236
import llama_cpp.llama_cpp as llama_cpp
3337
import llama_cpp.llama as llama
3438
import llama_cpp.llama_types as llama_types
@@ -2900,16 +2904,22 @@ def _init_mtmd_context(self, llama_model: llama.Llama):
29002904
raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Failed to load mtmd context from: {self.clip_model_path}")
29012905

29022906
# Check if vision is supported
2903-
if self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx):
2907+
self.is_support_vision = self._mtmd_cpp.mtmd_support_vision(self.mtmd_ctx)
2908+
if self.is_support_vision:
29042909
if self.verbose:
29052910
print(f"{self.log_prefix}(_init_mtmd_context): Vision support detected.", file=sys.stderr)
29062911
else:
2907-
raise ValueError(f"{self.log_prefix}(_init_mtmd_context): Vision is not supported by this model")
2912+
if self.verbose:
2913+
print(f"{self.log_prefix}(_init_mtmd_context): Vision is NOT supported by this mmproj model backend.", file=sys.stderr)
29082914

29092915
# Check if audio is supported
2910-
if self._mtmd_cpp.mtmd_support_audio(self.mtmd_ctx):
2916+
self.is_support_audio = self._mtmd_cpp.mtmd_support_audio(self.mtmd_ctx)
2917+
if self.is_support_audio:
29112918
if self.verbose:
29122919
print(f"{self.log_prefix}(_init_mtmd_context): Audio support detected.", file=sys.stderr)
2920+
else:
2921+
if self.verbose:
2922+
print(f"{self.log_prefix}(_init_mtmd_context): Audio is NOT supported by this mmproj model backend.", file=sys.stderr)
29132923

29142924
def close(self) -> None:
29152925
"""Explicitly free the mtmd context and vision model resources."""
@@ -2930,6 +2940,72 @@ def close(self) -> None:
29302940
def __del__(self) -> None:
29312941
self.close()
29322942

2943+
def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessage]) -> List[Dict[str, str]]:
    """
    Extract all media payloads (images, audio) sequentially to maintain exact
    chronological order across user prompts.

    Strictly enforces the backend capability flags set by `_init_mtmd_context`
    (`self.is_support_vision` / `self.is_support_audio`), failing fast when
    unsupported media is supplied.

    Args:
        messages: OpenAI-style chat messages; only messages whose content is a
            list of typed parts are inspected.

    Returns:
        A list of dicts, each holding the media 'url' (URL, local path, or
        data URI) and its 'type' ("image" or "audio"), in prompt order.

    Raises:
        ValueError: If an image/audio part arrives while the mmproj backend
            lacks the corresponding encoder, or if `input_audio.format` is not
            'wav' or 'mp3'.
    """
    media_items: List[Dict[str, str]] = []
    for message in messages:
        if not isinstance(message.get("content"), list):
            continue
        for content in message["content"]:
            content_type = content.get("type", "")

            # 1. Vision Processing
            if content_type == "image_url":
                if not self.is_support_vision:
                    raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support image inputs.")
                # Accept both the bare-string and the {"url": ...} spellings.
                url = content["image_url"] if isinstance(content["image_url"], str) else content["image_url"]["url"]
                media_items.append({"url": url, "type": "image"})

            # 2. Audio Processing
            elif content_type in ("audio_url", "input_audio"):
                if not self.is_support_audio:
                    raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.")

                # Case A: custom/forward-compatible audio_url format.
                # BUGFIX: the original compared the whole `content` dict to the
                # string "audio_url" (`if content == "audio_url"`), which is
                # always False, so audio_url parts were silently dropped.
                if content_type == "audio_url":
                    url = content["audio_url"] if isinstance(content["audio_url"], str) else content["audio_url"]["url"]
                    media_items.append({"url": url, "type": "audio"})
                # Case B: OpenAI standard input_audio format
                else:
                    input_audio = content.get("input_audio", {})
                    if isinstance(input_audio, dict) and "data" in input_audio:
                        # input_audio: {"data": <base64>, "format": "wav"|"mp3"}
                        audio_data = input_audio.get("data", "")
                        audio_format = input_audio.get("format", "")

                        # Strictly align with llama.cpp (require wav/mp3)
                        if audio_format not in ("wav", "mp3"):
                            raise ValueError(f"{self.log_prefix}: input_audio.format must be either 'wav' or 'mp3'")

                        # Re-encode as a data URI so the unified load_media
                        # path can handle it like any other audio source.
                        media_items.append({
                            "url": f"data:audio/{audio_format};base64,{audio_data}",
                            "type": "audio",
                        })
                    else:
                        # Raw base64 string passed directly instead of a dict.
                        url = input_audio if isinstance(input_audio, str) else ""
                        if url:
                            media_items.append({"url": url, "type": "audio"})

            # 3. Text & Unknown Types
            elif content_type == "text":
                continue
            else:
                if self.verbose:
                    print(f"{self.log_prefix}: Ignored unknown content type '{content_type}'.", file=sys.stderr)
    return media_items
3008+
29333009
def _create_bitmap_from_bytes(self, media_bytes: bytes):
29343010
"""
29353011
Constructs an mtmd_bitmap structure from a raw byte buffer containing media data.
@@ -2992,7 +3068,7 @@ def _process_mtmd_prompt(
29923068
if system_prompt == "" and self.DEFAULT_SYSTEM_MESSAGE is not None:
29933069
messages = [{"role": "system", "content": self.DEFAULT_SYSTEM_MESSAGE}] + messages
29943070

2995-
image_urls = self.get_image_urls(messages)
3071+
media_items = self._get_media_items(messages)
29963072
media_marker = self.media_marker
29973073

29983074
# 2. Render the chat template and replace actual URLs with C++ media markers
@@ -3004,31 +3080,31 @@ def _process_mtmd_prompt(
30043080
**getattr(self, 'extra_template_arguments', {})
30053081
)
30063082
# Replace image_url by media_marker in text
3007-
for url in image_urls:
3008-
text = text.replace(url, media_marker)
3083+
for item in media_items:
3084+
text = text.replace(item["url"], media_marker)
30093085

30103086
if self.verbose:
3011-
print(f"{self.log_prefix}(_process_mtmd_prompt): Rendered prompt length: {len(text)} chars, Image count: {len(image_urls)}.", file=sys.stderr)
3087+
print(f"{self.log_prefix}(_process_mtmd_prompt): Rendered prompt length: {len(text)} chars, Media count: {len(media_items)}.", file=sys.stderr)
30123088
print(f"{self.log_prefix}(_process_mtmd_prompt): Rendered prompt: {text}", file=sys.stderr)
30133089

30143090
# 3. Pre-allocate bitmap array to guarantee chronological order during concurrent decoding
3015-
bitmaps = [None] * len(image_urls)
3091+
bitmaps = [None] * len(media_items)
30163092
bitmap_cleanup = []
30173093
chunks = None
30183094

30193095
try:
30203096
# Concurrent Media Decoding
30213097
import concurrent.futures
3022-
if image_urls:
3023-
def _create_bitmap_func(idx: int, url: str):
3024-
media_bytes = self.load_image(url)
3098+
if media_items:
3099+
def _create_bitmap_func(idx: int, item: str):
3100+
media_bytes = self.load_media(item["url"], item["type"])
30253101
bitmap = self._create_bitmap_from_bytes(media_bytes)
30263102
return idx, bitmap
3027-
# This method uses multi-threaded parallel processing to convert images to bitmaps,
3103+
# This method uses multi-threaded parallel processing to convert images or audio to bitmaps,
30283104
# which can be used in the future to process large numbers of video frames.
3029-
max_workers = min(llama.n_threads, len(image_urls))
3105+
max_workers = min(llama.n_threads, len(media_items))
30303106
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
3031-
futures = [executor.submit(_create_bitmap_func, i, url) for i, url in enumerate(image_urls)]
3107+
futures = [executor.submit(_create_bitmap_func, i, item) for i, item in enumerate(media_items)]
30323108

30333109
for future in concurrent.futures.as_completed(futures):
30343110
idx, bitmap = future.result()
@@ -3040,7 +3116,7 @@ def _create_bitmap_func(idx: int, url: str):
30403116
raise RuntimeError(f"{self.log_prefix}(_create_bitmap_func): Failed to decode one or more media files.")
30413117
else:
30423118
if self.verbose:
3043-
print(f"{self.log_prefix}(_create_bitmap_func with {max_workers} threads): {len(image_urls)} bitmaps were successfully created.")
3119+
print(f"{self.log_prefix}(_create_bitmap_func with {max_workers} threads): {len(media_items)} bitmaps were successfully created.")
30443120
else:
30453121
# If there are no images, set the bitmaps to empty.
30463122
bitmaps = []
@@ -3423,8 +3499,95 @@ def __call__(
34233499
)
34243500
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
34253501

3426-
def load_image(self, image_url: str) -> bytes:
3427-
return self._load_image(image_url)
3502+
def load_media(self, media_url: str, media_type: str) -> bytes:
    """
    Unified dispatcher for loading media payloads.

    Routes the URL/URI to the image or audio loader based on `media_type`;
    audio payloads are additionally validated by magic-byte sniffing so that
    bad data fails here rather than inside the C++ backend.
    """
    if media_type == "image":
        return self._load_image(media_url)

    if media_type != "audio":
        raise ValueError(f"{self.log_prefix}(load_media): Unknown media type '{media_type}'")

    payload = self._load_audio(media_url)
    # Validate magic bytes before handing the buffer to the backend.
    try:
        self.detect_audio_format(payload)
    except ValueError as e:
        raise ValueError(f"{self.log_prefix}(load_media): {e}")
    return payload
3519+
3520+
@staticmethod
3521+
def detect_audio_format(audio_bytes: bytes) -> str:
3522+
"""
3523+
Pure utility function: Detects the audio format from magic bytes.
3524+
Strictly translated from llama.cpp's `is_audio_file` to ensure 100% compatibility
3525+
and avoid false positives (e.g., AVI files disguised as RIFF).
3526+
"""
3527+
length = len(audio_bytes)
3528+
3529+
if length < 12:
3530+
raise ValueError("Audio data is corrupted or too small (less than 12 bytes).")
3531+
3532+
# RIFF & WAVE magic bytes verification
3533+
is_wav = audio_bytes.startswith(b"RIFF") and audio_bytes[8:12] == b"WAVE"
3534+
3535+
# ID3 metadata or MPEG sync word verification
3536+
is_mp3 = length >= 3 and (
3537+
audio_bytes.startswith(b"ID3") or
3538+
(audio_bytes[0] == 0xFF and (audio_bytes[1] & 0xE0) == 0xE0)
3539+
)
3540+
3541+
# FLAC magic bytes verification
3542+
is_flac = audio_bytes.startswith(b"fLaC")
3543+
3544+
if is_wav:
3545+
return "wav"
3546+
elif is_mp3:
3547+
return "mp3"
3548+
elif is_flac:
3549+
return "flac"
3550+
else:
3551+
raise ValueError(
3552+
"Unsupported audio format detected via magic bytes. "
3553+
"The underlying C++ miniaudio backend ONLY supports WAV, MP3, and FLAC."
3554+
)
3555+
3556+
@staticmethod
3557+
def _load_audio(audio_url: str) -> bytes:
3558+
"""
3559+
Load audio from either a URL, local path, or a data URI and return raw bytes.
3560+
"""
3561+
3562+
audio_bytes = b""
3563+
3564+
# 1. Handle data URI (base64)
3565+
if audio_url.strip().startswith("data:"):
3566+
comma_pos = audio_url.find(",")
3567+
if comma_pos == -1:
3568+
raise ValueError("Invalid data URI: missing comma separator")
3569+
base64_data = audio_url[comma_pos + 1 :]
3570+
audio_bytes = base64.b64decode(base64_data)
3571+
3572+
# 2. Handle local file path
3573+
elif os.path.exists(audio_url):
3574+
with open(audio_url, "rb") as f:
3575+
audio_bytes = f.read()
3576+
3577+
# 3. Handle remote URL via HTTP/HTTPS
3578+
else:
3579+
headers = {"User-Agent": "Mozilla/5.0"}
3580+
req = urllib.request.Request(audio_url, headers=headers)
3581+
try:
3582+
with urllib.request.urlopen(req, timeout=15) as f:
3583+
audio_bytes = f.read()
3584+
except (URLError, HTTPError) as e:
3585+
raise ConnectionError(f"Failed to download audio from {audio_url}: {e}")
3586+
3587+
if not audio_bytes:
3588+
raise ValueError("Empty audio data received")
3589+
3590+
return audio_bytes
34283591

34293592
@staticmethod
34303593
def _load_image(image_url: str) -> bytes:
@@ -3444,7 +3607,6 @@ def _load_image(image_url: str) -> bytes:
34443607

34453608
# 1. Handle data URI (base64)
34463609
if image_url.strip().startswith("data:"):
3447-
import base64
34483610
# Split only once from the right to correctly handle mime types containing commas
34493611
comma_pos = image_url.find(",")
34503612
if comma_pos == -1:
@@ -3454,9 +3616,6 @@ def _load_image(image_url: str) -> bytes:
34543616

34553617
# 2. Handle local/remote URL
34563618
else:
3457-
import urllib.request
3458-
from urllib.error import URLError, HTTPError
3459-
34603619
headers = {"User-Agent": "Mozilla/5.0"}
34613620
req = urllib.request.Request(image_url, headers=headers)
34623621

@@ -3506,50 +3665,6 @@ def _load_image(image_url: str) -> bytes:
35063665
image.save(output, format="JPEG", quality=95, optimize=True, progressive=True)
35073666
return output.getvalue()
35083667

3509-
@staticmethod
3510-
def get_image_urls(messages: List[llama_types.ChatCompletionRequestMessage]):
3511-
image_urls: List[str] = []
3512-
for message in messages:
3513-
if message["role"] == "user":
3514-
if message["content"] is None:
3515-
continue
3516-
for content in message["content"]:
3517-
if isinstance(content, dict) and "type" in content:
3518-
if content["type"] == "image_url":
3519-
if (
3520-
isinstance(content["image_url"], dict)
3521-
and "url" in content["image_url"]
3522-
):
3523-
image_urls.append(content["image_url"]["url"])
3524-
else:
3525-
image_urls.append(content["image_url"])
3526-
return image_urls
3527-
3528-
@staticmethod
3529-
def split_text_on_image_urls(text: str, image_urls: List[str]):
3530-
"""This method is no longer used in the new implementation."""
3531-
def find_first(s: str, substrs: List[str]):
3532-
for i, substr in enumerate(substrs):
3533-
pos = s.find(substr)
3534-
if pos != -1:
3535-
return pos, i
3536-
return None, None
3537-
3538-
split_text: List[Tuple[Literal["text", "image_url"], str]] = []
3539-
remaining = text
3540-
while remaining:
3541-
# Find first image_url
3542-
pos, i = find_first(remaining, image_urls)
3543-
if pos is not None and i is not None:
3544-
if pos > 0:
3545-
split_text.append(("text", remaining[:pos]))
3546-
split_text.append(("image_url", image_urls[i]))
3547-
remaining = remaining[pos + len(image_urls[i]) :]
3548-
else:
3549-
split_text.append(("text", remaining))
3550-
remaining = ""
3551-
return split_text
3552-
35533668
@classmethod
35543669
def from_pretrained(
35553670
cls,

0 commit comments

Comments
 (0)