Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions huggingface_hub/src/api/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,9 @@ impl HfApi {
// Determine which files should be uploaded via xet (LFS) vs inline
// (regular). Files uploaded via xet are referenced by their SHA256 OID
// in the commit NDJSON.
let lfs_uploaded: HashMap<String, (String, u64)> =
self.preupload_and_upload_lfs_files(params, revision).await?;
let lfs_uploaded: HashMap<String, (String, u64)> = self
.preupload_and_upload_lfs_files(params, revision, params.progress_callback.as_ref())
.await?;

let mut ndjson_lines: Vec<Vec<u8>> = Vec::new();

Expand All @@ -187,6 +188,11 @@ impl HfApi {
ndjson_lines.push(serde_json::to_vec(&header_line)?);

for op in &params.operations {
let path_in_repo = match op {
CommitOperation::Add { path_in_repo, .. } => path_in_repo,
CommitOperation::Delete { path_in_repo } => path_in_repo,
};
let is_lfs = lfs_uploaded.contains_key(path_in_repo);
let line = match op {
CommitOperation::Add { path_in_repo, source } => {
if let Some((oid, size)) = lfs_uploaded.get(path_in_repo) {
Expand All @@ -211,6 +217,13 @@ impl HfApi {
},
};
ndjson_lines.push(serde_json::to_vec(&line)?);

// Call progress callback for non-LFS files (LFS files already triggered callback during upload)
if !is_lfs {
if let Some(ref callback) = params.progress_callback {
callback(path_in_repo);
}
}
}

let body: Vec<u8> = ndjson_lines
Expand Down Expand Up @@ -448,6 +461,7 @@ impl HfApi {
&self,
params: &CreateCommitParams,
revision: &str,
progress_callback: Option<&crate::types::CommitProgressCallback>,
) -> Result<HashMap<String, (String, u64)>> {
let add_ops: Vec<(&String, &AddSource)> = params
.operations
Expand Down Expand Up @@ -497,12 +511,13 @@ impl HfApi {
// LFS files require xet upload — fail if the feature is not enabled
#[cfg(not(feature = "xet"))]
{
let _ = lfs_files;
let _ = (lfs_files, progress_callback);
Err(HfError::XetNotEnabled)
}

#[cfg(feature = "xet")]
self.upload_lfs_files_via_xet(params, revision, &lfs_files).await
self.upload_lfs_files_via_xet(params, revision, &lfs_files, progress_callback)
.await
}

/// Call the Hub preupload endpoint to determine upload mode per file.
Expand Down Expand Up @@ -558,6 +573,7 @@ impl HfApi {
params: &CreateCommitParams,
revision: &str,
lfs_files: &[&(String, u64, Vec<u8>, &AddSource)],
progress_callback: Option<&crate::types::CommitProgressCallback>,
) -> Result<HashMap<String, (String, u64)>> {
// Step 4: Compute SHA256 for LFS files
let mut lfs_with_sha: Vec<(String, u64, String, &AddSource)> = Vec::new();
Expand All @@ -583,7 +599,8 @@ impl HfApi {
.map(|(path, _, _, source)| (path.clone(), (*source).clone()))
.collect();

crate::xet::xet_upload(self, &xet_files, &params.repo_id, params.repo_type, revision).await?;
crate::xet::xet_upload(self, &xet_files, &params.repo_id, params.repo_type, revision, progress_callback)
.await?;

let result: HashMap<String, (String, u64)> = lfs_with_sha
.into_iter()
Expand Down
7 changes: 7 additions & 0 deletions huggingface_hub/src/types/params.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use std::path::PathBuf;
use std::sync::Arc;

use typed_builder::TypedBuilder;

use super::commit::{AddSource, CommitOperation};
use super::repo::RepoType;

/// Callback function invoked after each operation is processed during a commit.
/// The argument is the path of the file that was just processed.
pub type CommitProgressCallback = Arc<dyn Fn(&str) + Send + Sync>;

#[derive(TypedBuilder)]
pub struct ModelInfoParams {
#[builder(setter(into))]
Expand Down Expand Up @@ -304,6 +309,8 @@ pub struct CreateCommitParams {
pub create_pr: Option<bool>,
#[builder(default, setter(into, strip_option))]
pub parent_commit: Option<String>,
#[builder(default, setter(strip_option))]
pub progress_callback: Option<CommitProgressCallback>,
}

#[derive(TypedBuilder)]
Expand Down
7 changes: 6 additions & 1 deletion huggingface_hub/src/xet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ pub(crate) async fn xet_upload(
repo_id: &str,
repo_type: Option<RepoType>,
revision: &str,
progress_callback: Option<&crate::types::CommitProgressCallback>,
) -> Result<Vec<XetFileInfo>> {
let session = api.get_or_init_xet_session("write", repo_id, repo_type, revision).await?;

Expand All @@ -158,7 +159,7 @@ pub(crate) async fn xet_upload(

let mut task_ids_in_order = Vec::with_capacity(files.len());

for (_path_in_repo, source) in files {
for (path_in_repo, source) in files {
let handle = match source {
AddSource::File(path) => commit
.upload_from_path(path.clone(), Sha256Policy::Compute)
Expand All @@ -170,6 +171,10 @@ pub(crate) async fn xet_upload(
.map_err(|e| HfError::Other(format!("Xet upload failed: {e}")))?,
};
task_ids_in_order.push(handle.task_id);

if let Some(callback) = progress_callback {
callback(path_in_repo);
}
}

let results = commit
Expand Down
Loading