From 27b5d7c6b2d6bd100c22cd59ebfa76789cef9a57 Mon Sep 17 00:00:00 2001 From: arcticfaded Date: Fri, 4 Nov 2022 16:19:14 -0400 Subject: [PATCH] load model from API --- modules/api/api.py | 18 +++++++++++++++++- modules/api/models.py | 3 +++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/modules/api/api.py b/modules/api/api.py index 8a7ab2f5..995cfc3d 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -10,7 +10,7 @@ from modules.api.models import * from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo -from modules.sd_models import checkpoints_list +from modules.sd_models import checkpoints_list, get_closet_checkpoint_match, reload_model_weights from modules.realesrgan_model import get_realesrgan_models from typing import List @@ -57,6 +57,7 @@ class Api: self.app.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[SamplerItem]) self.app.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[UpscalerItem]) self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[SDModelItem]) + self.app.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["POST"]) self.app.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[HypernetworkItem]) self.app.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[FaceRestorerItem]) self.app.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[RealesrganItem]) @@ -247,6 +248,21 @@ class Api: def get_sd_models(self): return [{"title":x.title, "model_name":x.model_name, "hash":x.hash, "filename": x.filename, "config": x.config} for x in checkpoints_list.values()] + def set_sd_models(self, req: LoadModelRequest): + name = req.name + + info = get_closet_checkpoint_match(name) + if info is None: + raise HTTPException(status_code=404, detail="Checkpoint not found") + + shared.state.begin() + with self.queue_lock: + reload_model_weights(shared.sd_model, info) + + shared.state.end() + + return "OK" + def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] diff --git a/modules/api/models.py b/modules/api/models.py index 2ae75f43..7f0f425e 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -167,6 +167,9 @@ class ProgressResponse(BaseModel): state: dict = Field(title="State", description="The current state snapshot") current_image: str = Field(default=None, title="Current image", description="The current image in base64 format. opts.show_progress_every_n_steps is required for this to work.") +class LoadModelRequest(BaseModel): + name: str = Field(title="Name", description="The name of the checkpoint") + fields = {} for key, value in opts.data.items(): metadata = opts.data_labels.get(key)