diff --git a/modules/sd_models.py b/modules/sd_models.py
index ae427a5c..63e07a12 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,11 +163,11 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+    if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+        sd_vae.restore_base_vae(model)
+        checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
 
-    checkpoint_key = checkpoint_info
-
-    if checkpoint_key not in checkpoints_loaded:
+    if checkpoint_info not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -197,18 +197,15 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         model.first_stage_model.to(devices.dtype_vae)
 
-        if shared.opts.sd_checkpoint_cache > 0:
-            # if PR #4035 were to get merged, restore base VAE first before caching
-            checkpoints_loaded[checkpoint_key] = model.state_dict().copy()
-            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-                checkpoints_loaded.popitem(last=False)  # LRU
-
     else:
         vae_name = sd_vae.get_filename(vae_file) if vae_file else None
         vae_message = f" with {vae_name} VAE" if vae_name else ""
         print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
-        checkpoints_loaded.move_to_end(checkpoint_key)
-        model.load_state_dict(checkpoints_loaded[checkpoint_key])
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+
+    if shared.opts.sd_checkpoint_cache > 0:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+            checkpoints_loaded.popitem(last=False)  # LRU
 
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
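
Note on the eviction loop the patch keeps: it relies on checkpoints_loaded being a collections.OrderedDict (as declared in sd_models.py), where popitem(last=False) removes the oldest insertion first. A minimal standalone sketch of that trim step, using hypothetical cache / max_entries names in place of checkpoints_loaded and shared.opts.sd_checkpoint_cache:

    import collections

    # Hypothetical stand-ins for checkpoints_loaded and shared.opts.sd_checkpoint_cache
    cache = collections.OrderedDict()
    max_entries = 2

    # Cache three fake "state dicts" keyed by checkpoint name; "A" is the oldest insertion
    for name in ("A", "B", "C"):
        cache[name] = {"weights": name}

    # Trim to the configured size, evicting oldest-first -- the same loop as in the patch
    while len(cache) > max_entries:
        cache.popitem(last=False)  # LRU

    print(list(cache))  # ['B', 'C'] -- "A" was evicted first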