diff --git a/plugin.py b/plugin.py index 5f96941..5d1bcb4 100644 --- a/plugin.py +++ b/plugin.py @@ -37,21 +37,24 @@ def set_config(update: dict): sd_plugin.set_config(update) # TODO: Validate config dict are all valid keys return sd_plugin.get_config() -@app.on_event("startup") -async def startup_event(): +@app.get("/startup/{plugin_name}") +async def startup_event(plugin_name: str): print("Starting up") # A slight delay to ensure the app has started up. try: - set_model() + set_model(plugin_name) print("Successfully started up") + print(sd_plugin.plugin_name) sd_plugin.notify_main_system_of_startup("True") - except: + return {"status": "Success", "detail": "Plugin started successfully"} + + except Exception as e: sd_plugin.notify_main_system_of_startup("False") @app.get("/set_model/") -def set_model(): +def set_model(plugin_name): global sd_plugin - args = {"plugin": plugin, "config": config, "endpoints": endpoints} + args = {"plugin": plugin, "config": config, "endpoints": endpoints, "name": plugin_name} sd_plugin = SD(Namespace(**args)) # try: # sd_plugin.set_model(args["model_name"], dtype=args["model_dtype"]) @@ -106,7 +109,6 @@ class SD(Plugin): """ def __init__(self, arguments: "Namespace") -> None: super().__init__(arguments) - self.plugin_name = "Diffusers" self.set_model() def load_lora_weights(self, pipeline, checkpoint_path, multiplier=1):