Skip to content

Commit ac7ff7d

Browse files
authored
add utilities for updating diffusers pipeline metadata. (#7573)
* add utilities for updating diffusers pipeline metadata. * style * remove first empty line
1 parent a0cf607 commit ac7ff7d

File tree

2 files changed

+136
-0
lines changed

2 files changed

+136
-0
lines changed
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
name: Update Diffusers metadata
2+
3+
on:
4+
workflow_dispatch:
5+
push:
6+
branches:
7+
- main
8+
- update_diffusers_metadata*
9+
10+
jobs:
11+
update_metadata:
12+
runs-on: ubuntu-22.04
13+
defaults:
14+
run:
15+
shell: bash -l {0}
16+
17+
steps:
18+
- uses: actions/checkout@v3
19+
20+
- name: Setup environment
21+
run: |
22+
pip install --upgrade pip
23+
pip install datasets pandas
24+
pip install .[torch]
25+
26+
- name: Update metadata
27+
env:
28+
HUGGING_FACE_HUB_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
29+
run: |
30+
python utils/update_metadata.py --commit_sha ${{ github.sha }}

utils/update_metadata.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# coding=utf-8
2+
# Copyright 2024 The HuggingFace Inc. team.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""
16+
Utility that updates the metadata of the Diffusers library in the repository `huggingface/diffusers-metadata`.
17+
18+
Usage for an update (as used by the GitHub action `update_metadata`):
19+
20+
```bash
21+
python utils/update_metadata.py
22+
```
23+
24+
Script modified from:
25+
https://github.com/huggingface/transformers/blob/main/utils/update_metadata.py
26+
"""
27+
import argparse
28+
import os
29+
import tempfile
30+
31+
import pandas as pd
32+
from datasets import Dataset
33+
from huggingface_hub import upload_folder
34+
35+
from diffusers.pipelines.auto_pipeline import (
36+
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
37+
AUTO_INPAINT_PIPELINES_MAPPING,
38+
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
39+
)
40+
41+
42+
def get_supported_pipeline_table() -> dict:
43+
"""
44+
Generates a dictionary containing the supported auto classes for each pipeline type,
45+
using the content of the auto modules.
46+
"""
47+
# All supported pipelines for automatic mapping.
48+
all_supported_pipeline_classes = [
49+
(class_name.__name__, "text-to-image", "AutoPipelineForText2Image")
50+
for _, class_name in AUTO_TEXT2IMAGE_PIPELINES_MAPPING.items()
51+
]
52+
all_supported_pipeline_classes += [
53+
(class_name.__name__, "image-to-image", "AutoPipelineForImage2Image")
54+
for _, class_name in AUTO_IMAGE2IMAGE_PIPELINES_MAPPING.items()
55+
]
56+
all_supported_pipeline_classes += [
57+
(class_name.__name__, "image-to-image", "AutoPipelineForInpainting")
58+
for _, class_name in AUTO_INPAINT_PIPELINES_MAPPING.items()
59+
]
60+
all_supported_pipeline_classes.sort(key=lambda x: x[0])
61+
all_supported_pipeline_classes = list(set(all_supported_pipeline_classes))
62+
63+
data = {}
64+
data["pipeline_class"] = [sample[0] for sample in all_supported_pipeline_classes]
65+
data["pipeline_tag"] = [sample[1] for sample in all_supported_pipeline_classes]
66+
data["auto_class"] = [sample[2] for sample in all_supported_pipeline_classes]
67+
68+
return data
69+
70+
71+
def update_metadata(commit_sha: str):
72+
"""
73+
Update the metadata for the Diffusers repo in `huggingface/diffusers-metadata`.
74+
75+
Args:
76+
commit_sha (`str`): The commit SHA on Diffusers corresponding to this update.
77+
"""
78+
pipelines_table = get_supported_pipeline_table()
79+
pipelines_table = pd.DataFrame(pipelines_table)
80+
pipelines_dataset = Dataset.from_pandas(pipelines_table)
81+
82+
with tempfile.TemporaryDirectory() as tmp_dir:
83+
pipelines_dataset.to_json(os.path.join(tmp_dir, "pipeline_tags.json"))
84+
85+
if commit_sha is not None:
86+
commit_message = (
87+
f"Update with commit {commit_sha}\n\nSee: "
88+
f"https://github.com/huggingface/diffusers/commit/{commit_sha}"
89+
)
90+
else:
91+
commit_message = "Update"
92+
93+
upload_folder(
94+
repo_id="huggingface/diffusers-metadata",
95+
folder_path=tmp_dir,
96+
repo_type="dataset",
97+
commit_message=commit_message,
98+
)
99+
100+
101+
if __name__ == "__main__":
102+
parser = argparse.ArgumentParser()
103+
parser.add_argument("--commit_sha", default=None, type=str, help="The sha of the commit going with this update.")
104+
args = parser.parse_args()
105+
106+
update_metadata(args.commit_sha)

0 commit comments

Comments
 (0)