@@ -183,6 +183,16 @@ import Foundation
183183 /// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit")
184184 /// ```
185185 public struct MLXLanguageModel : LanguageModel {
186+ /// Custom generation options for MLX models.
187+ public struct CustomGenerationOptions : AnyLanguageModel . CustomGenerationOptions {
188+ /// Additional key-value pairs injected into the chat template rendering context.
189+ public var additionalContext : [ String : MLXLMCommon . JSONValue ] ?
190+
191+ public init ( additionalContext: [ String : MLXLMCommon . JSONValue ] ? = nil ) {
192+ self . additionalContext = additionalContext
193+ }
194+ }
195+
186196 /// The reason the model is unavailable.
187197 public enum UnavailableReason : Sendable , Equatable , Hashable {
188198 /// The model has not been loaded into memory yet.
@@ -813,6 +823,11 @@ import Foundation
813823 // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
814824 let generateParameters = toGenerateParameters ( options)
815825
826+ // Extract additional context from custom options
827+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
828+ . flatMap { $0. additionalContext }
829+ . map { $0. mapValues { $0. toSendable ( ) } }
830+
816831 // Build chat history from full transcript
817832 var chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
818833
@@ -828,7 +843,8 @@ import Foundation
828843 let userInput = MLXLMCommon . UserInput (
829844 chat: chat,
830845 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
831- tools: toolSpecs
846+ tools: toolSpecs,
847+ additionalContext: additionalContext,
832848 )
833849 let lmInput = try await context. processor. prepare ( input: userInput)
834850 let resolved = resolveCache (
@@ -991,10 +1007,17 @@ import Foundation
9911007
9921008 // Build chat inside task to avoid Sendable issues
9931009 let generateParameters = toGenerateParameters ( options)
994- let userInput = makeUserInput (
995- session: session,
996- fallbackPrompt: prompt. description,
997- tools: nil
1010+ let chat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
1011+
1012+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
1013+ . flatMap { $0. additionalContext }
1014+ . map { $0. mapValues { $0. toSendable ( ) } }
1015+
1016+ let userInput = MLXLMCommon . UserInput (
1017+ chat: chat,
1018+ processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
1019+ tools: nil ,
1020+ additionalContext: additionalContext
9981021 )
9991022 let lmInput = try await context. processor. prepare ( input: userInput)
10001023 let resolved = resolveCache (
@@ -1529,10 +1552,16 @@ import Foundation
15291552 let baseChat = convertTranscriptToMLXChat ( session: session, fallbackPrompt: prompt. description)
15301553 let schemaPrompt = includeSchemaInPrompt ? schemaPrompt ( for: schema) : nil
15311554 let chat = normalizeChatForStructuredGeneration ( baseChat, schemaPrompt: schemaPrompt)
1555+
1556+ let additionalContext : [ String : any Sendable ] ? = options [ custom: MLXLanguageModel . self]
1557+ . flatMap { $0. additionalContext }
1558+ . map { $0. mapValues { $0. toSendable ( ) } }
1559+
15321560 let userInput = MLXLMCommon . UserInput (
15331561 chat: chat,
15341562 processing: . init( resize: . init( width: 512 , height: 512 ) ) ,
1535- tools: nil
1563+ tools: nil ,
1564+ additionalContext: additionalContext,
15361565 )
15371566 let lmInput = try await context. processor. prepare ( input: userInput)
15381567
@@ -1773,4 +1802,18 @@ import Foundation
17731802 return sampledToken. item ( Int . self)
17741803 }
17751804 }
1805+ extension MLXLMCommon . JSONValue {
1806+ /// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1807+ func toSendable( ) -> any Sendable {
1808+ switch self {
1809+ case . string( let s) : return s
1810+ case . int( let i) : return i
1811+ case . double( let d) : return d
1812+ case . bool( let b) : return b
1813+ case . null: return NSNull ( )
1814+ case . array( let arr) : return arr. map { $0. toSendable ( ) }
1815+ case . object( let obj) : return obj. mapValues { $0. toSendable ( ) }
1816+ }
1817+ }
1818+ }
17761819#endif // MLX
0 commit comments