// hyungi_document_server/clients/ds-app/Tests/AITests/OnDeviceProviderTests.swift
// Unit and live tests for the AIFabric OnDeviceProvider (Swift).

import XCTest
@testable import AIFabric
#if canImport(FoundationModels)
import FoundationModels
#endif
/// Hardware-free stand-in for `OnDeviceModelBackend` whose availability and generation
/// result are fixed at construction time.
///
/// - `avail` is reported verbatim by `availability` (defaults to `.available`).
/// - `outcome` is replayed by every `generate` call: `.success` returns its string,
///   `.failure` throws its error (defaults to `.success("on-device ok")`).
struct MockOnDeviceBackend: OnDeviceModelBackend {
    let avail: OnDeviceAvailability
    let outcome: Result<String, OnDeviceGenerationError>

    init(avail: OnDeviceAvailability = .available,
         outcome: Result<String, OnDeviceGenerationError> = .success("on-device ok")) {
        self.avail = avail
        self.outcome = outcome
    }

    var availability: OnDeviceAvailability {
        avail
    }

    /// Ignores the prompt, system prompt and token limit, and replays `outcome`.
    func generate(prompt: String, systemPrompt: String?, maxTokens: Int?) async throws -> String {
        try outcome.get()
    }
}
/// Spy `AIProvider` that counts how many times `complete(_:)` is invoked.
///
/// The canned response text "should-not-be-called" marks any invocation as unexpected.
/// Tests read `completeCalls` to assert that routing never reached this provider.
actor CountingProvider: AIProvider {
    nonisolated let id: AIProviderID
    let available: Bool
    /// Number of `complete(_:)` calls observed so far; writable only inside the actor.
    private(set) var completeCalls: Int = 0

    init(id: AIProviderID, available: Bool) {
        self.id = id
        self.available = available
    }

    /// Reports the fixed availability flag supplied at construction.
    var isAvailable: Bool {
        get async { available }
    }

    /// Records the call, then returns a sentinel response attributed to this provider.
    func complete(_ request: AICompletionRequest) async throws -> AICompletionResponse {
        completeCalls += 1
        let sentinel = "should-not-be-called"
        return AICompletionResponse(text: sentinel, providerUsed: id)
    }
}
/// Append-only collector for log lines, shareable across concurrency domains.
///
/// `@unchecked Sendable` is backed by `mutex`: every read and write of `buffer`
/// happens while the lock is held.
final class LogSink: @unchecked Sendable {
    private let mutex = NSLock()
    private var buffer: [String] = []

    /// Appends one log line.
    func append(_ line: String) {
        mutex.lock()
        defer { mutex.unlock() }
        buffer.append(line)
    }

    /// A copy of every line appended so far, in append order.
    var lines: [String] {
        mutex.lock()
        defer { mutex.unlock() }
        return buffer
    }
}
/// Tests for `OnDeviceProvider`. They cover availability reporting, how backend generation
/// errors map to responses or thrown errors, router fallback, and live FoundationModels runs
/// (skipped when the model is unavailable on this machine).
final class OnDeviceProviderTests: XCTestCase {
    // MARK: - Availability + happy path (mock backend)

    /// An available backend yields its text, attributed to `.onDevice`, with `.completed` and a latency.
    func testAvailableReturnsText() async throws {
        let p = OnDeviceProvider(backend: MockOnDeviceBackend(avail: .available, outcome: .success("요약 결과")))
        let available = await p.isAvailable
        XCTAssertTrue(available)
        let resp = try await p.complete(AICompletionRequest(task: .quickSummarize, prompt: "p"))
        XCTAssertEqual(resp.providerUsed, .onDevice)
        XCTAssertEqual(resp.finishReason, .completed)
        XCTAssertEqual(resp.text, "요약 결과")
        XCTAssertNotNil(resp.latencyMs)
    }

    /// A backend reporting `.unavailable` makes the provider report `isAvailable == false`.
    func testUnavailableReportsFalse() async throws {
        let p = OnDeviceProvider(backend: MockOnDeviceBackend(avail: .unavailable(reason: "appleIntelligenceNotEnabled")))
        let available = await p.isAvailable
        XCTAssertFalse(available)
    }

