load model from API

This commit is contained in:
arcticfaded 2022-11-04 16:19:14 -04:00
parent 822210bae5
commit 27b5d7c6b2
2 changed files with 20 additions and 1 deletions

View File

@ -10,7 +10,7 @@ from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.sd_samplers import all_samplers
from modules.extras import run_extras, run_pnginfo
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list, get_closet_checkpoint_match, reload_model_weights
from modules.realesrgan_model import get_realesrgan_models
from typing import List
@ -57,6 +57,7 @@ class Api:
self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem])
self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem])
self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["POST"])
self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem])
self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem])
self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem])
@ -247,6 +248,21 @@ class Api:
def get_sd_models(self):
return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()]
def set_sd_models(self, req: LoadModelRequest):
name = req.name
info = get_closet_checkpoint_match(name)
if info is None:
raise HTTPException(status_code=404, detail="Checkpoint not found")
shared.state.begin()
with self.queue_lock:
reload_model_weights(shared.sd_model, info)
shared.state.end()
return "OK"
def get_hypernetworks(self):
return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks]

View File

@ -167,6 +167,9 @@ class ProgressResponse(BaseModel):
state: dict = Field(title="State", description="The current state snapshot")
current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.")
class LoadModelRequest(BaseModel):
name: str = Field(title="Name", description="The name of the checkpoint")
fields = {}
for key, value in opts.data.items():
metadata = opts.data_labels.get(key)