Skip to content

Commit 6cd9f93

Browse files
committed
feat: add tile size parameter for ESRGAN upscaling
1 parent 1798ec0 commit 6cd9f93

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

esrgan.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,11 @@ struct ESRGAN : public GGMLRunner {
156156

157157
ESRGAN(ggml_backend_t backend,
158158
bool offload_params_to_cpu,
159+
int tile_size = 128,
159160
const String2TensorStorage& tensor_storage_map = {})
160161
: GGMLRunner(backend, offload_params_to_cpu) {
161162
// rrdb_net will be created in load_from_file
163+
this->tile_size = tile_size;
162164
}
163165

164166
std::string get_desc() override {

examples/cli/main.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ struct SDParams {
130130
bool canny_preprocess = false;
131131
bool color = false;
132132
int upscale_repeats = 1;
133+
int upscale_tile_size = 128;
133134

134135
// Photo Maker
135136
std::string photo_maker_path;
@@ -227,6 +228,7 @@ void print_params(SDParams params) {
227228
printf(" vae_tiling: %s\n", params.vae_tiling_params.enabled ? "true" : "false");
228229
printf(" force_sdxl_vae_conv_scale: %s\n", params.force_sdxl_vae_conv_scale ? "true" : "false");
229230
printf(" upscale_repeats: %d\n", params.upscale_repeats);
231+
printf(" upscale_tile_size: %d\n", params.upscale_tile_size);
230232
printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false");
231233
printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false");
232234
printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad);
@@ -647,6 +649,10 @@ void parse_args(int argc, const char** argv, SDParams& params) {
647649
"--upscale-repeats",
648650
"Run the ESRGAN upscaler this many times (default: 1)",
649651
&params.upscale_repeats},
652+
{"",
653+
"--upscale-tile-size",
654+
"tile size for ESRGAN upscaling (default: 128)",
655+
&params.upscale_tile_size},
650656
{"-H",
651657
"--height",
652658
"image height, in pixel space (default: 512)",
@@ -1371,6 +1377,11 @@ void parse_args(int argc, const char** argv, SDParams& params) {
13711377
exit(1);
13721378
}
13731379

1380+
if (params.upscale_tile_size < 1) {
1381+
fprintf(stderr, "error: upscale tile size must be at least 1\n");
1382+
exit(1);
1383+
}
1384+
13741385
if (params.mode == UPSCALE) {
13751386
if (params.esrgan_path.length() == 0) {
13761387
fprintf(stderr, "error: upscale mode needs an upscaler model (--upscale-model)\n");
@@ -1995,7 +2006,8 @@ int main(int argc, const char* argv[]) {
19952006
upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(),
19962007
params.offload_params_to_cpu,
19972008
params.diffusion_conv_direct,
1998-
params.n_threads);
2009+
params.n_threads,
2010+
params.upscale_tile_size);
19992011

20002012
if (upscaler_ctx == nullptr) {
20012013
printf("new_upscaler_ctx failed\n");

stable-diffusion.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ typedef struct upscaler_ctx_t upscaler_ctx_t;
332332
SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path,
333333
bool offload_params_to_cpu,
334334
bool direct,
335-
int n_threads);
335+
int n_threads,
336+
int tile_size);
336337
SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx);
337338

338339
SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx,

upscaler.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,14 @@ struct UpscalerGGML {
1010
std::string esrgan_path;
1111
int n_threads;
1212
bool direct = false;
13+
int tile_size = 128;
1314

1415
UpscalerGGML(int n_threads,
15-
bool direct = false)
16+
bool direct = false,
17+
int tile_size = 128)
1618
: n_threads(n_threads),
17-
direct(direct) {
19+
direct(direct),
20+
tile_size(tile_size) {
1821
}
1922

2023
bool load_from_file(const std::string& esrgan_path,
@@ -51,7 +54,7 @@ struct UpscalerGGML {
5154
backend = ggml_backend_cpu_init();
5255
}
5356
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
54-
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, model_loader.get_tensor_storage_map());
57+
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
5558
if (direct) {
5659
esrgan_upscaler->set_conv2d_direct_enabled(true);
5760
}
@@ -113,14 +116,15 @@ struct upscaler_ctx_t {
113116
upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str,
114117
bool offload_params_to_cpu,
115118
bool direct,
116-
int n_threads) {
119+
int n_threads,
120+
int tile_size) {
117121
upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t));
118122
if (upscaler_ctx == nullptr) {
119123
return nullptr;
120124
}
121125
std::string esrgan_path(esrgan_path_c_str);
122126

123-
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct);
127+
upscaler_ctx->upscaler = new UpscalerGGML(n_threads, direct, tile_size);
124128
if (upscaler_ctx->upscaler == nullptr) {
125129
return nullptr;
126130
}

0 commit comments

Comments
 (0)