diff --git a/.gitignore b/.gitignore index 0023a53..144ca3b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ .DS_Store -/.build +.build /Packages xcuserdata/ DerivedData/ diff --git a/Example/Package.resolved b/Example/Package.resolved new file mode 100644 index 0000000..dcdc20f --- /dev/null +++ b/Example/Package.resolved @@ -0,0 +1,24 @@ +{ + "originHash" : "a0588c75481977c80f444ce6d823a9ad8ef45dee9fae5613afbe8bd488224283", + "pins" : [ + { + "identity" : "eventsource", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/EventSource.git", + "state" : { + "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", + "version" : "1.3.0" + } + }, + { + "identity" : "swift-argument-parser", + "kind" : "remoteSourceControl", + "location" : "https://github.com/apple/swift-argument-parser", + "state" : { + "revision" : "cdd0ef3755280949551dc26dee5de9ddeda89f54", + "version" : "1.6.2" + } + } + ], + "version" : 3 +} diff --git a/Example/Package.swift b/Example/Package.swift new file mode 100644 index 0000000..a452baf --- /dev/null +++ b/Example/Package.swift @@ -0,0 +1,29 @@ +// swift-tools-version: 6.0 +import PackageDescription + +let package = Package( + name: "download-speed-test", + platforms: [ + .macOS(.v14), + .iOS(.v16), + ], + products: [ + .executable( + name: "download-speed-test", + targets: ["DownloadSpeedTest"] + ) + ], + dependencies: [ + .package(path: "../"), + .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"), + ], + targets: [ + .executableTarget( + name: "DownloadSpeedTest", + dependencies: [ + .product(name: "HuggingFace", package: "swift-huggingface"), + .product(name: "ArgumentParser", package: "swift-argument-parser"), + ] + ) + ] +) diff --git a/Example/README.md b/Example/README.md new file mode 100644 index 0000000..17b1a5c --- /dev/null +++ b/Example/README.md @@ -0,0 +1,163 @@ +# Download Speed Test Example + +This example demonstrates how to use the HuggingFace Swift package to 
download files from a repository and measure download performance. It's designed to compare download speeds with and without Xet support. + +## Usage + +### Running the Test + +From the `Example` directory: + +```bash +swift run download-speed-test +``` + +Use `--help` to see all arguments: + +```bash +swift run download-speed-test --help +``` + +### Command Line Options + +- `--repo <repo>` or `-r <repo>`: Repository to benchmark (default: Qwen/Qwen3-0.6B) +- `--file <path>` or `-f <path>`: Download a specific file (e.g., `model.safetensors`) (note: documented here, but not declared by the command in `main.swift`) +- `--min-size-mb <mb>`: Minimum file size in MB to test (default: 10, filters out small files) +- `--xet` / `--no-xet`: Enable or disable Xet acceleration + +### Testing with Xet Enabled/Disabled + +To compare performance with and without Xet: + +**Basic comparison (auto-selects large files):** +```bash +# With Xet +swift run download-speed-test + +# Without Xet (LFS) +swift run download-speed-test --no-xet +``` + +**Test a specific large file:** +```bash +# Download specific model file +swift run download-speed-test --file model.safetensors + +# Compare Xet vs LFS for the same file +swift run download-speed-test --file model.safetensors --no-xet +``` + +**Test a different repository:** +```bash +swift run download-speed-test --repo meta-llama/Llama-3.2-1B +``` + +**Adjust minimum file size filter:** +```bash +# Only test files >= 100 MB (better for Xet benchmarking) +swift run download-speed-test --min-size-mb 100 + +# Include smaller files (>= 1 MB) +swift run download-speed-test --min-size-mb 1 +``` + +**Notes:** +- Xet excels at large files (10+ MB), so the benchmark filters out small files by default +- Small files (configs, JSONs) add overhead that doesn't showcase Xet's strengths +- Use `--file` to benchmark a specific large model file for accurate comparison + +### Performance Features + +Xet is optimized for high-performance downloads by default: + +- **256 concurrent range GET requests** per file (automatically set) +- **High-performance mode 
enabled** for maximum throughput +- **XetClient reuse** across downloads for HTTP/TLS connection pooling +- **JWT token caching** per repository/revision to avoid redundant API calls + +**No configuration required!** Xet should match or exceed LFS speeds out of the box. + +If you need to adjust settings: + +- **XET_NUM_CONCURRENT_RANGE_GETS**: Override default per-file concurrency + ```bash + XET_NUM_CONCURRENT_RANGE_GETS=128 swift run download-speed-test # Lower for slow networks + ``` + +- **XET_HIGH_PERFORMANCE**: Disable high-performance mode + ```bash + XET_HIGH_PERFORMANCE=0 swift run download-speed-test # Conservative mode + ``` + +## What It Does + +The test: +1. Connects to the Hugging Face Hub +2. Lists files in the specified repository (default: `Qwen/Qwen3-0.6B`) +3. Selects large files (default: >= 10 MB) that showcase Xet's performance: + - Model files (`.safetensors`, `.bin`, `.gguf`, `.pt`, `.pth`) + - Prioritizes the largest files for meaningful benchmarking +4. Downloads each file and measures: + - Download time + - File size + - Download speed (MB/s) +5. Provides a summary with total time, size, and average speed + +**Why filter small files?** +Xet is optimized for large files through: +- Content-addressable storage with chunking +- Parallel chunk downloads +- Deduplication across files + +Small files (<10 MB) don't benefit from these optimizations and add per-file overhead that skews results. + +## Output Example + +``` +🚀 Hugging Face Download Speed Test +Repository: Qwen/Qwen3-0.6B +============================================================ + +✅ Xet support: ENABLED + +📋 Listing files in repository... +📦 Selected 3 files for testing: + • model.safetensors (1.2 GB) + • model-00001-of-00002.safetensors (987 MB) + • model-00002-of-00002.safetensors (256 MB) + +⬇️ Starting download tests... 
+ +✅ [1/3] model.safetensors + Time: 12.34s + Size: 1.2 GB + Speed: 99.2 MB/s + +✅ [2/3] model-00001-of-00002.safetensors + Time: 9.87s + Size: 987 MB + Speed: 100.1 MB/s + +✅ [3/3] model-00002-of-00002.safetensors + Time: 2.56s + Size: 256 MB + Speed: 100.0 MB/s + +============================================================ +📊 Summary +============================================================ +Total files: 3 +Total time: 24.77s +Total size: 2.4 GB +Average speed: 99.8 MB/s + +💡 Tip: toggle Xet via --xet / --no-xet to compare backends. +``` + +## Notes + +- The test uses a temporary directory that is automatically cleaned up +- Files are downloaded sequentially to get accurate timing +- The test automatically selects up to three large files (at least `--min-size-mb` in size) +- Results (time, size, speed) are printed after each file download + diff --git a/Example/Sources/DownloadSpeedTest/main.swift b/Example/Sources/DownloadSpeedTest/main.swift new file mode 100644 index 0000000..903309f --- /dev/null +++ b/Example/Sources/DownloadSpeedTest/main.swift @@ -0,0 +1,244 @@ +import ArgumentParser +import Foundation +import HuggingFace + +#if canImport(Xet) + import Xet +#endif + +@main +struct DownloadSpeedTest: AsyncParsableCommand { + static let configuration = CommandConfiguration( + commandName: "download-speed-test", + abstract: "Benchmark download performance for Hugging Face repositories." + ) + + @Option( + name: [.short, .long], + help: "Repository identifier to benchmark (e.g. owner/name)." + ) + var repo: String = "Qwen/Qwen3-0.6B" + + @Option( + name: .long, + help: "Minimum file size in MB to test (filters out small files). Default: 10 MB." + ) + var minSizeMb: Int = 10 + + @Flag( + name: .long, + inversion: .prefixedNo, + help: "Enable Xet acceleration (use --no-xet to force classic LFS)." + ) + var xet: Bool = HubClient.isXetSupported + + @Option( + name: .long, + help: "Maximum concurrent range GET requests for Xet downloads. Default: 256 (high performance)." + ) + var maxConcurrentRangeGets: Int? 
+ + @Flag( + name: .long, + inversion: .prefixedNo, + help: "Use high performance mode (256 concurrent requests). Use --no-high-performance for standard mode (48 concurrent requests)." + ) + var highPerformance: Bool = true + + func run() async throws { + guard let repoID = Repo.ID(rawValue: repo) else { + throw ValidationError("Invalid repository identifier: \(repo). Expected format is owner/name.") + } + + let xetConfiguration: XetConfiguration? + if xet { + if let maxConcurrent = maxConcurrentRangeGets { + xetConfiguration = XetConfiguration( + maxConcurrentRangeGets: maxConcurrent, + highPerformanceMode: highPerformance + ) + } else { + xetConfiguration = highPerformance ? .highPerformance() : .default() + } + } else { + xetConfiguration = nil + } + + let client = HubClient(xetConfiguration: xetConfiguration) + + print("🚀 Hugging Face Download Speed Test") + print("Repository: \(repoID)") + print("=" * 60) + print() + + if client.isXetEnabled { + print("✅ Xet support: ENABLED") + + if let config = xetConfiguration { + print(" Concurrent range GETs: \(config.maxConcurrentRangeGets)") + print(" High performance mode: \(config.highPerformanceMode ? "ON" : "OFF")") + } + } else { + print("❌ Xet support: DISABLED (using LFS)") + } + print() + + print("📋 Listing files in repository...") + do { + let testFiles: [Git.TreeEntry] + + // Auto-select large files + let files = try await client.listFiles( + in: repoID, + kind: .model, + revision: "main", + recursive: true + ) + + testFiles = Self.selectTestFiles(from: files, minSizeMB: minSizeMb) + + if testFiles.isEmpty { + print("⚠️ No suitable test files found (minimum size: \(minSizeMb) MB)") + print("💡 Try lowering --min-size-mb or specify a file with --file") + return + } + + print("📦 Selected \(testFiles.count) files for testing:") + for file in testFiles { + let size = file.size.map { Self.formatBytes(Int64($0)) } ?? 
"unknown size" + print(" • \(file.path) (\(size))") + } + print() + + let tempDir = FileManager.default.temporaryDirectory + .appendingPathComponent("hf-speed-test-\(UUID().uuidString)") + try FileManager.default.createDirectory( + at: tempDir, + withIntermediateDirectories: true + ) + defer { + try? FileManager.default.removeItem(at: tempDir) + } + + var totalTime: TimeInterval = 0 + var totalBytes: Int = 0 + + print("⬇️ Starting download tests...") + print() + + for (index, file) in testFiles.enumerated() { + let destination = tempDir.appendingPathComponent(file.path) + + try? FileManager.default.createDirectory( + at: destination.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + + let startTime = Date() + + do { + _ = try await client.downloadFile( + at: file.path, + from: repoID, + to: destination, + kind: .model, + revision: "main" + ) + + let elapsed = Date().timeIntervalSince(startTime) + let fileSize = file.size ?? 0 + let speed = fileSize > 0 ? Double(fileSize) / elapsed : 0 + + totalTime += elapsed + totalBytes += fileSize + + print("✅ [\(index + 1)/\(testFiles.count)] \(file.path)") + print(" Time: \(String(format: "%.2f", elapsed))s") + print(" Size: \(Self.formatBytes(Int64(fileSize)))") + print(" Speed: \(Self.formatBytes(Int64(speed)))/s") + print() + } catch { + print("❌ [\(index + 1)/\(testFiles.count)] \(file.path)") + print(" Error: \(error.localizedDescription)") + print() + } + } + + print("=" * 60) + print("📊 Summary") + print("=" * 60) + print("Total files: \(testFiles.count)") + print("Total time: \(String(format: "%.2f", totalTime))s") + print("Total size: \(Self.formatBytes(Int64(totalBytes)))") + if totalTime > 0 { + let avgSpeed = Double(totalBytes) / totalTime + print("Average speed: \(Self.formatBytes(Int64(avgSpeed)))/s") + } + print() + print("💡 Tip: toggle Xet via --xet / --no-xet to compare backends.") + + } catch { + print("❌ Error: \(error.localizedDescription)") + throw ExitCode.failure + } + } + + static 
func selectTestFiles(from files: [Git.TreeEntry], minSizeMB: Int) -> [Git.TreeEntry] { + let minSizeBytes = minSizeMB * 1024 * 1024 + + // Filter files by minimum size first (Xet excels at large files) + let largeFiles = files.filter { file in + file.type == .file && (file.size ?? 0) >= minSizeBytes + } + + guard !largeFiles.isEmpty else { + return [] + } + + var selected: [Git.TreeEntry] = [] + + // Prioritize model files (safetensors, bin) as they're typically large + let priorities = [ + "*.safetensors", + "*.bin", + "*.gguf", + "*.pt", + "*.pth", + ] + + for priority in priorities { + let pattern = priority.replacingOccurrences(of: "*", with: "") + if let file = largeFiles.first(where: { $0.path.contains(pattern) }) { + if !selected.contains(where: { $0.path == file.path }) { + selected.append(file) + } + } + } + + // If we need more files, add the largest remaining ones + if selected.count < 3 { + let remaining = largeFiles.filter { file in + !selected.contains(where: { $0.path == file.path }) + } + + let sorted = remaining.sorted { ($0.size ?? 0) > ($1.size ?? 
0) } + selected.append(contentsOf: sorted.prefix(3 - selected.count)) + } + + // Return up to 3 large files for benchmarking + return Array(selected.prefix(3)) + } + + static func formatBytes(_ bytes: Int64) -> String { + let formatter = ByteCountFormatter() + formatter.allowedUnits = [.useKB, .useMB, .useGB] + formatter.countStyle = .file + return formatter.string(fromByteCount: bytes) + } +} + +extension String { + static func * (lhs: String, rhs: Int) -> String { + return String(repeating: lhs, count: rhs) + } +} diff --git a/Package.resolved b/Package.resolved new file mode 100644 index 0000000..1b9229b --- /dev/null +++ b/Package.resolved @@ -0,0 +1,15 @@ +{ + "originHash" : "5bf87bf95304ccc81cb66bf2213ca3650163d512843ea900ebf7124b8935e25e", + "pins" : [ + { + "identity" : "eventsource", + "kind" : "remoteSourceControl", + "location" : "https://github.com/mattt/EventSource.git", + "state" : { + "revision" : "ca2a9d90cbe49e09b92f4b6ebd922c03ebea51d0", + "version" : "1.3.0" + } + } + ], + "version" : 3 +} diff --git a/Package.swift b/Package.swift index 27611a6..ae80a83 100644 --- a/Package.swift +++ b/Package.swift @@ -20,13 +20,15 @@ let package = Package( ) ], dependencies: [ - .package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0") + .package(url: "https://github.com/mattt/EventSource.git", from: "1.0.0"), + .package(name: "Xet", path: "../swift-xet-ported"), ], targets: [ .target( name: "HuggingFace", dependencies: [ - .product(name: "EventSource", package: "EventSource") + .product(name: "EventSource", package: "EventSource"), + .product(name: "Xet", package: "Xet"), ], path: "Sources/HuggingFace" ), diff --git a/README.md b/README.md index 23ff38a..05d5467 100644 --- a/README.md +++ b/README.md @@ -1000,3 +1000,8 @@ let response = try await client.speechToText( print("Transcription: \(response.text)") ``` + +## License + +This project is available under the MIT license. +See the LICENSE file for more info. 
diff --git a/Sources/HuggingFace/Hub/HubClient+Files.swift b/Sources/HuggingFace/Hub/HubClient+Files.swift index 12412a0..a7b09e0 100644 --- a/Sources/HuggingFace/Hub/HubClient+Files.swift +++ b/Sources/HuggingFace/Hub/HubClient+Files.swift @@ -6,6 +6,8 @@ import UniformTypeIdentifiers import FoundationNetworking #endif +import Xet + // MARK: - Upload Operations public extension HubClient { @@ -183,6 +185,31 @@ public extension HubClient { useRaw: Bool = false, cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy ) async throws -> Data { + #if canImport(Xet) + if isXetEnabled { + do { + let tempDirectory = FileManager.default.temporaryDirectory + .appendingPathComponent(UUID().uuidString, isDirectory: true) + let tempFile = tempDirectory.appendingPathComponent(UUID().uuidString) + try FileManager.default.createDirectory(at: tempDirectory, withIntermediateDirectories: true) + defer { try? FileManager.default.removeItem(at: tempDirectory) } + + if try await downloadFileWithXet( + repoPath: repoPath, + repo: repo, + revision: revision, + destination: tempFile, + progress: nil + ) != nil { + return try Data(contentsOf: tempFile) + } + } catch { + print("⚠️ Xet failed for \(repoPath): \(error), falling back to LFS") + } + } + #endif + + // Fallback to existing LFS download method let endpoint = useRaw ? "raw" : "resolve" let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)" var request = try httpClient.createRequest(.get, urlPath) @@ -215,6 +242,23 @@ public extension HubClient { cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy, progress: Progress? = nil ) async throws -> URL { + if isXetEnabled { + do { + if let downloaded = try await downloadFileWithXet( + repoPath: repoPath, + repo: repo, + revision: revision, + destination: destination, + progress: progress + ) { + return downloaded + } + } catch { + // Silently fall back to LFS + } + } + + // Fallback to existing LFS download method let endpoint = useRaw ? 
"raw" : "resolve" let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)" var request = try httpClient.createRequest(.get, urlPath) @@ -542,6 +586,129 @@ public extension HubClient { } } +extension HubClient { + /// Downloads a file using Xet's content-addressable storage system. + /// + /// This method uses a cached XetHubClient instance and JWT tokens to maximize + /// download performance through connection reuse and reduced API overhead. + /// + /// Performance optimizations: + /// - Reuses a single XetHubClient across all downloads for HTTP/TLS connection pooling + /// - Caches CAS JWT tokens per refresh route to avoid redundant API calls + /// - Leverages Xet's parallel chunk downloading + @discardableResult + func downloadFileWithXet( + repoPath: String, + repo: Repo.ID, + revision: String, + destination: URL, + progress: Progress? + ) async throws -> URL? { + guard isXetEnabled else { + return nil + } + + let xetClient = try getXetClient() + let token = try? httpClient.tokenProvider.getToken() + + // Get file metadata (includes hash and refresh route) + guard + let metadata = try await xetClient.getFileMetadata( + repo: repo.rawValue, + path: repoPath, + revision: revision, + host: host, + token: token + ) + else { + return nil + } + + // Ensure destination directory exists + let destinationDirectory = destination.deletingLastPathComponent() + try FileManager.default.createDirectory( + at: destinationDirectory, + withIntermediateDirectories: true + ) + + // Setup progress tracking + actor BytesTracker { + var totalBytesWritten: Int64 = 0 + let progressObject: Progress? + + init(progressObject: Progress?) 
{ + self.progressObject = progressObject + } + + func addBytes(_ bytes: Int64) { + totalBytesWritten += bytes + if let progressObject { + progressObject.completedUnitCount = totalBytesWritten + } + } + } + + let progressObject = progress + if let progressObject { + progressObject.totalUnitCount = Int64(metadata.size) + } + + let tracker = BytesTracker(progressObject: progressObject) + let progressHandler: XetHubClient.ProgressHandler = { bytesJustWritten in + Task { + await tracker.addBytes(Int64(bytesJustWritten)) + } + } + + if let fileHash = metadata.fileHash, let refreshRoute = metadata.refreshRoute { + // Xet file: Use CAS download + + // Get or fetch CAS JWT using refresh route + let jwt = try await getCachedJwt( + refreshRoute: refreshRoute, + token: token + ) + + let fileInfo = XetFileInfo( + hash: fileHash, + fileSize: metadata.size + ) + + _ = try await xetClient.downloadFromCas( + fileInfo: fileInfo, + jwt: jwt, + destination: destination, + progress: progressHandler + ) + + } else { + // Non-Xet file: Use parallel HTTP download + + guard let url = URL(string: metadata.downloadURL) else { + return nil + } + + var headers: [String: String] = [:] + if let token = token { + headers["Authorization"] = "Bearer \(token)" + } + + let request = HubDownloadRequest( + source: url, + destination: destination, + headers: headers + ) + + _ = try await xetClient.download( + request, + progress: progressHandler + ) + } + + return destination + } +} + // MARK: - Metadata Helpers extension HubClient { diff --git a/Sources/HuggingFace/Hub/HubClient.swift b/Sources/HuggingFace/Hub/HubClient.swift index 1781c36..ed240e5 100644 --- a/Sources/HuggingFace/Hub/HubClient.swift +++ b/Sources/HuggingFace/Hub/HubClient.swift @@ -4,6 +4,10 @@ import Foundation import FoundationNetworking #endif +#if canImport(Xet) + import Xet +#endif + /// A Hugging Face Hub API client. 
/// /// This client provides methods to interact with the Hugging Face Hub API, @@ -32,8 +36,16 @@ public final class HubClient: Sendable { /// environment variable (defaults to https://huggingface.co). public static let `default` = HubClient() + /// Indicates whether Xet acceleration is enabled for this client. + public let isXetEnabled: Bool + /// The underlying HTTP client. internal let httpClient: HTTPClient + + #if canImport(Xet) + /// Xet client instance for connection reuse (created once during initialization) + private let xetClient: XetHubClient? + #endif /// The host URL for requests made by the client. public var host: URL { @@ -65,15 +77,19 @@ public final class HubClient: Sendable { /// - Parameters: /// - session: The underlying client session. Defaults to `URLSession(configuration: .default)`. /// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`. + /// - xetConfiguration: Configuration for Xet downloads. Pass `nil` to disable Xet acceleration. + /// Defaults to `.highPerformance()` if Xet is supported, `nil` otherwise. public convenience init( session: URLSession = URLSession(configuration: .default), - userAgent: String? = nil + userAgent: String? = nil, + xetConfiguration: XetConfiguration? = HubClient.isXetSupported ? .highPerformance() : nil ) { self.init( session: session, host: Self.detectHost(), userAgent: userAgent, - tokenProvider: .environment + tokenProvider: .environment, + xetConfiguration: xetConfiguration ) } @@ -84,17 +100,21 @@ public final class HubClient: Sendable { /// - host: The host URL to use for requests. /// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`. /// - bearerToken: The Bearer token for authentication, if any. Defaults to `nil`. + /// - xetConfiguration: Configuration for Xet downloads. Pass `nil` to disable Xet acceleration. + /// Defaults to `.highPerformance()` if Xet is supported, `nil` otherwise. 
public convenience init( session: URLSession = URLSession(configuration: .default), host: URL, userAgent: String? = nil, - bearerToken: String? = nil + bearerToken: String? = nil, + xetConfiguration: XetConfiguration? = HubClient.isXetSupported ? .highPerformance() : nil ) { self.init( session: session, host: host, userAgent: userAgent, - tokenProvider: bearerToken.map { .fixed(token: $0) } ?? .none + tokenProvider: bearerToken.map { .fixed(token: $0) } ?? .none, + xetConfiguration: xetConfiguration ) } @@ -105,18 +125,34 @@ public final class HubClient: Sendable { /// - host: The host URL to use for requests. /// - userAgent: The value for the `User-Agent` header sent in requests, if any. Defaults to `nil`. /// - tokenProvider: The token provider for authentication. + /// - xetConfiguration: Configuration for Xet downloads. Pass `nil` to disable Xet acceleration. + /// Defaults to `.highPerformance()` if Xet is supported, `nil` otherwise. public init( session: URLSession = URLSession(configuration: .default), host: URL, userAgent: String? = nil, - tokenProvider: TokenProvider + tokenProvider: TokenProvider, + xetConfiguration: XetConfiguration? = HubClient.isXetSupported ? 
.highPerformance() : nil ) { + self.isXetEnabled = xetConfiguration != nil && HubClient.isXetSupported self.httpClient = HTTPClient( host: host, userAgent: userAgent, tokenProvider: tokenProvider, session: session ) + + #if canImport(Xet) + if let config = xetConfiguration, self.isXetEnabled { + // Create XetHubClient once during initialization + self.xetClient = XetHubClient( + sessionConfiguration: session.configuration, + configuration: config + ) + } else { + self.xetClient = nil + } + #endif } // MARK: - Auto-detection @@ -134,4 +170,45 @@ public final class HubClient: Sendable { } return defaultHost } -} + + public static var isXetSupported: Bool { + #if canImport(Xet) + return true + #else + return false + #endif + } + + // MARK: - Xet Client + + #if canImport(Xet) + /// Returns the Xet client for faster downloads. + /// + /// The client is created once during initialization and reused across downloads + /// to enable connection pooling and avoid reinitialization overhead. + /// + /// - Returns: A Xet client instance. + internal func getXetClient() throws -> XetHubClient { + guard isXetEnabled, let client = xetClient else { + throw HTTPClientError.requestError("Xet support is disabled for this client.") + } + return client + } + + /// Gets or fetches a CAS JWT for the given refresh route. + /// + /// JWTs are cached by the XetHubClient to avoid redundant API calls. + /// + /// - Parameters: + /// - refreshRoute: The refresh route URL for fetching the JWT + /// - token: Optional authentication token + /// - Returns: A CAS JWT info object + internal func getCachedJwt( + refreshRoute: String, + token: String? 
+ ) async throws -> CasJwtInfo { + let xetClient = try getXetClient() + return try await xetClient.getCasJwt(refreshRoute: refreshRoute, token: token) + } + #endif + } diff --git a/Tests/HuggingFaceTests/DownloadSpeedTests.swift b/Tests/HuggingFaceTests/DownloadSpeedTests.swift new file mode 100644 index 0000000..6015a65 --- /dev/null +++ b/Tests/HuggingFaceTests/DownloadSpeedTests.swift @@ -0,0 +1,182 @@ +import Foundation +import Testing +@testable import HuggingFace + +#if canImport(Xet) + import Xet +#endif + +/// Integration tests that exercise the high-level download path (including Xet) against the live Hugging Face Hub. +/// +/// These tests perform large network transfers, so they are opt-in. Set the environment variable +/// `HF_RUN_SPEED_TEST=1` before running `swift test` to enable them. +@Suite struct DownloadSpeedTests { + @Test("Xet download speed") + func xetDownloadSpeed() async throws { + print("xetDownloadSpeed started") + guard ProcessInfo.processInfo.environment["HF_RUN_SPEED_TEST"] == "1" else { + Issue.record("Set HF_RUN_SPEED_TEST=1 to enable the download speed integration test.") + return + } + + guard HubClient.isXetSupported else { + Issue.record("Xet acceleration is not supported on this platform.") + return + } + + let env = ProcessInfo.processInfo.environment + let repoIdentifier = env["HF_SPEED_TEST_REPO"] ?? "Qwen/Qwen3-0.6B" + let minSizeMB = Int(env["HF_SPEED_TEST_MIN_SIZE_MB"] ?? "") ?? 
10 + + guard let repoID = Repo.ID(rawValue: repoIdentifier) else { + Issue.record("Invalid repo identifier: \(repoIdentifier)") + return + } + + print("repoID: \(repoID)") + print("minSizeMB: \(minSizeMB)") + print("xetConfiguration: \(XetConfiguration.highPerformance())") + let tester = DownloadSpeedTester( + repo: repoID, + minSizeMB: minSizeMB, + xetConfiguration: .highPerformance() + ) + + let result: DownloadSpeedTester.Result + do { + print("tester.run() started") + result = try await tester.run() + } catch DownloadSpeedTester.RunError.noEligibleFiles { + Issue.record("No files of at least \(minSizeMB) MB were found in \(repoID).") + return + } + + #expect(result.totalBytes > 0, "Expected to download at least one file.") + #expect(result.averageSpeedBytesPerSecond > 0, "Average speed should be greater than zero.") + + Issue.record( + """ + Downloaded \(DownloadSpeedTester.formatBytes(Int64(result.totalBytes))) \ + in \(String(format: "%.2f", result.totalTime))s \ + (\(DownloadSpeedTester.formatBytes(Int64(result.averageSpeedBytesPerSecond)))/s average) + """ + ) + } +} + +private struct DownloadSpeedTester { + enum RunError: Error { + case noEligibleFiles + } + + struct Result { + let totalBytes: Int + let totalTime: TimeInterval + + var averageSpeedBytesPerSecond: Double { + guard totalTime > 0 else { return 0 } + return Double(totalBytes) / totalTime + } + } + + let repo: Repo.ID + let minSizeMB: Int + + let xetConfiguration: XetConfiguration? + + init(repo: Repo.ID, minSizeMB: Int, xetConfiguration: XetConfiguration?) 
{ + self.repo = repo + self.minSizeMB = minSizeMB + self.xetConfiguration = xetConfiguration + } + + func run() async throws -> Result { + print("run() started") + let client = HubClient(xetConfiguration: xetConfiguration) + print("client: \(client)") + + let files = try await client.listFiles( + in: repo, + kind: .model, + revision: "main", + recursive: true + ) + + let testFiles = Self.selectTestFiles(from: files, minSizeMB: minSizeMB) + guard !testFiles.isEmpty else { throw RunError.noEligibleFiles } + + let tempDir = FileManager.default.temporaryDirectory + .appendingPathComponent("hf-speed-test-\(UUID().uuidString)") + print("tempDir: \(tempDir)") + try FileManager.default.createDirectory(at: tempDir, withIntermediateDirectories: true) + print("tempDir created") + defer { + print("tempDir cleanup") + try? FileManager.default.removeItem(at: tempDir) + } + + var totalBytes = 0 + var totalTime: TimeInterval = 0 + + for file in testFiles { + print("downloading file: \(file.path)") + let destination = tempDir.appendingPathComponent(file.path) + try? FileManager.default.createDirectory( + at: destination.deletingLastPathComponent(), + withIntermediateDirectories: true + ) + + let start = Date() + _ = try await client.downloadFile( + at: file.path, + from: repo, + to: destination, + kind: .model, + revision: "main" + ) + let duration = Date().timeIntervalSince(start) + + totalBytes += file.size ?? 0 + totalTime += duration + } + + print("run() completed") + return Result(totalBytes: totalBytes, totalTime: totalTime) + } + + static func selectTestFiles(from files: [Git.TreeEntry], minSizeMB: Int) -> [Git.TreeEntry] { + let minSizeBytes = minSizeMB * 1_024 * 1_024 + let largeFiles = files.filter { file in + file.type == .file && (file.size ?? 
0) >= minSizeBytes + } + + guard !largeFiles.isEmpty else { return [] } + + var selected: [Git.TreeEntry] = [] + let priorities = ["safetensors", "bin", "gguf", "pt", "pth"] + + for priority in priorities { + if let match = largeFiles.first(where: { $0.path.contains(priority) }) { + if !selected.contains(where: { $0.path == match.path }) { + selected.append(match) + } + } + } + + if selected.count < 3 { + let remaining = largeFiles.filter { candidate in + !selected.contains(where: { $0.path == candidate.path }) + } + selected.append(contentsOf: remaining.sorted { ($0.size ?? 0) > ($1.size ?? 0) }) + } + + return Array(selected.prefix(3)) + } + + static func formatBytes(_ bytes: Int64) -> String { + let formatter = ByteCountFormatter() + formatter.allowedUnits = [.useKB, .useMB, .useGB] + formatter.countStyle = .file + return formatter.string(fromByteCount: bytes) + } +} diff --git a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift index b691ac4..aa477a0 100644 --- a/Tests/HuggingFaceTests/HubTests/HubClientTests.swift +++ b/Tests/HuggingFaceTests/HubTests/HubClientTests.swift @@ -38,4 +38,17 @@ struct HubClientTests { #expect(client.host.path.hasSuffix("/")) } + + @Test("Xet configuration can be toggled per client") + func testXetConfigurationToggle() throws { + let host = URL(string: "https://huggingface.co")! + + try #require(HubClient.isXetSupported, "Xet is not supported on this platform") + + let disabledClient = HubClient(host: host, xetConfiguration: nil) + #expect(disabledClient.isXetEnabled == false) + + let enabledClient = HubClient(host: host, xetConfiguration: .highPerformance()) + #expect(enabledClient.isXetEnabled) + } }