Skip to content

Commit c2a59d0

Browse files
committed
Support Custom ESRGAN tile size
1 parent 8f6c5c2 commit c2a59d0

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

esrgan.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,11 @@ struct ESRGAN : public GGMLRunner {
156156

157157
ESRGAN(ggml_backend_t backend,
158158
bool offload_params_to_cpu,
159+
int tile_size = 128,
159160
const String2TensorStorage& tensor_storage_map = {})
160161
: GGMLRunner(backend, offload_params_to_cpu) {
161162
// rrdb_net will be created in load_from_file
163+
this->tile_size = tile_size;
162164
}
163165

164166
std::string get_desc() override {

examples/cli/main.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ struct SDParams {
118118
bool canny_preprocess = false;
119119
bool color = false;
120120
int upscale_repeats = 1;
121+
int upscale_tile_size = 128;
121122

122123
// Photo Maker
123124
std::string photo_maker_path;
@@ -204,6 +205,7 @@ void print_params(SDParams params) {
204205
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
205206
printf(" force_sdxl_vae_conv_scale: %s\n", params.force_sdxl_vae_conv_scale ? "true" : "false");
206207
printf(" upscale_repeats: %d\n", params.upscale_repeats);
208+
printf(" upscale_tile_size: %d\n", params.upscale_tile_size);
207209
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
208210
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
209211
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
@@ -605,6 +607,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
605607
"--upscale-repeats",
606608
"Run the ESRGAN upscaler this many times (default: 1)",
607609
&params.upscale_repeats},
610+
{"",
611+
"--upscale-tile-size",
612+
"tile size for ESRGAN upscaling (default: 128)",
613+
&params.upscale_tile_size},
608614
{"-H",
609615
"--height",
610616
"image height, in pixel space (default: 512)",
@@ -1187,6 +1193,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
11871193
exit(1);
11881194
}
11891195

1196+
if (params.upscale_tile_size < 1) {
1197+
fprintf(stderr, "error: upscale tile size must be at least 1\n");
1198+
exit(1);
1199+
}
1200+
11901201
if (params.mode == UPSCALE) {
11911202
if (params.esrgan_path.length() == 0) {
11921203
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
@@ -1757,7 +1768,8 @@ int main(int argc, const char* argv[]) {
17571768
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
17581769
params.offload_params_to_cpu,
17591770
params.diffusion_conv_direct,
1760-
params.n_threads);
1771+
params.n_threads,
1772+
params.upscale_tile_size);
17611773

17621774
if (upscaler_ctx == nullptr) {
17631775
printf("new_upscaler_ctx failed\n");

stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
293293
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
294294
bool offload_params_to_cpu,
295295
bool direct,
296-
int n_threads);
296+
int n_threads,
297+
int tile_size);
297298
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
298299

299300
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,

upscaler.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ struct UpscalerGGML {
1010
std::string esrgan_path;
1111
int n_threads;
1212
bool direct = false;
13+
int tile_size = 128;
1314

1415
UpscalerGGML(int n_threads,
15-
bool direct = false)
16+
bool direct = false,
17+
int tile_size = 128)
1618
: n_threads(n_threads),
17-
direct(direct) {
19+
direct(direct),
20+
tile_size(tile_size) {
1821
}
1922

2023
bool load_from_file(const std::string& esrgan_path,
@@ -51,7 +54,7 @@ struct UpscalerGGML {
5154
backend = ggml_backend_cpu_init();
5255
}
5356
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
54-
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
57+
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
5558
if (direct) {
5659
esrgan_upscaler->set_conv2d_direct_enabled(true);
5760
}
@@ -113,14 +116,15 @@ struct upscaler_ctx_t {
113116
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
114117
bool offload_params_to_cpu,
115118
bool direct,
116-
int n_threads) {
119+
int n_threads,
120+
int tile_size) {
117121
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
118122
if (upscaler_ctx == nullptr) {
119123
return nullptr;
120124
}
121125
std::string esrgan_path(esrgan_path_c_str);
122126

123-
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
127+
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct, tile_size);
124128
if (upscaler_ctx->upscaler == nullptr) {
125129
return nullptr;
126130
}

0 commit comments

Comments
 (0)