11from __future__ import annotations
22
3- import os
4- import sys
5- import json
3+ import base64
64import ctypes
75import dataclasses
86import datetime
7+ import json
8+ import os
99import random
1010import string
11+ import sys
1112
1213from contextlib import ExitStack
1314from typing import (
2930import numpy as np
3031import numpy .typing as npt
3132
33+ import urllib .request
34+ from urllib .error import URLError , HTTPError
35+
3236import llama_cpp .llama_cpp as llama_cpp
3337import llama_cpp .llama as llama
3438import llama_cpp .llama_types as llama_types
@@ -2900,16 +2904,22 @@ def _init_mtmd_context(self, llama_model: llama.Llama):
29002904 raise ValueError (f"{ self .log_prefix } (_init_mtmd_context): Failed to load mtmd context from: { self .clip_model_path } " )
29012905
29022906 # Check if vision is supported
2903- if self ._mtmd_cpp .mtmd_support_vision (self .mtmd_ctx ):
2907+ self .is_support_vision = self ._mtmd_cpp .mtmd_support_vision (self .mtmd_ctx )
2908+ if self .is_support_vision :
29042909 if self .verbose :
29052910 print (f"{ self .log_prefix } (_init_mtmd_context): Vision support detected." , file = sys .stderr )
29062911 else :
2907- raise ValueError (f"{ self .log_prefix } (_init_mtmd_context): Vision is not supported by this model" )
2912+ if self .verbose :
2913+ print (f"{ self .log_prefix } (_init_mtmd_context): Vision is NOT supported by this mmproj model backend." , file = sys .stderr )
29082914
29092915 # Check if audio is supported
2910- if self ._mtmd_cpp .mtmd_support_audio (self .mtmd_ctx ):
2916+ self .is_support_audio = self ._mtmd_cpp .mtmd_support_audio (self .mtmd_ctx )
2917+ if self .is_support_audio :
29112918 if self .verbose :
29122919 print (f"{ self .log_prefix } (_init_mtmd_context): Audio support detected." , file = sys .stderr )
2920+ else :
2921+ if self .verbose :
2922+ print (f"{ self .log_prefix } (_init_mtmd_context): Audio is NOT supported by this mmproj model backend." , file = sys .stderr )
29132923
29142924 def close (self ) -> None :
29152925 """Explicitly free the mtmd context and vision model resources."""
@@ -2930,6 +2940,72 @@ def close(self) -> None:
def __del__(self) -> None:
    """Finalizer: delegates to ``close()`` when the object is collected."""
    self.close()
29322942
def _get_media_items(self, messages: List[llama_types.ChatCompletionRequestMessage]) -> List[Dict[str, str]]:
    """
    Collect every media payload (image or audio) referenced by ``messages``.

    Only messages whose ``content`` is a list are inspected. Parts are visited
    in order, so the returned list preserves the order in which media appears
    in the conversation.

    Recognised content parts:
        * ``{"type": "image_url", "image_url": <str> | {"url": <str>}}``
        * ``{"type": "audio_url", "audio_url": <str> | {"url": <str>}}``
        * ``{"type": "input_audio", "input_audio": {"data": <base64>, "format": "wav" | "mp3"}}``
          which is rewritten as a ``data:audio/<format>;base64,...`` URI.
        * ``{"type": "input_audio", "input_audio": <str>}``, where the string
          is passed through unchanged as the URL.

    ``text`` parts are skipped; any other type is ignored (and reported on
    stderr when ``self.verbose`` is set).

    Returns:
        A list of ``{"url": ..., "type": "image" | "audio"}`` dictionaries.

    Raises:
        ValueError: If an image/audio part is present but the corresponding
            capability flag (``is_support_vision`` / ``is_support_audio``) is
            false, or if ``input_audio.format`` is not ``wav`` or ``mp3``.
    """
    media_items: List[Dict[str, str]] = []
    for message in messages:
        parts = message.get("content")
        if not isinstance(parts, list):
            continue

        for content in parts:
            # Non-dict entries cannot describe media; skip them instead of
            # crashing on ``.get``.
            if not isinstance(content, dict):
                continue
            content_type = content.get("type", "")

            # 1. Vision
            if content_type == "image_url":
                if not self.is_support_vision:
                    raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support image inputs.")
                image_ref = content["image_url"]
                url = image_ref if isinstance(image_ref, str) else image_ref["url"]
                media_items.append({"url": url, "type": "image"})

            # 2. Audio
            elif content_type in ("audio_url", "input_audio"):
                if not self.is_support_audio:
                    raise ValueError(f"{self.log_prefix}: This mmproj model instance does not support audio inputs.")

                # Case A: custom/forward-compatible audio_url format.
                # The branch must test the part's *type*; comparing the dict
                # itself to a string is always False and drops the item.
                if content_type == "audio_url":
                    audio_ref = content["audio_url"]
                    url = audio_ref if isinstance(audio_ref, str) else audio_ref["url"]
                    media_items.append({"url": url, "type": "audio"})
                    continue

                # Case B: OpenAI-style input_audio format.
                input_audio = content.get("input_audio", {})
                if isinstance(input_audio, dict) and "data" in input_audio:
                    audio_data = input_audio.get("data", "")
                    audio_format = input_audio.get("format", "")

                    # Only wav/mp3 are accepted for this payload shape.
                    if audio_format not in ("wav", "mp3"):
                        raise ValueError(f"{self.log_prefix}: input_audio.format must be either 'wav' or 'mp3'")

                    # Wrap as a data URI so the regular media loader can decode it.
                    media_items.append({
                        "url": f"data:audio/{audio_format};base64,{audio_data}",
                        "type": "audio",
                    })
                else:
                    # A bare string payload is forwarded as-is; anything else is dropped.
                    url = input_audio if isinstance(input_audio, str) else ""
                    if url:
                        media_items.append({"url": url, "type": "audio"})

            # 3. Text and unknown types
            elif content_type == "text":
                continue
            else:
                if self.verbose:
                    print(f"{self.log_prefix}: Ignored unknown content type '{content_type}'.", file=sys.stderr)
    return media_items
3008+
29333009 def _create_bitmap_from_bytes (self , media_bytes : bytes ):
29343010 """
29353011 Constructs an mtmd_bitmap structure from a raw byte buffer containing media data.
@@ -2992,7 +3068,7 @@ def _process_mtmd_prompt(
29923068 if system_prompt == "" and self .DEFAULT_SYSTEM_MESSAGE is not None :
29933069 messages = [{"role" : "system" , "content" : self .DEFAULT_SYSTEM_MESSAGE }] + messages
29943070
2995- image_urls = self .get_image_urls (messages )
3071+ media_items = self ._get_media_items (messages )
29963072 media_marker = self .media_marker
29973073
29983074 # 2. Render the chat template and replace actual URLs with C++ media markers
@@ -3004,31 +3080,31 @@ def _process_mtmd_prompt(
30043080 ** getattr (self , 'extra_template_arguments' , {})
30053081 )
30063082 # Replace image_url by media_marker in text
3007- for url in image_urls :
3008- text = text .replace (url , media_marker )
3083+ for item in media_items :
3084+ text = text .replace (item [ " url" ] , media_marker )
30093085
30103086 if self .verbose :
3011- print (f"{ self .log_prefix } (_process_mtmd_prompt): Rendered prompt length: { len (text )} chars, Image count: { len (image_urls )} ." , file = sys .stderr )
3087+ print (f"{ self .log_prefix } (_process_mtmd_prompt): Rendered prompt length: { len (text )} chars, Media count: { len (media_items )} ." , file = sys .stderr )
30123088 print (f"{ self .log_prefix } (_process_mtmd_prompt): Rendered prompt: { text } " , file = sys .stderr )
30133089
30143090 # 3. Pre-allocate bitmap array to guarantee chronological order during concurrent decoding
3015- bitmaps = [None ] * len (image_urls )
3091+ bitmaps = [None ] * len (media_items )
30163092 bitmap_cleanup = []
30173093 chunks = None
30183094
30193095 try :
30203096 # Concurrent Media Decoding
30213097 import concurrent .futures
3022- if image_urls :
3023- def _create_bitmap_func (idx : int , url : str ):
3024- media_bytes = self .load_image ( url )
3098+ if media_items :
3099+ def _create_bitmap_func (idx : int , item : str ):
3100+ media_bytes = self .load_media ( item [ " url" ], item [ "type" ] )
30253101 bitmap = self ._create_bitmap_from_bytes (media_bytes )
30263102 return idx , bitmap
3027- # This method uses multi-threaded parallel processing to convert images to bitmaps,
3103+ # This method uses multi-threaded parallel processing to convert images or audio to bitmaps,
30283104 # which can be used in the future to process large numbers of video frames.
3029- max_workers = min (llama .n_threads , len (image_urls ))
3105+ max_workers = min (llama .n_threads , len (media_items ))
30303106 with concurrent .futures .ThreadPoolExecutor (max_workers = max_workers ) as executor :
3031- futures = [executor .submit (_create_bitmap_func , i , url ) for i , url in enumerate (image_urls )]
3107+ futures = [executor .submit (_create_bitmap_func , i , item ) for i , item in enumerate (media_items )]
30323108
30333109 for future in concurrent .futures .as_completed (futures ):
30343110 idx , bitmap = future .result ()
@@ -3040,7 +3116,7 @@ def _create_bitmap_func(idx: int, url: str):
30403116 raise RuntimeError (f"{ self .log_prefix } (_create_bitmap_func): Failed to decode one or more media files." )
30413117 else :
30423118 if self .verbose :
3043- print (f"{ self .log_prefix } (_create_bitmap_func with { max_workers } threads): { len (image_urls )} bitmaps were successfully created." )
3119+ print (f"{ self .log_prefix } (_create_bitmap_func with { max_workers } threads): { len (media_items )} bitmaps were successfully created." )
30443120 else :
30453121 # If there are no images, set the bitmaps to empty.
30463122 bitmaps = []
@@ -3423,8 +3499,95 @@ def __call__(
34233499 )
34243500 return _convert_completion_to_chat (completion_or_chunks , stream = stream )
34253501
3426- def load_image (self , image_url : str ) -> bytes :
3427- return self ._load_image (image_url )
def load_media(self, media_url: str, media_type: str) -> bytes:
    """
    Load a media payload and return its raw bytes.

    Dispatches on ``media_type``: ``"image"`` goes to ``self._load_image``,
    ``"audio"`` goes to ``self._load_audio`` and the result is then checked
    with ``self.detect_audio_format`` before being returned.

    Args:
        media_url: URL, path or data URI understood by the underlying loader.
        media_type: Either ``"image"`` or ``"audio"``.

    Returns:
        The bytes produced by the selected loader.

    Raises:
        ValueError: If ``media_type`` is unknown, or if the audio bytes fail
            format detection (the detector's error is kept as ``__cause__``).
    """
    if media_type == "image":
        return self._load_image(media_url)

    if media_type == "audio":
        audio_bytes = self._load_audio(media_url)
        # Validate the magic bytes before handing the buffer on.
        try:
            self.detect_audio_format(audio_bytes)
        except ValueError as e:
            # Re-raise with the handler prefix, keeping the original as the cause.
            raise ValueError(f"{self.log_prefix}(load_media): {e}") from e
        return audio_bytes

    raise ValueError(f"{self.log_prefix}(load_media): Unknown media type '{media_type}'")
3519+
3520+ @staticmethod
3521+ def detect_audio_format (audio_bytes : bytes ) -> str :
3522+ """
3523+ Pure utility function: Detects the audio format from magic bytes.
3524+ Strictly translated from llama.cpp's `is_audio_file` to ensure 100% compatibility
3525+ and avoid false positives (e.g., AVI files disguised as RIFF).
3526+ """
3527+ length = len (audio_bytes )
3528+
3529+ if length < 12 :
3530+ raise ValueError ("Audio data is corrupted or too small (less than 12 bytes)." )
3531+
3532+ # RIFF & WAVE magic bytes verification
3533+ is_wav = audio_bytes .startswith (b"RIFF" ) and audio_bytes [8 :12 ] == b"WAVE"
3534+
3535+ # ID3 metadata or MPEG sync word verification
3536+ is_mp3 = length >= 3 and (
3537+ audio_bytes .startswith (b"ID3" ) or
3538+ (audio_bytes [0 ] == 0xFF and (audio_bytes [1 ] & 0xE0 ) == 0xE0 )
3539+ )
3540+
3541+ # FLAC magic bytes verification
3542+ is_flac = audio_bytes .startswith (b"fLaC" )
3543+
3544+ if is_wav :
3545+ return "wav"
3546+ elif is_mp3 :
3547+ return "mp3"
3548+ elif is_flac :
3549+ return "flac"
3550+ else :
3551+ raise ValueError (
3552+ "Unsupported audio format detected via magic bytes. "
3553+ "The underlying C++ miniaudio backend ONLY supports WAV, MP3, and FLAC."
3554+ )
3555+
3556+ @staticmethod
3557+ def _load_audio (audio_url : str ) -> bytes :
3558+ """
3559+ Load audio from either a URL, local path, or a data URI and return raw bytes.
3560+ """
3561+
3562+ audio_bytes = b""
3563+
3564+ # 1. Handle data URI (base64)
3565+ if audio_url .strip ().startswith ("data:" ):
3566+ comma_pos = audio_url .find ("," )
3567+ if comma_pos == - 1 :
3568+ raise ValueError ("Invalid data URI: missing comma separator" )
3569+ base64_data = audio_url [comma_pos + 1 :]
3570+ audio_bytes = base64 .b64decode (base64_data )
3571+
3572+ # 2. Handle local file path
3573+ elif os .path .exists (audio_url ):
3574+ with open (audio_url , "rb" ) as f :
3575+ audio_bytes = f .read ()
3576+
3577+ # 3. Handle remote URL via HTTP/HTTPS
3578+ else :
3579+ headers = {"User-Agent" : "Mozilla/5.0" }
3580+ req = urllib .request .Request (audio_url , headers = headers )
3581+ try :
3582+ with urllib .request .urlopen (req , timeout = 15 ) as f :
3583+ audio_bytes = f .read ()
3584+ except (URLError , HTTPError ) as e :
3585+ raise ConnectionError (f"Failed to download audio from { audio_url } : { e } " )
3586+
3587+ if not audio_bytes :
3588+ raise ValueError ("Empty audio data received" )
3589+
3590+ return audio_bytes
34283591
34293592 @staticmethod
34303593 def _load_image (image_url : str ) -> bytes :
@@ -3444,7 +3607,6 @@ def _load_image(image_url: str) -> bytes:
34443607
34453608 # 1. Handle data URI (base64)
34463609 if image_url .strip ().startswith ("data:" ):
3447- import base64
34483610 # Split at the first comma: everything after it is the base64 payload
34493611 comma_pos = image_url .find ("," )
34503612 if comma_pos == - 1 :
@@ -3454,9 +3616,6 @@ def _load_image(image_url: str) -> bytes:
34543616
34553617 # 2. Handle local/remote URL
34563618 else :
3457- import urllib .request
3458- from urllib .error import URLError , HTTPError
3459-
34603619 headers = {"User-Agent" : "Mozilla/5.0" }
34613620 req = urllib .request .Request (image_url , headers = headers )
34623621
@@ -3506,50 +3665,6 @@ def _load_image(image_url: str) -> bytes:
35063665 image .save (output , format = "JPEG" , quality = 95 , optimize = True , progressive = True )
35073666 return output .getvalue ()
35083667
3509- @staticmethod
3510- def get_image_urls (messages : List [llama_types .ChatCompletionRequestMessage ]):
3511- image_urls : List [str ] = []
3512- for message in messages :
3513- if message ["role" ] == "user" :
3514- if message ["content" ] is None :
3515- continue
3516- for content in message ["content" ]:
3517- if isinstance (content , dict ) and "type" in content :
3518- if content ["type" ] == "image_url" :
3519- if (
3520- isinstance (content ["image_url" ], dict )
3521- and "url" in content ["image_url" ]
3522- ):
3523- image_urls .append (content ["image_url" ]["url" ])
3524- else :
3525- image_urls .append (content ["image_url" ])
3526- return image_urls
3527-
3528- @staticmethod
3529- def split_text_on_image_urls (text : str , image_urls : List [str ]):
3530- """This method is no longer used in the new implementation."""
3531- def find_first (s : str , substrs : List [str ]):
3532- for i , substr in enumerate (substrs ):
3533- pos = s .find (substr )
3534- if pos != - 1 :
3535- return pos , i
3536- return None , None
3537-
3538- split_text : List [Tuple [Literal ["text" , "image_url" ], str ]] = []
3539- remaining = text
3540- while remaining :
3541- # Find first image_url
3542- pos , i = find_first (remaining , image_urls )
3543- if pos is not None and i is not None :
3544- if pos > 0 :
3545- split_text .append (("text" , remaining [:pos ]))
3546- split_text .append (("image_url" , image_urls [i ]))
3547- remaining = remaining [pos + len (image_urls [i ]) :]
3548- else :
3549- split_text .append (("text" , remaining ))
3550- remaining = ""
3551- return split_text
3552-
35533668 @classmethod
35543669 def from_pretrained (
35553670 cls ,
0 commit comments