Skip to content

Commit 3a2b9ea

Browse files
committed
Incorporate feedback from review
1 parent 770e3bf commit 3a2b9ea

2 files changed

Lines changed: 36 additions & 15 deletions

File tree

Sources/AnyLanguageModel/Models/MLXLanguageModel.swift

Lines changed: 10 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -389,7 +389,7 @@ import Foundation
389389
func removeEntries(forModelKey modelKey: String) {
390390
lock.withLock {
391391
reapDeadSessionsLocked()
392-
for id in buckets.keys {
392+
for id in Array(buckets.keys) {
393393
guard var bucket = buckets[id] else {
394394
continue
395395
}
@@ -410,7 +410,10 @@ import Foundation
410410
}
411411

412412
private func reapDeadSessionsLocked() {
413-
for (id, bucket) in buckets where bucket.sessionReference.session == nil {
413+
let deadSessionIDs = buckets.compactMap { id, bucket in
414+
bucket.sessionReference.session == nil ? id : nil
415+
}
416+
for id in deadSessionIDs {
414417
buckets[id] = nil
415418
}
416419
}
@@ -442,7 +445,7 @@ import Foundation
442445
static let shared = GPUMemoryManager()
443446

444447
private let lock = NSLock()
445-
private var knownConfigs: Set<GPUMemoryConfiguration> = [.automatic]
448+
private var knownConfigs: Set<GPUMemoryConfiguration> = []
446449
private var activeScopes: [UUID: GPUMemoryConfiguration] = [:]
447450

448451
private init() {
@@ -500,7 +503,10 @@ import Foundation
500503
}
501504

502505
private func shouldClearOnEviction() -> Bool {
503-
knownConfigs.contains { $0.clearCacheOnEviction }
506+
if knownConfigs.isEmpty {
507+
return GPUMemoryConfiguration.automatic.clearCacheOnEviction
508+
}
509+
return knownConfigs.contains { $0.clearCacheOnEviction }
504510
}
505511
}
506512

@@ -852,17 +858,6 @@ import Foundation
852858
}
853859
previousToolCallSignature = signature
854860

855-
if !assistantText.isEmpty {
856-
allEntries.append(
857-
.response(
858-
Transcript.Response(
859-
assetIDs: [],
860-
segments: [.text(.init(content: assistantText))]
861-
)
862-
)
863-
)
864-
}
865-
866861
let resolution = try await resolveToolCalls(collectedToolCalls, session: session)
867862
switch resolution {
868863
case .stop(let calls):

Tests/AnyLanguageModelTests/MLXLanguageModelTests.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,32 @@ import Testing
7878
#expect(!second.content.isEmpty)
7979
}
8080

81+
@Test func rejectsConcurrentRequestsForSameSession() async throws {
82+
let session = LanguageModelSession(model: model)
83+
let stream = session.streamResponse(
84+
to: "Count from 1 to 400 with one number per line.",
85+
options: .init(maximumResponseTokens: 256)
86+
)
87+
88+
do {
89+
_ = try await session.respond(to: "This concurrent request should fail.")
90+
Issue.record("Expected concurrent request to throw.")
91+
} catch let error as LanguageModelSession.GenerationError {
92+
switch error {
93+
case .concurrentRequests:
94+
break
95+
default:
96+
Issue.record("Expected .concurrentRequests, got \(error)")
97+
}
98+
} catch {
99+
Issue.record("Expected GenerationError.concurrentRequests, got \(error)")
100+
}
101+
102+
for try await _ in stream {
103+
break
104+
}
105+
}
106+
81107
@Test func withGenerationOptions() async throws {
82108
let session = LanguageModelSession(model: model)
83109

0 commit comments

Comments
 (0)