diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f81816b..483cec6 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,9 +7,9 @@ on: branches: ["main"] jobs: - test: - name: Swift ${{ matrix.swift }} on Xcode ${{ matrix.xcode }} - runs-on: ${{ matrix.runs-on }} + test-macos: + name: Swift ${{ matrix.swift }} on macOS ${{ matrix.macos }} with Xcode ${{ matrix.xcode }} + runs-on: macos-${{ matrix.macos }} env: DEVELOPER_DIR: "/Applications/Xcode_${{ matrix.xcode }}.app/Contents/Developer" strategy: @@ -18,16 +18,53 @@ jobs: include: - swift: "6.0" xcode: "16.0" - runs-on: macos-15 + macos: "15" - swift: "6.1" xcode: "16.3" - runs-on: macos-15 + macos: "15" - swift: "6.2" xcode: "26.0" - runs-on: macos-26 + macos: "26" + timeout-minutes: 10 + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Cache Swift Package Manager dependencies + uses: actions/cache@v4 + with: + path: | + ~/.cache/org.swift.swiftpm + .build + key: ${{ runner.os }}-swift-${{ matrix.swift }}-spm-${{ hashFiles('**/Package.resolved') }} + restore-keys: | + ${{ runner.os }}-swift-${{ matrix.swift }}-spm- + + - name: Build + run: swift build -v + + - name: Test + run: swift test -v + test-linux: + name: Swift ${{ matrix.swift-version }} on Linux + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + swift-version: + - "6.0.3" + - "6.1.3" + - "6.2.3" + timeout-minutes: 10 steps: - - uses: actions/checkout@v4 + - name: Checkout code + uses: actions/checkout@v4 + + - name: Setup Swift + uses: vapor/swiftly-action@v0.2 + with: + toolchain: ${{ matrix.swift-version }} - name: Build run: swift build -v diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..1e48ef0 --- /dev/null +++ b/Package.resolved @@ -0,0 +1,33 @@ +{ + "originHash" : "4309fa84e67265c5b44e488827f3ab8aad324591e808443f9979ed8618e54d87", + "pins" : [ + { + "identity" : "eventsource", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/EventSource.git", + "state" : { + "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-asn1", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-asn1.git", + "state" : { + "revision" : "810496cf121e525d660cd0ea89a758740476b85f", + "version" : "1.5.1" + } + }, + { + "identity" : "swift-crypto", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-crypto.git", + "state" : { + "revision" : "6f70fa9eab24c1fd982af18c281c4525d05e3095", + "version" : "4.2.0" + } + } + ], + "version" : 3 +} diff --git a/Package.swift b/Package.swift index 90abb1d..74bdab8 100644 --- a/Package.swift +++ b/Package.swift @@ -20,13 +20,15 @@ let package = Package( ) ], dependencies: [ - .package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0") + .package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0"), + .package(url: "https://github.com/apple/swift-crypto.git", "1.0.0" ..< "5.0.0"), ], targets: [ .target( name: "HuggingFace", dependencies: [ - .product(name: "EventSource", package: "EventSource") + .product(name: "EventSource", package: "EventSource"), + .product(name: "Crypto", package: "swift-crypto"), ], path: "Sources/HuggingFace" ), diff --git a/Sources/HuggingFace/Hub/File.swift b/Sources/HuggingFace/Hub/File.swift index bf47c07..663d3ef 100644 --- a/Sources/HuggingFace/Hub/File.swift +++ b/Sources/HuggingFace/Hub/File.swift @@ -1,4 +1,3 @@ -import CryptoKit import Foundation /// Information about a file in a repository. diff --git a/Sources/HuggingFace/Hub/HubClient+Files.swift b/Sources/HuggingFace/Hub/HubClient+Files.swift index 0e26a51..deec7c4 100644 --- a/Sources/HuggingFace/Hub/HubClient+Files.swift +++ b/Sources/HuggingFace/Hub/HubClient+Files.swift @@ -1,5 +1,8 @@ import Foundation -import UniformTypeIdentifiers + +#if canImport(UniformTypeIdentifiers) + import UniformTypeIdentifiers +#endif #if canImport(FoundationNetworking) import FoundationNetworking @@ -285,10 +288,14 @@ public extension HubClient { var request = try await httpClient.createRequest(.get, url: url) request.cachePolicy = cachePolicy - let (tempURL, response) = try await session.download( - for: request, - delegate: progress.map { DownloadProgressDelegate(progress: $0) } - ) + #if canImport(FoundationNetworking) + let (tempURL, response) = try await session.asyncDownload(for: request, progress: progress) + #else + let (tempURL, response) = try await session.download( + for: request, + delegate: progress.map { DownloadProgressDelegate(progress: $0) } + ) + #endif _ = try httpClient.validateResponse(response, data: nil) // Store in cache before moving to destination @@ -321,29 +328,35 @@ public extension HubClient { return destination } - /// Download file with resume capability - /// - Parameters: - /// - resumeData: Resume data from a previous download attempt - /// - destination: Destination URL for downloaded file - /// - progress: Optional Progress object to track download progress - /// - Returns: Final destination URL - func resumeDownloadFile( - resumeData: Data, - to destination: URL, - progress: Progress? = nil - ) async throws -> URL { - let (tempURL, response) = try await session.download( - resumeFrom: resumeData, - delegate: progress.map { DownloadProgressDelegate(progress: $0) } - ) - _ = try httpClient.validateResponse(response, data: nil) + #if !canImport(FoundationNetworking) + /// Download file with resume capability + /// + /// - Note: This method is only available on Apple platforms. + /// On Linux, resume functionality is not supported. + /// + /// - Parameters: + /// - resumeData: Resume data from a previous download attempt + /// - destination: Destination URL for downloaded file + /// - progress: Optional Progress object to track download progress + /// - Returns: Final destination URL + func resumeDownloadFile( + resumeData: Data, + to destination: URL, + progress: Progress? = nil + ) async throws -> URL { + let (tempURL, response) = try await session.download( + resumeFrom: resumeData, + delegate: progress.map { DownloadProgressDelegate(progress: $0) } + ) + _ = try httpClient.validateResponse(response, data: nil) - // Move from temporary location to final destination - try? FileManager.default.removeItem(at: destination) - try FileManager.default.moveItem(at: tempURL, to: destination) + // Move from temporary location to final destination + try? FileManager.default.removeItem(at: destination) + try FileManager.default.moveItem(at: tempURL, to: destination) - return destination - } + return destination + } + #endif /// Download file to a destination URL (convenience method without progress tracking) /// - Parameters: @@ -379,32 +392,34 @@ public extension HubClient { // MARK: - Progress Delegate -private final class DownloadProgressDelegate: NSObject, URLSessionDownloadDelegate, @unchecked Sendable { - private let progress: Progress +#if !canImport(FoundationNetworking) + private final class DownloadProgressDelegate: NSObject, URLSessionDownloadDelegate, @unchecked Sendable { + private let progress: Progress - init(progress: Progress) { - self.progress = progress - } + init(progress: Progress) { + self.progress = progress + } - func urlSession( - _: URLSession, - downloadTask _: URLSessionDownloadTask, - didWriteData _: Int64, - totalBytesWritten: Int64, - totalBytesExpectedToWrite: Int64 - ) { - progress.totalUnitCount = totalBytesExpectedToWrite - progress.completedUnitCount = totalBytesWritten - } + func urlSession( + _: URLSession, + downloadTask _: URLSessionDownloadTask, + didWriteData _: Int64, + totalBytesWritten: Int64, + totalBytesExpectedToWrite: Int64 + ) { + progress.totalUnitCount = totalBytesExpectedToWrite + progress.completedUnitCount = totalBytesWritten + } - func urlSession( - _: URLSession, - downloadTask _: URLSessionDownloadTask, - didFinishDownloadingTo _: URL - ) { - // The actual file handling is done in the async/await layer + func urlSession( + _: URLSession, + downloadTask _: URLSessionDownloadTask, + didFinishDownloadingTo _: URL + ) { + // The actual file handling is done in the async/await layer + } } -} +#endif // MARK: - Delete Operations @@ -632,9 +647,107 @@ private struct UploadResponse: Codable { private extension URL { var mimeType: String? { - guard let uti = UTType(filenameExtension: pathExtension) else { - return nil - } - return uti.preferredMIMEType + #if canImport(UniformTypeIdentifiers) + guard let uti = UTType(filenameExtension: pathExtension) else { + return nil + } + return uti.preferredMIMEType + #else + // Fallback MIME type lookup for Linux + let ext = pathExtension.lowercased() + switch ext { + // MARK: - JSON + case "json": + return "application/json" + // MARK: - Text + case "txt": + return "text/plain" + case "md": + return "text/markdown" + case "csv": + return "text/csv" + case "tsv": + return "text/tab-separated-values" + // MARK: - HTML and Markup + case "html", "htm": + return "text/html" + case "xml": + return "application/xml" + case "svg": + return "image/svg+xml" + case "yaml", "yml": + return "application/x-yaml" + case "toml": + return "application/toml" + // MARK: - Code + case "js": + return "application/javascript" + case "py": + return "text/x-python" + case "swift": + return "text/x-swift" + case "css": + return "text/css" + case "ipynb": + return "application/x-ipynb+json" + // MARK: - Archives and Compressed + case "zip": + return "application/zip" + case "gz", "gzip": + return "application/gzip" + case "tar": + return "application/x-tar" + case "bz2": + return "application/x-bzip2" + case "7z": + return "application/x-7z-compressed" + // MARK: - PDF and Documents + case "pdf": + return "application/pdf" + // MARK: - Images + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "gif": + return "image/gif" + case "webp": + return "image/webp" + case "bmp": + return "image/bmp" + case "tiff", "tif": + return "image/tiff" + // MARK: - Audio + case "m4a": + return "audio/mp4" + case "mp3": + return "audio/mpeg" + case "wav": + return "audio/wav" + case "flac": + return "audio/flac" + case "ogg": + return "audio/ogg" + // MARK: - Video + case "mp4": + return "video/mp4" + case "webm": + return "video/webm" + // MARK: - ML/Model/Raw Data + case "bin", "safetensors", "gguf", "ggml": + return "application/octet-stream" + case "pt", "pth": + return "application/octet-stream" + case "onnx": + return "application/octet-stream" + case "ckpt": + return "application/octet-stream" + case "npz": + return "application/octet-stream" + // MARK: - Default + default: + return "application/octet-stream" + } + #endif } } diff --git a/Sources/HuggingFace/Hub/Pagination.swift b/Sources/HuggingFace/Hub/Pagination.swift index daa591f..1dbd5f1 100644 --- a/Sources/HuggingFace/Hub/Pagination.swift +++ b/Sources/HuggingFace/Hub/Pagination.swift @@ -1,4 +1,7 @@ import Foundation +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif /// Sort direction for list queries. public enum SortDirection: Int, Hashable, Sendable { @@ -28,38 +31,48 @@ public struct PaginatedResponse: Sendable { } } -// MARK: - +// MARK: - Link Header Parsing -extension HTTPURLResponse { - /// Parses the Link header to extract the next page URL. - /// - /// The Link header format follows RFC 8288: `; rel="next"` - /// - /// - Returns: The URL for the next page, or `nil` if not found. - func nextPageURL() -> URL? { - guard let linkHeader = value(forHTTPHeaderField: "Link") else { - return nil - } +/// Parses the Link header from an HTTP response to extract the next page URL. +/// +/// The Link header format follows RFC 8288: `; rel="next"` +/// +/// - Parameter response: The HTTP response to parse. +/// - Returns: The URL for the next page, or `nil` if not found. +func parseNextPageURL(from response: HTTPURLResponse) -> URL? { + guard let linkHeader = response.value(forHTTPHeaderField: "Link") else { + return nil + } + return parseNextPageURL(from: linkHeader) +} - // Parse Link header format: ; rel="next" - let links = linkHeader.components(separatedBy: ",") - for link in links { - let components = link.components(separatedBy: ";") - guard components.count >= 2 else { continue } +/// Parses a Link header string to extract the next page URL. +/// +/// - Parameter linkHeader: The Link header value. +/// - Returns: The URL for the next page, or `nil` if not found. +func parseNextPageURL(from linkHeader: String) -> URL? { + // Parse Link header format: ; rel="next" + let links = linkHeader.components(separatedBy: ",") + for link in links { + let components = link.components(separatedBy: ";") + guard components.count >= 2 else { continue } - let urlPart = components[0].trimmingCharacters(in: .whitespaces) - let relPart = components[1].trimmingCharacters(in: .whitespaces) + let urlPart = components[0].trimmingCharacters(in: .whitespaces) + let relPart = components[1].trimmingCharacters(in: .whitespaces) - // Check if this is the "next" link - if relPart.contains("rel=\"next\"") || relPart.contains("rel='next'") { - // Extract URL from angle brackets - let urlString = urlPart.trimmingCharacters(in: CharacterSet(charactersIn: "<>")) - if let url = URL(string: urlString) { - return url - } + // Check if this is the "next" link + if relPart.contains("rel=\"next\"") || relPart.contains("rel='next'") { + // Extract URL from angle brackets + let urlString = urlPart.trimmingCharacters(in: CharacterSet(charactersIn: "<>")) + + // Check for empty URL string to ensure consistent behavior across platforms + guard !urlString.isEmpty, let url = URL(string: urlString) else { + continue } - } - return nil + return url + } } + + return nil } diff --git a/Sources/HuggingFace/OAuth/HuggingFaceAuthenticationManager.swift b/Sources/HuggingFace/OAuth/HuggingFaceAuthenticationManager.swift index e194651..c61f83e 100644 --- a/Sources/HuggingFace/OAuth/HuggingFaceAuthenticationManager.swift +++ b/Sources/HuggingFace/OAuth/HuggingFaceAuthenticationManager.swift @@ -1,3 +1,5 @@ +import Foundation + #if canImport(AuthenticationServices) import AuthenticationServices import Observation @@ -486,13 +488,15 @@ // MARK: - -private extension URL { - /// Extracts the OAuth authorization code from a callback URL. - /// - Returns: The authorization code if found, nil otherwise. - var oauthCode: String? { - URLComponents(string: absoluteString)? - .queryItems? - .first(where: { $0.name == "code" })? - .value +#if canImport(AuthenticationServices) + private extension URL { + /// Extracts the OAuth authorization code from a callback URL. + /// - Returns: The authorization code if found, nil otherwise. + var oauthCode: String? { + URLComponents(string: absoluteString)? + .queryItems? + .first(where: { $0.name == "code" })? + .value + } } -} +#endif diff --git a/Sources/HuggingFace/OAuth/OAuthClient.swift b/Sources/HuggingFace/OAuth/OAuthClient.swift index 1bc7d33..4a5348a 100644 --- a/Sources/HuggingFace/OAuth/OAuthClient.swift +++ b/Sources/HuggingFace/OAuth/OAuthClient.swift @@ -1,368 +1,369 @@ -#if canImport(CryptoKit) - import CryptoKit - import Foundation - - #if canImport(FoundationNetworking) - import FoundationNetworking - #endif // canImport(FoundationNetworking) - - /// An OAuth 2.0 client for handling authentication flows - /// with support for token caching, refresh, and secure code exchange - /// using PKCE (Proof Key for Code Exchange). - public actor OAuthClient: Sendable { - /// The OAuth client configuration. - public let configuration: OAuthClientConfiguration - - /// The URL session to use for network requests. - let urlSession: URLSession - - private var cachedToken: OAuthToken? - private var refreshTask: Task? - private var codeVerifier: String? - - /// Initializes a new OAuth client with the specified configuration. - /// - Parameters: - /// - configuration: The OAuth configuration containing client credentials and endpoints. - /// - session: The URL session to use for network requests. Defaults to `.shared`. - public init(configuration: OAuthClientConfiguration, session: URLSession = .shared) { - self.configuration = configuration - self.urlSession = session +import Crypto +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +/// An OAuth 2.0 client for handling authentication flows +/// with support for token caching, refresh, and secure code exchange +/// using PKCE (Proof Key for Code Exchange). +public actor OAuthClient: Sendable { + /// The OAuth client configuration. + public let configuration: OAuthClientConfiguration + + /// The URL session to use for network requests. + let urlSession: URLSession + + private var cachedToken: OAuthToken? + private var refreshTask: Task? + private var codeVerifier: String? + + /// Initializes a new OAuth client with the specified configuration. + /// - Parameters: + /// - configuration: The OAuth configuration containing client credentials and endpoints. + /// - session: The URL session to use for network requests. Defaults to `.shared`. + public init(configuration: OAuthClientConfiguration, session: URLSession = .shared) { + self.configuration = configuration + self.urlSession = session + } + + /// Retrieves a valid OAuth token, using cached token if available and valid. + /// + /// This method first checks for a valid cached token. If no valid token exists and a refresh + /// is already in progress, it waits for that refresh to complete. If no refresh is in progress, + /// it throws `OAuthError.authenticationRequired` to indicate that fresh authentication is needed. + /// + /// - Returns: A valid OAuth token. + /// - Throws: `OAuthError.authenticationRequired` if no valid token is available and no refresh is in progress. + public func getValidToken() async throws -> OAuthToken { + // Return cached token if valid + if let token = cachedToken, token.isValid { + return token } - /// Retrieves a valid OAuth token, using cached token if available and valid. - /// - /// This method first checks for a valid cached token. If no valid token exists and a refresh - /// is already in progress, it waits for that refresh to complete. If no refresh is in progress, - /// it throws `OAuthError.authenticationRequired` to indicate that fresh authentication is needed. - /// - /// - Returns: A valid OAuth token. - /// - Throws: `OAuthError.authenticationRequired` if no valid token is available and no refresh is in progress. - public func getValidToken() async throws -> OAuthToken { - // Return cached token if valid - if let token = cachedToken, token.isValid { - return token - } - - // If refresh already in progress, wait for it - if let task = refreshTask { - return try await task.value - } - - // No valid token and no refresh in progress - need fresh authentication - throw OAuthError.authenticationRequired + // If refresh already in progress, wait for it + if let task = refreshTask { + return try await task.value } - /// Initiates the OAuth authentication flow using PKCE (Proof Key for Code Exchange). - /// - /// This method generates PKCE values, constructs the authorization URL, and presents - /// a web authentication session to the user. The user will be redirected to the OAuth - /// provider's authorization page where they can grant permissions. - /// - /// - Parameter handler: A closure that handles the authentication session flow. - /// - Returns: The authorization code from the OAuth callback. - /// - Throws: `OAuthError.sessionFailedToStart` if the authentication session cannot be started. - /// - Throws: `OAuthError.invalidCallback` if the callback URL is invalid or doesn't contain an authorization code. - public func authenticate(handler: @escaping (URL, String) async throws -> String) - async throws -> String - { - // Generate PKCE values - let (verifier, challenge) = Self.generatePKCEValues() - self.codeVerifier = verifier - - // Build authorization URL - let authURL = configuration.baseURL.appendingPathComponent("oauth/authorize") - var components = URLComponents(url: authURL, resolvingAgainstBaseURL: false)! - components.queryItems = [ - .init(name: "client_id", value: configuration.clientID), - .init(name: "redirect_uri", value: configuration.redirectURL.absoluteString), - .init(name: "response_type", value: "code"), - .init(name: "scope", value: configuration.scope), - .init(name: "code_challenge", value: challenge), - .init(name: "code_challenge_method", value: "S256"), - .init(name: "state", value: UUID().uuidString), - ] - - guard let finalAuthURL = components.url, - let scheme = configuration.redirectURL.scheme - else { - throw OAuthError.sessionFailedToStart - } - - return try await handler(finalAuthURL, scheme) + // No valid token and no refresh in progress - need fresh authentication + throw OAuthError.authenticationRequired + } + + /// Initiates the OAuth authentication flow using PKCE (Proof Key for Code Exchange). + /// + /// This method generates PKCE values, constructs the authorization URL, and presents + /// a web authentication session to the user. The user will be redirected to the OAuth + /// provider's authorization page where they can grant permissions. + /// + /// - Parameter handler: A closure that handles the authentication session flow. + /// - Returns: The authorization code from the OAuth callback. + /// - Throws: `OAuthError.sessionFailedToStart` if the authentication session cannot be started. + /// - Throws: `OAuthError.invalidCallback` if the callback URL is invalid or doesn't contain an authorization code. + public func authenticate(handler: @escaping (URL, String) async throws -> String) + async throws -> String + { + // Generate PKCE values + let (verifier, challenge) = Self.generatePKCEValues() + self.codeVerifier = verifier + + // Build authorization URL + let authURL = configuration.baseURL.appendingPathComponent("oauth/authorize") + var components = URLComponents(url: authURL, resolvingAgainstBaseURL: false)! + components.queryItems = [ + .init(name: "client_id", value: configuration.clientID), + .init(name: "redirect_uri", value: configuration.redirectURL.absoluteString), + .init(name: "response_type", value: "code"), + .init(name: "scope", value: configuration.scope), + .init(name: "code_challenge", value: challenge), + .init(name: "code_challenge_method", value: "S256"), + .init(name: "state", value: UUID().uuidString), + ] + + guard let finalAuthURL = components.url, + let scheme = configuration.redirectURL.scheme + else { + throw OAuthError.sessionFailedToStart } - /// Exchanges an authorization code for an OAuth token using PKCE. - /// - /// This method takes the authorization code received from the OAuth callback and exchanges - /// it for an access token and refresh token. The code verifier generated during authentication - /// is used to complete the PKCE flow for security. - /// - /// - Parameter code: The authorization code from the OAuth callback. - /// - Returns: An OAuth token containing access and refresh tokens. - /// - Throws: `OAuthError.missingCodeVerifier` if no code verifier is available. - /// - Throws: `OAuthError.tokenExchangeFailed` if the token exchange request fails. - public func exchangeCode(_ code: String) async throws -> OAuthToken { - guard let verifier = codeVerifier else { - throw OAuthError.missingCodeVerifier - } - - let tokenURL = configuration.baseURL.appendingPathComponent("oauth/token") - var request = URLRequest(url: tokenURL) - request.httpMethod = "POST" - request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") - - var components = URLComponents() - components.queryItems = [ - .init(name: "grant_type", value: "authorization_code"), - .init(name: "code", value: code), - .init(name: "redirect_uri", value: configuration.redirectURL.absoluteString), - .init(name: "client_id", value: configuration.clientID), - .init(name: "code_verifier", value: verifier), - ] - request.httpBody = components.percentEncodedQuery?.data(using: .utf8) - - let (data, response) = try await urlSession.data(for: request) - - guard let httpResponse = response as? HTTPURLResponse, - (200 ... 299).contains(httpResponse.statusCode) - else { - throw OAuthError.tokenExchangeFailed - } - - let tokenResponse = try await MainActor.run { - try JSONDecoder().decode(TokenResponse.self, from: data) - } - let token = OAuthToken( - accessToken: tokenResponse.accessToken, - refreshToken: tokenResponse.refreshToken, - expiresAt: Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)) - ) - - self.cachedToken = token - self.codeVerifier = nil + return try await handler(finalAuthURL, scheme) + } - return token + /// Exchanges an authorization code for an OAuth token using PKCE. + /// + /// This method takes the authorization code received from the OAuth callback and exchanges + /// it for an access token and refresh token. The code verifier generated during authentication + /// is used to complete the PKCE flow for security. + /// + /// - Parameter code: The authorization code from the OAuth callback. + /// - Returns: An OAuth token containing access and refresh tokens. + /// - Throws: `OAuthError.missingCodeVerifier` if no code verifier is available. + /// - Throws: `OAuthError.tokenExchangeFailed` if the token exchange request fails. + public func exchangeCode(_ code: String) async throws -> OAuthToken { + guard let verifier = codeVerifier else { + throw OAuthError.missingCodeVerifier } - /// Refreshes an OAuth token using a refresh token. - /// - /// This method prevents multiple concurrent refresh operations by tracking an active refresh task. - /// If a refresh is already in progress, it waits for that refresh to complete rather than - /// starting a new one. - /// - /// - Parameter refreshToken: The refresh token to use for obtaining a new access token. - /// - Returns: A new OAuth token with updated access and refresh tokens. - /// - Throws: `OAuthError.tokenExchangeFailed` if the refresh request fails. - public func refreshToken(using refreshToken: String) async throws -> OAuthToken { - // Start refresh task if not already running - if let task = refreshTask { - return try await task.value - } - - let task = Task { - try await performRefresh(refreshToken: refreshToken) - } - refreshTask = task - - defer { - Task { clearRefreshTask() } - } + let tokenURL = configuration.baseURL.appendingPathComponent("oauth/token") + var request = URLRequest(url: tokenURL) + request.httpMethod = "POST" + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + + var components = URLComponents() + components.queryItems = [ + .init(name: "grant_type", value: "authorization_code"), + .init(name: "code", value: code), + .init(name: "redirect_uri", value: configuration.redirectURL.absoluteString), + .init(name: "client_id", value: configuration.clientID), + .init(name: "code_verifier", value: verifier), + ] + request.httpBody = components.percentEncodedQuery?.data(using: .utf8) + + let (data, response) = try await urlSession.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse, + (200 ... 299).contains(httpResponse.statusCode) + else { + throw OAuthError.tokenExchangeFailed + } + let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data) + let token = OAuthToken( + accessToken: tokenResponse.accessToken, + refreshToken: tokenResponse.refreshToken, + expiresAt: Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)) + ) + + self.cachedToken = token + self.codeVerifier = nil + + return token + } + + /// Refreshes an OAuth token using a refresh token. + /// + /// This method prevents multiple concurrent refresh operations by tracking an active refresh task. + /// If a refresh is already in progress, it waits for that refresh to complete rather than + /// starting a new one. + /// + /// - Parameter refreshToken: The refresh token to use for obtaining a new access token. + /// - Returns: A new OAuth token with updated access and refresh tokens. + /// - Throws: `OAuthError.tokenExchangeFailed` if the refresh request fails. + public func refreshToken(using refreshToken: String) async throws -> OAuthToken { + // Start refresh task if not already running + if let task = refreshTask { return try await task.value } - private func clearRefreshTask() { - refreshTask = nil + let task = Task { + try await performRefresh(refreshToken: refreshToken) } + refreshTask = task - private func performRefresh(refreshToken: String) async throws -> OAuthToken { - let tokenURL = configuration.baseURL.appendingPathComponent("oauth/token") - var request = URLRequest(url: tokenURL) - request.httpMethod = "POST" - request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") - - var components = URLComponents() - components.queryItems = [ - .init(name: "grant_type", value: "refresh_token"), - .init(name: "refresh_token", value: refreshToken), - .init(name: "client_id", value: configuration.clientID), - ] - request.httpBody = components.percentEncodedQuery?.data(using: .utf8) - - let (data, response) = try await urlSession.data(for: request) - - guard let httpResponse = response as? HTTPURLResponse, - (200 ... 299).contains(httpResponse.statusCode) - else { - throw OAuthError.tokenExchangeFailed - } - - let tokenResponse = try await MainActor.run { - try JSONDecoder().decode(TokenResponse.self, from: data) - } - let token = OAuthToken( - accessToken: tokenResponse.accessToken, - refreshToken: tokenResponse.refreshToken ?? refreshToken, - expiresAt: Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)) - ) - - self.cachedToken = token - return token + defer { + Task { clearRefreshTask() } } - /// Generates PKCE code verifier and challenge values as a tuple. - /// - Returns: A tuple containing the code verifier and its corresponding challenge. - private static func generatePKCEValues() -> (verifier: String, challenge: String) { - // Generate a cryptographically secure random code verifier - var buffer = [UInt8](repeating: 0, count: 32) - _ = SecRandomCopyBytes(kSecRandomDefault, buffer.count, &buffer) - let verifier = Data(buffer).urlSafeBase64EncodedString() - .trimmingCharacters(in: .whitespaces) + return try await task.value + } - // Generate SHA256 hash of the verifier for the challenge - let data = Data(verifier.utf8) - let hashed = SHA256.hash(data: data) - let challenge = Data(hashed).urlSafeBase64EncodedString() + private func clearRefreshTask() { + refreshTask = nil + } - return (verifier, challenge) + private func performRefresh(refreshToken: String) async throws -> OAuthToken { + let tokenURL = configuration.baseURL.appendingPathComponent("oauth/token") + var request = URLRequest(url: tokenURL) + request.httpMethod = "POST" + request.setValue("application/x-www-form-urlencoded", forHTTPHeaderField: "Content-Type") + + var components = URLComponents() + components.queryItems = [ + .init(name: "grant_type", value: "refresh_token"), + .init(name: "refresh_token", value: refreshToken), + .init(name: "client_id", value: configuration.clientID), + ] + request.httpBody = components.percentEncodedQuery?.data(using: .utf8) + + let (data, response) = try await urlSession.data(for: request) + + guard let httpResponse = response as? HTTPURLResponse, + (200 ... 299).contains(httpResponse.statusCode) + else { + throw OAuthError.tokenExchangeFailed } + + let tokenResponse = try JSONDecoder().decode(TokenResponse.self, from: data) + let token = OAuthToken( + accessToken: tokenResponse.accessToken, + refreshToken: tokenResponse.refreshToken ?? refreshToken, + expiresAt: Date().addingTimeInterval(TimeInterval(tokenResponse.expiresIn)) + ) + + self.cachedToken = token + return token } - // MARK: - - - /// Configuration for OAuth authentication client - public struct OAuthClientConfiguration: Sendable { - /// The base URL for OAuth endpoints - public let baseURL: URL - - /// The redirect URL for OAuth callbacks - public let redirectURL: URL - - /// The OAuth client ID - public let clientID: String - - /// The scopes for OAuth requests as a space-separated string - public let scope: String - - /// Initializes a new OAuth configuration with the specified parameters. - /// - Parameters: - /// - baseURL: The base URL for OAuth endpoints. - /// - redirectURL: The redirect URL for OAuth callbacks. - /// - clientID: The OAuth client ID. - /// - scope: The scopes for OAuth requests. - public init( - baseURL: URL, - redirectURL: URL, - clientID: String, - scope: String - ) { - self.baseURL = baseURL - self.redirectURL = redirectURL - self.clientID = clientID - self.scope = scope - } + /// Generates PKCE code verifier and challenge values as a tuple. + /// - Returns: A tuple containing the code verifier and its corresponding challenge. + private static func generatePKCEValues() -> (verifier: String, challenge: String) { + // Generate a cryptographically secure random code verifier + var buffer = [UInt8](repeating: 0, count: 32) + #if os(macOS) || os(iOS) || os(watchOS) || os(tvOS) + _ = SecRandomCopyBytes(kSecRandomDefault, buffer.count, &buffer) + #else + // This should be cryptographically secure, see: https://forums.swift.org/t/random-data-uint8-random-or-secrandomcopybytes/56165/9 + var generator = SystemRandomNumberGenerator() + buffer = buffer.map { _ in UInt8.random(in: 0 ... 255, using: &generator) } + #endif + + let verifier = Data(buffer).urlSafeBase64EncodedString() + .trimmingCharacters(in: .whitespaces) + + // Generate SHA256 hash of the verifier for the challenge + let data = Data(verifier.utf8) + let hashed = SHA256.hash(data: data) + let challenge = Data(hashed).urlSafeBase64EncodedString() + + return (verifier, challenge) } +} + +// MARK: - + +/// Configuration for OAuth authentication client +public struct OAuthClientConfiguration: Sendable { + /// The base URL for OAuth endpoints + public let baseURL: URL + + /// The redirect URL for OAuth callbacks + public let redirectURL: URL + + /// The OAuth client ID + public let clientID: String + + /// The scopes for OAuth requests as a space-separated string + public let scope: String + + /// Initializes a new OAuth configuration with the specified parameters. + /// - Parameters: + /// - baseURL: The base URL for OAuth endpoints. + /// - redirectURL: The redirect URL for OAuth callbacks. + /// - clientID: The OAuth client ID. + /// - scope: The scopes for OAuth requests. + public init( + baseURL: URL, + redirectURL: URL, + clientID: String, + scope: String + ) { + self.baseURL = baseURL + self.redirectURL = redirectURL + self.clientID = clientID + self.scope = scope + } +} - // MARK: - +// MARK: - - /// OAuth token containing access and refresh tokens - public struct OAuthToken: Sendable, Codable { - /// The access token - public let accessToken: String +/// OAuth token containing access and refresh tokens +public struct OAuthToken: Sendable, Codable { + /// The access token + public let accessToken: String - /// The refresh token - public let refreshToken: String? + /// The refresh token + public let refreshToken: String? - /// The expiration date of the token - public let expiresAt: Date + /// The expiration date of the token + public let expiresAt: Date - /// Whether the token is valid - public var isValid: Bool { - Date() < expiresAt.addingTimeInterval(-300) // 5 min buffer - } + /// Whether the token is valid + public var isValid: Bool { + Date() < expiresAt.addingTimeInterval(-300) // 5 min buffer + } - /// Initializes a new OAuth token with the specified parameters. - /// - Parameters: - /// - accessToken: The access token. - /// - refreshToken: The refresh token. - /// - expiresAt: The expiration date of the token. - public init(accessToken: String, refreshToken: String?, expiresAt: Date) { - self.accessToken = accessToken - self.refreshToken = refreshToken - self.expiresAt = expiresAt - } + /// Initializes a new OAuth token with the specified parameters. + /// - Parameters: + /// - accessToken: The access token. + /// - refreshToken: The refresh token. + /// - expiresAt: The expiration date of the token. + public init(accessToken: String, refreshToken: String?, expiresAt: Date) { + self.accessToken = accessToken + self.refreshToken = refreshToken + self.expiresAt = expiresAt + } - enum CodingKeys: String, CodingKey { - case accessToken = "access_token" - case refreshToken = "refresh_token" - case expiresAt = "expires_at" - } + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case refreshToken = "refresh_token" + case expiresAt = "expires_at" } +} - /// OAuth error enum - public enum OAuthError: LocalizedError, Equatable, Sendable { - /// Authentication required - case authenticationRequired +/// OAuth error enum +public enum OAuthError: LocalizedError, Equatable, Sendable { + /// Authentication required + case authenticationRequired - /// Invalid callback - case invalidCallback + /// Invalid callback + case invalidCallback - /// Session failed to start - case sessionFailedToStart + /// Session failed to start + case sessionFailedToStart - /// Missing code verifier - case missingCodeVerifier + /// Missing code verifier + case missingCodeVerifier - /// Token exchange failed - case tokenExchangeFailed + /// Token exchange failed + case tokenExchangeFailed - /// Token storage error - case tokenStorageError(String) + /// Token storage error + case tokenStorageError(String) - /// Invalid configuration - case invalidConfiguration(String) + /// Invalid configuration + case invalidConfiguration(String) - /// The error description - public var errorDescription: String? { - switch self { - case .authenticationRequired: return "Authentication required" - case .invalidCallback: return "Invalid callback" - case .sessionFailedToStart: return "Session failed to start" - case .missingCodeVerifier: return "Missing code verifier" - case .tokenExchangeFailed: return "Token exchange failed" - case .tokenStorageError(let error): return "Token storage error: \(error)" - case .invalidConfiguration(let error): return "Invalid configuration: \(error)" - } + /// The error description + public var errorDescription: String? { + switch self { + case .authenticationRequired: return "Authentication required" + case .invalidCallback: return "Invalid callback" + case .sessionFailedToStart: return "Session failed to start" + case .missingCodeVerifier: return "Missing code verifier" + case .tokenExchangeFailed: return "Token exchange failed" + case .tokenStorageError(let error): return "Token storage error: \(error)" + case .invalidConfiguration(let error): return "Invalid configuration: \(error)" } } - - private struct TokenResponse: Sendable, Codable { - let accessToken: String - let refreshToken: String? - let expiresIn: Int - let tokenType: String - - enum CodingKeys: String, CodingKey { - case accessToken = "access_token" - case refreshToken = "refresh_token" - case expiresIn = "expires_in" - case tokenType = "token_type" - } +} + +private struct TokenResponse: Sendable, Codable { + let accessToken: String + let refreshToken: String? + let expiresIn: Int + let tokenType: String + + enum CodingKeys: String, CodingKey { + case accessToken = "access_token" + case refreshToken = "refresh_token" + case expiresIn = "expires_in" + case tokenType = "token_type" } - - // MARK: - - - private extension Data { - /// Returns a URL-safe Base64 encoded string suitable for use in URLs and OAuth flows. - /// - /// This method applies the standard Base64 encoding and then replaces characters - /// that are not URL-safe (+ becomes -, / becomes _, = padding is removed). - /// - Returns: A URL-safe Base64 encoded string. - func urlSafeBase64EncodedString() -> String { - base64EncodedString() - .replacingOccurrences(of: "+", with: "-") - .replacingOccurrences(of: "/", with: "_") - .replacingOccurrences(of: "=", with: "") - } +} + +// MARK: - + +private extension Data { + /// Returns a URL-safe Base64 encoded string suitable for use in URLs and OAuth flows. + /// + /// This method applies the standard Base64 encoding and then replaces characters + /// that are not URL-safe (+ becomes -, / becomes _, = padding is removed). + /// - Returns: A URL-safe Base64 encoded string. + func urlSafeBase64EncodedString() -> String { + base64EncodedString() + .replacingOccurrences(of: "+", with: "-") + .replacingOccurrences(of: "/", with: "_") + .replacingOccurrences(of: "=", with: "") } -#endif // canImport(CryptoKit) +} diff --git a/Sources/HuggingFace/OAuth/TokenStorage.swift b/Sources/HuggingFace/OAuth/TokenStorage.swift new file mode 100644 index 0000000..71fd3e7 --- /dev/null +++ b/Sources/HuggingFace/OAuth/TokenStorage.swift @@ -0,0 +1,194 @@ +import Foundation + +/// A cross-platform mechanism for storing and retrieving OAuth tokens. +/// +/// This provides a file-based storage implementation that works on all platforms, +/// including Linux. For Apple platforms, the `HuggingFaceAuthenticationManager` +/// provides keychain-based storage through its own `TokenStorage` type. +/// +/// Example usage: +/// ```swift +/// let storage = FileTokenStorage.default +/// try storage.store(token) +/// let retrieved = try storage.retrieve() +/// ``` +public struct FileTokenStorage: Sendable { + private let fileURL: URL + + /// Creates a new file-based token storage at the specified URL. + /// - Parameter fileURL: The URL where tokens will be stored. + public init(fileURL: URL) { + self.fileURL = fileURL + } + + /// The default token storage location. + /// + /// On Linux/Unix: `~/.cache/huggingface/token.json` + /// On macOS: `~/Library/Caches/huggingface/token.json` + public static var `default`: FileTokenStorage { + let cacheDir: URL + #if os(macOS) || os(iOS) || os(tvOS) || os(watchOS) || os(visionOS) + cacheDir = + FileManager.default.urls(for: .cachesDirectory, in: .userDomainMask).first + ?? FileManager.default.temporaryDirectory + #else + // Linux/Unix: Use XDG_CACHE_HOME or ~/.cache + if let xdgCache = ProcessInfo.processInfo.environment["XDG_CACHE_HOME"] { + cacheDir = URL(fileURLWithPath: xdgCache) + } else { + let home = + ProcessInfo.processInfo.environment["HOME"] + ?? NSHomeDirectory() + cacheDir = URL(fileURLWithPath: home).appendingPathComponent(".cache") + } + #endif + + let tokenDir = cacheDir.appendingPathComponent("huggingface") + let tokenFile = tokenDir.appendingPathComponent("token.json") + + return FileTokenStorage(fileURL: tokenFile) + } + + /// Stores an OAuth token to the file. + /// - Parameter token: The token to store. + /// - Throws: An error if the token cannot be encoded or written. + public func store(_ token: OAuthToken) throws { + // Create directory if needed + let directory = fileURL.deletingLastPathComponent() + try FileManager.default.createDirectory( + at: directory, + withIntermediateDirectories: true + ) + + // Encode and write + let encoder = JSONEncoder() + encoder.outputFormatting = .prettyPrinted + let data = try encoder.encode(token) + try data.write(to: fileURL, options: .atomic) + + // Set file permissions to owner-only (0600) on Unix systems + #if !os(Windows) + try FileManager.default.setAttributes( + [.posixPermissions: 0o600], + ofItemAtPath: fileURL.path + ) + #endif + } + + /// Retrieves the stored OAuth token. + /// - Returns: The stored token, or `nil` if no token is stored. + /// - Throws: An error if the token file exists but cannot be read or decoded. + public func retrieve() throws -> OAuthToken? { + guard FileManager.default.fileExists(atPath: fileURL.path) else { + return nil + } + + let data = try Data(contentsOf: fileURL) + let decoder = JSONDecoder() + return try decoder.decode(OAuthToken.self, from: data) + } + + /// Deletes the stored OAuth token. + /// - Throws: An error if the token file exists but cannot be deleted. + public func delete() throws { + guard FileManager.default.fileExists(atPath: fileURL.path) else { + return + } + try FileManager.default.removeItem(at: fileURL) + } + + /// Whether a token is currently stored. + public var hasStoredToken: Bool { + FileManager.default.fileExists(atPath: fileURL.path) + } +} + +// MARK: - Environment Token Storage + +/// A simple token storage that reads from an environment variable. +/// +/// This is useful for server-side applications and CI/CD environments +/// where tokens are provided via environment variables. +public struct EnvironmentTokenStorage: Sendable { + private let variableName: String + + /// Creates a new environment token storage. + /// - Parameter variableName: The name of the environment variable containing the token. + /// Defaults to `HF_TOKEN`. + public init(variableName: String = "HF_TOKEN") { + self.variableName = variableName + } + + /// Retrieves the token from the environment variable. + /// - Returns: An OAuth token with the access token from the environment, or `nil` if not set. + public func retrieve() -> OAuthToken? { + guard let token = ProcessInfo.processInfo.environment[variableName], + !token.isEmpty + else { + return nil + } + + // Environment tokens don't expire and don't have refresh tokens + return OAuthToken( + accessToken: token, + refreshToken: nil, + expiresAt: Date.distantFuture + ) + } +} + +// MARK: - Composite Token Storage + +/// A token storage that tries multiple storage backends in order. +/// +/// This is useful for applications that want to support multiple token sources, +/// such as checking environment variables first, then falling back to file storage. +public struct CompositeTokenStorage: Sendable { + private let storages: [@Sendable () throws -> OAuthToken?] + private let primaryStorage: FileTokenStorage? + + /// Creates a composite token storage with the specified backends. + /// - Parameters: + /// - environment: Whether to check environment variables first. + /// - file: The file storage to use, or `nil` to skip file storage. + public init( + environment: Bool = true, + file: FileTokenStorage? = .default + ) { + var storages: [@Sendable () throws -> OAuthToken?] = [] + + if environment { + let envStorage = EnvironmentTokenStorage() + storages.append { envStorage.retrieve() } + } + + if let file = file { + storages.append { try file.retrieve() } + } + + self.storages = storages + self.primaryStorage = file + } + + /// Retrieves a token from the first storage that has one. + /// - Returns: The first available token, or `nil` if none found. + public func retrieve() throws -> OAuthToken? { + for storage in storages { + if let token = try storage() { + return token + } + } + return nil + } + + /// Stores a token to the primary (file) storage. + /// - Parameter token: The token to store. + public func store(_ token: OAuthToken) throws { + try primaryStorage?.store(token) + } + + /// Deletes the token from the primary (file) storage. + public func delete() throws { + try primaryStorage?.delete() + } +} diff --git a/Sources/HuggingFace/Shared/HTTPClient.swift b/Sources/HuggingFace/Shared/HTTPClient.swift index 43754b2..2a3ca5e 100644 --- a/Sources/HuggingFace/Shared/HTTPClient.swift +++ b/Sources/HuggingFace/Shared/HTTPClient.swift @@ -72,6 +72,7 @@ final class HTTPClient: @unchecked Sendable { headers: [String: String]? = nil ) async throws -> T { let request = try await createRequest(method, path, params: params, headers: headers) + return try await performFetch(request: request) } @@ -118,7 +119,7 @@ final class HTTPClient: @unchecked Sendable { do { let items = try jsonDecoder.decode([T].self, from: data) - let nextURL = httpResponse.nextPageURL() + let nextURL = parseNextPageURL(from: httpResponse) return PaginatedResponse(items: items, nextURL: nextURL) } catch { throw HTTPClientError.decodingError( @@ -164,41 +165,87 @@ final class HTTPClient: @unchecked Sendable { let task = Task { do { let request = try await requestBuilder() - let (bytes, response) = try await session.bytes(for: request) - let httpResponse = try validateResponse(response) - guard (200 ..< 300).contains(httpResponse.statusCode) else { - var errorData = Data() - for try await byte in bytes { - errorData.append(byte) + #if canImport(FoundationNetworking) + // Linux: Use buffered approach since true streaming is not available + let (data, response) = try await session.data(for: request) + let httpResponse = try validateResponse(response, data: data) + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + return } - // validateResponse will throw the appropriate error - _ = try validateResponse(response, data: errorData) - return // This line will never be reached, but satisfies the compiler - } - - for try await event in bytes.events { - // Check for [DONE] signal - if event.data.trimmingCharacters(in: .whitespacesAndNewlines) == "[DONE]" { + + // Parse SSE events from the buffered response + guard let responseString = String(data: data, encoding: .utf8) else { continuation.finish() return } - guard let jsonData = event.data.data(using: .utf8) else { - continue + for line in responseString.components(separatedBy: "\n") { + let trimmed = line.trimmingCharacters(in: .whitespaces) + guard trimmed.hasPrefix("data:") else { continue } + + let eventData = String(trimmed.dropFirst(5)).trimmingCharacters( + in: .whitespaces + ) + + // Check for [DONE] signal + if eventData == "[DONE]" { + continuation.finish() + return + } + + guard let jsonData = eventData.data(using: .utf8) else { + continue + } + + do { + let decoded = try jsonDecoder.decode(T.self, from: jsonData) + continuation.yield(decoded) + } catch { + print("Warning: Failed to decode streaming response chunk: \(error)") + } + } + + continuation.finish() + #else + // Apple platforms: Use native streaming APIs + let (bytes, response) = try await session.bytes(for: request) + let httpResponse = try validateResponse(response) + + guard (200 ..< 300).contains(httpResponse.statusCode) else { + var errorData = Data() + for try await byte in bytes { + errorData.append(byte) + } + // validateResponse will throw the appropriate error + _ = try validateResponse(response, data: errorData) + return // This line will never be reached, but satisfies the compiler } - do { - let decoded = try jsonDecoder.decode(T.self, from: jsonData) - continuation.yield(decoded) - } catch { - // Log decoding errors but don't fail the stream - // This allows the stream to continue even if individual chunks fail - print("Warning: Failed to decode streaming response chunk: \(error)") + for try await event in bytes.events { + // Check for [DONE] signal + if event.data.trimmingCharacters(in: .whitespacesAndNewlines) == "[DONE]" { + continuation.finish() + return + } + + guard let jsonData = event.data.data(using: .utf8) else { + continue + } + + do { + let decoded = try jsonDecoder.decode(T.self, from: jsonData) + continuation.yield(decoded) + } catch { + // Log decoding errors but don't fail the stream + // This allows the stream to continue even if individual chunks fail + print("Warning: Failed to decode streaming response chunk: \(error)") + } } - } - continuation.finish() + continuation.finish() + #endif } catch { continuation.finish(throwing: error) } diff --git a/Sources/HuggingFace/Shared/URLSession+Linux.swift b/Sources/HuggingFace/Shared/URLSession+Linux.swift new file mode 100644 index 0000000..3364494 --- /dev/null +++ b/Sources/HuggingFace/Shared/URLSession+Linux.swift @@ -0,0 +1,239 @@ +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking + + // MARK: - URLSession Async Extensions for Linux + + /// Provides async/await wrappers for URLSession APIs that are missing on Linux. + /// These extensions bridge the callback-based APIs to Swift's concurrency model. + extension URLSession { + /// Performs an HTTP request and returns the response data. + /// + /// This is a compatibility shim for Linux where the native async `data(for:)` may not be available. + /// + /// - Parameter request: The URL request to perform. + /// - Returns: A tuple containing the response data and URL response. + func data(for request: URLRequest) async throws -> (Data, URLResponse) { + try await withCheckedThrowingContinuation { continuation in + let task = self.dataTask(with: request) { data, response, error in + if let error = error { + continuation.resume(throwing: error) + return + } + guard let data = data, let response = response else { + continuation.resume( + throwing: URLError(.badServerResponse) + ) + return + } + continuation.resume(returning: (data, response)) + } + task.resume() + } + } + + /// Uploads data to a URL and returns the response. + /// + /// - Parameters: + /// - request: The URL request to perform. + /// - data: The data to upload. + /// - Returns: A tuple containing the response data and URL response. + func upload(for request: URLRequest, from data: Data) async throws -> (Data, URLResponse) { + try await withCheckedThrowingContinuation { continuation in + let task = self.uploadTask(with: request, from: data) { data, response, error in + if let error = error { + continuation.resume(throwing: error) + return + } + guard let data = data, let response = response else { + continuation.resume( + throwing: URLError(.badServerResponse) + ) + return + } + continuation.resume(returning: (data, response)) + } + task.resume() + } + } + + /// Uploads a file to a URL and returns the response. + /// + /// - Parameters: + /// - request: The URL request to perform. + /// - fileURL: The URL of the file to upload. + /// - Returns: A tuple containing the response data and URL response. + func upload(for request: URLRequest, fromFile fileURL: URL) async throws -> (Data, URLResponse) { + try await withCheckedThrowingContinuation { continuation in + let task = self.uploadTask(with: request, fromFile: fileURL) { data, response, error in + if let error = error { + continuation.resume(throwing: error) + return + } + guard let data = data, let response = response else { + continuation.resume( + throwing: URLError(.badServerResponse) + ) + return + } + continuation.resume(returning: (data, response)) + } + task.resume() + } + } + + /// Downloads a file from a URL to a temporary location. + /// + /// - Parameters: + /// - request: The URL request to perform. + /// - progress: Optional progress object to track download progress. + /// - Returns: A tuple containing the temporary file URL and URL response. + func asyncDownload( + for request: URLRequest, + progress: Progress? = nil + ) async throws -> (URL, URLResponse) { + try await withCheckedThrowingContinuation { continuation in + let delegate = progress.map { LinuxDownloadDelegate(progress: $0, continuation: continuation) } + + if let delegate = delegate { + // Use delegate-based download for progress tracking + let session = URLSession( + configuration: self.configuration, + delegate: delegate, + delegateQueue: nil + ) + let task = session.downloadTask(with: request) + delegate.task = task + task.resume() + } else { + // Simple download without progress + let task = self.downloadTask(with: request) { url, response, error in + if let error = error { + continuation.resume(throwing: error) + return + } + guard let tempURL = url, let response = response else { + continuation.resume(throwing: URLError(.badServerResponse)) + return + } + // Copy to a new temp location since the original will be deleted + let newTempURL = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + do { + try FileManager.default.copyItem(at: tempURL, to: newTempURL) + continuation.resume(returning: (newTempURL, response)) + } catch { + continuation.resume(throwing: error) + } + } + task.resume() + } + } + } + + /// Streams bytes from a URL request. + /// + /// This provides a simplified streaming-like interface for Linux where `bytes(for:)` is not available. + /// + /// - Important: This implementation **buffers the entire response in memory** before streaming bytes. + /// It is **not** true streaming and is **not suitable for large responses or long‑lived streams**, + /// as it may cause excessive memory usage. + /// For true streaming on Linux, consider using a different HTTP client library. + /// + /// - Parameter request: The URL request to perform. + /// - Returns: A tuple containing the response bytes and URL response. + func asyncBytes(for request: URLRequest) async throws -> (LinuxAsyncBytes, URLResponse) { + let (data, response) = try await data(for: request) + return (LinuxAsyncBytes(data: data), response) + } + } + + // MARK: - Linux Download Delegate + + /// A delegate for tracking download progress on Linux. + private final class LinuxDownloadDelegate: NSObject, URLSessionDownloadDelegate, @unchecked Sendable { + let progress: Progress + let continuation: CheckedContinuation<(URL, URLResponse), Error> + var task: URLSessionDownloadTask? + private var hasResumed = false + + init(progress: Progress, continuation: CheckedContinuation<(URL, URLResponse), Error>) { + self.progress = progress + self.continuation = continuation + } + + func urlSession( + _ session: URLSession, + downloadTask: URLSessionDownloadTask, + didWriteData bytesWritten: Int64, + totalBytesWritten: Int64, + totalBytesExpectedToWrite: Int64 + ) { + progress.totalUnitCount = totalBytesExpectedToWrite + progress.completedUnitCount = totalBytesWritten + } + + func urlSession( + _ session: URLSession, + downloadTask: URLSessionDownloadTask, + didFinishDownloadingTo location: URL + ) { + guard !hasResumed else { return } + hasResumed = true + + guard let response = downloadTask.response else { + continuation.resume(throwing: URLError(.badServerResponse)) + return + } + + // Copy to a new temp location since the original will be deleted + let newTempURL = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString) + do { + try FileManager.default.copyItem(at: location, to: newTempURL) + continuation.resume(returning: (newTempURL, response)) + } catch { + continuation.resume(throwing: error) + } + } + + func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { + guard !hasResumed else { return } + hasResumed = true + + if let error = error { + continuation.resume(throwing: error) + } + // If no error and not resumed, the download delegate method should have handled it + } + } + + // MARK: - Linux Async Bytes + + /// A simple async sequence wrapper for bytes on Linux. + /// This is a simplified implementation that works with pre-loaded data. + struct LinuxAsyncBytes: AsyncSequence, Sendable { + typealias Element = UInt8 + + let data: Data + + struct AsyncIterator: AsyncIteratorProtocol { + var index: Data.Index + let endIndex: Data.Index + let data: Data + + mutating func next() async -> UInt8? { + guard index < endIndex else { return nil } + let byte = data[index] + index = data.index(after: index) + return byte + } + } + + func makeAsyncIterator() -> AsyncIterator { + AsyncIterator(index: data.startIndex, endIndex: data.endIndex, data: data) + } + } + +#endif diff --git a/Tests/HuggingFaceTests/Helpers/MockURLProtocol.swift b/Tests/HuggingFaceTests/Helpers/MockURLProtocol.swift index 64afc1b..3e7b477 100644 --- a/Tests/HuggingFaceTests/Helpers/MockURLProtocol.swift +++ b/Tests/HuggingFaceTests/Helpers/MockURLProtocol.swift @@ -1,60 +1,61 @@ import Foundation import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace // MARK: - Request Handler Storage /// Stores and manages handlers for MockURLProtocol's request handling. -private actor RequestHandlerStorage { - private var requestHandler: (@Sendable (URLRequest) async throws -> (HTTPURLResponse, Data))? - private var isLocked = false +private final class RequestHandlerStorage: @unchecked Sendable { + private let lock = NSLock() + private var requestHandler: (@Sendable (URLRequest) throws -> (HTTPURLResponse, Data))? func setHandler( - _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) - ) async { - // Wait for any existing handler to be released - while isLocked { - try? await Task.sleep(for: .milliseconds(10)) - } + _ handler: @Sendable @escaping (URLRequest) throws -> (HTTPURLResponse, Data) + ) { + lock.lock() requestHandler = handler - isLocked = true + lock.unlock() } - func clearHandler() async { + func clearHandler() { + lock.lock() requestHandler = nil - isLocked = false + lock.unlock() } - func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { - guard let handler = requestHandler else { + func executeHandler(for request: URLRequest) throws -> (HTTPURLResponse, Data) { + lock.lock() + let handler = requestHandler + lock.unlock() + + guard let handler else { throw NSError( domain: "MockURLProtocolError", code: 0, userInfo: [NSLocalizedDescriptionKey: "No request handler set"] ) } - return try await handler(request) + return try handler(request) } } // MARK: - Mock URL Protocol /// Custom URLProtocol for testing network requests -final class MockURLProtocol: URLProtocol, @unchecked Sendable { +final class MockURLProtocol: URLProtocol { /// Storage for request handlers fileprivate static let requestHandlerStorage = RequestHandlerStorage() /// Set a handler to process mock requests static func setHandler( - _ handler: @Sendable @escaping (URLRequest) async throws -> (HTTPURLResponse, Data) + _ handler: @Sendable @escaping (URLRequest) throws -> (HTTPURLResponse, Data) ) async { - await requestHandlerStorage.setHandler(handler) - } - - /// Execute the stored handler for a request - func executeHandler(for request: URLRequest) async throws -> (HTTPURLResponse, Data) { - return try await Self.requestHandlerStorage.executeHandler(for: request) + requestHandlerStorage.setHandler(handler) } override class func canInit(with request: URLRequest) -> Bool { @@ -66,15 +67,17 @@ final class MockURLProtocol: URLProtocol, @unchecked Sendable { } override func startLoading() { - Task { - do { - let (response, data) = try await self.executeHandler(for: request) - client?.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) - client?.urlProtocol(self, didLoad: data) - client?.urlProtocolDidFinishLoading(self) - } catch { - client?.urlProtocol(self, didFailWithError: error) - } + do { + let (response, data) = try Self.requestHandlerStorage.executeHandler(for: self.request) + self.client?.urlProtocol( + self, + didReceive: response, + cacheStoragePolicy: .notAllowed + ) + self.client?.urlProtocol(self, didLoad: data) + self.client?.urlProtocolDidFinishLoading(self) + } catch { + self.client?.urlProtocol(self, didFailWithError: error) } } @@ -90,36 +93,41 @@ final class MockURLProtocol: URLProtocol, @unchecked Sendable { /// /// Provides mutual exclusion across async test execution to prevent /// interference between parallel test suites using shared mock handlers. - /// - /// Note: We can't use `NSLock` or `OSAllocatedUnfairLock` here because: - /// - They're synchronous locks designed for very short critical sections - /// - They block threads (bad for Swift concurrency's cooperative thread pool) - /// - They can't be held across suspension points (await calls) - /// - /// An actor-based lock is idiomatic for Swift's async/await model. private actor MockURLProtocolLock { static let shared = MockURLProtocolLock() + private var waiters: [CheckedContinuation] = [] private var isLocked = false private init() {} - func withLock(_ operation: @Sendable () async throws -> T) async rethrows -> T { - // Wait for lock to be available - while isLocked { - try? await Task.sleep(for: .milliseconds(10)) + func acquire() async { + if isLocked { + await withCheckedContinuation { continuation in + waiters.append(continuation) + } + } else { + isLocked = true } + } - // Acquire lock - isLocked = true + func release() { + if let next = waiters.first { + waiters.removeFirst() + next.resume() + } else { + isLocked = false + } + } - // Execute operation and ensure lock is released even on error + func withLock(_ operation: @Sendable () async throws -> T) async rethrows -> T { + await acquire() do { let result = try await operation() - isLocked = false + release() return result } catch { - isLocked = false + release() throw error } } @@ -135,13 +143,13 @@ final class MockURLProtocol: URLProtocol, @unchecked Sendable { // Serialize all MockURLProtocol tests to prevent interference try await MockURLProtocolLock.shared.withLock { // Clear handler before test - await MockURLProtocol.requestHandlerStorage.clearHandler() + MockURLProtocol.requestHandlerStorage.clearHandler() // Execute the test try await function() // Clear handler after test - await MockURLProtocol.requestHandlerStorage.clearHandler() + MockURLProtocol.requestHandlerStorage.clearHandler() } } } diff --git a/Tests/HuggingFaceTests/HubTests/CacheLocationProviderTests.swift b/Tests/HuggingFaceTests/HubTests/CacheLocationProviderTests.swift index 8277ad9..465f85d 100644 --- a/Tests/HuggingFaceTests/HubTests/CacheLocationProviderTests.swift +++ b/Tests/HuggingFaceTests/HubTests/CacheLocationProviderTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/CollectionTests.swift b/Tests/HuggingFaceTests/HubTests/CollectionTests.swift index 20b5bb4..9a76fc5 100644 --- a/Tests/HuggingFaceTests/HubTests/CollectionTests.swift +++ b/Tests/HuggingFaceTests/HubTests/CollectionTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/HubTests/DatasetTests.swift b/Tests/HuggingFaceTests/HubTests/DatasetTests.swift index 3a6c93c..5b0d505 100644 --- a/Tests/HuggingFaceTests/HubTests/DatasetTests.swift +++ b/Tests/HuggingFaceTests/HubTests/DatasetTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/DiscussionTests.swift b/Tests/HuggingFaceTests/HubTests/DiscussionTests.swift index 6c62ba6..c5819c2 100644 --- a/Tests/HuggingFaceTests/HubTests/DiscussionTests.swift +++ b/Tests/HuggingFaceTests/HubTests/DiscussionTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/HubTests/FileLockTests.swift b/Tests/HuggingFaceTests/HubTests/FileLockTests.swift index 29ac262..5af6cf7 100644 --- a/Tests/HuggingFaceTests/HubTests/FileLockTests.swift +++ b/Tests/HuggingFaceTests/HubTests/FileLockTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift b/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift index d5d4150..81fc688 100644 --- a/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift +++ b/Tests/HuggingFaceTests/HubTests/FileOperationsTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/GatedModeTests.swift b/Tests/HuggingFaceTests/HubTests/GatedModeTests.swift index 5d6e412..94bd500 100644 --- a/Tests/HuggingFaceTests/HubTests/GatedModeTests.swift +++ b/Tests/HuggingFaceTests/HubTests/GatedModeTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/GitTests.swift b/Tests/HuggingFaceTests/HubTests/GitTests.swift index a273461..276b1a8 100644 --- a/Tests/HuggingFaceTests/HubTests/GitTests.swift +++ b/Tests/HuggingFaceTests/HubTests/GitTests.swift @@ -1,5 +1,9 @@ import Testing import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif @testable import HuggingFace @Suite("Git Tests") diff --git a/Tests/HuggingFaceTests/HubTests/HubCacheTests.swift b/Tests/HuggingFaceTests/HubTests/HubCacheTests.swift index a8aad4a..8380159 100644 --- a/Tests/HuggingFaceTests/HubTests/HubCacheTests.swift +++ b/Tests/HuggingFaceTests/HubTests/HubCacheTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift index b691ac4..8d55afd 100644 --- a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift +++ b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/ModelTests.swift b/Tests/HuggingFaceTests/HubTests/ModelTests.swift index b047ef4..38c36d2 100644 --- a/Tests/HuggingFaceTests/HubTests/ModelTests.swift +++ b/Tests/HuggingFaceTests/HubTests/ModelTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/OrganizationTests.swift b/Tests/HuggingFaceTests/HubTests/OrganizationTests.swift index 873b865..6bf92d9 100644 --- a/Tests/HuggingFaceTests/HubTests/OrganizationTests.swift +++ b/Tests/HuggingFaceTests/HubTests/OrganizationTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/PaginationTests.swift b/Tests/HuggingFaceTests/HubTests/PaginationTests.swift index 206199b..ccb1bbc 100644 --- a/Tests/HuggingFaceTests/HubTests/PaginationTests.swift +++ b/Tests/HuggingFaceTests/HubTests/PaginationTests.swift @@ -37,7 +37,7 @@ struct PaginationTests { linkHeader: "; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect(nextURL?.absoluteString == "https://huggingface.co/api/models?limit=10&skip=10") @@ -49,7 +49,7 @@ struct PaginationTests { linkHeader: "; rel='next'" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect(nextURL?.absoluteString == "https://huggingface.co/api/page2") @@ -62,7 +62,7 @@ struct PaginationTests { "; rel=\"prev\", ; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect(nextURL?.absoluteString == "https://huggingface.co/api/page3") @@ -74,7 +74,7 @@ struct PaginationTests { linkHeader: " ; rel=\"next\" " ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect(nextURL?.absoluteString == "https://huggingface.co/api/page2") @@ -87,7 +87,7 @@ struct PaginationTests { "; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect( @@ -100,7 +100,7 @@ struct PaginationTests { func testMissingLinkHeader() { let response = makeHTTPResponse(linkHeader: nil) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL == nil) } @@ -109,7 +109,7 @@ struct PaginationTests { func testEmptyLinkHeader() { let response = makeHTTPResponse(linkHeader: "") - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL == nil) } @@ -120,7 +120,7 @@ struct PaginationTests { linkHeader: "; rel=\"prev\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL == nil) } @@ -131,7 +131,7 @@ struct PaginationTests { linkHeader: "https://huggingface.co/api/page2; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) // Should still extract the URL even without proper angle brackets #expect(nextURL != nil) @@ -143,7 +143,7 @@ struct PaginationTests { linkHeader: "<>; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL == nil) } @@ -154,7 +154,7 @@ struct PaginationTests { linkHeader: " rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL == nil) } @@ -165,7 +165,7 @@ struct PaginationTests { linkHeader: "; rel=\"next\"; title=\"Next Page\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) #expect(nextURL?.absoluteString == "https://huggingface.co/api/page2") @@ -178,7 +178,7 @@ struct PaginationTests { "; rel=\"next\", ; rel=\"next\"" ) - let nextURL = response.nextPageURL() + let nextURL = parseNextPageURL(from: response) #expect(nextURL != nil) // Should return the first "next" link found diff --git a/Tests/HuggingFaceTests/HubTests/PaperTests.swift b/Tests/HuggingFaceTests/HubTests/PaperTests.swift index 6d4aba7..e1145cf 100644 --- a/Tests/HuggingFaceTests/HubTests/PaperTests.swift +++ b/Tests/HuggingFaceTests/HubTests/PaperTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/RepoIDTests.swift b/Tests/HuggingFaceTests/HubTests/RepoIDTests.swift index b1395b0..62c45fb 100644 --- a/Tests/HuggingFaceTests/HubTests/RepoIDTests.swift +++ b/Tests/HuggingFaceTests/HubTests/RepoIDTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/RepoTests.swift b/Tests/HuggingFaceTests/HubTests/RepoTests.swift index 8f30e97..599ee82 100644 --- a/Tests/HuggingFaceTests/HubTests/RepoTests.swift +++ b/Tests/HuggingFaceTests/HubTests/RepoTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/HubTests/SpaceTests.swift b/Tests/HuggingFaceTests/HubTests/SpaceTests.swift index 5124dba..a800347 100644 --- a/Tests/HuggingFaceTests/HubTests/SpaceTests.swift +++ b/Tests/HuggingFaceTests/HubTests/SpaceTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/HubTests/UserTests.swift b/Tests/HuggingFaceTests/HubTests/UserTests.swift index 8a63f4f..7217c85 100644 --- a/Tests/HuggingFaceTests/HubTests/UserTests.swift +++ b/Tests/HuggingFaceTests/HubTests/UserTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/ChatCompletionTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/ChatCompletionTests.swift index 92066b3..713ed78 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/ChatCompletionTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/ChatCompletionTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/FeatureExtractionTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/FeatureExtractionTests.swift index 937c015..00bc60f 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/FeatureExtractionTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/FeatureExtractionTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/InferenceClientTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/InferenceClientTests.swift index d42fd69..a0d0dc0 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/InferenceClientTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/InferenceClientTests.swift @@ -1,6 +1,14 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/SpeechToTextTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/SpeechToTextTests.swift index dab10d7..fa00576 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/SpeechToTextTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/SpeechToTextTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/TextToImageTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/TextToImageTests.swift index 9cf856f..8ac5734 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/TextToImageTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/TextToImageTests.swift @@ -1,6 +1,10 @@ import Foundation import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + @testable import HuggingFace #if swift(>=6.1) @@ -97,7 +101,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "stabilityai/stable-diffusion-xl-base-1.0") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A futuristic city") } #expect(request.url?.path == "/v1/images/generations") @@ -167,7 +171,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "stabilityai/stable-diffusion-xl-base-1.0") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A stunning anime character") } #expect(request.url?.path == "/v1/images/generations") @@ -190,7 +194,7 @@ import Testing let result = try await client.textToImage( model: "stabilityai/stable-diffusion-xl-base-1.0", - prompt: "A beautiful anime character", + prompt: "A stunning anime character", loras: loras ) @@ -222,7 +226,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "stabilityai/stable-diffusion-xl-base-1.0") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A detailed architectural drawing") } #expect(request.url?.path == "/v1/images/generations") @@ -281,7 +285,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "stabilityai/stable-diffusion-xl-base-1.0") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A wide landscape view") } #expect(request.url?.path == "/v1/images/generations") #expect(request.httpMethod == "POST") @@ -386,7 +390,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "stabilityai/stable-diffusion-xl-base-1.0") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "High resolution artwork") } #expect(request.url?.path == "/v1/images/generations") #expect(request.httpMethod == "POST") diff --git a/Tests/HuggingFaceTests/InferenceProvidersTests/TextToVideoTests.swift b/Tests/HuggingFaceTests/InferenceProvidersTests/TextToVideoTests.swift index 98c42e0..d19c2b1 100644 --- a/Tests/HuggingFaceTests/InferenceProvidersTests/TextToVideoTests.swift +++ b/Tests/HuggingFaceTests/InferenceProvidersTests/TextToVideoTests.swift @@ -1,4 +1,8 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace @@ -44,7 +48,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "zeroscope_v2_576w") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A cat playing with a ball") } let response = HTTPURLResponse( @@ -100,7 +104,7 @@ import Testing if let body = request.httpBody { let json = try JSONSerialization.jsonObject(with: body) as! [String: Any] #expect(json["model"] as? String == "zeroscope_v2_576w") - #expect(json["prompt"] as? String == "A beautiful sunset over mountains") + #expect(json["prompt"] as? String == "A dancing robot") } #expect(request.url?.path == "/v1/videos/generations") #expect(request.httpMethod == "POST") diff --git a/Tests/HuggingFaceTests/OAuthTests/HuggingFaceAuthenticationManagerTests.swift b/Tests/HuggingFaceTests/OAuthTests/HuggingFaceAuthenticationManagerTests.swift index 2db58ad..57e54d7 100644 --- a/Tests/HuggingFaceTests/OAuthTests/HuggingFaceAuthenticationManagerTests.swift +++ b/Tests/HuggingFaceTests/OAuthTests/HuggingFaceAuthenticationManagerTests.swift @@ -1,9 +1,13 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing @testable import HuggingFace -#if swift(>=6.1) +#if swift(>=6.1) && canImport(AuthenticationServices) @Suite("HuggingFace Authentication Manager Tests") struct HuggingFaceAuthenticationManagerTests { @Test("HuggingFaceAuthenticationManager can be initialized with valid parameters") @@ -185,4 +189,4 @@ import Testing #expect(customScope == .other("custom-scope")) } } -#endif // swift(>=6.1) +#endif // canImport(AuthenticationServices) && swift(>=6.1) diff --git a/Tests/HuggingFaceTests/OAuthTests/OAuthClientTests.swift b/Tests/HuggingFaceTests/OAuthTests/OAuthClientTests.swift index 131b792..15b73bb 100644 --- a/Tests/HuggingFaceTests/OAuthTests/OAuthClientTests.swift +++ b/Tests/HuggingFaceTests/OAuthTests/OAuthClientTests.swift @@ -1,5 +1,12 @@ import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif import Testing +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif @testable import HuggingFace