Skip to content

Commit 9a01667

Browse files
authored
Add support for reusing cache between Swift and Python Hub clients (#6)
* Add support for reusing cache between Swift and Python Hub clients * Add test coverage for HubCache and CacheLocationProvider * Create parent directories for nested paths * Add file locking when downloading files * Refactor FileLock implementation * Validate untrusted etag headers to prevent path traversal * Fix documentation comments for CacheLocationProvider * Validate filenames to prevent arbitrary path traversal
1 parent 21c7681 commit 9a01667

File tree

8 files changed

+2584
-179
lines changed

8 files changed

+2584
-179
lines changed

Sources/HuggingFace/Hub/HubCache.swift

Lines changed: 543 additions & 0 deletions
Large diffs are not rendered by default.

Sources/HuggingFace/Hub/HubClient+Files.swift

Lines changed: 84 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import CryptoKit
21
import Foundation
32
import UniformTypeIdentifiers
43

@@ -178,11 +177,23 @@ public extension HubClient {
178177
func downloadContentsOfFile(
179178
at repoPath: String,
180179
from repo: Repo.ID,
181-
kind _: Repo.Kind = .model,
180+
kind: Repo.Kind = .model,
182181
revision: String = "main",
183182
useRaw: Bool = false,
184183
cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy
185184
) async throws -> Data {
185+
// Check cache first
186+
if let cache = cache,
187+
let cachedPath = cache.cachedFilePath(
188+
repo: repo,
189+
kind: kind,
190+
revision: revision,
191+
filename: repoPath
192+
)
193+
{
194+
return try Data(contentsOf: cachedPath)
195+
}
196+
186197
let endpoint = useRaw ? "raw" : "resolve"
187198
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
188199
var request = try httpClient.createRequest(.get, urlPath)
@@ -191,6 +202,23 @@ public extension HubClient {
191202
let (data, response) = try await session.data(for: request)
192203
_ = try httpClient.validateResponse(response, data: data)
193204

205+
// Store in cache if we have etag and commit info
206+
if let cache = cache,
207+
let httpResponse = response as? HTTPURLResponse,
208+
let etag = httpResponse.value(forHTTPHeaderField: "ETag"),
209+
let commitHash = httpResponse.value(forHTTPHeaderField: "X-Repo-Commit")
210+
{
211+
try? cache.storeData(
212+
data,
213+
repo: repo,
214+
kind: kind,
215+
revision: commitHash,
216+
filename: repoPath,
217+
etag: etag,
218+
ref: revision != commitHash ? revision : nil
219+
)
220+
}
221+
194222
return data
195223
}
196224

@@ -209,12 +237,33 @@ public extension HubClient {
209237
at repoPath: String,
210238
from repo: Repo.ID,
211239
to destination: URL,
212-
kind _: Repo.Kind = .model,
240+
kind: Repo.Kind = .model,
213241
revision: String = "main",
214242
useRaw: Bool = false,
215243
cachePolicy: URLRequest.CachePolicy = .useProtocolCachePolicy,
216244
progress: Progress? = nil
217245
) async throws -> URL {
246+
// Check cache first
247+
if let cache = cache,
248+
let cachedPath = cache.cachedFilePath(
249+
repo: repo,
250+
kind: kind,
251+
revision: revision,
252+
filename: repoPath
253+
)
254+
{
255+
// Create parent directory if needed
256+
try FileManager.default.createDirectory(
257+
at: destination.deletingLastPathComponent(),
258+
withIntermediateDirectories: true
259+
)
260+
// Copy from cache to destination
261+
try? FileManager.default.removeItem(at: destination)
262+
try FileManager.default.copyItem(at: cachedPath, to: destination)
263+
progress?.completedUnitCount = progress?.totalUnitCount ?? 100
264+
return destination
265+
}
266+
218267
let endpoint = useRaw ? "raw" : "resolve"
219268
let urlPath = "/\(repo)/\(endpoint)/\(revision)/\(repoPath)"
220269
var request = try httpClient.createRequest(.get, urlPath)
@@ -226,6 +275,29 @@ public extension HubClient {
226275
)
227276
_ = try httpClient.validateResponse(response, data: nil)
228277

278+
// Store in cache before moving to destination
279+
if let cache = cache,
280+
let httpResponse = response as? HTTPURLResponse,
281+
let etag = httpResponse.value(forHTTPHeaderField: "ETag"),
282+
let commitHash = httpResponse.value(forHTTPHeaderField: "X-Repo-Commit")
283+
{
284+
try? cache.storeFile(
285+
at: tempURL,
286+
repo: repo,
287+
kind: kind,
288+
revision: commitHash,
289+
filename: repoPath,
290+
etag: etag,
291+
ref: revision != commitHash ? revision : nil
292+
)
293+
}
294+
295+
// Create parent directory if needed
296+
try FileManager.default.createDirectory(
297+
at: destination.deletingLastPathComponent(),
298+
withIntermediateDirectories: true
299+
)
300+
229301
// Move from temporary location to final destination
230302
try? FileManager.default.removeItem(at: destination)
231303
try FileManager.default.moveItem(at: tempURL, to: destination)
@@ -457,6 +529,11 @@ public extension HubClient {
457529

458530
public extension HubClient {
459531
/// Download a repository snapshot to a local directory.
532+
///
533+
/// This method downloads all files from a repository to the specified destination.
534+
/// Files are automatically cached in the Python-compatible cache directory,
535+
/// allowing cache reuse between Swift and Python Hugging Face clients.
536+
///
460537
/// - Parameters:
461538
/// - repo: Repository identifier
462539
/// - kind: Kind of repository
@@ -473,13 +550,6 @@ public extension HubClient {
473550
matching globs: [String] = [],
474551
progressHandler: ((Progress) -> Void)? = nil
475552
) async throws -> URL {
476-
let repoDestination = destination
477-
let repoMetadataDestination =
478-
repoDestination
479-
.appendingPathComponent(".cache")
480-
.appendingPathComponent("huggingface")
481-
.appendingPathComponent("download")
482-
483553
let filenames = try await listFiles(in: repo, kind: kind, revision: revision, recursive: true)
484554
.map(\.path)
485555
.filter { filename in
@@ -494,25 +564,9 @@ public extension HubClient {
494564

495565
for filename in filenames {
496566
let fileProgress = Progress(totalUnitCount: 100, parent: progress, pendingUnitCount: 1)
567+
let fileDestination = destination.appendingPathComponent(filename)
497568

498-
let fileDestination = repoDestination.appendingPathComponent(filename)
499-
let metadataDestination = repoMetadataDestination.appendingPathComponent(filename + ".metadata")
500-
501-
let localMetadata = readDownloadMetadata(at: metadataDestination)
502-
let remoteFile = try await getFile(at: filename, in: repo, kind: kind, revision: revision)
503-
504-
let localCommitHash = localMetadata?.commitHash ?? ""
505-
let remoteCommitHash = remoteFile.revision ?? ""
506-
507-
if isValidHash(remoteCommitHash, pattern: commitHashPattern),
508-
FileManager.default.fileExists(atPath: fileDestination.path),
509-
localMetadata != nil,
510-
localCommitHash == remoteCommitHash
511-
{
512-
fileProgress.completedUnitCount = 100
513-
continue
514-
}
515-
569+
// downloadFile handles cache lookup and storage automatically
516570
_ = try await downloadFile(
517571
at: filename,
518572
from: repo,
@@ -522,58 +576,15 @@ public extension HubClient {
522576
progress: fileProgress
523577
)
524578

525-
if let etag = remoteFile.etag, let revision = remoteFile.revision {
526-
try writeDownloadMetadata(
527-
commitHash: revision,
528-
etag: etag,
529-
to: metadataDestination
530-
)
531-
}
532-
533579
if Task.isCancelled {
534-
return repoDestination
580+
return destination
535581
}
536582

537583
fileProgress.completedUnitCount = 100
538584
}
539585

540586
progressHandler?(progress)
541-
return repoDestination
542-
}
543-
}
544-
545-
// MARK: - Metadata Helpers
546-
547-
extension HubClient {
548-
private var sha256Pattern: String { "^[0-9a-f]{64}$" }
549-
private var commitHashPattern: String { "^[0-9a-f]{40}$" }
550-
551-
/// Read metadata about a file in the local directory.
552-
func readDownloadMetadata(at metadataPath: URL) -> LocalDownloadFileMetadata? {
553-
FileManager.default.readDownloadMetadata(at: metadataPath)
554-
}
555-
556-
/// Write metadata about a downloaded file.
557-
func writeDownloadMetadata(commitHash: String, etag: String, to metadataPath: URL) throws {
558-
try FileManager.default.writeDownloadMetadata(
559-
commitHash: commitHash,
560-
etag: etag,
561-
to: metadataPath
562-
)
563-
}
564-
565-
/// Check if a hash matches the expected pattern.
566-
func isValidHash(_ hash: String, pattern: String) -> Bool {
567-
guard let regex = try? NSRegularExpression(pattern: pattern) else {
568-
return false
569-
}
570-
let range = NSRange(location: 0, length: hash.utf16.count)
571-
return regex.firstMatch(in: hash, options: [], range: range) != nil
572-
}
573-
574-
/// Compute SHA256 hash of a file.
575-
func computeFileHash(at url: URL) throws -> String {
576-
try FileManager.default.computeFileHash(at: url)
587+
return destination
577588
}
578589
}
579590

@@ -586,90 +597,6 @@ private struct UploadResponse: Codable {
586597

587598
// MARK: -
588599

589-
private extension FileManager {
590-
/// Read metadata about a file in the local directory.
591-
func readDownloadMetadata(at metadataPath: URL) -> LocalDownloadFileMetadata? {
592-
guard fileExists(atPath: metadataPath.path) else {
593-
return nil
594-
}
595-
596-
do {
597-
let contents = try String(contentsOf: metadataPath, encoding: .utf8)
598-
let lines = contents.components(separatedBy: .newlines)
599-
600-
guard lines.count >= 3 else {
601-
try? removeItem(at: metadataPath)
602-
return nil
603-
}
604-
605-
let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines)
606-
let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines)
607-
608-
guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines))
609-
else {
610-
try? removeItem(at: metadataPath)
611-
return nil
612-
}
613-
614-
let timestampDate = Date(timeIntervalSince1970: timestamp)
615-
let filename = metadataPath.lastPathComponent.replacingOccurrences(
616-
of: ".metadata",
617-
with: ""
618-
)
619-
620-
return LocalDownloadFileMetadata(
621-
commitHash: commitHash,
622-
etag: etag,
623-
filename: filename,
624-
timestamp: timestampDate
625-
)
626-
} catch {
627-
try? removeItem(at: metadataPath)
628-
return nil
629-
}
630-
}
631-
632-
/// Write metadata about a downloaded file.
633-
func writeDownloadMetadata(commitHash: String, etag: String, to metadataPath: URL) throws {
634-
let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"
635-
try createDirectory(
636-
at: metadataPath.deletingLastPathComponent(),
637-
withIntermediateDirectories: true
638-
)
639-
try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8)
640-
}
641-
642-
/// Compute SHA256 hash of a file.
643-
func computeFileHash(at url: URL) throws -> String {
644-
guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
645-
throw HTTPClientError.unexpectedError("Unable to open file: \(url.path)")
646-
}
647-
648-
defer {
649-
try? fileHandle.close()
650-
}
651-
652-
var hasher = SHA256()
653-
let chunkSize = 1024 * 1024
654-
655-
while autoreleasepool(invoking: {
656-
guard let nextChunk = try? fileHandle.read(upToCount: chunkSize),
657-
!nextChunk.isEmpty
658-
else {
659-
return false
660-
}
661-
662-
hasher.update(data: nextChunk)
663-
return true
664-
}) {}
665-
666-
let digest = hasher.finalize()
667-
return digest.map { String(format: "%02x", $0) }.joined()
668-
}
669-
}
670-
671-
// MARK: -
672-
673600
private extension URL {
674601
var mimeType: String? {
675602
guard let uti = UTType(filenameExtension: pathExtension) else {

0 commit comments

Comments
 (0)