Skip to content

Commit 8823dc4

Browse files
authored
feat: align the spatial size to the corresponding multiple (#1073)
1 parent 1ac5a61 commit 8823dc4

File tree

2 files changed

+44
-17
lines changed

2 files changed

+44
-17
lines changed

ggml_extend.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@
6060
#define SD_UNUSED(x) (void)(x)
6161
#endif
6262

63+
__STATIC_INLINE__ int align_up_offset(int n, int multiple) {
64+
return (multiple - n % multiple) % multiple;
65+
}
66+
67+
__STATIC_INLINE__ int align_up(int n, int multiple) {
68+
return n + align_up_offset(n, multiple);
69+
}
70+
6371
__STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void*) {
6472
switch (level) {
6573
case GGML_LOG_LEVEL_DEBUG:

stable-diffusion.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,18 @@ class StableDiffusionGGML {
18981898
return vae_scale_factor;
18991899
}
19001900

1901+
int get_diffusion_model_down_factor() {
1902+
int down_factor = 8; // unet
1903+
if (sd_version_is_dit(version)) {
1904+
if (sd_version_is_wan(version)) {
1905+
down_factor = 2;
1906+
} else {
1907+
down_factor = 1;
1908+
}
1909+
}
1910+
return down_factor;
1911+
}
1912+
19011913
int get_latent_channel() {
19021914
int latent_channel = 4;
19031915
if (sd_version_is_dit(version)) {
@@ -3133,22 +3145,19 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
31333145
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
31343146
int width = sd_img_gen_params->width;
31353147
int height = sd_img_gen_params->height;
3136-
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
3137-
if (sd_version_is_dit(sd_ctx->sd->version)) {
3138-
if (width % 16 || height % 16) {
3139-
LOG_ERROR("Image dimensions must be must be a multiple of 16 on each axis for %s models. (Got %dx%d)",
3140-
model_version_to_str[sd_ctx->sd->version],
3141-
width,
3142-
height);
3143-
return nullptr;
3144-
}
3145-
} else if (width % 64 || height % 64) {
3146-
LOG_ERROR("Image dimensions must be must be a multiple of 64 on each axis for %s models. (Got %dx%d)",
3147-
model_version_to_str[sd_ctx->sd->version],
3148-
width,
3149-
height);
3150-
return nullptr;
3148+
3149+
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
3150+
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
3151+
int spatial_multiple = vae_scale_factor * diffusion_model_down_factor;
3152+
3153+
int width_offset = align_up_offset(width, spatial_multiple);
3154+
int height_offset = align_up_offset(height, spatial_multiple);
3155+
if (width_offset > 0 || height_offset > 0) {
3156+
width += width_offset;
3157+
height += height_offset;
3158+
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_img_gen_params->width, sd_img_gen_params->height, width, height, spatial_multiple);
31513159
}
3160+
31523161
LOG_DEBUG("generate_image %dx%d", width, height);
31533162
if (sd_ctx == nullptr || sd_img_gen_params == nullptr) {
31543163
return nullptr;
@@ -3422,9 +3431,19 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34223431
int frames = sd_vid_gen_params->video_frames;
34233432
frames = (frames - 1) / 4 * 4 + 1;
34243433
int sample_steps = sd_vid_gen_params->sample_params.sample_steps;
3425-
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
34263434

3427-
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
3435+
int vae_scale_factor = sd_ctx->sd->get_vae_scale_factor();
3436+
int diffusion_model_down_factor = sd_ctx->sd->get_diffusion_model_down_factor();
3437+
int spatial_multiple = vae_scale_factor * diffusion_model_down_factor;
3438+
3439+
int width_offset = align_up_offset(width, spatial_multiple);
3440+
int height_offset = align_up_offset(height, spatial_multiple);
3441+
if (width_offset > 0 || height_offset > 0) {
3442+
width += width_offset;
3443+
height += height_offset;
3444+
LOG_WARN("align up %dx%d to %dx%d (multiple=%d)", sd_vid_gen_params->width, sd_vid_gen_params->height, width, height, spatial_multiple);
3445+
}
3446+
LOG_INFO("generate_video %dx%dx%d", width, height, frames);
34283447

34293448
enum sample_method_t sample_method = sd_vid_gen_params->sample_params.sample_method;
34303449
if (sample_method == SAMPLE_METHOD_COUNT) {

0 commit comments

Comments
 (0)