Compare commits
2 Commits
derek
...
load_model
Author | SHA1 | Date |
---|---|---|
![]() |
22c07148e6 | |
![]() |
27b5d7c6b2 |
|
@ -10,7 +10,7 @@ from modules.api.models import *
|
||||||
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
|
||||||
from modules.sd_samplers import all_samplers
|
from modules.sd_samplers import all_samplers
|
||||||
from modules.extras import run_extras, run_pnginfo
|
from modules.extras import run_extras, run_pnginfo
|
||||||
from modules.sd_models import checkpoints_list
|
from modules.sd_models import checkpoints_list, get_closest_checkpoint_match, reload_model_weights
|
||||||
from modules.realesrgan_model import get_realesrgan_models
|
from modules.realesrgan_model import get_realesrgan_models
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
@ -57,6 +57,7 @@ class Api:
|
||||||
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
|
||||||
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
|
||||||
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
|
||||||
|
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["POST"])
|
||||||
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
|
||||||
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
|
||||||
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
|
||||||
|
@ -247,6 +248,21 @@ class Api:
|
||||||
def get_sd_models(self):
|
def get_sd_models(self):
|
||||||
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
|
||||||
|
|
||||||
|
def set_sd_models(self, req: LoadModelRequest):
|
||||||
|
name = req.name
|
||||||
|
|
||||||
|
info = get_closest_checkpoint_match(name)
|
||||||
|
if info is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Checkpoint not found")
|
||||||
|
|
||||||
|
shared.state.begin()
|
||||||
|
with self.queue_lock:
|
||||||
|
reload_model_weights(shared.sd_model, info)
|
||||||
|
|
||||||
|
shared.state.end()
|
||||||
|
|
||||||
|
return "OK"
|
||||||
|
|
||||||
def get_hypernetworks(self):
|
def get_hypernetworks(self):
|
||||||
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]
|
||||||
|
|
||||||
|
|
|
@ -167,6 +167,9 @@ class ProgressResponse(BaseModel):
|
||||||
state: dict = Field(title="State", description="The current state snapshot")
|
state: dict = Field(title="State", description="The current state snapshot")
|
||||||
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
|
||||||
|
|
||||||
|
class LoadModelRequest(BaseModel):
|
||||||
|
name: str = Field(title="Name", description="The name of the checkpoint")
|
||||||
|
|
||||||
fields = {}
|
fields = {}
|
||||||
for key, value in opts.data.items():
|
for key, value in opts.data.items():
|
||||||
metadata = opts.data_labels.get(key)
|
metadata = opts.data_labels.get(key)
|
||||||
|
|
|
@ -84,7 +84,7 @@ def list_models():
|
||||||
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
|
||||||
|
|
||||||
|
|
||||||
def get_closet_checkpoint_match(searchString):
|
def get_closest_checkpoint_match(searchString):
|
||||||
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
|
||||||
if len(applicable) > 0:
|
if len(applicable) > 0:
|
||||||
return applicable[0]
|
return applicable[0]
|
||||||
|
|
|
@ -542,7 +542,7 @@ def apply_setting(key, value):
|
||||||
return gr.update()
|
return gr.update()
|
||||||
|
|
||||||
if key == "sd_model_checkpoint":
|
if key == "sd_model_checkpoint":
|
||||||
ckpt_info = sd_models.get_closet_checkpoint_match(value)
|
ckpt_info = sd_models.get_closest_checkpoint_match(value)
|
||||||
|
|
||||||
if ckpt_info is not None:
|
if ckpt_info is not None:
|
||||||
value = ckpt_info.title
|
value = ckpt_info.title
|
||||||
|
|
|
@ -85,7 +85,7 @@ def confirm_samplers(p, xs):
|
||||||
|
|
||||||
|
|
||||||
def apply_checkpoint(p, x, xs):
|
def apply_checkpoint(p, x, xs):
|
||||||
info = modules.sd_models.get_closet_checkpoint_match(x)
|
info = modules.sd_models.get_closest_checkpoint_match(x)
|
||||||
if info is None:
|
if info is None:
|
||||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
modules.sd_models.reload_model_weights(shared.sd_model, info)
|
||||||
|
@ -94,7 +94,7 @@ def apply_checkpoint(p, x, xs):
|
||||||
|
|
||||||
def confirm_checkpoints(p, xs):
|
def confirm_checkpoints(p, xs):
|
||||||
for x in xs:
|
for x in xs:
|
||||||
if modules.sd_models.get_closet_checkpoint_match(x) is None:
|
if modules.sd_models.get_closest_checkpoint_match(x) is None:
|
||||||
raise RuntimeError(f"Unknown checkpoint: {x}")
|
raise RuntimeError(f"Unknown checkpoint: {x}")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue