Skip to content

Commit d27310d

Browse files
noorbhatiamattt
authored andcommitted
Add additionalContext support to MLXLanguageModel
1 parent 884251f commit d27310d

2 files changed

Lines changed: 71 additions & 6 deletions

File tree

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,16 @@ import Foundation
183183
/// let model = MLXLanguageModel(modelId: "mlx-community/Llama-3.2-3B-Instruct-4bit")
184184
/// ```
185185
public struct MLXLanguageModel: LanguageModel {
186+
/// Custom generation options for MLX models.
187+
public struct CustomGenerationOptions: AnyLanguageModel.CustomGenerationOptions {
188+
/// Additional key-value pairs injected into the chat template rendering context.
189+
public var additionalContext: [String: MLXLMCommon.JSONValue]?
190+
191+
public init(additionalContext: [String: MLXLMCommon.JSONValue]? = nil) {
192+
self.additionalContext = additionalContext
193+
}
194+
}
195+
186196
/// The reason the model is unavailable.
187197
public enum UnavailableReason: Sendable, Equatable, Hashable {
188198
/// The model has not been loaded into memory yet.
@@ -813,6 +823,11 @@ import Foundation
813823
// Map AnyLanguageModel GenerationOptions to MLX GenerateParameters
814824
let generateParameters = toGenerateParameters(options)
815825

826+
// Extract additional context from custom options
827+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
828+
.flatMap { $0.additionalContext }
829+
.map { $0.mapValues { $0.toSendable() } }
830+
816831
// Build chat history from full transcript
817832
var chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
818833

@@ -828,7 +843,8 @@ import Foundation
828843
let userInput = MLXLMCommon.UserInput(
829844
chat: chat,
830845
processing: .init(resize: .init(width: 512, height: 512)),
831-
tools: toolSpecs
846+
tools: toolSpecs,
847+
additionalContext: additionalContext,
832848
)
833849
let lmInput = try await context.processor.prepare(input: userInput)
834850
let resolved = resolveCache(
@@ -991,10 +1007,17 @@ import Foundation
9911007

9921008
// Build chat inside task to avoid Sendable issues
9931009
let generateParameters = toGenerateParameters(options)
994-
let userInput = makeUserInput(
995-
session: session,
996-
fallbackPrompt: prompt.description,
997-
tools: nil
1010+
let chat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
1011+
1012+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
1013+
.flatMap { $0.additionalContext }
1014+
.map { $0.mapValues { $0.toSendable() } }
1015+
1016+
let userInput = MLXLMCommon.UserInput(
1017+
chat: chat,
1018+
processing: .init(resize: .init(width: 512, height: 512)),
1019+
tools: nil,
1020+
additionalContext: additionalContext
9981021
)
9991022
let lmInput = try await context.processor.prepare(input: userInput)
10001023
let resolved = resolveCache(
@@ -1529,10 +1552,16 @@ import Foundation
15291552
let baseChat = convertTranscriptToMLXChat(session: session, fallbackPrompt: prompt.description)
15301553
let schemaPrompt = includeSchemaInPrompt ? schemaPrompt(for: schema) : nil
15311554
let chat = normalizeChatForStructuredGeneration(baseChat, schemaPrompt: schemaPrompt)
1555+
1556+
let additionalContext: [String: any Sendable]? = options[custom: MLXLanguageModel.self]
1557+
.flatMap { $0.additionalContext }
1558+
.map { $0.mapValues { $0.toSendable() } }
1559+
15321560
let userInput = MLXLMCommon.UserInput(
15331561
chat: chat,
15341562
processing: .init(resize: .init(width: 512, height: 512)),
1535-
tools: nil
1563+
tools: nil,
1564+
additionalContext: additionalContext,
15361565
)
15371566
let lmInput = try await context.processor.prepare(input: userInput)
15381567

@@ -1773,4 +1802,18 @@ import Foundation
17731802
return sampledToken.item(Int.self)
17741803
}
17751804
}
1805+
extension MLXLMCommon.JSONValue {
1806+
/// Recursively converts a `JSONValue` to its primitive Swift equivalent.
1807+
func toSendable() -> any Sendable {
1808+
switch self {
1809+
case .string(let s): return s
1810+
case .int(let i): return i
1811+
case .double(let d): return d
1812+
case .bool(let b): return b
1813+
case .null: return NSNull()
1814+
case .array(let arr): return arr.map { $0.toSendable() }
1815+
case .object(let obj): return obj.mapValues { $0.toSendable() }
1816+
}
1817+
}
1818+
}
17761819
#endif // MLX

Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,28 @@ import Testing
255255
#expect([Priority.low, Priority.medium, Priority.high].contains(response.content))
256256
}
257257

258+
@Test func withAdditionalContext() async throws {
259+
let session = LanguageModelSession(model: model)
260+
261+
var options = GenerationOptions(
262+
temperature: 0.7,
263+
maximumResponseTokens: 32
264+
)
265+
options[custom: MLXLanguageModel.self] = .init(
266+
additionalContext: [
267+
"user_name": .string("Alice"),
268+
"turn_count": .int(3),
269+
"verbose": .bool(true),
270+
]
271+
)
272+
273+
let response = try await session.respond(
274+
to: "Say hello",
275+
options: options
276+
)
277+
#expect(!response.content.isEmpty)
278+
}
279+
258280
@Test func unavailableForNonexistentModel() async {
259281
let model = MLXLanguageModel(modelId: "mlx-community/does-not-exist-anylanguagemodel-test")
260282
await model.removeFromCache()

0 commit comments

Comments
 (0)