diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index c0ee1182d..94221d8e6 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -107,6 +107,7 @@ class StableDiffusionGGML { SDVersion version; bool vae_decode_only = false; + bool external_vae_is_invalid = false; bool free_params_immediately = false; std::shared_ptr rng = std::make_shared(); @@ -321,6 +322,7 @@ class StableDiffusionGGML { LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path); if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) { LOG_WARN("loading vae from '%s' failed", sd_ctx_params->vae_path); + external_vae_is_invalid = true; } } @@ -619,10 +621,10 @@ class StableDiffusionGGML { first_stage_model->set_conv2d_direct_enabled(true); } if (sd_version_is_sdxl(version) && - (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale)) { + (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 || sd_ctx_params->force_sdxl_vae_conv_scale || external_vae_is_invalid)) { float vae_conv_2d_scale = 1.f / 32.f; LOG_WARN( - "No VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " + "No valid VAE specified with --vae or --force-sdxl-vae-conv-scale flag set, " "using Conv2D scale %.3f", vae_conv_2d_scale); first_stage_model->set_conv2d_scale(vae_conv_2d_scale);