    // MARK: - GenerationError mapping (S2-3c)

    /// Guardrail violations and refusals come back as a `.refused` response (a kind of answer),
    /// not as a thrown error that would trigger fallback.
    func testGuardrailAndRefusalReturnRefused() async throws {
        for err in [OnDeviceGenerationError.guardrailViolation, .refusal] {
            let p = OnDeviceProvider(backend: MockOnDeviceBackend(outcome: .failure(err)))
            let resp = try await p.complete(AICompletionRequest(task: .quickSummarize, prompt: "p"))
            XCTAssertEqual(resp.finishReason, .refused, "\(err) → .refused (답변의 일종, 폴백 X)")
            XCTAssertEqual(resp.providerUsed, .onDevice)
        }
    }

    /// `rateLimited` surfaces as `AIProviderError.unavailable(.onDevice)` and is written to the log.
    func testRateLimitedThrowsUnavailableAndLoudLogs() async throws {
        let sink = LogSink()
        let p = OnDeviceProvider(backend: MockOnDeviceBackend(outcome: .failure(.rateLimited)),
                                 log: { sink.append($0) })
        do {
            _ = try await p.complete(AICompletionRequest(task: .quickSummarize, prompt: "p"))
            XCTFail("rateLimited → throw unavailable")
        } catch let AIProviderError.unavailable(id) {
            XCTAssertEqual(id, .onDevice)
        }
        XCTAssertTrue(sink.lines.contains { $0.contains("rateLimited") }, "stateless 위반은 loud log")
    }

    /// `concurrentRequests` surfaces as `AIProviderError.unavailable(.onDevice)` and is written to the log.
    func testConcurrentRequestsThrowsUnavailableAndLoudLogs() async throws {
        let sink = LogSink()
        let p = OnDeviceProvider(backend: MockOnDeviceBackend(outcome: .failure(.concurrentRequests)),
                                 log: { sink.append($0) })
        do {
            _ = try await p.complete(AICompletionRequest(task: .quickSummarize, prompt: "p"))
            XCTFail("concurrentRequests → throw unavailable")
        } catch let AIProviderError.unavailable(id) {
            XCTAssertEqual(id, .onDevice)
        }
        XCTAssertTrue(sink.lines.contains { $0.contains("concurrentRequests") })
    }

    /// A context-window overflow surfaces as `AIProviderError.unavailable(.onDevice)` so callers can fall back.
    func testContextOverflowThrowsUnavailable() async throws {
        let p = OnDeviceProvider(backend: MockOnDeviceBackend(outcome: .failure(.exceededContextWindowSize)))
        do {
            _ = try await p.complete(AICompletionRequest(task: .quickSummarize, prompt: "p"))
            XCTFail("exceededContextWindowSize → throw unavailable (폴백 유도)")
        } catch let AIProviderError.unavailable(id) {
            XCTAssertEqual(id, .onDevice)
        }
    }

    // MARK: - Router fallback (S2-3d)

    /// An on-device context overflow makes the router fall back to `.localMLX` and record a routing note.
    func testRouterFallsBackOnDeviceOverflowToLocalMLX() async throws {
        let router = AIRouter(providers: [
            .onDevice: OnDeviceProvider(backend: MockOnDeviceBackend(outcome: .failure(.exceededContextWindowSize))),
            .localMLX: MockAIProvider(id: .localMLX, available: true),
        ])
        let resp = try await router.route(AICompletionRequest(task: .quickSummarize, prompt: "p"))
        XCTAssertEqual(resp.providerUsed, .localMLX)
        XCTAssertEqual(resp.routingNote, "fallback from onDevice → localMLX")
    }

    /// When `.onDevice` is requested explicitly and is unavailable, the router throws
    /// `explicitProviderUnavailable` and never calls another provider's `complete()`.
    func testExplicitOnDeviceUnavailableNoFallback() async throws {
        let counting = CountingProvider(id: .localMLX, available: true)
        let router = AIRouter(providers: [
            .onDevice: OnDeviceProvider(backend: MockOnDeviceBackend(avail: .unavailable(reason: "deviceNotEligible"))),
            .localMLX: counting,
        ])
        do {
            _ = try await router.route(AICompletionRequest(task: .quickSummarize, prompt: "p", explicitProvider: .onDevice))
            XCTFail("explicit onDevice unavailable → explicitProviderUnavailable, 자동 폴백 금지")
        } catch let AIRoutingError.explicitProviderUnavailable(id) {
            XCTAssertEqual(id, .onDevice)
        }
        let calls = await counting.completeCalls
        XCTAssertEqual(calls, 0, "명시 불가 시 타 provider complete() 호출 0")
    }

    // MARK: - SDK GenerationError translation lock
    // Pins FoundationModelsBackend.translate for each SDK GenerationError case exercised here,
    // so a change in the SDK's cases or in the mapping is caught at test time.
    #if canImport(FoundationModels)
    func testTranslateGenerationErrorCases() {
        let ctx = LanguageModelSession.GenerationError.Context(debugDescription: "test")
        XCTAssertEqual(FoundationModelsBackend.translate(.exceededContextWindowSize(ctx)), .exceededContextWindowSize)
        XCTAssertEqual(FoundationModelsBackend.translate(.guardrailViolation(ctx)), .guardrailViolation)
        XCTAssertEqual(FoundationModelsBackend.translate(.rateLimited(ctx)), .rateLimited)
        XCTAssertEqual(FoundationModelsBackend.translate(.concurrentRequests(ctx)), .concurrentRequests)
        XCTAssertEqual(FoundationModelsBackend.translate(.unsupportedLanguageOrLocale(ctx)), .unsupportedLanguageOrLocale)
        XCTAssertEqual(FoundationModelsBackend.translate(.assetsUnavailable(ctx)), .assetsUnavailable)
    }
    #endif

    // MARK: - Live tests (skipped when FoundationModels is unavailable on this machine)

    /// End-to-end call against the real default backend; skipped if the model is not available.
    func testLiveOnDeviceIntegration() async throws {
        let p = OnDeviceProvider() // default backend: FoundationModels
        guard await p.isAvailable else {
            throw XCTSkip("FoundationModels not available on this machine — live test skipped")
        }
        let resp = try await p.complete(
            AICompletionRequest(task: .quickSummarize,
                                prompt: "엘보 내경 가공의 핵심 관리 포인트를 한 문장으로 요약해줘.",
                                maxTokens: 120)
        )
        XCTAssertEqual(resp.providerUsed, .onDevice)
        XCTAssertEqual(resp.finishReason, .completed)
        XCTAssertFalse(resp.text.isEmpty, "라이브 응답은 비어있지 않아야")
    }

    /// Cancelling a long live generation must stop it quickly (cooperative cancellation).
    ///
    /// The timing bound is checked on every exit path. The task may throw `CancellationError`
    /// or return normally, for example if it finished before the cancel landed. Previously a
    /// normal return skipped the check entirely, so a provider that ignored cancellation and
    /// ran to completion would still pass.
    func testLiveCancellationCooperative() async throws {
        let p = OnDeviceProvider()
        guard await p.isAvailable else {
            throw XCTSkip("FoundationModels not available — cancellation live test skipped")
        }
        let started = Date()
        let task = Task { () -> AIFinishReason in
            let r = try await p.complete(
                AICompletionRequest(task: .quickSummarize,
                                    prompt: "대한민국 압력용기 산업과 ASME 표준 채택 역사를 아주 길고 자세하게 여러 단락으로 서술해줘.",
                                    maxTokens: 4000)
            )
            return r.finishReason
        }
        // Give the generation time to get underway before cancelling it.
        try? await Task.sleep(nanoseconds: 500_000_000)
        task.cancel()
        // `result` never throws, so control always reaches the timing assertion below.
        let outcome = await task.result
        let elapsed = Date().timeIntervalSince(started)
        // Errors other than cancellation are genuine failures; propagate them as before.
        if case .failure(let error) = outcome, !(error is CancellationError) {
            throw error
        }
        XCTAssertLessThan(elapsed, 8.0, "협조적 취소면 빠르게 중단(S2-3a: ~33ms 후)")
    }
}