Liu Song’s Projects


~/Projects/stable-diffusion-webui

git clone https://code.lsong.org/stable-diffusion-webui

Commit
ef1698fd6dbd6387341a1eeeded068ff1476ee50
Author
AUTOMATIC1111 <[email protected]>
Date
2023-08-05 08:01:38 +0300
Diffstat
 .github/workflows/run_tests.yaml | 1 
 CHANGELOG.md | 90 ++
 README.md | 4 
 extensions-builtin/Lora/extra_networks_lora.py | 36 
 extensions-builtin/Lora/lora.py | 506 ------------
 extensions-builtin/Lora/lyco_helpers.py | 21 
 extensions-builtin/Lora/network.py | 155 +++
 extensions-builtin/Lora/network_full.py | 22 
 extensions-builtin/Lora/network_hada.py | 55 +
 extensions-builtin/Lora/network_ia3.py | 30 
 extensions-builtin/Lora/network_lokr.py | 64 +
 extensions-builtin/Lora/network_lora.py | 86 ++
 extensions-builtin/Lora/networks.py | 468 +++++++++++
 extensions-builtin/Lora/preload.py | 1 
 extensions-builtin/Lora/scripts/lora_script.py | 90 +
 extensions-builtin/Lora/ui_edit_user_metadata.py | 30 
 extensions-builtin/Lora/ui_extra_networks_lora.py | 43 
 extensions-builtin/mobile/javascript/mobile.js | 26 
 html/extra-networks-card.html | 2 
 javascript/extraNetworks.js | 2 
 javascript/hints.js | 11 
 javascript/localization.js | 10 
 javascript/ui.js | 14 
 launch.py | 7 
 modules/api/api.py | 51 
 modules/api/models.py | 8 
 modules/cache.py | 31 
 modules/call_queue.py | 5 
 modules/cmd_args.py | 6 
 modules/devices.py | 86 +
 modules/errors.py | 53 +
 modules/extensions.py | 18 
 modules/extra_networks.py | 35 
 modules/extras.py | 38 
 modules/generation_parameters_copypaste.py | 3 
 modules/gradio_extensons.py | 60 +
 modules/hypernetworks/hypernetwork.py | 7 
 modules/images.py | 4 
 modules/img2img.py | 34 
 modules/launch_utils.py | 85 +
 modules/lowvram.py | 67 +
 modules/paths.py | 25 
 modules/processing.py | 229 +++-
 modules/prompt_parser.py | 131 ++
 modules/rng_philox.py | 102 ++
 modules/script_loading.py | 5 
 modules/scripts.py | 81 -
 modules/sd_disable_initialization.py | 102 ++
 modules/sd_hijack.py | 60 +
 modules/sd_hijack_clip.py | 37 
 modules/sd_hijack_inpainting.py | 2 
 modules/sd_hijack_open_clip.py | 34 
 modules/sd_hijack_optimizations.py | 57 +
 modules/sd_hijack_unet.py | 8 
 modules/sd_models.py | 276 ++++-
 modules/sd_models_config.py | 9 
 modules/sd_models_xl.py | 108 ++
 modules/sd_samplers.py | 3 
 modules/sd_samplers_common.py | 19 
 modules/sd_samplers_compvis.py | 6 
 modules/sd_samplers_extra.py | 74 +
 modules/sd_samplers_kdiffusion.py | 67 +
 modules/sd_vae.py | 16 
 modules/sd_vae_approx.py | 59 
 modules/sd_vae_taesd.py | 24 
 modules/shared.py | 89 +
 modules/styles.py | 5 
 modules/sysinfo.py | 7 
 modules/textual_inversion/textual_inversion.py | 24 
 modules/timer.py | 24 
 modules/txt2img.py | 3 
 modules/ui.py | 481 ++++-------
 modules/ui_checkpoint_merger.py | 124 ++
 modules/ui_common.py | 38 
 modules/ui_components.py | 2 
 modules/ui_extensions.py | 27 
 modules/ui_extra_networks.py | 18 
 modules/ui_extra_networks_checkpoints.py | 6 
 modules/ui_extra_networks_checkpoints_user_metadata.py | 60 +
 modules/ui_extra_networks_hypernets.py | 2 
 modules/ui_extra_networks_textual_inversion.py | 2 
 modules/ui_extra_networks_user_metadata.py | 8 
 modules/ui_postprocessing.py | 2 
 modules/ui_prompt_styles.py | 110 ++
 modules/ui_settings.py | 2 
 requirements.txt | 3 
 requirements_versions.txt | 17 
 scripts/xyz_grid.py | 40 
 style.css | 45 
 webui.py | 63 -
 webui.sh | 15 

Merge branch 'dev' into extra-networks-always-visible


diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index e9370cc0758f9fa40f72d0bfbc999f3acf935dd6..3dafaf8dcfcd14fd7a7ca3385806efad5550b871 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -41,6 +41,7 @@           launch.py
           --skip-prepare-environment
           --skip-torch-cuda-test
           --test-server
+          --do-not-download-clip
           --no-half
           --disable-opt-split-attention
           --use-cpu all




diff --git a/CHANGELOG.md b/CHANGELOG.md
index 925403a9138f294fe1b2534471cf509c717c3dfe..b18c6867348e0cdd3610f1f3c1f292e9347c902d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,93 @@
+## 1.5.1

+

+### Minor:

+ * support parsing text encoder blocks in some new LoRAs

+ * delete scale checker script due to user demand

+

+### Extensions and API:

+ * add postprocess_batch_list script callback

+

+### Bug Fixes:

+ * fix TI training for SD1

+ * fix reload altclip model error

+ * prepend the pythonpath instead of overriding it

+ * fix typo in SD_WEBUI_RESTARTING

+ * if txt2img/img2img raises an exception, finally call state.end()

+ * fix composable diffusion weight parsing

+ * restyle Startup profile for black users

+ * fix webui not launching with --nowebui

+ * catch exception for non git extensions

+ * fix some options missing from /sdapi/v1/options

+ * fix for extension update status always saying "unknown"

+ * fix display of extra network cards that have `<>` in the name

+ * update lora extension to work with python 3.8

+

+

+## 1.5.0

+

+### Features:

+ * SD XL support

+ * user metadata system for custom networks

+ * extended Lora metadata editor: set activation text, default weight, view tags, training info

+ * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)

+ * show github stars for extenstions

+ * img2img batch mode can read extra stuff from png info

+ * img2img batch works with subdirectories

+ * hotkeys to move prompt elements: alt+left/right

+ * restyle time taken/VRAM display

+ * add textual inversion hashes to infotext

+ * optimization: cache git extension repo information

+ * move generate button next to the generated picture for mobile clients

+ * hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface

+ * skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds

+

+### Minor:

+ * checkbox to check/uncheck all extensions in the Installed tab

+ * add gradio user to infotext and to filename patterns

+ * allow gif for extra network previews

+ * add options to change colors in grid

+ * use natural sort for items in extra networks

+ * Mac: use empty_cache() from torch 2 to clear VRAM

+ * added automatic support for installing the right libraries for Navi3 (AMD)

+ * add option SWIN_torch_compile to accelerate SwinIR upscale

+ * suppress printing TI embedding info at start to console by default

+ * speedup extra networks listing

+ * added `[none]` filename token.

+ * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)

+ * add always_discard_next_to_last_sigma option to XYZ plot

+ * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.

+ 

+### Extensions and API:

+ * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop

+ * allow Script to have custom metaclass

+ * add model exists status check /sdapi/v1/options

+ * rename --add-stop-route to --api-server-stop

+ * add `before_hr` script callback

+ * add callback `after_extra_networks_activate`

+ * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable

+ * return http 404 when thumb file not found

+ * allow replacing extensions index with environment variable

+ 

+### Bug Fixes:

+ * fix for catch errors when retrieving extension index #11290

+ * fix very slow loading speed of .safetensors files when reading from network drives

+ * API cache cleanup

+ * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode

+ * fix warning of 'has_mps' deprecated from PyTorch

+ * fix problem with extra network saving images as previews losing generation info

+ * fix throwing exception when trying to resize image with I;16 mode

+ * fix for #11534: canvas zoom and pan extension hijacking shortcut keys

+ * fixed launch script to be runnable from any directory

+ * don't add "Seed Resize: -1x-1" to API image metadata

+ * correctly remove end parenthesis with ctrl+up/down

+ * fixing --subpath on newer gradio version

+ * fix: check fill size none zero when resize  (fixes #11425)

+ * use submit and blur for quick settings textbox

+ * save img2img batch with images.save_image()

+ * prevent running preload.py for disabled extensions

+ * fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included

+

+

 ## 1.4.1

 

 ### Bug Fixes:





diff --git a/README.md b/README.md
index e6d8e4bd423925b38f685b4c268b4751a8bd9bb0..2fd6e425de71d7e8eacba73672143e5f788a23ba 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@ - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
 - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions

 - Now without any bad letters!

 - Load checkpoints in safetensors format

-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64

+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64

 - Now with a license!

 - Reorder elements in the UI from settings screen

 

@@ -168,5 +168,7 @@ - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK

 - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC

 - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd

+- LyCORIS - KohakuBlueleaf

+- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling

 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.

 - (You)





diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 66ee9c8563919fcda648f58d34278b85b4571f8d..ba2945c6fe1e77d87226f08fb20da0624959364b 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,5 +1,5 @@
 from modules import extra_networks, shared

-import lora

+import networks

 

 

 class ExtraNetworkLora(extra_networks.ExtraNetwork):

@@ -9,29 +9,44 @@ 
     def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

         names = []
-        multipliers = []
+        te_multipliers = []
+        unet_multipliers = []
+        dyn_dims = []
         for params in params_list:
             assert params.items

-            names.append(params.items[0])
-            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+            names.append(params.positional[0])
+
+            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+            te_multiplier = float(params.named.get("te", te_multiplier))
+
+            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
+            unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+            te_multipliers.append(te_multiplier)
+            unet_multipliers.append(unet_multiplier)
+            dyn_dims.append(dyn_dim)

-        lora.load_loras(names, multipliers)
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

         if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                 if not shorthash:
                     continue

@@ -41,11 +56,12 @@                     continue

                 alias = alias.replace(":", "").replace(",", "")

-                lora_hashes.append(f"{alias}: {shorthash}")
+                network_hashes.append(f"{alias}: {shorthash}")

-            if lora_hashes:
-                p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+            if network_hashes:
+                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

 

     def deactivate(self, p):

         pass
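The rewritten activate() above parses per-network text-encoder and UNet multipliers plus an optional dynamic dimension from the `<lora:...>` prompt tag, accepting both positional and named (`te=`, `unet=`, `dyn=`) arguments. Below is a minimal standalone sketch of that argument handling; `parse_lora_tag_args` is an illustrative helper, not part of the extension, and it assumes the positional/named split that `ExtraNetworkParams` provides:

```python
def parse_lora_tag_args(positional, named):
    # positional: strings after the network name in <lora:name:...>; named: key=value parts
    te_multiplier = float(positional[0]) if len(positional) > 0 else 1.0
    te_multiplier = float(named.get("te", te_multiplier))

    unet_multiplier = float(positional[1]) if len(positional) > 1 else te_multiplier
    unet_multiplier = float(named.get("unet", unet_multiplier))

    dyn_dim = int(positional[2]) if len(positional) > 2 else None
    dyn_dim = int(named["dyn"]) if "dyn" in named else dyn_dim

    return te_multiplier, unet_multiplier, dyn_dim


# e.g. a tag like <lora:example:0.8:unet=0.4:dyn=16> would arrive roughly as:
print(parse_lora_tag_args(["0.8"], {"unet": "0.4", "dyn": "16"}))  # (0.8, 0.4, 16)
```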





diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 467ad65f200ef643df68494ee1a3354f68de7b14..9365aa74b4fe30afbb27fa08284ccd6bcb094fe7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -1,526 +1,20 @@
-import os

-import re

-import torch

-from typing import Union

-

-from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache

-

-metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

-

-re_digits = re.compile(r"\d+")

-re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")

-re_compiled = {}

-

-suffix_conversion = {

-    "attentions": {},

-    "resnets": {

-        "conv1": "in_layers_2",

-        "conv2": "out_layers_3",

-        "time_emb_proj": "emb_layers_1",

-        "conv_shortcut": "skip_connection",

-    }

-}

 

-

-def convert_diffusers_name_to_compvis(key, is_sd2):

-    def match(match_list, regex_text):

-        regex = re_compiled.get(regex_text)

-        if regex is None:

-            regex = re.compile(regex_text)

-            re_compiled[regex_text] = regex

-

-        r = re.match(regex, key)

-        if not r:

-            return False

-

-        match_list.clear()

-        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])

-        return True

 

-    m = []

-

-    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):

-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])

-        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

-

-    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):

-        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])

-        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

-

-    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
-        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])

-        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

-

-    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):

-        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

-

-    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):

-        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

-

-    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):

-        if is_sd2:

-            if 'mlp_fc1' in m[1]:

-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"

-            elif 'mlp_fc2' in m[1]:

-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"

-            else:

-                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

-
-        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
-

-    return key

-

-

-class LoraOnDisk:

-    def __init__(self, name, filename):

-        self.name = name

-        self.filename = filename

-        self.metadata = {}

-        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

-

-        def read_metadata():

-            metadata = sd_models.read_metadata_from_safetensors(filename)

-            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

-

-            return metadata

-

-        if self.is_safetensors:

-            try:

-                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)

-            except Exception as e:

-                errors.display(e, f"reading lora {filename}")

-

-        if self.metadata:

-            m = {}

-            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):

-                m[k] = v

-

-            self.metadata = m

-

-        self.alias = self.metadata.get('ss_output_name', self.name)

-

-        self.hash = None

-        self.shorthash = None

-        self.set_hash(

-            self.metadata.get('sshs_model_hash') or

-            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or

-            ''

-        )

-

-    def set_hash(self, v):

-        self.hash = v

-        self.shorthash = self.hash[0:12]

-

-        if self.shorthash:

-            available_lora_hash_lookup[self.shorthash] = self

-

-    def read_hash(self):

-        if not self.hash:

-            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

-

-    def get_alias(self):

-        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
-            return self.name

-        else:

-            return self.alias

-

-

-class LoraModule:

-    def __init__(self, name, lora_on_disk: LoraOnDisk):

-        self.name = name

-        self.lora_on_disk = lora_on_disk

-        self.multiplier = 1.0

-        self.modules = {}

-        self.mtime = None

-

-        self.mentioned_name = None

-        """the text that was used to add lora to prompt - can be either name or an alias"""

-

-

-class LoraUpDownModule:

-    def __init__(self):

-        self.up = None

-        self.down = None

-        self.alpha = None

-

-

-def assign_lora_names_to_compvis_modules(sd_model):

-    lora_layer_mapping = {}

-

-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():

-        lora_name = name.replace(".", "_")

-        lora_layer_mapping[lora_name] = module

-        module.lora_layer_name = lora_name

-

-    for name, module in shared.sd_model.model.named_modules():

-        lora_name = name.replace(".", "_")

-        lora_layer_mapping[lora_name] = module

-        module.lora_layer_name = lora_name

-

-    sd_model.lora_layer_mapping = lora_layer_mapping

-

-

-def load_lora(name, lora_on_disk):

-    lora = LoraModule(name, lora_on_disk)

-    lora.mtime = os.path.getmtime(lora_on_disk.filename)

-

-    sd = sd_models.read_state_dict(lora_on_disk.filename)

-

-    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0

-    if not hasattr(shared.sd_model, 'lora_layer_mapping'):

-        assign_lora_names_to_compvis_modules(shared.sd_model)

-

-    keys_failed_to_match = {}

-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping

-

-    for key_diffusers, weight in sd.items():

-        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)

-        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)

-

-        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)

-

-        if sd_module is None:

-            m = re_x_proj.match(key)

-            if m:

-                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)

-

-        if sd_module is None:

-            keys_failed_to_match[key_diffusers] = key

-            continue

-

-        lora_module = lora.modules.get(key, None)

-        if lora_module is None:

-            lora_module = LoraUpDownModule()

-            lora.modules[key] = lora_module

-

-        if lora_key == "alpha":

-            lora_module.alpha = weight.item()

-            continue

-

-        if type(sd_module) == torch.nn.Linear:

-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)

-        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:

-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)

-        elif type(sd_module) == torch.nn.MultiheadAttention:

-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)

-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):

-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)

-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):

-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)

-        else:

-            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')

-            continue

-            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")

-

-        with torch.no_grad():

-            module.weight.copy_(weight)

-

-        module.to(device=devices.cpu, dtype=devices.dtype)

-

-        if lora_key == "lora_up.weight":

-            lora_module.up = module

-        elif lora_key == "lora_down.weight":

-            lora_module.down = module

-        else:

-            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")

-

-    if keys_failed_to_match:

-        print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}")

-

-    return lora

-

-

-def load_loras(names, multipliers=None):

-    already_loaded = {}

-

-    for lora in loaded_loras:

-        if lora.name in names:

-            already_loaded[lora.name] = lora

-

-    loaded_loras.clear()

-

-    loras_on_disk = [available_lora_aliases.get(name, None) for name in names]

-    if any(x is None for x in loras_on_disk):

-        list_available_loras()

-

-        loras_on_disk = [available_lora_aliases.get(name, None) for name in names]

-

-    failed_to_load_loras = []

-

-    for i, name in enumerate(names):

-        lora = already_loaded.get(name, None)

-

-        lora_on_disk = loras_on_disk[i]

-

-        if lora_on_disk is not None:

-            if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime:

-                try:

-                    lora = load_lora(name, lora_on_disk)

-                except Exception as e:

-                    errors.display(e, f"loading Lora {lora_on_disk.filename}")

-                    continue

-

-            lora.mentioned_name = name

-

-            lora_on_disk.read_hash()

-

-        if lora is None:

-            failed_to_load_loras.append(name)

-            print(f"Couldn't find Lora with name {name}")

-            continue

-

-        lora.multiplier = multipliers[i] if multipliers else 1.0

-        loaded_loras.append(lora)

-

-    if failed_to_load_loras:

-        sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras))

-

-

-def lora_calc_updown(lora, module, target):

-    with torch.no_grad():

-        up = module.up.weight.to(target.device, dtype=target.dtype)

-        down = module.down.weight.to(target.device, dtype=target.dtype)

-

-        if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):

-            updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)

-        elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):

-            updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)

-        else:

-            updown = up @ down

-

-        updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

-

-        return updown

-

-

-def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):

-    weights_backup = getattr(self, "lora_weights_backup", None)

-

-    if weights_backup is None:

-        return

-

-    if isinstance(self, torch.nn.MultiheadAttention):

-        self.in_proj_weight.copy_(weights_backup[0])

-        self.out_proj.weight.copy_(weights_backup[1])

-    else:

-        self.weight.copy_(weights_backup)

-

-

-def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):

-    """

-    Applies the currently selected set of Loras to the weights of torch layer self.

-    If weights already have this particular set of loras applied, does nothing.

-    If not, restores orginal weights from backup and alters weights according to loras.

-    """

-

-    lora_layer_name = getattr(self, 'lora_layer_name', None)

-    if lora_layer_name is None:

-        return

-

-    current_names = getattr(self, "lora_current_names", ())

-    wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)

-

-    weights_backup = getattr(self, "lora_weights_backup", None)

-    if weights_backup is None:

-        if isinstance(self, torch.nn.MultiheadAttention):

-            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))

-        else:

-            weights_backup = self.weight.to(devices.cpu, copy=True)

-

-        self.lora_weights_backup = weights_backup

-

-    if current_names != wanted_names:

-        lora_restore_weights_from_backup(self)

-

-        for lora in loaded_loras:

-            module = lora.modules.get(lora_layer_name, None)

-            if module is not None and hasattr(self, 'weight'):

-                self.weight += lora_calc_updown(lora, module, self.weight)

-                continue

-

-            module_q = lora.modules.get(lora_layer_name + "_q_proj", None)

-            module_k = lora.modules.get(lora_layer_name + "_k_proj", None)

-            module_v = lora.modules.get(lora_layer_name + "_v_proj", None)

-            module_out = lora.modules.get(lora_layer_name + "_out_proj", None)

-

-            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:

-                updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)

-                updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)

-                updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)

-                updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

-

-                self.in_proj_weight += updown_qkv

-                self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)

-                continue

-

-            if module is None:

-                continue

-

-            print(f'failed to calculate lora weights for layer {lora_layer_name}')

-

-        self.lora_current_names = wanted_names

-

-

-def lora_forward(module, input, original_forward):

-    """

-    Old way of applying Lora by executing operations during layer's forward.

-    Stacking many loras this way results in big performance degradation.

-    """

-

-    if len(loaded_loras) == 0:

-        return original_forward(module, input)

-

-    input = devices.cond_cast_unet(input)

-

-    lora_restore_weights_from_backup(module)

-    lora_reset_cached_weight(module)

-

-    res = original_forward(module, input)

-

-    lora_layer_name = getattr(module, 'lora_layer_name', None)

-    for lora in loaded_loras:

-        module = lora.modules.get(lora_layer_name, None)

-        if module is None:

-            continue

-

-        module.up.to(device=devices.device)

-        module.down.to(device=devices.device)

-

-        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)

-

-    return res

-

-

-def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):

-    self.lora_current_names = ()

-    self.lora_weights_backup = None

-

-

-def lora_Linear_forward(self, input):

-    if shared.opts.lora_functional:

-        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)

-

-    lora_apply_weights(self)

-

-    return torch.nn.Linear_forward_before_lora(self, input)

-

-

-def lora_Linear_load_state_dict(self, *args, **kwargs):

-    lora_reset_cached_weight(self)

-

-    return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)

-

-

-def lora_Conv2d_forward(self, input):

-    if shared.opts.lora_functional:

-        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)

-

-    lora_apply_weights(self)

-

-    return torch.nn.Conv2d_forward_before_lora(self, input)

-

-

-def lora_Conv2d_load_state_dict(self, *args, **kwargs):

-    lora_reset_cached_weight(self)

-

-    return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)

-

-

-def lora_MultiheadAttention_forward(self, *args, **kwargs):

-    lora_apply_weights(self)

-

-    return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)

-

-

-def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):

-    lora_reset_cached_weight(self)

-

-    return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)

-

-

-def list_available_loras():

-    available_loras.clear()

-    available_lora_aliases.clear()

-    forbidden_lora_aliases.clear()

-    available_lora_hash_lookup.clear()

-    forbidden_lora_aliases.update({"none": 1, "Addams": 1})

-

-    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

-

-    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

-    for filename in candidates:

-        if os.path.isdir(filename):

-            continue

-

-        name = os.path.splitext(os.path.basename(filename))[0]

-        try:

-            entry = LoraOnDisk(name, filename)

-        except OSError:  # should catch FileNotFoundError and PermissionError etc.

-            errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True)

-            continue

-

-        available_loras[name] = entry

-

-        if entry.alias in available_lora_aliases:

-            forbidden_lora_aliases[entry.alias.lower()] = 1

-

-        available_lora_aliases[name] = entry

-        available_lora_aliases[entry.alias] = entry

-

-

-re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

-

-

-def infotext_pasted(infotext, params):

-    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:

-        return  # if the other extension is active, it will handle those fields, no need to do anything

-

-    added = []

-

-    for k in params:

-        if not k.startswith("AddNet Model "):

-            continue

-

-        num = k[13:]

-

-        if params.get("AddNet Module " + num) != "LoRA":

-            continue

-

-        name = params.get("AddNet Model " + num)

-        if name is None:

-            continue

-

-        m = re_lora_name.match(name)

-        if m:

-            name = m.group(1)

-

-        multiplier = params.get("AddNet Weight A " + num, "1.0")

-

-        added.append(f"<lora:{name}:{multiplier}>")

-

-    if added:

-        params["Prompt"] += "\n" + "".join(added)

-

-

-available_loras = {}

-available_lora_aliases = {}

-available_lora_hash_lookup = {}

-forbidden_lora_aliases = {}

-loaded_loras = []

-

-list_available_loras()





diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..279b34bc928bcb52979fd67068be0c2ca35b847b
--- /dev/null
+++ b/extensions-builtin/Lora/lyco_helpers.py
@@ -0,0 +1,21 @@
+import torch

+

+

+def make_weight_cp(t, wa, wb):

+    temp = torch.einsum('i j k l, j r -> i r k l', t, wb)

+    return torch.einsum('i j k l, i r -> r j k l', temp, wa)

+

+

+def rebuild_conventional(up, down, shape, dyn_dim=None):

+    up = up.reshape(up.size(0), -1)

+    down = down.reshape(down.size(0), -1)

+    if dyn_dim is not None:

+        up = up[:, :dyn_dim]

+        down = down[:dyn_dim, :]

+    return (up @ down).reshape(shape)

+

+

+def rebuild_cp_decomposition(up, down, mid):

+    up = up.reshape(up.size(0), -1)

+    down = down.reshape(down.size(0), -1)

+    return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)





diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a18d69eb26412d54c3b36e7801b5557691e2b68
--- /dev/null
+++ b/extensions-builtin/Lora/network.py
@@ -0,0 +1,155 @@
+from __future__ import annotations

+import os

+from collections import namedtuple

+import enum

+

+from modules import sd_models, cache, errors, hashes, shared

+

+NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module'])

+

+metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}

+

+

+class SdVersion(enum.Enum):

+    Unknown = 1

+    SD1 = 2

+    SD2 = 3

+    SDXL = 4

+

+

+class NetworkOnDisk:

+    def __init__(self, name, filename):

+        self.name = name

+        self.filename = filename

+        self.metadata = {}

+        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

+

+        def read_metadata():

+            metadata = sd_models.read_metadata_from_safetensors(filename)

+            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text

+

+            return metadata

+

+        if self.is_safetensors:

+            try:

+                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)

+            except Exception as e:

+                errors.display(e, f"reading lora {filename}")

+

+        if self.metadata:

+            m = {}

+            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):

+                m[k] = v

+

+            self.metadata = m

+

+        self.alias = self.metadata.get('ss_output_name', self.name)

+

+        self.hash = None

+        self.shorthash = None

+        self.set_hash(

+            self.metadata.get('sshs_model_hash') or

+            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or

+            ''

+        )

+

+        self.sd_version = self.detect_version()

+

+    def detect_version(self):

+        if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"):

+            return SdVersion.SDXL

+        elif str(self.metadata.get('ss_v2', "")) == "True":

+            return SdVersion.SD2

+        elif len(self.metadata):

+            return SdVersion.SD1

+

+        return SdVersion.Unknown

+

+    def set_hash(self, v):

+        self.hash = v

+        self.shorthash = self.hash[0:12]

+

+        if self.shorthash:

+            import networks

+            networks.available_network_hash_lookup[self.shorthash] = self

+

+    def read_hash(self):

+        if not self.hash:

+            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')

+

+    def get_alias(self):

+        import networks

+        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases:

+            return self.name

+        else:

+            return self.alias

+

+

+class Network:  # LoraModule

+    def __init__(self, name, network_on_disk: NetworkOnDisk):

+        self.name = name

+        self.network_on_disk = network_on_disk

+        self.te_multiplier = 1.0

+        self.unet_multiplier = 1.0

+        self.dyn_dim = None

+        self.modules = {}

+        self.mtime = None

+

+        self.mentioned_name = None

+        """the text that was used to add the network to prompt - can be either name or an alias"""

+

+

+class ModuleType:

+    def create_module(self, net: Network, weights: NetworkWeights) -> Network | None:

+        return None

+

+

+class NetworkModule:

+    def __init__(self, net: Network, weights: NetworkWeights):

+        self.network = net

+        self.network_key = weights.network_key

+        self.sd_key = weights.sd_key

+        self.sd_module = weights.sd_module

+

+        if hasattr(self.sd_module, 'weight'):

+            self.shape = self.sd_module.weight.shape

+

+        self.dim = None

+        self.bias = weights.w.get("bias")

+        self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None

+        self.scale = weights.w["scale"].item() if "scale" in weights.w else None

+

+    def multiplier(self):

+        if 'transformer' in self.sd_key[:20]:

+            return self.network.te_multiplier

+        else:

+            return self.network.unet_multiplier

+

+    def calc_scale(self):

+        if self.scale is not None:

+            return self.scale

+        if self.dim is not None and self.alpha is not None:

+            return self.alpha / self.dim

+

+        return 1.0

+

+    def finalize_updown(self, updown, orig_weight, output_shape):

+        if self.bias is not None:

+            updown = updown.reshape(self.bias.shape)

+            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)

+            updown = updown.reshape(output_shape)

+

+        if len(output_shape) == 4:

+            updown = updown.reshape(output_shape)

+

+        if orig_weight.size().numel() == updown.size().numel():

+            updown = updown.reshape(orig_weight.shape)

+

+        return updown * self.calc_scale() * self.multiplier()

+

+    def calc_updown(self, target):

+        raise NotImplementedError()

+

+    def forward(self, x, y):

+        raise NotImplementedError()
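NetworkModule above is the shared base for all network types: calc_scale() prefers an explicit `scale` value and otherwise falls back to alpha/dim, while multiplier() picks the text-encoder or UNet strength depending on the key. A short arithmetic sketch, with made-up numbers, of the factor that finalize_updown() applies to the raw delta:

```python
alpha, dim = 8.0, 16   # values typically stored in the network file / taken from the down weight
te_multiplier = 0.8    # strength parsed from the <lora:...> tag

scale = alpha / dim           # calc_scale() when no explicit "scale" key exists
print(scale * te_multiplier)  # 0.4 -- factor multiplied into the rebuilt delta
```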

+





diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py
new file mode 100644
index 0000000000000000000000000000000000000000..109b4c2c594e5079067d55331271ebafcf6c9fe4
--- /dev/null
+++ b/extensions-builtin/Lora/network_full.py
@@ -0,0 +1,22 @@
+import network

+

+

+class ModuleTypeFull(network.ModuleType):

+    def create_module(self, net: network.Network, weights: network.NetworkWeights):

+        if all(x in weights.w for x in ["diff"]):

+            return NetworkModuleFull(net, weights)

+

+        return None

+

+

+class NetworkModuleFull(network.NetworkModule):

+    def __init__(self,  net: network.Network, weights: network.NetworkWeights):

+        super().__init__(net, weights)

+

+        self.weight = weights.w.get("diff")

+

+    def calc_updown(self, orig_weight):

+        output_shape = self.weight.shape

+        updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)

+

+        return self.finalize_updown(updown, orig_weight, output_shape)
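The "full" module type covers networks that ship a complete weight difference: the stored `diff` tensor itself is the delta, only moved to the target device/dtype before the usual scaling. A tiny PyTorch-only sketch with made-up values:

```python
import torch

orig_weight = torch.zeros(4, 4)
diff = torch.full((4, 4), 0.5)                  # corresponds to weights.w["diff"]
updown = diff.to(orig_weight.device, dtype=orig_weight.dtype)
print((orig_weight + updown)[0])                # tensor([0.5000, 0.5000, 0.5000, 0.5000])
```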





diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fcb0695fbb36f01a1dbdeaba87ea288a3d00eb2
--- /dev/null
+++ b/extensions-builtin/Lora/network_hada.py
@@ -0,0 +1,55 @@
+import lyco_helpers

+import network

+

+

+class ModuleTypeHada(network.ModuleType):

+    def create_module(self, net: network.Network, weights: network.NetworkWeights):

+        if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]):

+            return NetworkModuleHada(net, weights)

+

+        return None

+

+

+class NetworkModuleHada(network.NetworkModule):

+    def __init__(self,  net: network.Network, weights: network.NetworkWeights):

+        super().__init__(net, weights)

+

+        if hasattr(self.sd_module, 'weight'):

+            self.shape = self.sd_module.weight.shape

+

+        self.w1a = weights.w["hada_w1_a"]

+        self.w1b = weights.w["hada_w1_b"]

+        self.dim = self.w1b.shape[0]

+        self.w2a = weights.w["hada_w2_a"]

+        self.w2b = weights.w["hada_w2_b"]

+

+        self.t1 = weights.w.get("hada_t1")

+        self.t2 = weights.w.get("hada_t2")

+

+    def calc_updown(self, orig_weight):

+        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)

+        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)

+        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)

+        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

+

+        output_shape = [w1a.size(0), w1b.size(1)]

+

+        if self.t1 is not None:

+            output_shape = [w1a.size(1), w1b.size(1)]

+            t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype)

+            updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b)

+            output_shape += t1.shape[2:]

+        else:

+            if len(w1b.shape) == 4:

+                output_shape += w1b.shape[2:]

+            updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape)

+

+        if self.t2 is not None:

+            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)

+            updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

+        else:

+            updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape)

+

+        updown = updown1 * updown2

+

+        return self.finalize_updown(updown, orig_weight, output_shape)
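NetworkModuleHada implements LoHa: two low-rank products are rebuilt and combined elementwise (a Hadamard product), with the optional hada_t1/hada_t2 tensors handling the CP case. A minimal PyTorch-only sketch of the simple linear path, with made-up shapes:

```python
import torch

w1a, w1b = torch.randn(320, 8), torch.randn(8, 320)  # hada_w1_a / hada_w1_b
w2a, w2b = torch.randn(320, 8), torch.randn(8, 320)  # hada_w2_a / hada_w2_b

updown = (w1a @ w1b) * (w2a @ w2b)  # elementwise product of the two rank-8 rebuilds
print(updown.shape)                 # torch.Size([320, 320])
```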





diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py
new file mode 100644
index 0000000000000000000000000000000000000000..7edc4249791c0c58528ef4eb3f7df04c81d56020
--- /dev/null
+++ b/extensions-builtin/Lora/network_ia3.py
@@ -0,0 +1,30 @@
+import network

+

+

+class ModuleTypeIa3(network.ModuleType):

+    def create_module(self, net: network.Network, weights: network.NetworkWeights):

+        if all(x in weights.w for x in ["weight"]):

+            return NetworkModuleIa3(net, weights)

+

+        return None

+

+

+class NetworkModuleIa3(network.NetworkModule):

+    def __init__(self,  net: network.Network, weights: network.NetworkWeights):

+        super().__init__(net, weights)

+

+        self.w = weights.w["weight"]

+        self.on_input = weights.w["on_input"].item()

+

+    def calc_updown(self, orig_weight):

+        w = self.w.to(orig_weight.device, dtype=orig_weight.dtype)

+

+        output_shape = [w.size(0), orig_weight.size(1)]

+        if self.on_input:

+            output_shape.reverse()

+        else:

+            w = w.reshape(-1, 1)

+

+        updown = orig_weight * w

+

+        return self.finalize_updown(updown, orig_weight, output_shape)
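The IA3 module stores a single scaling vector; the delta is the original weight multiplied by that vector, broadcast over rows or columns depending on `on_input`. A minimal PyTorch-only sketch of the `on_input == False` branch, with made-up shapes:

```python
import torch

orig_weight = torch.randn(320, 768)
w = torch.randn(320)                      # weights.w["weight"]

updown = orig_weight * w.reshape(-1, 1)   # scales every output row of the layer
print(updown.shape)                       # torch.Size([320, 768])
```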





diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py
new file mode 100644
index 0000000000000000000000000000000000000000..340acdab3d0c7d2275db9d2684c8e70a0ea6d889
--- /dev/null
+++ b/extensions-builtin/Lora/network_lokr.py
@@ -0,0 +1,64 @@
+import torch

+

+import lyco_helpers

+import network

+

+

+class ModuleTypeLokr(network.ModuleType):

+    def create_module(self, net: network.Network, weights: network.NetworkWeights):

+        has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w)

+        has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w)

+        if has_1 and has_2:

+            return NetworkModuleLokr(net, weights)

+

+        return None

+

+

+def make_kron(orig_shape, w1, w2):

+    if len(w2.shape) == 4:

+        w1 = w1.unsqueeze(2).unsqueeze(2)

+    w2 = w2.contiguous()

+    return torch.kron(w1, w2).reshape(orig_shape)

+

+

+class NetworkModuleLokr(network.NetworkModule):

+    def __init__(self,  net: network.Network, weights: network.NetworkWeights):

+        super().__init__(net, weights)

+

+        self.w1 = weights.w.get("lokr_w1")

+        self.w1a = weights.w.get("lokr_w1_a")

+        self.w1b = weights.w.get("lokr_w1_b")

+        self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim

+        self.w2 = weights.w.get("lokr_w2")

+        self.w2a = weights.w.get("lokr_w2_a")

+        self.w2b = weights.w.get("lokr_w2_b")

+        self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim

+        self.t2 = weights.w.get("lokr_t2")

+

+    def calc_updown(self, orig_weight):

+        if self.w1 is not None:

+            w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype)

+        else:

+            w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)

+            w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)

+            w1 = w1a @ w1b

+

+        if self.w2 is not None:

+            w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype)

+        elif self.t2 is None:

+            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)

+            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

+            w2 = w2a @ w2b

+        else:

+            t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype)

+            w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)

+            w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)

+            w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b)

+

+        output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)]

+        if len(orig_weight.shape) == 4:

+            output_shape = orig_weight.shape

+

+        updown = make_kron(output_shape, w1, w2)

+

+        return self.finalize_updown(updown, orig_weight, output_shape)
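NetworkModuleLokr builds the delta as a Kronecker product of two blocks (each possibly factored into a/b pairs or a CP triple). A minimal PyTorch-only sketch of what make_kron() computes for a linear layer, with made-up block sizes whose product matches the weight shape:

```python
import torch

w1 = torch.randn(4, 4)      # lokr_w1 (or w1a @ w1b)
w2 = torch.randn(80, 192)   # lokr_w2 (or rebuilt from w2a/w2b/t2)

updown = torch.kron(w1, w2.contiguous()).reshape(320, 768)  # (4*80, 4*192)
print(updown.shape)         # torch.Size([320, 768])
```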





diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py
new file mode 100644
index 0000000000000000000000000000000000000000..26c0a72c237f36f11ac6bf39d0ddc3d4d286bda2
--- /dev/null
+++ b/extensions-builtin/Lora/network_lora.py
@@ -0,0 +1,86 @@
+import torch

+

+import lyco_helpers

+import network

+from modules import devices

+

+

+class ModuleTypeLora(network.ModuleType):

+    def create_module(self, net: network.Network, weights: network.NetworkWeights):

+        if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]):

+            return NetworkModuleLora(net, weights)

+

+        return None

+

+

+class NetworkModuleLora(network.NetworkModule):

+    def __init__(self,  net: network.Network, weights: network.NetworkWeights):

+        super().__init__(net, weights)

+

+        self.up_model = self.create_module(weights.w, "lora_up.weight")

+        self.down_model = self.create_module(weights.w, "lora_down.weight")

+        self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True)

+

+        self.dim = weights.w["lora_down.weight"].shape[0]

+

+    def create_module(self, weights, key, none_ok=False):

+        weight = weights.get(key)

+

+        if weight is None and none_ok:

+            return None

+

+        is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention]

+        is_conv = type(self.sd_module) in [torch.nn.Conv2d]

+

+        if is_linear:

+            weight = weight.reshape(weight.shape[0], -1)

+            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)

+        elif is_conv and key == "lora_down.weight" or key == "dyn_up":

+            if len(weight.shape) == 2:

+                weight = weight.reshape(weight.shape[0], -1, 1, 1)

+

+            if weight.shape[2] != 1 or weight.shape[3] != 1:

+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)

+            else:

+                module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)

+        elif is_conv and key == "lora_mid.weight":

+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False)

+        elif is_conv and key == "lora_up.weight" or key == "dyn_down":

+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)

+        else:

+            raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}')

+

+        with torch.no_grad():

+            if weight.shape != module.weight.shape:

+                weight = weight.reshape(module.weight.shape)

+            module.weight.copy_(weight)

+

+        module.to(device=devices.cpu, dtype=devices.dtype)

+        module.weight.requires_grad_(False)

+

+        return module

+

+    def calc_updown(self, orig_weight):

+        up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)

+        down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)

+

+        output_shape = [up.size(0), down.size(1)]

+        if self.mid_model is not None:

+            # cp-decomposition

+            mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype)

+            updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid)

+            output_shape += mid.shape[2:]

+        else:

+            if len(down.shape) == 4:

+                output_shape += down.shape[2:]

+            updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim)

+

+        return self.finalize_updown(updown, orig_weight, output_shape)

+

+    def forward(self, x, y):

+        self.up_model.to(device=devices.device)

+        self.down_model.to(device=devices.device)

+

+        return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale()
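Besides baking the delta into the weights via calc_updown(), NetworkModuleLora keeps a forward() path that adds up(down(x)) on the fly, scaled by the multiplier and alpha/dim. A standalone PyTorch-only sketch of that runtime path, with made-up layer sizes:

```python
import torch

base = torch.nn.Linear(320, 768, bias=False)   # the original layer
down = torch.nn.Linear(320, 4, bias=False)     # rank-4 down projection
up = torch.nn.Linear(4, 768, bias=False)       # up projection

x = torch.randn(2, 320)
multiplier, scale = 0.8, 8 / 4                 # tag strength 0.8, alpha=8, dim=4

y = base(x) + up(down(x)) * multiplier * scale
print(y.shape)                                 # torch.Size([2, 768])
```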

+

+





diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..17cbe1bb7fe383ff1213b17d06d0665f7765a392
--- /dev/null
+++ b/extensions-builtin/Lora/networks.py
@@ -0,0 +1,468 @@
+import os

+import re

+

+import network

+import network_lora

+import network_hada

+import network_ia3

+import network_lokr

+import network_full

+

+import torch

+from typing import Union

+

+from modules import shared, devices, sd_models, errors, scripts, sd_hijack

+

+module_types = [

+    network_lora.ModuleTypeLora(),

+    network_hada.ModuleTypeHada(),

+    network_ia3.ModuleTypeIa3(),

+    network_lokr.ModuleTypeLokr(),

+    network_full.ModuleTypeFull(),

+]

+

+

+re_digits = re.compile(r"\d+")

+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")

+re_compiled = {}

+

+suffix_conversion = {

+    "attentions": {},

+    "resnets": {

+        "conv1": "in_layers_2",

+        "conv2": "out_layers_3",

+        "time_emb_proj": "emb_layers_1",

+        "conv_shortcut": "skip_connection",

+    }

+}

+

+

+def convert_diffusers_name_to_compvis(key, is_sd2):

+    def match(match_list, regex_text):

+        regex = re_compiled.get(regex_text)

+        if regex is None:

+            regex = re.compile(regex_text)

+            re_compiled[regex_text] = regex

+

+        r = re.match(regex, key)

+        if not r:

+            return False

+

+        match_list.clear()

+        match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])

+        return True

+

+    m = []

+

+    if match(m, r"lora_unet_conv_in(.*)"):

+        return f'diffusion_model_input_blocks_0_0{m[0]}'

+

+    if match(m, r"lora_unet_conv_out(.*)"):

+        return f'diffusion_model_out_2{m[0]}'

+

+    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):

+        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

+

+    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):

+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])

+        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

+

+    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):

+        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])

+        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

+

+    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):

+        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])

+        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

+

+    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):

+        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

+

+    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):

+        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

+

+    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):

+        if is_sd2:

+            if 'mlp_fc1' in m[1]:

+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"

+            elif 'mlp_fc2' in m[1]:

+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"

+            else:

+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

+

+        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

+

+    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):

+        if 'mlp_fc1' in m[1]:

+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"

+        elif 'mlp_fc2' in m[1]:

+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"

+        else:

+            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

+

+    return key

+

+

+def assign_network_names_to_compvis_modules(sd_model):

+    network_layer_mapping = {}

+

+    if shared.sd_model.is_sdxl:

+        for i, embedder in enumerate(shared.sd_model.conditioner.embedders):

+            if not hasattr(embedder, 'wrapped'):

+                continue

+

+            for name, module in embedder.wrapped.named_modules():

+                network_name = f'{i}_{name.replace(".", "_")}'

+                network_layer_mapping[network_name] = module

+                module.network_layer_name = network_name

+    else:

+        for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():

+            network_name = name.replace(".", "_")

+            network_layer_mapping[network_name] = module

+            module.network_layer_name = network_name

+

+    for name, module in shared.sd_model.model.named_modules():

+        network_name = name.replace(".", "_")

+        network_layer_mapping[network_name] = module

+        module.network_layer_name = network_name

+

+    sd_model.network_layer_mapping = network_layer_mapping

+

+

+def load_network(name, network_on_disk):

+    net = network.Network(name, network_on_disk)

+    net.mtime = os.path.getmtime(network_on_disk.filename)

+

+    sd = sd_models.read_state_dict(network_on_disk.filename)

+

+    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0

+    if not hasattr(shared.sd_model, 'network_layer_mapping'):

+        assign_network_names_to_compvis_modules(shared.sd_model)

+

+    keys_failed_to_match = {}

+    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping

+

+    matched_networks = {}

+

+    for key_network, weight in sd.items():

+        key_network_without_network_parts, network_part = key_network.split(".", 1)

+

+        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)

+        sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+

+        if sd_module is None:

+            m = re_x_proj.match(key)

+            if m:

+                sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None)

+

+        # SDXL loras seem to already have correct compvis keys, so we only need to replace "lora_unet" with "diffusion_model"

+        if sd_module is None and "lora_unet" in key_network_without_network_parts:

+            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")

+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:

+            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")

+            sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+

+            # some SD1 Loras also have correct compvis keys

+            if sd_module is None:

+                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")

+                sd_module = shared.sd_model.network_layer_mapping.get(key, None)

+

+        if sd_module is None:

+            keys_failed_to_match[key_network] = key

+            continue

+

+        if key not in matched_networks:

+            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

+

+        matched_networks[key].w[network_part] = weight

+

+    for key, weights in matched_networks.items():

+        net_module = None

+        for nettype in module_types:

+            net_module = nettype.create_module(net, weights)

+            if net_module is not None:

+                break

+

+        if net_module is None:

+            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

+

+        net.modules[key] = net_module

+

+    if keys_failed_to_match:

+        print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}")

+

+    return net

+

+

+def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):

+    already_loaded = {}

+

+    for net in loaded_networks:

+        if net.name in names:

+            already_loaded[net.name] = net

+

+    loaded_networks.clear()

+

+    networks_on_disk = [available_network_aliases.get(name, None) for name in names]

+    if any(x is None for x in networks_on_disk):

+        list_available_networks()

+

+        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

+

+    failed_to_load_networks = []

+

+    for i, name in enumerate(names):

+        net = already_loaded.get(name, None)

+

+        network_on_disk = networks_on_disk[i]

+

+        if network_on_disk is not None:

+            if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:

+                try:

+                    net = load_network(name, network_on_disk)

+                except Exception as e:

+                    errors.display(e, f"loading network {network_on_disk.filename}")

+                    continue

+

+            net.mentioned_name = name

+

+            network_on_disk.read_hash()

+

+        if net is None:

+            failed_to_load_networks.append(name)

+            print(f"Couldn't find network with name {name}")

+            continue

+

+        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0

+        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0

+        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0

+        loaded_networks.append(net)

+

+    if failed_to_load_networks:

+        sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))

+

+

+def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):

+    weights_backup = getattr(self, "network_weights_backup", None)

+

+    if weights_backup is None:

+        return

+

+    if isinstance(self, torch.nn.MultiheadAttention):

+        self.in_proj_weight.copy_(weights_backup[0])

+        self.out_proj.weight.copy_(weights_backup[1])

+    else:

+        self.weight.copy_(weights_backup)

+

+

+def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):

+    """

+    Applies the currently selected set of networks to the weights of torch layer self.

+    If weights already have this particular set of networks applied, does nothing.

+    If not, restores original weights from backup and alters weights according to networks.

+    """

+

+    network_layer_name = getattr(self, 'network_layer_name', None)

+    if network_layer_name is None:

+        return

+

+    current_names = getattr(self, "network_current_names", ())

+    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

+

+    weights_backup = getattr(self, "network_weights_backup", None)

+    if weights_backup is None:

+        if isinstance(self, torch.nn.MultiheadAttention):

+            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))

+        else:

+            weights_backup = self.weight.to(devices.cpu, copy=True)

+

+        self.network_weights_backup = weights_backup

+

+    if current_names != wanted_names:

+        network_restore_weights_from_backup(self)

+

+        for net in loaded_networks:

+            module = net.modules.get(network_layer_name, None)

+            if module is not None and hasattr(self, 'weight'):

+                with torch.no_grad():

+                    updown = module.calc_updown(self.weight)

+

+                    if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:

+                        # inpainting model: zero-pad updown so channel[1] grows from 4 to 9

+                        updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

+

+                    self.weight += updown

+                    continue

+

+            module_q = net.modules.get(network_layer_name + "_q_proj", None)

+            module_k = net.modules.get(network_layer_name + "_k_proj", None)

+            module_v = net.modules.get(network_layer_name + "_v_proj", None)

+            module_out = net.modules.get(network_layer_name + "_out_proj", None)

+

+            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:

+                with torch.no_grad():

+                    updown_q = module_q.calc_updown(self.in_proj_weight)

+                    updown_k = module_k.calc_updown(self.in_proj_weight)

+                    updown_v = module_v.calc_updown(self.in_proj_weight)

+                    updown_qkv = torch.vstack([updown_q, updown_k, updown_v])

+                    updown_out = module_out.calc_updown(self.out_proj.weight)

+

+                    self.in_proj_weight += updown_qkv

+                    self.out_proj.weight += updown_out

+                    continue

+

+            if module is None:

+                continue

+

+            print(f'failed to calculate network weights for layer {network_layer_name}')

+

+        self.network_current_names = wanted_names

+

+

+def network_forward(module, input, original_forward):

+    """

+    Old way of applying Lora by executing operations during the layer's forward pass.

+    Stacking many loras this way results in a large performance penalty.

+    """

+

+    if len(loaded_networks) == 0:

+        return original_forward(module, input)

+

+    input = devices.cond_cast_unet(input)

+

+    network_restore_weights_from_backup(module)

+    network_reset_cached_weight(module)

+

+    y = original_forward(module, input)

+

+    network_layer_name = getattr(module, 'network_layer_name', None)

+    for lora in loaded_networks:

+        module = lora.modules.get(network_layer_name, None)

+        if module is None:

+            continue

+

+        y = module.forward(y, input)

+

+    return y

+

+

+def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):

+    self.network_current_names = ()

+    self.network_weights_backup = None

+

+

+def network_Linear_forward(self, input):

+    if shared.opts.lora_functional:

+        return network_forward(self, input, torch.nn.Linear_forward_before_network)

+

+    network_apply_weights(self)

+

+    return torch.nn.Linear_forward_before_network(self, input)

+

+

+def network_Linear_load_state_dict(self, *args, **kwargs):

+    network_reset_cached_weight(self)

+

+    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)

+

+

+def network_Conv2d_forward(self, input):

+    if shared.opts.lora_functional:

+        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)

+

+    network_apply_weights(self)

+

+    return torch.nn.Conv2d_forward_before_network(self, input)

+

+

+def network_Conv2d_load_state_dict(self, *args, **kwargs):

+    network_reset_cached_weight(self)

+

+    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)

+

+

+def network_MultiheadAttention_forward(self, *args, **kwargs):

+    network_apply_weights(self)

+

+    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)

+

+

+def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):

+    network_reset_cached_weight(self)

+

+    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)

+

+

+def list_available_networks():

+    available_networks.clear()

+    available_network_aliases.clear()

+    forbidden_network_aliases.clear()

+    available_network_hash_lookup.clear()

+    forbidden_network_aliases.update({"none": 1, "Addams": 1})

+

+    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

+

+    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

+    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

+    for filename in candidates:

+        if os.path.isdir(filename):

+            continue

+

+        name = os.path.splitext(os.path.basename(filename))[0]

+        try:

+            entry = network.NetworkOnDisk(name, filename)

+        except OSError:  # should catch FileNotFoundError and PermissionError etc.

+            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)

+            continue

+

+        available_networks[name] = entry

+

+        if entry.alias in available_network_aliases:

+            forbidden_network_aliases[entry.alias.lower()] = 1

+

+        available_network_aliases[name] = entry

+        available_network_aliases[entry.alias] = entry

+

+

+re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")

+

+

+def infotext_pasted(infotext, params):

+    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:

+        return  # if the other extension is active, it will handle those fields, no need to do anything

+

+    added = []

+

+    for k in params:

+        if not k.startswith("AddNet Model "):

+            continue

+

+        num = k[13:]

+

+        if params.get("AddNet Module " + num) != "LoRA":

+            continue

+

+        name = params.get("AddNet Model " + num)

+        if name is None:

+            continue

+

+        m = re_network_name.match(name)

+        if m:

+            name = m.group(1)

+

+        multiplier = params.get("AddNet Weight A " + num, "1.0")

+

+        added.append(f"<lora:{name}:{multiplier}>")

+

+    if added:

+        params["Prompt"] += "\n" + "".join(added)

+

+

+available_networks = {}

+available_network_aliases = {}

+loaded_networks = []

+available_network_hash_lookup = {}

+forbidden_network_aliases = {}

+

+list_available_networks()
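
For reference, a minimal, self-contained sketch (not part of the commit) of the key-mapping idea used by convert_diffusers_name_to_compvis above. The match() helper here is an assumption modelled on the one defined earlier in networks.py: it fills a list with the regex captures, converting numeric captures to int so they can be used in arithmetic. The suffix_conversion lookup is omitted for brevity.

    import re

    def match(groups, regex, text):
        # assumed equivalent of the match() helper used above
        m = re.match(regex, text)
        if not m:
            return False
        groups.clear()
        groups.extend(int(g) if g.isdecimal() else g for g in m.groups())
        return True

    def to_compvis(key):
        m = []
        # same rule as the up_blocks branch above, without suffix_conversion
        if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)", key):
            return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{m[3]}"
        return key

    print(to_compvis("lora_unet_up_blocks_1_attentions_2_proj_in"))
    # -> diffusion_model_output_blocks_5_1_proj_in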





diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py
index 863dc5c0b510a20e03ed68adec0108ebf1e03474..50961be33d7e29214b00b6e69daf0c3e5e8f108c 100644
--- a/extensions-builtin/Lora/preload.py
+++ b/extensions-builtin/Lora/preload.py
@@ -4,3 +4,4 @@ 
 

 def preload(parser):

     parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora'))

+    parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS'))





diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index e650f469f3663cee190818444e5c18f67e3ef6e4..cd28afc92e7ae82d9df4329febcc28f40a254abe 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -4,87 +4,93 @@ import torch
 import gradio as gr

 from fastapi import FastAPI

 

-import lora

+import network

+import networks

+import lora  # noqa:F401

 import extra_networks_lora

 import ui_extra_networks_lora

 from modules import script_callbacks, ui_extra_networks, extra_networks, shared

 

 def unload():

-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora

+    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network

-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora

+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network

-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora

+    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network

-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora

+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network

-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora

+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network

-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora

+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network

 

 

 def before_ui():

     ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())

-    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())

+    extra_network = extra_networks_lora.ExtraNetworkLora()

+    extra_networks.register_extra_network(extra_network)

+    extra_networks.register_extra_network_alias(extra_network, "lyco")

 

 

+if not hasattr(torch.nn, 'Linear_forward_before_network'):

+    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward

+

+if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):

+    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict

+

+if not hasattr(torch.nn, 'Conv2d_forward_before_network'):

+    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward

+

+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):

+    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict

+

+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):

+    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward

+

+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):

+    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict

+

+torch.nn.Linear.forward = networks.network_Linear_forward

+torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict

+torch.nn.Conv2d.forward = networks.network_Conv2d_forward

+torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict

+torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward

+torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

 

-script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)

+script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)

 script_callbacks.on_script_unloaded(unload)

 script_callbacks.on_before_ui(before_ui)


 

 

 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {


     "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),

     "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),

+    "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),

+    "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),

 }))

 

 

 shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {

-    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),

+    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),

 }))

 

 

-def create_lora_json(obj: lora.LoraOnDisk):

+def create_lora_json(obj: network.NetworkOnDisk):

     return {

         "name": obj.name,

         "alias": obj.alias,

@@ -92,19 +97,19 @@         "metadata": obj.metadata,
     }

 

 

-def api_loras(_: gr.Blocks, app: FastAPI):

+def api_networks(_: gr.Blocks, app: FastAPI):

     @app.get("/sdapi/v1/loras")

     async def get_loras():

-        return [create_lora_json(obj) for obj in lora.available_loras.values()]

+        return [create_lora_json(obj) for obj in networks.available_networks.values()]

 

     @app.post("/sdapi/v1/refresh-loras")

     async def refresh_loras():

-        return lora.list_available_loras()

+        return networks.list_available_networks()

 

 

 

-script_callbacks.on_app_started(api_loras)

+script_callbacks.on_app_started(api_networks)

 

 re_lora = re.compile("<lora:([^:]+):")

 

@@ -117,20 +121,20 @@ 
     hashes = [x.strip().split(':', 1) for x in hashes.split(",")]

     hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes}

 

-    def lora_replacement(m):

+    def network_replacement(m):

         alias = m.group(1)

         shorthash = hashes.get(alias)

         if shorthash is None:

             return m.group(0)

 

-        lora_on_disk = lora.available_lora_hash_lookup.get(shorthash)

+        network_on_disk = networks.available_network_hash_lookup.get(shorthash)

-        if lora_on_disk is None:

+        if network_on_disk is None:

             return m.group(0)

 

-        return f'<lora:{lora_on_disk.get_alias()}:'

+        return f'<lora:{network_on_disk.get_alias()}:'

 

-    d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"])

+    d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"])

 

 

 script_callbacks.on_infotext_pasted(infotext_pasted)
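
A minimal sketch (illustrative only, not the commit's code) of the monkey-patching pattern the torch.nn assignments above rely on: the original forward is stashed once under a new attribute, and the replacement does its extra work before delegating to it. The real replacements live in networks.py and apply the loaded network weights.

    import torch

    # stash the original once, so repeated script reloads don't wrap the wrapper
    if not hasattr(torch.nn, 'Linear_forward_before_network'):
        torch.nn.Linear_forward_before_network = torch.nn.Linear.forward

    def patched_linear_forward(self, input):
        # a real implementation would call networks.network_apply_weights(self) here
        return torch.nn.Linear_forward_before_network(self, input)

    torch.nn.Linear.forward = patched_linear_forward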





diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index 354a1d686c089107c326d651f1677e876ae63141..390d9dde3fbc0a848185809f39ea694a7bd00ac2 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -1,3 +1,4 @@
+import datetime

 import html

 import random

 

@@ -46,15 +47,19 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
     def __init__(self, ui, tabname, page):

         super().__init__(ui, tabname, page)

 

+        self.select_sd_version = None

+

         self.taginfo = None

         self.edit_activation_text = None

         self.slider_preferred_weight = None

         self.edit_notes = None

 

-    def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes):

+    def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes):

         user_metadata = self.get_user_metadata(name)

         user_metadata["description"] = desc

+        user_metadata["sd version"] = sd_version

         user_metadata["activation text"] = activation_text

         user_metadata["preferred weight"] = preferred_weight

         user_metadata["notes"] = notes

@@ -69,6 +74,7 @@ 
         keys = {

             'ss_sd_model_name': "Model:",

             'ss_clip_skip': "Clip skip:",

+            'ss_network_module': "Kohya module:",

         }

 

         for key, label in keys.items():

@@ -76,6 +82,10 @@             value = metadata.get(key, None)
             if value is not None and str(value) != "None":

                 table.append((label, html.escape(value)))

 

+        ss_training_started_at = metadata.get('ss_training_started_at')

+        if ss_training_started_at:

+            table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M')))

+

         ss_bucket_info = metadata.get("ss_bucket_info")

         if ss_bucket_info and "buckets" in ss_bucket_info:

             resolutions = {}

@@ -113,13 +123,12 @@         tags = build_tags(metadata)
         gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]]

 

         return [

-            *values[0:4],

+            *values[0:5],

+            item.get("sd_version", "Unknown"),

             gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False),

             user_metadata.get('activation text', ''),

             float(user_metadata.get('preferred weight', 0.0)),


 

             gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False),

         ]

@@ -144,10 +153,16 @@                 res.append(tag)
 

         return ", ".join(sorted(res))

 

+    def create_extra_default_items_in_left_column(self):

+

+        # this would be a lot better as gr.Radio but I can't make it work

+        self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True)

+

     def create_editor(self):

         self.create_default_editor_elems()

 

         self.taginfo = gr.HighlightedText(label="Training dataset tags")

         self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora")

         self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01)

 

@@ -156,7 +172,7 @@                 random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
 

             with gr.Column(scale=1, min_width=120):


 

         self.edit_notes = gr.TextArea(label='Notes', lines=4)

 

@@ -182,10 +198,12 @@             self.edit_description,
             self.html_filedata,

             self.html_preview,

+            self.select_sd_version,

             self.edit_activation_text,

             self.slider_preferred_weight,

-            self.edit_notes,

             row_random_prompt,

             random_prompt,

         ]

@@ -196,6 +214,7 @@             .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
 

         edited_components = [

             self.edit_description,

+            self.select_sd_version,

             self.edit_activation_text,

             self.slider_preferred_weight,

             self.edit_notes,





diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py
index b2bc18102703f0ae831df1112f7f4d0171ead4c3..3629e5c0cf227192d5618fb0800a2be75f84ccf8 100644
--- a/extensions-builtin/Lora/ui_extra_networks_lora.py
+++ b/extensions-builtin/Lora/ui_extra_networks_lora.py
@@ -1,5 +1,8 @@
 import os

+

+import network

 import lora

+import networks

 

 from modules import shared, ui_extra_networks

 from modules.ui_extra_networks import quote_js

@@ -11,18 +14,16 @@     def __init__(self):
         super().__init__('Lora')

 

     def refresh(self):

-        lora.list_available_loras()

+        networks.list_available_networks()

 

-    def create_item(self, name, index=None):

+    def create_item(self, name, index=None, enable_filter=True):

-        lora_on_disk = lora.available_loras.get(name)

+        lora_on_disk = networks.available_networks.get(name)

 

         path, ext = os.path.splitext(lora_on_disk.filename)

 

         alias = lora_on_disk.get_alias()

 

         item = {

             "name": name,

             "filename": lora_on_disk.filename,

@@ -32,6 +33,7 @@             "search_term": self.search_terms_from_path(lora_on_disk.filename),
             "local_preview": f"{path}.{shared.opts.samples_format}",

             "metadata": lora_on_disk.metadata,

             "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)},

+            "sd_version": lora_on_disk.sd_version.name,

         }

 

         self.read_user_metadata(item)

@@ -42,22 +44,43 @@ 
         if activation_text:

             item["prompt"] += " + " + quote_js(" " + activation_text)

 

+        sd_version = item["user_metadata"].get("sd version")

+        if sd_version in network.SdVersion.__members__:

+            item["sd_version"] = sd_version

+            sd_version = network.SdVersion[sd_version]

+        else:

+            sd_version = lora_on_disk.sd_version

 

 

+        if shared.opts.lora_show_all or not enable_filter:

+            pass

+        elif sd_version == network.SdVersion.Unknown:

+            model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1

+            if model_version.name in shared.opts.lora_hide_unknown_for_versions:

+                return None

+        elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL:

+            return None

+        elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2:

+            return None

+        elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1:

+            return None

+

+        return item

 

     def list_items(self):

-        for index, name in enumerate(lora.available_loras):

+        for index, name in enumerate(networks.available_networks):

             item = self.create_item(name, index)

-            yield item

+

+            if item is not None:

+                yield item

 

     def allowed_directories_for_previews(self):

-        return [shared.cmd_opts.lora_dir]

+        return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat]

 

     def create_user_metadata_editor(self, ui, tabname):

         return LoraUserMetadataEditor(ui, tabname, self)





diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
new file mode 100644
index 0000000000000000000000000000000000000000..12cae4b75764779f7da3e424a959f966c06a8648
--- /dev/null
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -0,0 +1,26 @@
+var isSetupForMobile = false;
+
+function isMobile() {
+    for (var tab of ["txt2img", "img2img"]) {
+        var imageTab = gradioApp().getElementById(tab + '_results');
+        if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) {
+            return true;
+        }
+    }
+
+    return false;
+}
+
+function reportWindowSize() {
+    var currentlyMobile = isMobile();
+    if (currentlyMobile == isSetupForMobile) return;
+    isSetupForMobile = currentlyMobile;
+
+    for (var tab of ["txt2img", "img2img"]) {
+        var button = gradioApp().getElementById(tab + '_generate_box');
+        var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column');
+        target.insertBefore(button, target.firstElementChild);
+    }
+}
+
+window.addEventListener("resize", reportWindowSize);




diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html
index eb8b1a6726295cf9af40652f6925c4bdca1c79a5..39674666f1e336d9bf61d2a6986721cf8591eeee 100644
--- a/html/extra-networks-card.html
+++ b/html/extra-networks-card.html
@@ -1,8 +1,8 @@
 <div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}>
 	{background_image}
 	<div class="button-row">
-		{edit_button}
 		{metadata_button}
+		{edit_button}
 	</div>
 	<div class='actions'>
 		<div class='additional'>




diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js
index 2361144a9abe4df20b0f66d0e86ec408e874d5b7..44d02349a9d9f75efcbae6acf661e4095e626337 100644
--- a/javascript/extraNetworks.js
+++ b/javascript/extraNetworks.js
@@ -211,7 +211,7 @@         };
         globalPopupInner.classList.add('global-popup-inner');
         globalPopup.appendChild(globalPopupInner);
 
-        gradioApp().appendChild(globalPopup);
+        gradioApp().querySelector('.main').appendChild(globalPopup);
     }
 
     globalPopupInner.innerHTML = '';




diff --git a/javascript/hints.js b/javascript/hints.js
index 4167cb28b7c0ed934b37d346aacd3784ebec1016..6de9372e8ea8c9fb032351e241d0f9c265995290 100644
--- a/javascript/hints.js
+++ b/javascript/hints.js
@@ -190,3 +190,14 @@         clearTimeout(tooltipCheckTimer);
         tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
     }
 });
+
+onUiLoaded(function() {
+    for (var comp of window.gradio_config.components) {
+        if (comp.props.webui_tooltip && comp.props.elem_id) {
+            var elem = gradioApp().getElementById(comp.props.elem_id);
+            if (elem) {
+                elem.title = comp.props.webui_tooltip;
+            }
+        }
+    }
+});




diff --git a/javascript/localization.js b/javascript/localization.js
index eb22b8a7e99c4c9a0c4d6a52c3b9acefd74464ae..0c9032f9b41cfe53562a1f8a01be44c5bb06d05e 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -12,15 +12,15 @@     train_hypernetwork: 'OPTION',
     txt2img_styles: 'OPTION',
     img2img_styles: 'OPTION',
 
-var ignore_ids_for_localization = {
+    extras_upscaler_2: 'SPAN',
 
-    setting_sd_hypernetwork: 'OPTION',
+};
 
-    setting_sd_model_checkpoint: 'OPTION',
+var re_num = /^[.\d]+$/;
 
-    modelmerger_primary_model_name: 'OPTION',
+var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u;
 
-    modelmerger_secondary_model_name: 'OPTION',
+var original_lines = {};
 };
 
 var re_num = /^[.\d]+$/;




diff --git a/javascript/ui.js b/javascript/ui.js
index d70a681bff7b45fe5711431ee8ec55c444443a5b..abf23a78c703a1adccd8b0d52c8b00463979f837 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -152,7 +152,11 @@ function submit() {
     showSubmitButtons('txt2img', false);
 
     var id = randomId();
-    localStorage.setItem("txt2img_task_id", id);
+    try {
+        localStorage.setItem("txt2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
         showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
     showSubmitButtons('img2img', false);
 
     var id = randomId();
-    localStorage.setItem("img2img_task_id", id);
+    try {
+        localStorage.setItem("img2img_task_id", id);
+    } catch (e) {
+        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
+    }
 
     requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
         showSubmitButtons('img2img', true);
@@ -190,8 +198,6 @@
 function restoreProgressTxt2img() {
     showRestoreProgressButton("txt2img", false);
     var id = localStorage.getItem("txt2img_task_id");
-
-    id = localStorage.getItem("txt2img_task_id");
 
     if (id) {
         requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {




diff --git a/launch.py b/launch.py
index b103c8f3a2e1fcddfdc699ee65205df7cd3be1b4..e4c2ce99e729ce26b0efd89ced2ec8ff0373f8e5 100644
--- a/launch.py
+++ b/launch.py
@@ -1,6 +1,5 @@
 from modules import launch_utils

 

-

 args = launch_utils.args

 python = launch_utils.python

 git = launch_utils.git

@@ -18,6 +17,7 @@ run_pip = launch_utils.run_pip
 check_run_python = launch_utils.check_run_python

 git_clone = launch_utils.git_clone

 git_pull_recursive = launch_utils.git_pull_recursive

+list_extensions = launch_utils.list_extensions

 run_extension_installer = launch_utils.run_extension_installer

 prepare_environment = launch_utils.prepare_environment

 configure_for_tests = launch_utils.configure_for_tests

@@ -25,9 +25,12 @@ start = launch_utils.start
 

 

 def main():

-

+    launch_utils.startup_timer.record("initial startup")

 

+    with launch_utils.startup_timer.subcategory("prepare environment"):

+        if not args.skip_prepare_environment:

+            prepare_environment()

 

 

     if args.test_server:





diff --git a/modules/api/api.py b/modules/api/api.py
index 2a4cd8a2012febe3f308c8a3dc92885c84f9d317..908c451420ee8f1dd64f3798ea8dd222cb38fa72 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -197,6 +197,7 @@         self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem])
         self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem])
         self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse)
         self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
+        self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"])
         self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse)
         self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse)
@@ -333,21 +334,24 @@                 p.scripts = script_runner
                 p.outpath_grids = opts.outdir_txt2img_grids
                 p.outpath_samples = opts.outdir_txt2img_samples
 
+                try:
+                    shared.state.begin(job="scripts_txt2img")
+                    if selectable_scripts is not None:
+                        p.script_args = script_args
+                        processed = scripts.scripts_txt2img.run(p, *p.script_args)  # Need to pass args as list here
+                    else:
+                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
+                        processed = process_images(p)
+                finally:
+                    shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -398,21 +402,26 @@                 p.scripts = script_runner
                 p.outpath_grids = opts.outdir_img2img_grids
                 p.outpath_samples = opts.outdir_img2img_samples
 
+                try:
+                    shared.state.begin(job="scripts_img2img")
+                    if selectable_scripts is not None:
+                        p.script_args = script_args
+                        processed = scripts.scripts_img2img.run(p, *p.script_args)  # Need to pass args as list here
+                    else:
+                        p.script_args = tuple(script_args)  # Need to pass args as tuple here
+                        processed = process_images(p)
+                finally:
+                    shared.state.end()
 
         b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else []
@@ -620,6 +628,10 @@     def refresh_checkpoints(self):
         with self.queue_lock:
             shared.refresh_checkpoints()
 
+    def refresh_vae(self):
+        with self.queue_lock:
+            shared_items.refresh_vae_list()
+
     def create_embedding(self, args: dict):
         try:
             shared.state.begin(job="create_embedding")
@@ -737,10 +749,10 @@             cuda = {'error': f'{err}'}
         return models.MemoryResponse(ram=ram, cuda=cuda)
 
         self.app.include_router(self.router)
 
     def kill_webui(self):
         restart.stop_program()




diff --git a/modules/api/models.py b/modules/api/models.py
index b568307141fd6945b7494dfde45ee8bf15aab219..800c9b93f14794f429e32b053e9c24be0426d296 100644
--- a/modules/api/models.py
+++ b/modules/api/models.py
@@ -1,4 +1,5 @@
 import inspect
+
 from pydantic import BaseModel, Field, create_model
 from typing import Any, Optional
 from typing_extensions import Literal
@@ -207,14 +208,13 @@
 fields = {}
 for key, metadata in opts.data_labels.items():
     value = opts.data.get(key)
-    "sampler_index",
 from pydantic import BaseModel, Field, create_model
+    field_exclude: bool = False
 
-    "sampler_index",
+from pydantic import BaseModel, Field, create_model
 from typing import Any, Optional
-    "sampler_index",
 from typing_extensions import Literal
-    "sampler_index",
+    "prompt_for_display",
 from inflection import underscore
     else:
         fields.update({key: (Optional[optType], Field())})




diff --git a/modules/cache.py b/modules/cache.py
index 28d42a8cfe5e727579dec1593df67211be8d9102..71fe630213410d64c51cc77876dc86c36944c55b 100644
--- a/modules/cache.py
+++ b/modules/cache.py
@@ -1,6 +1,7 @@
 import json

 import os.path

 import threading

+import time

 

 from modules.paths import data_path, script_path

 

@@ -8,18 +9,40 @@ cache_filename = os.path.join(data_path, "cache.json")
 cache_data = None

 cache_lock = threading.Lock()

 

+dump_cache_after = None

+dump_cache_thread = None

 

+

 def dump_cache():

+    global dump_cache_after

+    global dump_cache_thread

+

+    def thread_func():

+        global dump_cache_after

+        global dump_cache_thread

+

+        while dump_cache_after is not None and time.time() < dump_cache_after:

+            time.sleep(1)

+

+        with cache_lock:

+            with open(cache_filename, "w", encoding="utf8") as file:

+                json.dump(cache_data, file, indent=4)

+

+            dump_cache_after = None

+            dump_cache_thread = None

 

     with cache_lock:

-        with open(cache_filename, "w", encoding="utf8") as file:

-            json.dump(cache_data, file, indent=4)

+        dump_cache_after = time.time() + 5

+        if dump_cache_thread is None:

+            dump_cache_thread = threading.Thread(name='cache-writer', target=thread_func)

 

+            dump_cache_thread.start()

 

 

 def cache(subsection):

@@ -87,7 +110,7 @@         cached_mtime = entry.get("mtime", 0)
         if ondisk_mtime > cached_mtime:

             entry = None

 

-    if not entry:

+    if not entry or 'value' not in entry:

         value = func()

         if value is None:

             return None
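
A standalone sketch (hypothetical names, not the commit's code) of the debounce pattern the dump_cache() change above uses: each request pushes the write deadline five seconds out, and a single background thread performs the JSON dump once the deadline passes, so a burst of cache updates results in one disk write.

    import json, threading, time

    _lock = threading.Lock()
    _write_after = None   # stand-in for dump_cache_after
    _writer = None        # stand-in for dump_cache_thread

    def request_dump(data, filename="cache.json"):
        global _write_after, _writer

        def worker():
            global _write_after, _writer
            while _write_after is not None and time.time() < _write_after:
                time.sleep(1)                      # wait until nobody re-arms the timer
            with _lock:
                with open(filename, "w", encoding="utf8") as f:
                    json.dump(data, f, indent=4)   # single write for a burst of requests
                _write_after = None
                _writer = None

        with _lock:
            _write_after = time.time() + 5         # re-arm the 5 second deadline
            if _writer is None:
                _writer = threading.Thread(target=worker, daemon=True)
                _writer.start()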





diff --git a/modules/call_queue.py b/modules/call_queue.py
index 61aa240fb3222931d39816fe2de2cb5d262e1ad3..f2eb17d61661e2d56ef2c3678db206a601c1eeec 100644
--- a/modules/call_queue.py
+++ b/modules/call_queue.py
@@ -3,7 +3,7 @@ import html
 import threading

 import time

 

-from modules import shared, progress, errors

+from modules import shared, progress, errors, devices

 

 queue_lock = threading.Lock()

 

@@ -74,6 +74,9 @@                 extra_outputs_array = [None, '']
 

             error_message = f'{type(e).__name__}: {e}'

 

             res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

 

+        devices.torch_gc()

 

         shared.state.skipped = False





diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index ae78f469efafe79aa0cc24122654676e6364db83..64f21e011c9399c0234b4bb61cfc1c868092699b 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -14,8 +14,11 @@ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed")
 parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup")

 parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing")

+parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed")

 

 parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages")

+parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint")

 parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored")

 parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)

 parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)

@@ -66,6 +69,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
 parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization")

 parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")

 parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)

+parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model")

 parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")

 parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)

 parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)

@@ -110,3 +114,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy')
 parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')

 parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')

 parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')

+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)

+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)





diff --git a/modules/devices.py b/modules/devices.py
index 57e51da30e26f0586c14321b5c0453f8a3ba5c64..00a00b18ab3b6c166c318cc185c30fab4306925a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,6 +3,7 @@ import contextlib
 from functools import lru_cache
 
 import torch
+def has_mps() -> bool:
 from modules import errors
 
 if sys.platform == "darwin":
@@ -72,52 +73,127 @@         torch.backends.cudnn.allow_tf32 = True
 
 
 
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
+unet_needs_upcast = False
 
 
 def cond_cast_unet(input):
     return input.to(dtype_unet) if unet_needs_upcast else input
 
 
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input
+
+
+nv_rng = None
+
+
 def randn(seed, shape):
     from modules.shared import opts
 
-    torch.manual_seed(seed)
+    manual_seed(seed)
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
 
 
+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=device)
+
+    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
 def randn_without_seed(shape):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
     from modules.shared import opts
 
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
     if opts.randn_source == "CPU" or device.type == 'mps':
         return torch.randn(shape, device=cpu).to(device)
+
     return torch.randn(shape, device=device)
 
 
+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
 
 def autocast(disable=False):
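
A small sketch of the seeded-generator idea behind randn_local() above, under the assumption that drawing from a private torch.Generator (on CPU or on the target device) leaves the global RNG state untouched; the NV/Philox branch is omitted here.

    import torch

    def randn_local(seed, shape, device="cpu", use_cpu_rng=True):
        # draw from a private generator so torch's global RNG state is not touched
        local_device = torch.device("cpu") if use_cpu_rng else torch.device(device)
        gen = torch.Generator(local_device).manual_seed(int(seed))
        return torch.randn(shape, device=local_device, generator=gen).to(device)

    # same seed -> same tensor, regardless of what the global RNG did in between
    a = randn_local(1234, (2, 3))
    torch.randn(5)  # perturb the global RNG
    b = randn_local(1234, (2, 3))
    assert torch.equal(a, b)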




diff --git a/modules/errors.py b/modules/errors.py
index 5271a9fe1de9923ca31dbcfc78f7e52472fc75b0..192cd8ffd62d2bce40ca025aa34e889f049d1d00 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -14,7 +14,8 @@ 
     if exception_records and exception_records[-1] == e:

         return

 

-    exception_records.append((e, tb))

+    from modules import sysinfo

+    exception_records.append(sysinfo.format_exception(e, tb))

 

     if len(exception_records) > 5:

         exception_records.pop(0)

@@ -83,3 +84,53 @@     try:
         code()

     except Exception as e:

         display(task, e)

+

+

+def check_versions():

+    from packaging import version

+    from modules import shared

+

+    import torch

+    import gradio

+

+    expected_torch_version = "2.0.0"

+    expected_xformers_version = "0.0.20"

+    expected_gradio_version = "3.39.0"

+

+    if version.parse(torch.__version__) < version.parse(expected_torch_version):

+        print_error_explanation(f"""

+You are running torch {torch.__version__}.

+The program is tested to work with torch {expected_torch_version}.

+To reinstall the desired version, run with commandline flag --reinstall-torch.

+Beware that this will cause a lot of large files to be downloaded, and that

+there are reports of issues with the training tab on the latest version.

+

+Use --skip-version-check commandline argument to disable this check.

+        """.strip())

+

+    if shared.xformers_available:

+        import xformers

+

+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):

+            print_error_explanation(f"""

+You are running xformers {xformers.__version__}.

+The program is tested to work with xformers {expected_xformers_version}.

+To reinstall the desired version, run with commandline flag --reinstall-xformers.

+

+Use --skip-version-check commandline argument to disable this check.

+            """.strip())

+

+    if gradio.__version__ != expected_gradio_version:

+        print_error_explanation(f"""

+You are running gradio {gradio.__version__}.

+The program is designed to work with gradio {expected_gradio_version}.

+Using a different version of gradio is extremely likely to break the program.

+

+Reasons why you might have a mismatched gradio version include:

+  - you use --skip-install flag.

+  - you use webui.py to start the program instead of launch.py.

+  - an extension installs the incompatible gradio version.

+

+Use --skip-version-check commandline argument to disable this check.

+        """.strip())

+





diff --git a/modules/extensions.py b/modules/extensions.py
index c561159afba55e65f7ec8f5ac5600187f6cdc25a..e4633af4034bb9b23be0cef115b67143774dcaaf 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,9 +12,10 @@ 
 

 def active():

-    if shared.opts.disable_all_extensions == "all":

+    if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":

         return []

-    elif shared.opts.disable_all_extensions == "extra":

+    elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":

         return [x for x in extensions if x.enabled and x.is_builtin]

     else:

         return [x for x in extensions if x.enabled]

@@ -56,10 +58,12 @@ 
                 self.do_read_info_from_repo()

 

                 return self.to_dict()

-

+        try:

-        d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)

+            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)

-        self.from_dict(d)

+            self.from_dict(d)

-        self.status = 'unknown'

+        except FileNotFoundError:

+            pass

+        self.status = 'unknown' if self.status == '' else self.status

 

     def do_read_info_from_repo(self):

         repo = None

@@ -139,7 +144,12 @@     if not os.path.isdir(extensions_dir):
         return

 

+    if shared.cmd_opts.disable_all_extensions:

+        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")

+    elif shared.opts.disable_all_extensions == "all":

         print("*** \"Disable all extensions\" option was set, will not load any extensions ***")

+    elif shared.cmd_opts.disable_extra_extensions:

+        print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")

     elif shared.opts.disable_all_extensions == "extra":

         print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

 





diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 41799b0a9849cde7da05b28bc64274b1d09284e2..fa28ac752ac24f7a2c26240baa76a807eb958fd9 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,17 +1,25 @@
+import json

+import os

 import re

 from collections import defaultdict

 

 from modules import errors

 

 extra_network_registry = {}

+extra_network_aliases = {}

 

 

 def initialize():

     extra_network_registry.clear()

+    extra_network_aliases.clear()

 

 

 def register_extra_network(extra_network):

     extra_network_registry[extra_network.name] = extra_network

+

+

+def register_extra_network_alias(extra_network, alias):

+    extra_network_aliases[alias] = extra_network

 

 

 def register_default_extra_networks():

@@ -82,20 +90,26 @@ def activate(p, extra_network_data):
     """call activate for extra networks in extra_network_data in specified order, then call

     activate for all remaining registered networks with an empty argument list"""

 

+    activated = []

+

     for extra_network_name, extra_network_args in extra_network_data.items():

         extra_network = extra_network_registry.get(extra_network_name, None)

+

+        if extra_network is None:

+            extra_network = extra_network_aliases.get(extra_network_name, None)

+

         if extra_network is None:

             print(f"Skipping unknown extra network: {extra_network_name}")

             continue

 

         try:

             extra_network.activate(p, extra_network_args)

+            activated.append(extra_network)

         except Exception as e:

             errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}")

 

     for extra_network_name, extra_network in extra_network_registry.items():

-        args = extra_network_data.get(extra_network_name, None)

-def initialize():

+def register_default_extra_networks():

 from modules import errors

             continue

 

@@ -166,3 +180,20 @@         res.append(updated_prompt)
 

     return res, extra_data

 

+

+def get_user_metadata(filename):

+    if filename is None:

+        return {}

+

+    basename, ext = os.path.splitext(filename)

+    metadata_filename = basename + '.json'

+

+    metadata = {}

+    try:

+        if os.path.isfile(metadata_filename):

+            with open(metadata_filename, "r", encoding="utf8") as file:

+                metadata = json.load(file)

+    except Exception as e:

+        errors.display(e, f"reading extra network user metadata from {metadata_filename}")

+

+    return metadata
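get_user_metadata() simply looks for a .json file next to the network file, so callers can do something like the following (the path is hypothetical):

from modules import extra_networks

# for "models/Lora/example_lora.safetensors" this reads "models/Lora/example_lora.json" if present
metadata = extra_networks.get_user_metadata("models/Lora/example_lora.safetensors")
print(metadata.get("description", "(no user metadata)"))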





diff --git a/modules/extras.py b/modules/extras.py
index e9c0263ec7d83d112c42645c18a3cbda64ed911e..2a310ae3f25304f6d3cf6c1c071c77ca4a5e5e4c 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -7,7 +7,7 @@ 
 import torch

 import tqdm

 

-from modules import shared, images, sd_models, sd_vae, sd_models_config

+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors

 from modules.ui_common import plaintext_to_html

 import gradio as gr

 import safetensors.torch

@@ -72,8 +72,21 @@ 
     return tensor

 

 

+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):

+    metadata = {}

 

+    for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:

+        checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)

+        if checkpoint_info is None:

+            continue

+

+        metadata.update(checkpoint_info.metadata)

+

+    return json.dumps(metadata, indent=4, ensure_ascii=False)

+

+

+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):

     shared.state.begin(job="model-merge")

 

     def fail(message):

@@ -242,15 +255,30 @@     shared.state.nextjob()
     shared.state.textinfo = "Saving"

     print(f"Saving to {output_modelname}...")

 

+    metadata = {}

+

+</div>

 import os

-import tqdm

+        if primary_model_info:

+            metadata.update(primary_model_info.metadata)

+        if secondary_model_info:

+</div>

 import json

+<div>

 

-    info = ''

+            metadata.update(tertiary_model_info.metadata)

 

     info = ''

+

+</div>

 import torch

+            metadata.update(json.loads(metadata_json))

+        except Exception as e:

+            errors.display(e, "reading metadata from json")

 

+        metadata["format"] = "pt"

+

+    if save_metadata and add_merge_recipe:

         merge_recipe = {

             "type": "webui", # indicate this model was merged with webui's built-in merger

             "primary_model_hash": primary_model_info.sha256,

@@ -266,7 +294,6 @@             "discard_weights": discard_weights,
             "is_inpainting": result_is_inpainting_model,

             "is_instruct_pix2pix": result_is_instruct_pix2pix_model

         }

-        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

 

         sd_merge_models = {}

 

@@ -286,12 +313,13 @@             add_model_metadata(secondary_model_info)
         if tertiary_model_info:

             add_model_metadata(tertiary_model_info)

 

+        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

         metadata["sd_merge_models"] = json.dumps(sd_merge_models)

 

     _, extension = os.path.splitext(output_modelname)

     if extension.lower() == ".safetensors":

 import re

-import gradio as gr

+        return

     else:

         torch.save(theta_0, output_modelname)
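The metadata handling added to the merger is essentially dictionary merging in checkpoint order, with later checkpoints and the user-supplied JSON overriding earlier keys; a self-contained illustration with made-up values:

import json

def combine_metadata(metadata_dicts):
    # later dicts win on key collisions, same as the successive metadata.update() calls above
    combined = {}
    for item in metadata_dicts:
        if item:
            combined.update(item)
    return json.dumps(combined, indent=4, ensure_ascii=False)

print(combine_metadata([{"ss_base_model_version": "sdxl_base_v1-0"}, {"format": "pt"}]))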

 





diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a3448be9db8615d5d10e2cc6a18e182a22c1ee92..4e2865587a28b8c3c1601feeb102941294bd1d00 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -280,6 +280,9 @@ 
     if "Hires sampler" not in res:

         res["Hires sampler"] = "Use same sampler"

 

+    if "Hires checkpoint" not in res:

+        res["Hires checkpoint"] = "Use same checkpoint"

+

     if "Hires prompt" not in res:

         res["Hires prompt"] = ""
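Parameter pasting backfills keys that older infotexts do not carry; the same defaults gathered in one place (a small sketch, not the full parse_generation_parameters flow):

def backfill_hires_defaults(res: dict) -> dict:
    # keys absent from old infotexts get the UI's "no change" defaults
    res.setdefault("Hires sampler", "Use same sampler")
    res.setdefault("Hires checkpoint", "Use same checkpoint")
    res.setdefault("Hires prompt", "")
    return res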

 





diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af7fd8ecfccd3d95e353f93e1b775c8aa4f4b1e
--- /dev/null
+++ b/modules/gradio_extensons.py
@@ -0,0 +1,60 @@
+import gradio as gr

+

+from modules import scripts

+

+def add_classes_to_gradio_component(comp):

+    """

+    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others

+    """

+

+    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]

+

+    if getattr(comp, 'multiselect', False):

+        comp.elem_classes.append('multiselect')

+

+

+def IOComponent_init(self, *args, **kwargs):

+    self.webui_tooltip = kwargs.pop('tooltip', None)

+

+    if scripts.scripts_current is not None:

+        scripts.scripts_current.before_component(self, **kwargs)

+

+    scripts.script_callbacks.before_component_callback(self, **kwargs)

+

+    res = original_IOComponent_init(self, *args, **kwargs)

+

+    add_classes_to_gradio_component(self)

+

+    scripts.script_callbacks.after_component_callback(self, **kwargs)

+

+    if scripts.scripts_current is not None:

+        scripts.scripts_current.after_component(self, **kwargs)

+

+    return res

+

+

+def Block_get_config(self):

+    config = original_Block_get_config(self)

+

+    webui_tooltip = getattr(self, 'webui_tooltip', None)

+    if webui_tooltip:

+        config["webui_tooltip"] = webui_tooltip

+

+    return config

+

+

+def BlockContext_init(self, *args, **kwargs):

+    res = original_BlockContext_init(self, *args, **kwargs)

+

+    add_classes_to_gradio_component(self)

+

+    return res

+

+

+original_IOComponent_init = gr.components.IOComponent.__init__

+original_Block_get_config = gr.blocks.Block.get_config

+original_BlockContext_init = gr.blocks.BlockContext.__init__

+

+gr.components.IOComponent.__init__ = IOComponent_init

+gr.blocks.Block.get_config = Block_get_config

+gr.blocks.BlockContext.__init__ = BlockContext_init
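Once this module is imported by the webui, every gradio IOComponent accepts an extra tooltip keyword that ends up in the component config as webui_tooltip; a minimal usage sketch (labels are made up, and it assumes the webui environment so that modules.scripts imports cleanly):

import modules.gradio_extensons  # noqa: F401  - importing it applies the monkey patches
import gradio as gr

with gr.Blocks() as demo:
    gr.Textbox(label="Prompt", tooltip="Text you want the model to draw")
    gr.Button("Generate", tooltip="Start generation")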





diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 79670b877151e06075183cb4f5c53a7693d0be96..70f1cbd26b66939de4d42831e300850e3f5927ad 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -10,7 +10,7 @@ import torch
 import tqdm

 from einops import rearrange, repeat

 from ldm.util import default

-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors

+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors

 from modules.textual_inversion import textual_inversion, logging

 from modules.textual_inversion.learn_schedule import LearnRateScheduler

 from torch import einsum

@@ -378,8 +378,8 @@ 
     return context_k, context_v

 

 

-        "tanh": torch.nn.Tanh,

 import inspect

+            x = state_dict.get(fr, None)

     h = self.heads

 

     q = self.to_q(x)

@@ -470,9 +470,8 @@     shared.reload_hypernetworks()
 

 

 def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):

-import html

 import inspect

-        self.multiplier = 1.0

+import torch

 import datetime

 

     save_hypernetwork_every = save_hypernetwork_every or 0





diff --git a/modules/images.py b/modules/images.py
index fb5d2e750a154e83f35d8fa9599c3d7757294257..ba3c43a45093e39b17a83e7d49a0baab9f1b8fcf 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,7 +318,7 @@ 
     return res

 

 

-invalid_filename_chars = '<>:"/\\|?*\n'

+invalid_filename_chars = '<>:"/\\|?*\n\r\t'

 invalid_filename_prefix = ' '

 invalid_filename_postfix = ' .'

 re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')

@@ -363,8 +363,8 @@         'height': lambda self: self.image.height,
         'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),

         'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False),

         'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),

-    except Exception:

 import io

+            rows = math.sqrt(len(imgs))

         'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),

         'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]

         'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
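The filename change only extends the set of rejected characters; the sanitization idea itself is small enough to show standalone (simplified, without the prefix/postfix trimming and length limit of the real sanitize_filename_part):

invalid_filename_chars = '<>:"/\\|?*\n\r\t'

def strip_invalid(text: str) -> str:
    # carriage returns and tabs are now rejected along with the previous set
    return "".join(ch for ch in text if ch not in invalid_filename_chars)

print(strip_invalid("red fox\t| sharp focus\r\n"))  # -> "red fox sharp focus"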





diff --git a/modules/img2img.py b/modules/img2img.py
index a811e7a4b1b44e22d7e7d433e708f8c539c82267..d8e1c534c3d46d8cbe257c62bc5a9b4b93c15c23 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,14 +3,13 @@ from contextlib import closing
 from pathlib import Path

 

 import numpy as np

-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError

+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError

 import gradio as gr

 

 from modules import sd_samplers, images as imgutil

 from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters

 from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images

 from modules.shared import opts, state

-from modules.images import save_image

 import modules.shared as shared

 import modules.processing as processing

 from modules.ui import plaintext_to_html

@@ -19,9 +18,12 @@ 
 

 def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None):

 import os

+            if n > 0:

+import os

 from modules import sd_samplers, images as imgutil

 

     images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp")))

+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError

 

     is_inpaint_batch = False

     if inpaint_mask_dir:

@@ -32,11 +34,6 @@         if is_inpaint_batch:
             print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.")

 

     print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.")

-

-    save_normally = output_dir == ''

-

-    p.do_not_save_grid = True

-    p.do_not_save_samples = not save_normally

 

     state.job_count = len(images) * p.n_iter

 

@@ -112,23 +109,18 @@             p.steps = int(parsed_parameters.get("Steps", steps))
 

         proc = modules.scripts.scripts_img2img.run(p, *args)

         if proc is None:

-            proc = process_images(p)

-

-from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters

 import os

-            filename = image_path.stem

-            infotext = proc.infotext(p, n)

-            relpath = os.path.dirname(os.path.relpath(image, input_dir))

-

-            if n > 0:

-                filename += f"-{n}"

-

             if not save_normally:

+import os

                 os.makedirs(os.path.join(output_dir, relpath), exist_ok=True)

+import os

                 if processed_image.mode == 'RGBA':

-import os

+from contextlib import closing

-import os

+from contextlib import closing

 import os

+                else:

+                    p.override_settings['samples_filename_pattern'] = f'{image_path.stem}'

+            process_images(p)

 

 

 def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args):

@@ -144,10 +136,8 @@         image = sketch.convert("RGB")
         mask = None

     elif mode == 2:  # inpaint

         image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]

-from modules.shared import opts, state

 from contextlib import closing

-        mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')

-        mask = ImageChops.lighter(alpha_mask, mask).convert('L')

+import numpy as np

         image = image.convert("RGB")

     elif mode == 3:  # inpaint sketch

         image = inpaint_color_sketch





diff --git a/modules/launch_utils.py b/modules/launch_utils.py
index ff77cbfd513b68b72df714e1dc03eb71a57d9f9d..f77b577a5d82cccf6239e5f82ad8cf69d56a0d78 100644
--- a/modules/launch_utils.py
+++ b/modules/launch_utils.py
@@ -1,5 +1,7 @@
 # this script installs necessary requirements and launches the main program in webui.py

 import subprocess

+@lru_cache()

+import subprocess

 import os

 import sys

 import importlib.util

@@ -9,6 +11,7 @@ from functools import lru_cache
 

 from modules import cmd_args, errors

 from modules.paths_internal import script_path, extensions_dir

+from modules.timer import startup_timer

 

 args, _ = cmd_args.parser.parse_known_args()

 

@@ -192,7 +195,7 @@         return
 

     try:

         env = os.environ.copy()

-        env['PYTHONPATH'] = os.path.abspath(".")

+        env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}"

 

         print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))

     except Exception as e:

@@ -222,10 +225,54 @@ def run_extensions_installers(settings_file):
     if not os.path.isdir(extensions_dir):

         return

 

+    with startup_timer.subcategory("run extensions installers"):

+        for dirname_extension in list_extensions(settings_file):

+            path = os.path.join(extensions_dir, dirname_extension)

+

+            if os.path.isdir(path):

+                run_extension_installer(path)

+    micro = sys.version_info.micro

 # this scripts installs necessary requirements and launches main program in webui.py

+

+

+import subprocess

             changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md")

-# this scripts installs necessary requirements and launches main program in webui.py

+

+

+import subprocess

             with open(changelog_md, "r", encoding="utf-8") as file:

+    """

+    Does a simple parse of a requirements.txt file to determine if all requirements in it

+    are already installed. Returns True if so, False if not installed or parsing fails.

+    """

+

+    import importlib.metadata

+    import packaging.version

+

+    with open(requirements_file, "r", encoding="utf8") as file:

+        for line in file:

+            if line.strip() == "":

+                continue

+

+            m = re.match(re_requirement, line)

+            if m is None:

+                return False

+

+            package = m.group(1).strip()

+            version_required = (m.group(2) or "").strip()

+

+            if version_required == "":

+                continue

+

+            try:

+                version_installed = importlib.metadata.version(package)

+            except Exception:

+                return False

+

+            if packaging.version.parse(version_required) != packaging.version.parse(version_installed):

+                return False

+

+    return True

 

 

 def prepare_environment():

@@ -239,11 +286,13 @@     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")

 

     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")

+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")

     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')

     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')

     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')

 

     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")

+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")

     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")

     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")

     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

@@ -251,17 +300,20 @@ 
     try:

         # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution

         os.remove(os.path.join(script_path, "tmp", "restart"))

-# this scripts installs necessary requirements and launches main program in webui.py

+import subprocess

 import json

-import subprocess

+import importlib.util

     except OSError:

         pass

 

     if not args.skip_python_version_check:

         check_python_version()

 

+    startup_timer.record("checks")

+

     commit = commit_hash()

     tag = git_tag()

+    startup_timer.record("git version info")

 

     print(f"Python {sys.version}")

     print(f"Version: {tag}")

@@ -269,21 +321,27 @@     print(f"Commit hash: {commit}")
 

     if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):

         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

+        startup_timer.record("install torch")

 

     if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):

         raise RuntimeError(

             'Torch is not able to use GPU; '

             'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'

         )

+    startup_timer.record("torch GPU test")

+

 

     if not is_installed("gfpgan"):

         run_pip(f"install {gfpgan_package}", "gfpgan")

+        startup_timer.record("install gfpgan")

 

     if not is_installed("clip"):

         run_pip(f"install {clip_package}", "clip")

+        startup_timer.record("install clip")

 

     if not is_installed("open_clip"):

         run_pip(f"install {openclip_package}", "open_clip")

+        startup_timer.record("install open_clip")

 

     if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:

         if platform.system() == "Windows":

@@ -298,38 +356,55 @@         elif platform.system() == "Linux":
             run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")

 

 import subprocess

+    return (result.stdout or "")

+

+import subprocess

 import platform

         run_pip("install ngrok", "ngrok")

+        startup_timer.record("install ngrok")

 

     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)

 

     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)

 import subprocess

+        spec = importlib.util.find_spec(package)

+import subprocess

 from modules import cmd_args, errors

     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)

     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)

 

 import subprocess

+    except ModuleNotFoundError:

+

+import subprocess

 python = sys.executable

         run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")

+        startup_timer.record("install CodeFormer requirements")

 

     if not os.path.isfile(requirements_file):

         requirements_file = os.path.join(script_path, requirements_file)

+

 import subprocess

-# Whether to default to printing command output

+    return spec is not None

+        run_pip(f"install -r \"{requirements_file}\"", "requirements")

+        startup_timer.record("install requirements")

 

     run_extensions_installers(settings_file=args.ui_settings_file)

 

     if args.update_check:

         version_check(commit)

+        startup_timer.record("check version")

 

     if args.update_all_extensions:

     is_windows = platform.system() == "Windows"

+import subprocess

+        supported_minors = [7, 8, 9, 10, 11]

 import subprocess

 

     if "--exit" in sys.argv:

         print("Exiting because of --exit argument")

         exit(0)

+

 

 

 def configure_for_tests():
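Pulled together as one function, the requirements check introduced above works like this (the regular expression shown is an assumption; the webui defines its own re_requirement pattern elsewhere in the file):

import re
import importlib.metadata
import packaging.version

re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")

def requirements_met(requirements_file):
    # returns True only if every pinned requirement is installed at exactly that version
    with open(requirements_file, "r", encoding="utf8") as file:
        for line in file:
            if line.strip() == "":
                continue

            m = re.match(re_requirement, line)
            if m is None:
                return False

            package = m.group(1).strip()
            version_required = (m.group(2) or "").strip()

            if version_required == "":
                continue  # unpinned requirement: assume any installed version is fine

            try:
                version_installed = importlib.metadata.version(package)
            except Exception:
                return False

            if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
                return False

    return True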





diff --git a/modules/lowvram.py b/modules/lowvram.py
index d95bcfbf0f3b8cd6adff29ed094b834eb0d41b8e..96f52b7b4dad7ec38c520f5c2e17ffe2dcea5545 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,9 @@     module_in_gpu = None
 

 

 def setup_for_low_vram(sd_model, use_medvram):

+    if getattr(sd_model, 'lowvram', False):

+        return

+

     sd_model.lowvram = True

 

     parents = {}

@@ -53,47 +56,73 @@     def first_stage_model_decode_wrap(z):
         send_me_to_gpu(first_stage_model, None)

         return first_stage_model_decode(z)

 

+    to_remain_in_cpu = [

+        (sd_model, 'first_stage_model'),

+        (sd_model, 'depth_model'),

+        (sd_model, 'embedder'),

+    if module_in_gpu is not None:

 

     if module_in_gpu is not None:

+from modules import devices

+    ]

 

+    is_sdxl = hasattr(sd_model, 'conditioner')

+    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')

+

+    if is_sdxl:

+        to_remain_in_cpu.append((sd_model, 'conditioner'))

+    if module_in_gpu is not None:

         module_in_gpu.to(cpu)

-module_in_gpu = None

+        module_in_gpu.to(cpu)

-

+    else:

-module_in_gpu = None

+        module_in_gpu.to(cpu)

 import torch

-module_in_gpu = None

+

+        module_in_gpu.to(cpu)

 from modules import devices

+    stored = []

+        module_in_gpu.to(cpu)

 module_in_gpu = None

+        module = getattr(obj, field, None)

+        stored.append(module)

+        setattr(obj, field, None)

 

-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None

+    # send the model to GPU.

     sd_model.to(devices.device)

-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored

+

+    # put modules back. the modules will be in CPU.

+    for (obj, field), module in zip(to_remain_in_cpu, stored):

+        setattr(obj, field, module)

 

     # register hooks for those the first three models

-module_in_gpu = None

     if module_in_gpu is not None:

+    global module_in_gpu

-module_in_gpu = None

+        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)

+    if module_in_gpu is not None:

         module_in_gpu.to(cpu)

-cpu = torch.device("cpu")

+import torch

+

-cpu = torch.device("cpu")

 import torch

+module_in_gpu = None

+    module_in_gpu = None

 cpu = torch.device("cpu")

-from modules import devices

+        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model

+    else:

+        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)

+        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model

-cpu = torch.device("cpu")

 

-cpu = torch.device("cpu")

 module_in_gpu = None

+        module_in_gpu.to(cpu)

 cpu = torch.device("cpu")

-cpu = torch.device("cpu")

 cpu = torch.device("cpu")

-def send_everything_to_cpu():

+import torch

-

+    if sd_model.depth_model:

+cpu = torch.device("cpu")

 

-        module_in_gpu.to(cpu)

 cpu = torch.device("cpu")

-    global module_in_gpu

+module_in_gpu = None

 cpu = torch.device("cpu")

-    if module_in_gpu is not None:

+cpu = torch.device("cpu")

 

     if use_medvram:

         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
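The low-VRAM setup keeps the big submodules on the CPU and swaps only the one about to run onto the GPU from a forward pre-hook; a generic, self-contained sketch of that technique (not the webui code itself, and the attribute names are illustrative):

import torch

cpu = torch.device("cpu")
gpu = torch.device("cuda")
module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    # evict whatever module currently occupies the GPU, then move this one over
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)
    module.to(gpu)
    module_in_gpu = module

def setup_for_low_vram(model):
    # inputs are assumed to already live on the GPU device
    for submodule in [model.encoder, model.decoder]:  # illustrative attribute names
        submodule.register_forward_pre_hook(send_me_to_gpu)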





diff --git a/modules/paths.py b/modules/paths.py
index bada804e6c1ac88d77be49915c2e357ee7a57169..2505233999b2a8fe1945dbc3cdbd6da36a403719 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -5,6 +5,21 @@ 
 import modules.safe  # noqa: F401

 

 

+def mute_sdxl_imports():

+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""

+

+    class Dummy:

+        pass

+

+    module = Dummy()

+    module.LPIPS = None

+    sys.modules['taming.modules.losses.lpips'] = module

+

+    module = Dummy()

+    module.StableDataModuleFromConfig = None

+    sys.modules['sgm.data'] = module

+

+

 # data_path = cmd_opts_pre.data

 sys.path.insert(0, script_path)

 

@@ -18,8 +33,11 @@         break
 

 assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"

 

+mute_sdxl_imports()

+

 path_dirs = [

     (sd_path, 'ldm', 'Stable Diffusion', []),

+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),

     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),

     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),

     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),

@@ -35,6 +53,13 @@     else:
         d = os.path.abspath(d)

         if "atstart" in options:

             sys.path.insert(0, d)

+        elif "sgm" in options:

+            # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we

+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.

+

+            sys.path.insert(0, d)

+            import sgm  # noqa: F401

+            sys.path.pop(0)

         else:

             sys.path.append(d)

         paths[what] = d
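mute_sdxl_imports() relies on a standard trick: if a fully qualified module name is already present in sys.modules, a from-import of it resolves to that object without ever touching the real package. The same trick in isolation:

import sys
import types

stub = types.ModuleType("taming.modules.losses.lpips")
stub.LPIPS = None
sys.modules["taming.modules.losses.lpips"] = stub

# succeeds even if the real 'taming' package is not installed
from taming.modules.losses.lpips import LPIPS
assert LPIPS is None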





diff --git a/modules/processing.py b/modules/processing.py
index 49441e7761b5f5feeba6849fdb9a602d29eed716..ae58b108a411352ad55ec091ae33b521a7e593d6 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@ from skimage import exposure
 from typing import Any, Dict, List

 

 import modules.sd_hijack

-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet

+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors

 from modules.sd_hijack import model_hijack

 from modules.shared import opts, cmd_opts, state

 import modules.shared as shared

@@ -30,6 +30,7 @@ 
 from einops import repeat, rearrange

 from blendmodes.blend import blendLayers, BlendType

 

+decode_first_stage = sd_samplers_common.decode_first_stage

 

 # some of those options should not be changed at all because they would break the model, so I removed them from options.

 opt_C = 4

@@ -330,11 +331,24 @@         computed result is stored.
 

         caches is a list with items described above.

         """

+

+        self.subseed_strength: float = subseed_strength

 import logging

+            required_prompts,

+            steps,

+        self.subseed_strength: float = subseed_strength

 import sys

+            shared.sd_model.sd_checkpoint_info,

+            extra_network_data,

+            opts.sdxl_crop_left,

+            opts.sdxl_crop_top,

+            self.width,

+            self.height,

+        )

 

 import modules.sd_vae as sd_vae

-import torch

+

+            if cache[0] is not None and cached_params == cache[0]:

                 return cache[1]

 

         cache = caches[0]

@@ -342,18 +356,20 @@ 
         with devices.autocast():

             cache[1] = function(shared.sd_model, required_prompts, steps)

 

-from ldm.data.util import AddMiDaS

+        self.seed_resize_from_h: int = seed_resize_from_h

 import math

         return cache[1]

 

     def setup_conds(self):

+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)

+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)

+

         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)

         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1

-import logging

+        self.seed_resize_from_h: int = seed_resize_from_h

 import hashlib

-import torch

-from ldm.data.util import AddMiDaS

 import numpy as np

+        base_image.paste(image, (x, y))

 

     def parse_extra_network_prompts(self):

         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

@@ -483,8 +499,8 @@     for i, seed in enumerate(seeds):
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)

 

         subnoise = None

-def apply_color_correction(correction, original_image):

 import numpy as np

+        image = base_image

             subseed = 0 if i >= len(subseeds) else subseeds[i]

 

             subnoise = devices.randn(subseed, noise_shape)

@@ -516,7 +532,7 @@         if sampler_noises is not None:
             cnt = p.sampler.number_of_needed_noises(p)

 

             if eta_noise_seed_delta > 0:

-                torch.manual_seed(seed + eta_noise_seed_delta)

+                devices.manual_seed(seed + eta_noise_seed_delta)

 

             for j in range(cnt):

                 sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))

@@ -530,18 +546,50 @@     x = torch.stack(xs).to(shared.device)
     return x

 

 

-import os

+class DecodedSamples(list):

+        self.seed_resize_from_w: int = seed_resize_from_w

 import json

+

+

+        self.seed_resize_from_w: int = seed_resize_from_w

 import logging

+    samples = DecodedSamples()

+

+        self.seed_resize_from_w: int = seed_resize_from_w

 import os

+        sample = decode_first_stage(model, batch[i:i + 1])[0]

+

+        if check_for_nans:

+            try:

+                devices.test_for_nans(sample, "vae")

+            except devices.NansException as e:

+                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:

+        self.sampler_name: str = sampler_name

 import json

+

+                errors.print_error_explanation(

+        self.sampler_name: str = sampler_name

 import math

+        self.sampler_name: str = sampler_name

 import os

+                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"

+                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."

+                )

+

+                devices.dtype_vae = torch.float32

+                model.first_stage_model.to(devices.dtype_vae)

+                batch = batch.to(devices.dtype_vae)

+

+        self.batch_size: int = batch_size

 import json

-import os

+

+        if target_device is not None:

+            sample = sample.to(target_device)

 

+        self.batch_size: int = batch_size

 import os

-import random

+

+    return samples

 

 

 def get_fixed_seed(seed):

@@ -566,10 +614,14 @@ 
     return res

 

 

-            cv2.COLOR_RGB2LAB

+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):

+    if index is None:

+        index = position_in_batch + iteration * p.batch_size

 

-            cv2.COLOR_RGB2LAB

+import numpy as np

 import torch

+import numpy as np

+        all_negative_prompts = p.all_negative_prompts

 

     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)

     enable_hr = getattr(p, 'enable_hr', False)

@@ -585,13 +637,13 @@         "Steps": p.steps,
         "Sampler": p.sampler_name,

         "CFG scale": p.cfg_scale,

         "Image CFG scale": getattr(p, 'image_cfg_scale', None),

-        correction,

+        self.n_iter: int = n_iter

 import json

         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),

         "Size": f"{p.width}x{p.height}",

         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),

-        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),

+        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),

-        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),

+        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),

         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),

         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),

         "Denoising strength": getattr(p, 'denoising_strength', None),

@@ -601,8 +653,8 @@         "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,

         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,

         "Init image hash": getattr(p, 'init_img_hash', None),

+        self.n_iter: int = n_iter

 import os

-        image = images.resize_image(1, image, w, h)

         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,

         **p.extra_generation_params,

         "Version": program_version() if opts.add_version_to_infotext else None,

@@ -612,7 +664,7 @@ 
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

 

     prompt_text = p.prompt if use_main_prompt else all_prompts[index]

-    negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""

+    negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""

 

     return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()

 

@@ -685,9 +737,6 @@     if type(subseed) == list:
         p.all_subseeds = subseed

     else:

         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]

-

-    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):

-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)

 

     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:

         model_hijack.embedding_db.load_textual_inversion_embeddings()

@@ -762,10 +811,12 @@ 
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():

                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

 

-            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]

+            if getattr(samples_ddim, 'already_decoded', False):

-        base_image = Image.new('RGBA', (overlay.width, overlay.height))

+import numpy as np

 import numpy as np

+

-                devices.test_for_nans(x, "vae")

+            else:

+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

 

             x_samples_ddim = torch.stack(x_samples_ddim).float()

             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

@@ -780,6 +831,16 @@ 
             if p.scripts is not None:

                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

 

+                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]

+                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]

+

+                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))

+                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)

+                x_samples_ddim = batch_params.images

+

+            def infotext(index=0, use_main_prompt=False):

+                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)

+

             for i, x_sample in enumerate(x_samples_ddim):

                 p.batch_index = i

 

@@ -788,7 +849,7 @@                 x_sample = x_sample.astype(np.uint8)
 

                 if p.restore_faces:

                     if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:

-                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

+                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")

 

                     devices.torch_gc()

 

@@ -805,18 +866,17 @@ 
                 if p.color_corrections is not None and i < len(p.color_corrections):

                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:

                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)

-                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")

+                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")

                     image = apply_color_correction(p.color_corrections[i], image)

 

                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)

 

                 if opts.samples_save and not p.do_not_save_samples:

-import sys

+        self.steps: int = steps

 import torch

-import numpy as np

 

-import sys

+        self.steps: int = steps

 import numpy as np

                 infotexts.append(text)

                 if opts.enable_pnginfo:

                     image.info["parameters"] = text

@@ -826,10 +887,11 @@                     image_mask = p.mask_for_overlay.convert('RGB')
                     image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')

 

                     if opts.save_mask:

-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")

+                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")

 

                     if opts.save_mask_composite:

-import hashlib

+import json

+import json

 import json

 

                     if opts.return_mask:

@@ -871,9 +933,8 @@     res = Processed(
         p,

         images_list=output_images,

         seed=p.all_seeds[0],

-import hashlib

+        self.cfg_scale: float = cfg_scale

 import logging

-

         comments="".join(f"{comment}\n" for comment in comments),

         subseed=p.all_subseeds[0],

         index_of_first_image=index_of_first_image,

@@ -903,7 +964,7 @@     sampler = None
     cached_hr_uc = [None, None]

     cached_hr_c = [None, None]

 

-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):

+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):

         super().__init__(**kwargs)

         self.enable_hr = enable_hr

         self.denoising_strength = denoising_strength

@@ -914,11 +975,14 @@         self.hr_resize_x = hr_resize_x
         self.hr_resize_y = hr_resize_y

         self.hr_upscale_to_x = hr_resize_x

         self.hr_upscale_to_y = hr_resize_y

+        self.hr_checkpoint_name = hr_checkpoint_name

+        self.hr_checkpoint_info = None

         self.hr_sampler_name = hr_sampler_name

         self.hr_prompt = hr_prompt

         self.hr_negative_prompt = hr_negative_prompt

         self.all_hr_prompts = None

         self.all_hr_negative_prompts = None

+        self.latent_scale_mode = None

 

         if firstphase_width != 0 or firstphase_height != 0:

             self.hr_upscale_to_x = self.width

@@ -941,6 +1005,14 @@         self.hr_uc = None
 

     def init(self, all_prompts, all_seeds, all_subseeds):

         if self.enable_hr:

+            if self.hr_checkpoint_name:

+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)

+

+                if self.hr_checkpoint_info is None:

+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')

+

+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title

+

             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:

                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name

 

@@ -950,6 +1022,11 @@ 
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):

                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt

 

+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")

+            if self.enable_hr and self.latent_scale_mode is None:

+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):

+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")

+

             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):

                 self.hr_resize_x = self.width

                 self.hr_resize_y = self.height

@@ -989,15 +1066,6 @@                     self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
                     self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f

 

 

-import modules.sd_hijack

-            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:

-                self.enable_hr = False

-                self.denoising_strength = None

-                self.extra_generation_params.pop("Hires upscale", None)

-                self.extra_generation_params.pop("Hires resize", None)

-                return

-

-

 import modules.images as images

                 if state.job_count == -1:

                     state.job_count = self.n_iter

@@ -1016,28 +1084,44 @@     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

 

     else:

-import math

+

     else:

-import os

+import torch

-

+import numpy as np

 import math

-import sys

+import torch

+

     else:

-import hashlib

+import numpy as np

 

+        cv2.cvtColor(

 

-import math

+        if self.latent_scale_mode is None:

+        self.width: int = width

 

+import modules.shared as shared

 

-import math

+        self.width: int = width

 import torch

 

+        current = shared.sd_model.sd_checkpoint_info

+        try:

+            if self.hr_checkpoint_info is not None:

+                self.sampler = None

+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)

+                devices.torch_gc()

 

+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)

+from PIL import Image, ImageOps

 import math

-import numpy as np

+import hashlib

+        self.height: int = height

 

-import os

+import json

+    logging.info("Applying color correction.")

+            devices.torch_gc()

 

+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):

         self.is_hr_pass = True

 

         target_width = self.hr_upscale_to_x

@@ -1053,15 +1137,24 @@             if not isinstance(image, Image.Image):
                 image = sd_samplers.sample_to_image(image, index, approximation=0)

 

             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)

-        # Still takes up a bit of memory, but no encoder call.

+        self.restore_faces: bool = restore_faces

 

+

         # Still takes up a bit of memory, but no encoder call.

+

+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM

+            img2img_sampler_name = 'DDIM'

+

+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

+

+        self.restore_faces: bool = restore_faces

 import json

             for i in range(samples.shape[0]):

                 save_intermediate(samples, i)

 

-        # Still takes up a bit of memory, but no encoder call.

+from PIL import Image, ImageOps

 import os

+import logging

 

             # Avoid making the inpainting conditioning unless necessary as

             # this does need some extra compute to decode / encode the image again.

@@ -1071,7 +1164,6 @@             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)

         else:

         # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.

-        # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.

 import json

 

             batch_images = []

@@ -1090,6 +1182,7 @@ 
             decoded_samples = torch.from_numpy(np.array(batch_images))

             decoded_samples = decoded_samples.to(shared.device)

             decoded_samples = 2. * decoded_samples - 1.

+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)

 

             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

 

@@ -1098,20 +1191,11 @@ 
         shared.state.nextjob()

 

         return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

-import sys

-

-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM

-            img2img_sampler_name = 'DDIM'

-

-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)

-

-        return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)

 import numpy as np

 

         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)

 

         # GC now before running the next img2img to prevent running out of memory

-        x = None

         devices.torch_gc()

 

         if not self.disable_extra_networks:

@@ -1131,10 +1215,11 @@ 
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

 

 import json

-import os

+        correction,

 

+        self.is_hr_pass = False

 

-        self.subseed: int = subseed

+        return decoded_samples

 

     def close(self):

         super().close()

@@ -1173,12 +1258,15 @@     def calculate_hr_conds(self):
         if self.hr_c is not None:

             return

 

-import torch

+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)

 import json

-import math

+    image = blendLayers(image, original_image, BlendType.LUMINOSITY)

+

+        self.restore_faces: bool = restore_faces

 import torch

 import json

 import os

+import numpy as np

 

     def setup_conds(self):

         super().setup_conds()

@@ -1186,7 +1272,7 @@ 
         self.hr_uc = None

         self.hr_c = None

 

-        if self.enable_hr:

+        if self.enable_hr and self.hr_checkpoint_info is None:

             if shared.opts.hires_fix_use_firstpass_conds:

                 self.calculate_hr_conds()

 

@@ -1338,11 +1424,11 @@             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 

         image = torch.from_numpy(batch_images)

         image = 2. * image - 1.

-import numpy as np

+        self.tiling: bool = tiling

 import json

-import os

 

         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))

+        devices.torch_gc()

 

         if self.resize_mode == 3:

             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")





diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0e1a5376dc663a249428ef77421d74dca..32d214e3a1a80ddaaeade96ef2e9d922127bc0de 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations

+

 import re

 from collections import namedtuple

 from typing import List

@@ -17,9 +19,11 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
 !emphasized: "(" prompt ")"

         | "(" prompt ":" prompt ")"

         | "[" prompt "]"

-import re

+from collections import namedtuple

 # [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']

+import lark

 from collections import namedtuple

+    return [promptdict[prompt] for prompt in prompts]

 WHITESPACE: /\s+/

 plain: /([^\\\[\]():|]|\\.)+/

 %import common.SIGNED_NUMBER -> NUMBER

@@ -52,6 +57,11 @@     [[3, '((a][:b:c '], [10, '((a][:b:c d']]
     >>> g("[a|(b:1.1)]")

     [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]

 from collections import namedtuple

+ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

+    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]

+    >>> g("[fe|||]male")

+    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]

+from collections import namedtuple

 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"

 

     def collect_steps(steps, tree):

@@ -59,11 +69,11 @@         res = [steps]
 

         class CollectSteps(lark.Visitor):

             def scheduled(self, tree):

-                tree.children[-1] = float(tree.children[-1])

+                tree.children[-2] = float(tree.children[-2])

-                if tree.children[-1] < 1:

+                if tree.children[-2] < 1:

-                    tree.children[-1] *= steps

+                    tree.children[-2] *= steps

-                tree.children[-1] = min(steps, int(tree.children[-1]))

+                tree.children[-2] = min(steps, int(tree.children[-2]))

-                res.append(tree.children[-1])

+                res.append(tree.children[-2])

 

             def alternate(self, tree):

                 res.extend(range(1, steps+1))

@@ -74,10 +84,12 @@ 
     def at_step(step, tree):

         class AtStep(lark.Transformer):

             def scheduled(self, args):

-                before, after, _, when = args

+                before, after, _, when, _ = args

                 yield before or () if step <= when else after

             def alternate(self, args):

+                args = ["" if not arg else arg for arg in args]
 
                 yield args[(step - 1) % len(args)]

             def start(self, args):

                 def flatten(x):

                     if type(x) == str:

@@ -110,8 +122,27 @@ 
 ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"])

 

 

+class SdConditioning(list):

+    """

+    A list with prompts for stable diffusion's conditioner model.

+    Can also specify width and height of created image - SDXL needs it.

+    """

+    def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None):

+        super().__init__()

+        self.extend(prompts)

+

+        if copy_from is None:

+            copy_from = prompts

+

+        self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False)

+        self.width = width or getattr(copy_from, 'width', None)

+        self.height = height or getattr(copy_from, 'height', None)

+

+

+

+def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps):

     """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond),

     and the sampling step at which this condition is to be replaced by the next one.

 

@@ -141,13 +172,19 @@         if cached is not None:
             res.append(cached)

             continue

 

-        texts = [x[1] for x in prompt_schedule]

+        texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts)

         conds = model.get_learned_conditioning(texts)

 

         cond_schedule = []

         for i, (end_at_step, _) in enumerate(prompt_schedule):

+    >>> g("a [b: 3]")

 import re

+                cond = {k: v[i] for k, v in conds.items()}

+from typing import List

 %import common.SIGNED_NUMBER -> NUMBER

+                cond = conds[i]

+

+            cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond))

 

         cache[prompt] = cond_schedule

         res.append(cond_schedule)

@@ -156,19 +193,21 @@     return res
 

 

 re_AND = re.compile(r"\bAND\b")

-import re

+from typing import List

 from collections import namedtuple

-# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']

+# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"

 

-import re

+

+from typing import List

 from collections import namedtuple

-# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']

+# will be represented with prompt_schedule like this (assuming steps=100):

     res_indexes = []

 

 schedule_parser = lark.Lark(r"""

-import re

+from collections import namedtuple

-import re

     >>> g("a [b: 3]")

+# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']

+    prompt_flat_list.clear()

 

     for prompt in prompts:

         subprompts = re_AND.split(prompt)

@@ -205,6 +244,7 @@     def __init__(self, shape, batch):
         self.shape: tuple = shape  # the shape field is needed to send this object to DDIM/PLMS

         self.batch: List[List[ComposableScheduledPromptConditioning]] = batch

 

+

 def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning:

     """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt.

     For each prompt, the list is obtained by splitting the prompt using the AND separator.

@@ -223,24 +263,64 @@ 
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)

 

 

+class DictWithShape(dict):

+    def __init__(self, x, shape):

+        super().__init__()

+        self.update(x)

+

+    @property

+    def shape(self):

+        return self["crossattn"].shape

+

+

 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):

     param = c[0][0].cond

-import re

+    >>> g("a [[[b]]:2]")

 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"

+

+    >>> g("a [[[b]]:2]")

 # will be represented with prompt_schedule like this (assuming steps=100):

+        dict_cond = param

+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}

+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)

+    else:

+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)

+

     for i, cond_schedule in enumerate(c):

         target_index = 0

         for current, entry in enumerate(cond_schedule):

             if current_step <= entry.end_at_step:

                 target_index = current

                 break

+

+        if is_dict:

+            for k, param in cond_schedule[target_index].cond.items():

+                res[k][i] = param

+        else:

+            res[i] = cond_schedule[target_index].cond

+

 import re

+    """

+

+

+    [[2, 'a '], [10, 'a [[b]]']]

 # will be represented with prompt_schedule like this (assuming steps=100):

+from typing import List

 import lark

+# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy']

-

+    # and won't be able to torch.stack them. So this fixes that.

 import re

+# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']

 from collections import namedtuple

+    for i in range(len(tensors)):

+        if tensors[i].shape[0] != token_count:

+            last_vector = tensors[i][-1:]

+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

 # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"

+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

+

+    return torch.stack(tensors)

+

 

 

 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):

@@ -264,22 +344,20 @@             tensors.append(composable_prompt.schedules[target_index].cond)
 

         conds_list.append(conds_for_batch)

 

-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

+    >>> g("[(a:2):3]")

-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

+    >>> g("[(a:2):3]")

 import re

-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

+    >>> g("[(a:2):3]")

 from collections import namedtuple

-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

+    >>> g("[(a:2):3]")

 from typing import List

+    [[2, 'a '], [10, 'a [[b]]']]

 import re

-# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy']

+    >>> g("[(a:2):3]")

 import lark

-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"

 

-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])

-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])

+    >>> g("[(a:2):3]")

 

-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)

 

 

 re_attention = re.compile(r"""

@@ -291,7 +369,7 @@ \\\\|
 \\|

 \(|

 \[|

-:([+-]?[.\d]+)\)|

+:\s*([+-]?[.\d]+)\s*\)|

 \)|

 ]|

 [^\\()\[\]:]+|




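The prompt_parser changes above do two things: the grammar now tolerates whitespace around the number in `[a:b:0.5]` editing syntax, and conditioning is passed around in an `SdConditioning` list that also records width/height for SDXL. A hedged sketch of how a prompt schedule like the docstring examples is consumed at sampling time (the lookup loop is illustrative, not webui code):

```python
# Illustrative only: walk a schedule as produced by
# get_learned_conditioning_prompt_schedules([...], steps) in the diff above.
schedule = [[25, 'fantasy landscape with a mountain'], [100, 'fantasy landscape with a lake']]

def text_at_step(schedule, step):
    # each entry is [end_at_step, text]; the first entry whose end_at_step
    # has not been passed yet is the active prompt
    for end_at_step, text in schedule:
        if step <= end_at_step:
            return text
    return schedule[-1][1]

print(text_at_step(schedule, 10))   # fantasy landscape with a mountain
print(text_at_step(schedule, 60))   # fantasy landscape with a lake
```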

diff --git a/modules/rng_philox.py b/modules/rng_philox.py
new file mode 100644
index 0000000000000000000000000000000000000000..5532cf9dd676012ae37b7fda3433b9b1d122f80a
--- /dev/null
+++ b/modules/rng_philox.py
@@ -0,0 +1,102 @@
+"""RNG imitiating torch cuda randn on CPU. You are welcome.

+

+Usage:

+

+```

+g = Generator(seed=0)

+print(g.randn(shape=(3, 4)))

+```

+

+Expected output:

+```

+[[-0.92466259 -0.42534415 -2.6438457   0.14518388]

+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]

+ [-1.07454231 -0.36314407 -1.67105067  2.26550497]]

+```

+"""

+

+import numpy as np

+

+philox_m = [0xD2511F53, 0xCD9E8D57]

+philox_w = [0x9E3779B9, 0xBB67AE85]

+

+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)

+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)

+

+

+def uint32(x):

+    """Converts (N,) np.uint64 array into (2, N) np.unit32 array."""

+    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)

+

+

+def philox4_round(counter, key):

+    """A single round of the Philox 4x32 random number generator."""

+

+    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])

+    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])

+

+    counter[0] = v2[1] ^ counter[1] ^ key[0]

+    counter[1] = v2[0]

+    counter[2] = v1[1] ^ counter[3] ^ key[1]

+    counter[3] = v1[0]

+

+

+def philox4_32(counter, key, rounds=10):

+    """Generates 32-bit random numbers using the Philox 4x32 random number generator.

+

+    Parameters:

+        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).

+        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).

+        rounds (int): The number of rounds to perform.

+

+    Returns:

+        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.

+    """

+

+    for _ in range(rounds - 1):

+        philox4_round(counter, key)

+

+        key[0] = key[0] + philox_w[0]

+        key[1] = key[1] + philox_w[1]

+

+    philox4_round(counter, key)

+    return counter

+

+

+def box_muller(x, y):

+    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""

+    u = x * two_pow32_inv + two_pow32_inv / 2

+    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2

+

+    s = np.sqrt(-2.0 * np.log(u))

+

+    r1 = s * np.sin(v)

+    return r1.astype(np.float32)

+

+

+class Generator:

+    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""

+

+    def __init__(self, seed):

+        self.seed = seed

+        self.offset = 0

+

+    def randn(self, shape):

+        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""

+

+        n = 1

+        for x in shape:

+            n *= x

+

+        counter = np.zeros((4, n), dtype=np.uint32)

+        counter[0] = self.offset

+        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]

+        self.offset += 1

+

+        key = np.empty(n, dtype=np.uint64)

+        key.fill(self.seed)

+        key = uint32(key)

+

+        g = philox4_32(counter, key)

+

+        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]




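A short usage sketch for the new module, following the docstring above. The only assumptions are that numpy is installed and the file is importable as modules.rng_philox; two generators with the same seed produce identical noise, which is what makes "GPU-style" seeds reproducible on the CPU:

```python
import numpy as np
from modules.rng_philox import Generator

a = Generator(seed=1234)
b = Generator(seed=1234)

n1 = a.randn((2, 3))
n2 = b.randn((2, 3))
assert np.array_equal(n1, n2)   # same seed and same offset -> identical noise

n3 = a.randn((2, 3))            # the internal offset advanced, so this batch differs from n1
assert not np.array_equal(n1, n3)
print(n1.dtype, n1.shape)       # float32 (2, 3)
```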

diff --git a/modules/script_loading.py b/modules/script_loading.py
index 306a1f35f778ee06b12910b92a38c0cc2be75a31..0d55f1932ee20122bff1cef60551548f90d6a634 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -12,11 +12,12 @@ 
     return module

 

 

-def preload_extensions(extensions_dir, parser):

+def preload_extensions(extensions_dir, parser, extension_list=None):

     if not os.path.isdir(extensions_dir):

         return

 

-    for dirname in sorted(os.listdir(extensions_dir)):

+    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)

+    for dirname in sorted(extensions):

         preload_script = os.path.join(extensions_dir, dirname, "preload.py")

         if not os.path.isfile(preload_script):

             continue




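A hedged sketch of the new optional parameter: a caller that already knows which extensions are enabled can pass that list so only those preload.py files run instead of scanning the whole directory. The extension name below is purely an example:

```python
import argparse
from modules import script_loading

parser = argparse.ArgumentParser()

# previous behaviour: scan every directory under extensions/
script_loading.preload_extensions("extensions", parser)

# new behaviour: restrict preloading to a known set of extensions
script_loading.preload_extensions("extensions", parser, extension_list=["sd-webui-controlnet"])
```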

diff --git a/modules/scripts.py b/modules/scripts.py
index 7d9dd59f2ad40d567254df4a90d9fc6a9d8d9dcd..f7d060aa59cbe5691fb7128064079db12388f1ef 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -16,6 +16,11 @@     def __init__(self, image):
         self.image = image

 

 

+class PostprocessBatchListArgs:

+    def __init__(self, images):

+        self.images = images

+

+

 class Script:

     name = None

     """script's internal name derived from title"""

@@ -119,7 +124,7 @@         pass
 

     def after_extra_networks_activate(self, p, *args, **kwargs):

         """

-        Calledafter extra networks activation, before conds calculation
 
+        Called after extra networks activation, before conds calculation

         allow modification of the network after extra networks activation been applied

         won't be call if p.disable_extra_networks

 

@@ -152,6 +157,25 @@ 
         **kwargs will have same items as process_batch, and also:

           - batch_number - index of current batch, from 0 to number of batches-1

           - images - torch tensor with all generated images, with values ranging from 0 to 1;

+        """

+

+        pass

+

+    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):

+        """

+        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.

+        This is useful when you want to update the entire batch instead of individual images.

+

+        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.

+        If the number of images is different from the batch size when returning,

+        then the script has the responsibility to also update the following attributes in the processing object (p):

+          - p.prompts

+          - p.negative_prompts

+          - p.seeds

+          - p.subseeds

+

+        **kwargs will have same items as process_batch, and also:

+          - batch_number - index of current batch, from 0 to number of batches-1

         """

 

         pass

@@ -537,6 +561,15 @@             except Exception:
                 errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)

 

+    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+
     def postprocess_image(self, p, pp: PostprocessImageArgs, *args):

         for script in self.alwayson_scripts:

             try:

@@ -600,49 +633,3 @@     scripts_img2img.reload_sources(cache)
 

 

 reload_scripts = load_scripts  # compatibility alias

-

-

-def add_classes_to_gradio_component(comp):

-    """

-    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others

-    """

-

-    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]

-

-    if getattr(comp, 'multiselect', False):

-        comp.elem_classes.append('multiselect')

-

-

-

-def IOComponent_init(self, *args, **kwargs):

-    if scripts_current is not None:

-        scripts_current.before_component(self, **kwargs)

-

-    script_callbacks.before_component_callback(self, **kwargs)

-

-    res = original_IOComponent_init(self, *args, **kwargs)

-

-    add_classes_to_gradio_component(self)

-

-    script_callbacks.after_component_callback(self, **kwargs)

-

-    if scripts_current is not None:

-        scripts_current.after_component(self, **kwargs)

-

-    return res

-

-

-original_IOComponent_init = gr.components.IOComponent.__init__

-gr.components.IOComponent.__init__ = IOComponent_init

-

-

-def BlockContext_init(self, *args, **kwargs):

-    res = original_BlockContext_init(self, *args, **kwargs)

-

-    add_classes_to_gradio_component(self)

-

-    return res

-

-

-original_BlockContext_init = gr.blocks.BlockContext.__init__

-gr.blocks.BlockContext.__init__ = BlockContext_init




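A minimal sketch of an always-on script using the new hook, following the docstring added above: if the script changes the number of images in the batch, it must also resize the per-image fields on the processing object. The duplication logic is purely illustrative, not part of webui:

```python
from modules import scripts


class DuplicateBatch(scripts.Script):
    def title(self):
        return "Duplicate batch (example)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def postprocess_batch_list(self, p, pp: scripts.PostprocessBatchListArgs, *args, **kwargs):
        # pp.images is a list of 3D tensors; doubling it changes the batch size,
        # so the matching per-image lists on p are doubled to stay in sync
        pp.images.extend(list(pp.images))
        p.prompts = p.prompts * 2
        p.negative_prompts = p.negative_prompts * 2
        p.seeds = p.seeds * 2
        p.subseeds = p.subseeds * 2
```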

diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9fc89dc6a75985e68bb39f564dd55ce928ec6b29..695c573626947078aeffa550b40a5d7a38c6467c 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -3,8 +3,32 @@ import open_clip
 import torch

 import transformers.utils.hub

 

+from modules import shared

 

+

+class ReplaceHelper:

+    def __init__(self):

+        self.replaced = []

+

+    def replace(self, obj, field, func):

+        original = getattr(obj, field, None)

+        if original is None:

+            return None

+

+        self.replaced.append((obj, field, original))

+        setattr(obj, field, func)

+        return original
+
+    def restore(self):
+        for obj, field, original in self.replaced:
+            setattr(obj, field, original)
+
+        self.replaced.clear()
+
+
-class DisableInitialization:
+class DisableInitialization(ReplaceHelper):

     """

     When an object of this class enters a `with` block, it starts:

     - preventing torch's layer initialization functions from working

@@ -21,7 +45,7 @@     ```
     """

 

     def __init__(self, disable_clip=True):

-        self.replaced = []

+        super().__init__()

         self.disable_clip = disable_clip

 

     def replace(self, obj, field, func):

@@ -86,11 +110,87 @@             self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
             self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)

 

     def __exit__(self, exc_type, exc_val, exc_tb):

+    """

 class DisableInitialization:

+

+

+    """

     """

+    """

+    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device,

+    which results in those parameters having no values and taking no memory. model.to() will be broken and

+    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.

+

+    Usage:

+    ```

+    with sd_disable_initialization.InitializeOnMeta():

+        sd_model = instantiate_from_config(sd_config.model)

+    ```

+    """

+

+    def __enter__(self):

+        if shared.cmd_opts.disable_model_loading_ram_optimization:

+            return

+

+    When an object of this class enters a `with` block, it starts:

 class DisableInitialization:

+            x["device"] = "meta"

+    When an object of this class enters a `with` block, it starts:

     When an object of this class enters a `with` block, it starts:

 

+        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))

+        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))

+        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))

+        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)

+

 class DisableInitialization:

+class DisableInitialization:

+        self.restore()

+

+

+class LoadStateDictOnMeta(ReplaceHelper):

+    """

+    Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device.

+    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.

+    Meant to be used together with InitializeOnMeta above.

+

+    Usage:

+    ```

+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):

+        model.load_state_dict(state_dict, strict=False)

+    ```

+    """

+

+    def __init__(self, state_dict, device):

+        super().__init__()

+        self.state_dict = state_dict

+        self.device = device

+

+    def __enter__(self):

+        if shared.cmd_opts.disable_model_loading_ram_optimization:

+            return

+

+        sd = self.state_dict

+        device = self.device

+

+        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):

+            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]

+

+            for name, param in params:

+                if param.is_meta:

+                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)

+

+            original(self, state_dict, prefix, *args, **kwargs)
+

+            for name, _ in params:

+                key = prefix + name

+                if key in sd:

+                    del sd[key]

+

+        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))

+        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))

+        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))

+

+    def __exit__(self, exc_type, exc_val, exc_tb):

+        self.restore()




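The two context managers are meant to compose, as their docstrings above describe. A hedged sketch of the combined flow; sd_config and state_dict are assumed to come from the surrounding checkpoint-loading code and are not defined here:

```python
from ldm.util import instantiate_from_config
from modules import sd_disable_initialization, devices

with sd_disable_initialization.InitializeOnMeta():
    # layer parameters are allocated on the "meta" device, so no RAM is spent on weights yet
    sd_model = instantiate_from_config(sd_config.model)

with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    # meta parameters are materialized as the state dict is read, and consumed keys are
    # dropped from state_dict, keeping peak memory at roughly one copy of the weights
    sd_model.load_state_dict(state_dict, strict=False)
```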

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b5aae4b5f8b89dcb106d4b92f5db235915cafaf..9ad98199818bdea18f278168a5d66813d1bad66d 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,11 +2,10 @@ import torch
 from torch.nn.functional import silu

 from types import MethodType

 

-import modules.textual_inversion.textual_inversion

 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet

 from modules.hypernetworks import hypernetwork

 from modules.shared import cmd_opts

-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting

 

 import ldm.modules.attention

 import ldm.modules.diffusionmodules.model

@@ -15,6 +14,11 @@ import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms

 import ldm.modules.encoders.modules

 

+import sgm.modules.attention

+import sgm.modules.diffusionmodules.model

+import sgm.modules.diffusionmodules.openaimodel

+import sgm.modules.encoders.modules

+

 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward

 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity

 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

@@ -26,9 +30,13 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 

 # silence new console spam from SD2

+ldm.modules.attention.print = shared.ldm_print

+ldm.modules.diffusionmodules.model.print = shared.ldm_print

+ldm.util.print = shared.ldm_print

+ldm.models.diffusion.ddpm.print = shared.ldm_print

 

+ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention

 

 optimizers = []

 current_optimizer: sd_hijack_optimizations.SdOptimization = None

@@ -58,6 +66,9 @@ 
     ldm.modules.diffusionmodules.model.nonlinearity = silu

     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

 

+    sgm.modules.diffusionmodules.model.nonlinearity = silu

+    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

+

     if current_optimizer is not None:

         current_optimizer.undo()

         current_optimizer = None

@@ -90,6 +101,10 @@ def undo_optimizations():
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity

     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward

     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

+

+    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity

+    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward

+    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

 

 

 def fix_checkpoint():

@@ -155,13 +170,13 @@     clip = None
     optimization_method = None

-    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
-
     def __init__(self):
+        import modules.textual_inversion.textual_inversion
+

         self.extra_generation_params = {}

         self.comments = []

 

+        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

 

     def apply_optimizations(self, option=None):

@@ -172,6 +187,32 @@             errors.display(e, "applying cross attention optimization")
             undo_optimizations()

 

     def hijack(self, m):

+        conditioner = getattr(m, 'conditioner', None)

+        if conditioner:

+            text_cond_models = []

+

+            for i in range(len(conditioner.embedders)):

+                embedder = conditioner.embedders[i]

+                typename = type(embedder).__name__

+                if typename == 'FrozenOpenCLIPEmbedder':

+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)

+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)

+                    text_cond_models.append(conditioner.embedders[i])

+                if typename == 'FrozenCLIPEmbedder':

+                    model_embeddings = embedder.transformer.text_model.embeddings

+                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)

+                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)

+                    text_cond_models.append(conditioner.embedders[i])

+                if typename == 'FrozenOpenCLIPEmbedder2':

+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')

+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)

+                    text_cond_models.append(conditioner.embedders[i])

+

+            if len(text_cond_models) == 1:

+                m.cond_stage_model = text_cond_models[0]

+            else:

+                m.cond_stage_model = conditioner

+

         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:

             model_embeddings = m.cond_stage_model.roberta.embeddings

             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)

@@ -209,9 +250,8 @@ 
         ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

 

     def undo_hijack(self, m):

         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:

             m.cond_stage_model = m.cond_stage_model.wrapped

 

         elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:

@@ -260,11 +300,12 @@         self.hijack(m)
 

 

 class EmbeddingsWithFixes(torch.nn.Module):

-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):

         super().__init__()

         self.wrapped = wrapped

         self.embeddings = embeddings

+        self.textual_inversion_key = textual_inversion_key

 

     def forward(self, input_ids):

         batch_fixes = self.embeddings.fixes

@@ -278,8 +319,9 @@ 
         vecs = []

         for fixes, tensor in zip(batch_fixes, inputs_embeds):

             for offset, embedding in fixes:

-                emb = devices.cond_cast_unet(embedding.vec)
+                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = devices.cond_cast_unet(vec)

                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])

                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

 




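The `textual_inversion_key` added to EmbeddingsWithFixes exists because an SDXL checkpoint has two text encoders, so a textual inversion embedding can carry one vector set per encoder. A small illustration of the lookup; the tensor shapes and dict layout are illustrative:

```python
import torch

embedding_vec = {
    "clip_l": torch.zeros(2, 768),    # vectors for the CLIP ViT-L encoder
    "clip_g": torch.zeros(2, 1280),   # vectors for the OpenCLIP ViT-bigG encoder
}

textual_inversion_key = "clip_g"
vec = embedding_vec[textual_inversion_key] if isinstance(embedding_vec, dict) else embedding_vec
print(vec.shape)  # torch.Size([2, 1280])
```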

diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index c1d780a3317634cff6440baf461126709d7bd033..8f29057a9cfcfafcf18e5c5c3eb095a8a1649198 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@ 
         self.hijack: sd_hijack.StableDiffusionModelHijack = hijack

         self.chunk_length = 75

 

+        self.is_trainable = getattr(wrapped, 'is_trainable', False)

+        self.input_key = getattr(wrapped, 'input_key', 'txt')

+        self.legacy_ucg_val = None

+

     def empty_chunk(self):

         """creates an empty PromptChunk and returns it"""

 

@@ -157,7 +161,7 @@                     chunk.multipliers.append(weight)
                     position += 1

                     continue

 

-                emb_len = int(embedding.vec.shape[0])

+                emb_len = int(embedding.vectors)

                 if len(chunk.tokens) + emb_len > self.chunk_length:

                     next_chunk()

 

@@ -199,10 +203,10 @@     def forward(self, texts):
         """

         Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.

         Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will

+        be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280.

         An example shape returned by this function can be: (2, 77, 768).

+        For SDXL, instead of returning one tensor above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.

         Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet

         is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"

         """

@@ -242,9 +246,14 @@                 name = name.replace(":", "").replace(",", "")
                 hashes.append(f"{name}: {shorthash}")

 

             if hashes:

+                if self.hijack.extra_generation_params.get("TI hashes"):

+                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))

                 self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

 

-        return torch.hstack(zs)

+        if getattr(self.wrapped, 'return_pooled', False):

+            return torch.hstack(zs), zs[0].pooled

+        else:

+            return torch.hstack(zs)

 

     def process_tokens(self, remade_batch_tokens, batch_multipliers):

         """

@@ -264,12 +273,17 @@                 tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad
 

         z = self.encode_with_transformers(tokens)

 

+        pooled = getattr(z, 'pooled', None)

+

         # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise

         batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)

         original_mean = z.mean()

         z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)

         new_mean = z.mean()

         z = z * (original_mean / new_mean)

+

+        if pooled is not None:

+            z.pooled = pooled

 

         return z

 

@@ -326,3 +340,18 @@         ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
         embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

 

         return embedded

+

+

+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):

+    def __init__(self, wrapped, hijack):

+        super().__init__(wrapped, hijack)

+

+    def encode_with_transformers(self, tokens):

+        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")

+

+        if self.wrapped.layer == "last":

+            z = outputs.last_hidden_state

+        else:

+            z = outputs.hidden_states[self.wrapped.layer_idx]

+

+        return z




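SDXL needs both the per-token hidden states and a pooled text embedding, and the code above passes the pooled tensor along by attaching it as an attribute of the main tensor so existing call sites keep working. A hedged sketch of that pattern with illustrative shapes:

```python
import torch

z = torch.randn(2, 77, 1280)       # per-token hidden states
z.pooled = torch.randn(2, 1280)    # pooled embedding attached as a plain attribute

def rescale(z):
    pooled = getattr(z, 'pooled', None)
    out = z * 1.5                   # any operation producing a new tensor loses the attribute
    if pooled is not None:
        out.pooled = pooled         # so it is re-attached, as process_tokens does above
    return out

out = rescale(z)
print(out.shape, out.pooled.shape)  # torch.Size([2, 77, 1280]) torch.Size([2, 1280])
```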

diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index c1977b194190858fbb9726eb02256cdeafc654a0..2d44b856668a1a6a361a672a68beaf513ad717b8 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -92,6 +92,4 @@     return x_prev, pred_x0, e_t
 
 
 def do_inpainting_hijack():
-    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
     ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms




diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index f733e8529fb6cd68d97b2f255bc705d0cd949fbc..25c5e9831a3527b9555971b5d9add68519a2208e 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -35,3 +35,37 @@         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
         embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

 

         return embedded

+

+

+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):

+    def __init__(self, wrapped, hijack):

+        super().__init__(wrapped, hijack)

+

+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]

+        self.id_start = tokenizer.encoder["<start_of_text>"]

+        self.id_end = tokenizer.encoder["<end_of_text>"]

+        self.id_pad = 0

+

+    def tokenize(self, texts):

+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'

+

+        tokenized = [tokenizer.encode(text) for text in texts]

+

+        return tokenized

+

+    def encode_with_transformers(self, tokens):

+        d = self.wrapped.encode_with_transformer(tokens)

+        z = d[self.wrapped.layer]

+

+        pooled = d.get("pooled")

+        if pooled is not None:

+            z.pooled = pooled

+

+        return z

+

+    def encode_embedding_init_text(self, init_text, nvpt):

+        ids = tokenizer.encode(init_text)

+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)

+        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)

+

+        return embedded





diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 53e27adea2caaf39f3d58a9a015bb36eb0395bba..0e810eec8a9a01f28ca96007595eb9d00e08eff9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -14,7 +14,11 @@ 
 import ldm.modules.attention

 import ldm.modules.diffusionmodules.model

 

+import sgm.modules.attention

+import sgm.modules.diffusionmodules.model

+

 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward

+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward

 

 

 class SdOptimization:

@@ -39,6 +43,9 @@     def undo(self):
         ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

 

+        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward

+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward

+

 

 class SdOptimizationXformers(SdOptimization):

     name = "xformers"

@@ -51,6 +58,8 @@ 
     def apply(self):

         ldm.modules.attention.CrossAttention.forward = xformers_attention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward

+        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward

+        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward

 

 

 class SdOptimizationSdpNoMem(SdOptimization):

@@ -65,6 +74,8 @@ 
     def apply(self):

         ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward

+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward

+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward

 

 

 class SdOptimizationSdp(SdOptimizationSdpNoMem):

@@ -76,6 +87,8 @@ 
     def apply(self):

         ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward

+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward

+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward

 

 

 class SdOptimizationSubQuad(SdOptimization):

@@ -86,6 +99,8 @@ 
     def apply(self):

         ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward

+        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward

+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward

 

 

 class SdOptimizationV1(SdOptimization):

@@ -94,9 +109,10 @@     label = "original v1"
     cmd_opt = "opt_split_attention_v1"

     priority = 10

 

-

     def apply(self):

         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1

 

 

@@ -110,6 +126,7 @@         return 1000 if not torch.cuda.is_available() else 10
 

     def apply(self):

         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

 

 

 class SdOptimizationDoggettx(SdOptimization):

@@ -120,6 +137,9 @@ 
     def apply(self):

         ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward

         ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward

 

 

@@ -157,7 +177,7 @@         return psutil.virtual_memory().available
 

 

 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion

-def split_cross_attention_forward_v1(self, x, context=None, mask=None):

+def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):

     h = self.heads

 

     q_in = self.to_q(x)

@@ -198,8 +218,8 @@     return self.to_out(r2)
 

 

 # taken from https://github.com/Doggettx/stable-diffusion and modified

-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):

     h = self.heads

 

     q_in = self.to_q(x)

@@ -241,9 +261,9 @@             max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
             raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '

                                f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

 

-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]

+        slice_size = q.shape[1] // steps

         for i in range(0, q.shape[1], slice_size):

-            end = i + slice_size

+            end = min(i + slice_size, q.shape[1])

             s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

 

             s2 = s1.softmax(dim=-1, dtype=q.dtype)

@@ -265,17 +285,20 @@ 
 # -- Taken from https://github.com/invoke-ai/InvokeAI and modified --

 mem_total_gb = psutil.virtual_memory().total // (1 << 30)

 

+

 def einsum_op_compvis(q, k, v):

     s = einsum('b i d, b j d -> b i j', q, k)

     s = s.softmax(dim=-1, dtype=s.dtype)

     return einsum('b i j, b j d -> b i d', s, v)

 

+

 def einsum_op_slice_0(q, k, v, slice_size):

     r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

     for i in range(0, q.shape[0], slice_size):

         end = i + slice_size

         r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])

     return r

+

 

 def einsum_op_slice_1(q, k, v, slice_size):

     r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)

@@ -283,6 +306,7 @@     for i in range(0, q.shape[1], slice_size):
         end = i + slice_size

         r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)

     return r

+

 

 def einsum_op_mps_v1(q, k, v):

     if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096

@@ -293,11 +317,13 @@         if slice_size % 4096 == 0:
             slice_size -= 1

         return einsum_op_slice_1(q, k, v, slice_size)

 

+

 def einsum_op_mps_v2(q, k, v):

     if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:

         return einsum_op_compvis(q, k, v)

     else:

         return einsum_op_slice_0(q, k, v, 1)

+

 

 

 def einsum_op_tensor_mem(q, k, v, max_tensor_mb):

@@ -309,6 +335,7 @@     if div <= q.shape[0]:
         return einsum_op_slice_0(q, k, v, q.shape[0] // div)

     return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))

 

+

 def einsum_op_cuda(q, k, v):

     stats = torch.cuda.memory_stats(q.device)

     mem_active = stats['active_bytes.all.current']

@@ -318,6 +345,7 @@     mem_free_torch = mem_reserved - mem_active
     mem_free_total = mem_free_cuda + mem_free_torch

     # Divide factor of safety as there's copying and fragmentation

     return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

+

 

 def einsum_op(q, k, v):

     if q.device.type == 'cuda':

@@ -332,7 +360,8 @@     # Smaller slices are faster due to L2/L3/SLC caches.
     # Tested on i7 with 8MB L3 cache.

     return einsum_op_tensor_mem(q, k, v, 32)

 

-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):

+

+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):

     h = self.heads

 

     q = self.to_q(x)

@@ -360,8 +389,8 @@ 
 

 # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1

 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface

-        return f"{self.name} - {self.label}"

 

+from __future__ import annotations

     assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

 

     h = self.heads

@@ -396,6 +425,7 @@     x = out_proj(x)
     x = dropout(x)

 

     return x

+

 

 def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):

     bytes_per_token = torch.finfo(q.dtype).bits//8

@@ -447,8 +477,8 @@ 
     return None

 

 

+    cmd_opt = "opt_sdp_no_mem_attention"

 import math

-def list_optimizers(res):

     h = self.heads

     q_in = self.to_q(x)

     context = default(context, x)

@@ -470,11 +500,12 @@     out = out.to(dtype)
 

     out = rearrange(out, 'b n h d -> b n (h d)', h=h)

     return self.to_out(out)

+

 

 # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py

 # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface

-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):

     batch_size, sequence_length, inner_dim = x.shape

 

     if mask is not None:

@@ -514,11 +545,13 @@     # dropout
     hidden_states = self.to_out[1](hidden_states)

     return hidden_states

 

+
+
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):

     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):

         return scaled_dot_product_attention_forward(self, x, context, mask)

 

+

 def cross_attention_attnblock_forward(self, x):

         h_ = x

         h_ = self.norm(h_)

@@ -577,6 +610,7 @@         h3 += x
 

         return h3

 

+

 def xformers_attnblock_forward(self, x):

     try:

         h_ = x

@@ -599,6 +633,7 @@         out = self.proj_out(out)
         return x + out

     except NotImplementedError:

         return cross_attention_attnblock_forward(self, x)

+

 

 def sdp_attnblock_forward(self, x):

     h_ = x

@@ -620,9 +655,11 @@     out = rearrange(out, 'b (h w) c -> b c h w', h=h)
     out = self.proj_out(out)

     return x + out

 

+

 def sdp_no_mem_attnblock_forward(self, x):

     with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):

         return sdp_attnblock_forward(self, x)

+

 

 def sub_quad_attnblock_forward(self, x):

     h_ = x




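One of the smaller fixes above changes the Doggettx attention slicing so the slice end is clamped with min() instead of falling back to one huge slice whenever the token count is not divisible by the step count. A tiny illustration of the resulting chunk boundaries (not webui code):

```python
def chunk_bounds(length, steps):
    slice_size = length // steps
    return [(i, min(i + slice_size, length)) for i in range(0, length, slice_size)]

print(chunk_bounds(10, 3))  # [(0, 3), (3, 6), (6, 9), (9, 10)] - the final slice is clamped to the sequence length
```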

diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index ca1daf45f3c625fc6947b9547f4130ca7ed80ab2..2101f1a04152bd53214934d61b06e0d22af71ca7 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -39,8 +39,11 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
 

     if isinstance(cond, dict):

         for y in cond.keys():

-            cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+            if isinstance(cond[y], list):

+                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]

+            else:

+                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

 

     with devices.autocast():

         return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()

@@ -78,3 +81,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)

 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)

 CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)

+

+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)

+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)




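The two new CondFunc hooks route SDXL's sgm UNet through the same fp16-upcast logic as the ldm one. CondFunc (from modules.sd_hijack_utils) patches a dotted path so a substitute runs only while a condition holds; the tiny stand-in below illustrates the idea and is not the webui class:

```python
def cond_func(orig, sub, cond):
    def wrapper(*args, **kwargs):
        if cond(*args, **kwargs):
            return sub(orig, *args, **kwargs)   # substitute receives the original as its first argument
        return orig(*args, **kwargs)
    return wrapper

def timestep_embedding(t):                      # stand-in for the function being hooked
    return f"embedding({t})"

hooked = cond_func(
    timestep_embedding,
    lambda orig, t: orig(t).upper(),            # replacement behaviour
    lambda t: isinstance(t, int),               # ...applied only when the condition matches
)

print(hooked(5))      # EMBEDDING(5)
print(hooked("0.5"))  # embedding(0.5)
```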

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 060e0007c10aacd6e7517d476d0c3fad43c7e593..f60516046324fecad4df2f5baf42c2bea5fa3e42 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,9 +14,9 @@ import ldm.modules.midas as midas
 

 from ldm.util import instantiate_from_config

 

-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
+from modules.sd_hijack_inpainting import do_inpainting_hijack

 from modules.timer import Timer

 import tomesd

 

@@ -34,6 +34,8 @@     def __init__(self, filename):
         self.filename = filename

         abspath = os.path.abspath(filename)

 

+        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"

+

         if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):

             name = abspath.replace(shared.cmd_opts.ckpt_dir, '')

         elif abspath.startswith(model_path):

@@ -44,35 +46,43 @@ 
         if name.startswith("\\") or name.startswith("/"):

             name = name[1:]

 

         self.name = name
         self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
         self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
         self.hash = model_hash(filename)
 
         self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
-        self.shorthash = self.sha256[0:10] if self.sha256 else None
-
-        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
-
-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.shorthash = self.sha256[0:10] if self.sha256 else None
+
+        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
+
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

 

     def register(self):

         checkpoints_list[self.title] = self

@@ -89,9 +99,10 @@ 
         if self.shorthash not in self.ids:

             self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

 

         checkpoints_list.pop(self.title)

         self.title = f'{self.name} [{self.shorthash}]'

+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'

         self.register()

 

         return self.shorthash

@@ -112,14 +123,9 @@ 
     enable_midas_autodownload()

 

 

-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]

 

 

 def list_models():

@@ -142,18 +148,26 @@         shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:

         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

 

     for filename in model_list:

         checkpoint_info = CheckpointInfo(filename)

         checkpoint_info.register()

 

 

+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

+

+

 def get_closet_checkpoint_match(search_string):

     checkpoint_info = checkpoint_aliases.get(search_string, None)

     if checkpoint_info is not None:

         return checkpoint_info

 

     found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))

+    if found:

+        return found[0]

+

+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)

+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))

     if found:

         return found[0]

 

@@ -301,16 +315,24 @@ 
     if state_dict is None:

         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

 

+    model.is_sdxl = hasattr(model, 'conditioner')

+    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')

+    model.is_sd1 = not model.is_sdxl and not model.is_sd2

+

+    if model.is_sdxl:

+        sd_models_xl.extend_sdxl(model)

+

     model.load_state_dict(state_dict, strict=False)
-    del state_dict

     timer.record("apply weights to model")

 

     if shared.opts.sd_checkpoint_cache > 0:

         # cache newly loaded model

+        checkpoints_loaded[checkpoint_info] = state_dict

+

-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+
+    del state_dict

 

     if shared.cmd_opts.opt_channelslast:

         model.to(memory_format=torch.channels_last)

@@ -334,7 +356,7 @@             model.depth_model = depth_model
 

         timer.record("apply half()")

 

-    devices.dtype_unet = model.model.diffusion_model.dtype

+    devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype

     devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

 

     model.first_stage_model.to(devices.dtype_vae)

@@ -349,9 +371,10 @@     model.sd_model_checkpoint = checkpoint_info.filename
     model.sd_checkpoint_info = checkpoint_info

     shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

 

-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training

 

     sd_vae.delete_base_vae()

     sd_vae.clear_loaded_vae()

@@ -408,11 +431,12 @@ 
     if not hasattr(sd_config.model.params, "use_ema"):

         sd_config.model.params.use_ema = False

 

-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    if hasattr(sd_config.model.params, 'unet_config'):
+        if shared.cmd_opts.no_half:
+            sd_config.model.params.unet_config.params.use_fp16 = False
+        elif shared.cmd_opts.upcast_sampling:
+            sd_config.model.params.unet_config.params.use_fp16 = True

 

     if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:

         sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

@@ -425,11 +449,15 @@ 
 

 sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'

 sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'

+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'

 

 

+

 class SdModelData:

     def __init__(self):

         self.sd_model = None

+        self.loaded_sd_models = []

         self.was_loaded_at_least_once = False

         self.lock = threading.Lock()

 

@@ -444,6 +472,7 @@                     return self.sd_model
 

                 try:

                     load_model()

+

                 except Exception as e:

                     errors.display(e, "loading stable diffusion model", full_traceback=True)

                     print("", file=sys.stderr)

@@ -455,34 +484,79 @@ 
     def set_sd_model(self, v):

         self.sd_model = v

 

+        try:

+            self.loaded_sd_models.remove(v)

+        except ValueError:

+            pass

+

+        if v is not None:

+            self.loaded_sd_models.insert(0, v)

+

 

 model_data = SdModelData()

 

 

+def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
     timer = Timer()
 
     if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
         devices.torch_gc()
 
+    timer.record("unload existing model")

 

     if already_loaded_state_dict is not None:

         state_dict = already_loaded_state_dict

@@ -490,7 +563,7 @@     else:
         state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

 

     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

 

     timer.record("find config")

 

@@ -503,32 +576,32 @@     print(f"Creating model from config: {checkpoint_config}")
 

     sd_model = None

     try:

-        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):

+        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):

-            sd_model = instantiate_from_config(sd_config.model)

+            with sd_disable_initialization.InitializeOnMeta():

-    except Exception:

+                sd_model = instantiate_from_config(sd_config.model)

-        pass

+

+    except Exception as e:

+        errors.display(e, "creating model quickly", full_traceback=True)

 

     if sd_model is None:

         print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)

-        sd_model = instantiate_from_config(sd_config.model)

 

+        with sd_disable_initialization.InitializeOnMeta():

+            sd_model = instantiate_from_config(sd_config.model)

 

     sd_model.used_config = checkpoint_config

 

     timer.record("create model")

 

-    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")
 
+    send_model_to_device(sd_model)

     timer.record("move model to device")

 

     sd_hijack.model_hijack.hijack(sd_model)

@@ -535,9 +609,8 @@ 
     timer.record("hijack")

 

     sd_model.eval()

-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)

     model_data.was_loaded_at_least_once = True

 

     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

@@ -549,8 +622,8 @@ 
     timer.record("scripts callbacks")

 

     with devices.autocast(), torch.no_grad():

-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

 

     timer.record("calculate empty prompt")

 

@@ -559,55 +632,107 @@ 
     return sd_model

 

 

+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+            send_model_to_cpu(sd_model)
+            timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None

+

 

 def reload_model_weights(sd_model=None, info=None):
+    from modules import devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
+    timer = Timer()
+
     if not sd_model:
+        sd_model = model_data.sd_model
 
+    if sd_model is None:  # previous model load failed
+        current_checkpoint_info = None
+    else:
+        current_checkpoint_info = sd_model.sd_checkpoint_info
+        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
+            return sd_model
 
-            sd_model.to(devices.cpu)
-
-        sd_hijack.model_hijack.undo_hijack(sd_model)
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
+
+    if sd_model is not None:
+        send_model_to_cpu(sd_model)
+        sd_hijack.model_hijack.undo_hijack(sd_model)
 

     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

 

@@ -616,9 +739,10 @@ 
     timer.record("find config")

 

     if sd_model is None or checkpoint_config != sd_model.used_config:

-        del sd_model
+        if sd_model is not None:
+            send_model_to_trash(sd_model)
+

         load_model(checkpoint_info, already_loaded_state_dict=state_dict)

         return model_data.sd_model

 

@@ -640,6 +764,8 @@             sd_model.to(devices.device)
             timer.record("move model to device")

 

     print(f"Weights loaded in {timer.summary()}.")

+

+    model_data.set_sd_model(sd_model)

 

     return sd_model
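To make the new checkpoint-cache behaviour above easier to follow, here is a minimal standalone sketch (plain Python, not code from this commit): loaded_sd_models acts as a most-recently-used list, set_sd_model moves the active model to the front, and models beyond sd_checkpoints_limit are evicted from the back. The ModelCache class and its names are illustrative assumptions.

# Standalone sketch of the MRU checkpoint cache; names are illustrative only.
class ModelCache:
    def __init__(self, limit):
        self.limit = limit
        self.loaded = []  # most recently used first

    def set_current(self, model):
        # mirrors SdModelData.set_sd_model: move the active model to the front
        if model in self.loaded:
            self.loaded.remove(model)
        if model is not None:
            self.loaded.insert(0, model)

    def evict_over_limit(self):
        # mirrors the eviction in reuse_model_from_already_loaded
        while len(self.loaded) > self.limit > 0:
            victim = self.loaded.pop()
            print(f"unloading {victim}")


cache = ModelCache(limit=2)
for name in ["sd15.safetensors", "sdxl_base.safetensors", "sdxl_refiner.safetensors"]:
    cache.set_current(name)
    cache.evict_over_limit()

print(cache.loaded)  # ['sdxl_refiner.safetensors', 'sdxl_base.safetensors']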

 





diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9bfe1237d1ded6d8ade2e565fe33622b444b8b02..8266fa39797b2044ddd0abd5e921bca5c6f87bea 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -6,12 +6,15 @@ from modules import shared, paths, sd_disable_initialization
 

 sd_configs_path = shared.sd_configs_path

 sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")

+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")

 

 

 config_default = shared.sd_default_config

 config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")

 config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")

 config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")

+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")

+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")

 config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")

 config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")

 config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")

@@ -68,7 +71,11 @@     sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
     diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)

     sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

 

-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:

+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:

+        return config_sdxl

+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:

+        return config_sdxl_refiner

+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:

         return config_depth_model

     elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:

         return config_unclip
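The config detection above keys off which text-encoder weights are present in the checkpoint: SDXL base is recognised by the second embedder's ln_final weight, the refiner by the first. A hedged, self-contained illustration of that ordering (the guess_config helper and returned file names are assumptions for demonstration only):

# Illustrative sketch; the dict stands in for a real checkpoint state dict.
def guess_config(sd: dict) -> str:
    if sd.get('conditioner.embedders.1.model.ln_final.weight') is not None:
        return "sd_xl_base.yaml"
    if sd.get('conditioner.embedders.0.model.ln_final.weight') is not None:
        return "sd_xl_refiner.yaml"
    return "v1-inference.yaml"

print(guess_config({'conditioner.embedders.1.model.ln_final.weight': object()}))  # sd_xl_base.yaml
print(guess_config({'conditioner.embedders.0.model.ln_final.weight': object()}))  # sd_xl_refiner.yaml
print(guess_config({}))                                                           # v1-inference.yaml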





diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..0112332161fa21807e705cc1763681eefd7456e5
--- /dev/null
+++ b/modules/sd_models_xl.py
@@ -0,0 +1,108 @@
+from __future__ import annotations

+

+import torch

+

+import sgm.models.diffusion

+import sgm.modules.diffusionmodules.denoiser_scaling

+import sgm.modules.diffusionmodules.discretizer

+from modules import devices, shared, prompt_parser

+

+

+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):

+    for embedder in self.conditioner.embedders:

+        embedder.ucg_rate = 0.0

+

+    width = getattr(batch, 'width', 1024)

+    height = getattr(batch, 'height', 1024)

+    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)

+    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score

+

+    devices_args = dict(device=devices.device, dtype=devices.dtype)

+

+    sdxl_conds = {

+        "txt": batch,

+        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),

+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),

+        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),

+        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),

+    }

+

+    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)

+    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])

+

+    return c

+

+

+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):

+    return self.model(x, t, cond)

+

+

+def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility

+    return x

+

+

+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning

+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model

+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding

+

+

+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):

+    res = []

+

+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:

+        encoded = embedder.encode_embedding_init_text(init_text, nvpt)

+        res.append(encoded)

+

+    return torch.cat(res, dim=1)

+

+

+def tokenize(self: sgm.modules.GeneralConditioner, texts):

+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:

+        return embedder.tokenize(texts)

+

+    raise AssertionError('no tokenizer available')

+

+

+

+def process_texts(self, texts):

+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:

+        return embedder.process_texts(texts)

+

+

+def get_target_prompt_token_count(self, token_count):

+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:

+        return embedder.get_target_prompt_token_count(token_count)

+

+

+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code

+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text

+sgm.modules.GeneralConditioner.tokenize = tokenize

+sgm.modules.GeneralConditioner.process_texts = process_texts

+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count

+

+

+def extend_sdxl(model):

+    """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""

+

+    dtype = next(model.model.diffusion_model.parameters()).dtype

+    model.model.diffusion_model.dtype = dtype

+    model.model.conditioning_key = 'crossattn'

+    model.cond_stage_key = 'txt'

+    # model.cond_stage_model will be set in sd_hijack

+

+    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"

+

+    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()

+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)

+

+    model.conditioner.wrapped = torch.nn.Module()

+

+

+sgm.modules.attention.print = shared.ldm_print

+sgm.modules.diffusionmodules.model.print = shared.ldm_print

+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print

+sgm.modules.encoders.modules.print = shared.ldm_print

+

+# this gets the code to load the vanilla attention that we override

+sgm.modules.attention.SDP_IS_AVAILABLE = True

+sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
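For readers unfamiliar with SDXL conditioning, the dictionary built in get_learned_conditioning above can be pictured with plain tensors. This standalone sketch assumes a batch of two prompts at 1024x1024 and uses illustrative values; it does not touch the sgm package:

import torch

batch = ["a photo of a cat", "a photo of a dog"]
width, height = 1024, 1024

sdxl_conds = {
    "txt": batch,
    "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1),
    "crop_coords_top_left": torch.tensor([0, 0]).repeat(len(batch), 1),
    "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1),
    "aesthetic_score": torch.tensor([6.0]).repeat(len(batch), 1),  # refiner-only input
}

print(sdxl_conds["original_size_as_tuple"].shape)  # torch.Size([2, 2])
print(sdxl_conds["aesthetic_score"].shape)         # torch.Size([2, 1])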





diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f22aad8f2165c5eb350d147b673b1accacbcd818..bea2684c4db6171075a38427bb9c01b34d9e688a 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -28,6 +28,9 @@     config = find_sampler_config(name)
 

     assert config is not None, f'bad sampler name: {name}'

 

+    if model.is_sdxl and config.options.get("no_sdxl", False):

+        raise Exception(f"Sampler {config.name} is not supported for SDXL")

+

     sampler = config.constructor(model)

     sampler.config = config
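A minimal sketch of the gating added above, with a stand-in check_sampler function (an assumption, not webui API): samplers whose options carry "no_sdxl": True are refused when an SDXL model is active.

def check_sampler(config_options: dict, model_is_sdxl: bool, name: str):
    # mirrors the check performed before constructing the sampler
    if model_is_sdxl and config_options.get("no_sdxl", False):
        raise Exception(f"Sampler {name} is not supported for SDXL")

check_sampler({"no_sdxl": True}, model_is_sdxl=False, name="PLMS")   # fine
# check_sampler({"no_sdxl": True}, model_is_sdxl=True, name="PLMS")  # would raise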

 





diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 763829f1ca34be1878cdcafd5d7e594488f6f2f8..b3d344e777b61da442c0fd8d8d2b7e7b7bb0ffc0 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,10 +2,9 @@ from collections import namedtuple
 import numpy as np

 import torch

 from PIL import Image

-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd

 

+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared

 from modules.shared import opts, state

-import modules.shared as shared

 

 SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

 

@@ -37,7 +36,7 @@     elif approximation == 3:
         x_sample = sample * 1.5

         x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()

     else:

-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5

+        x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5

 

     x_sample = torch.clamp(x_sample, min=0.0, max=1.0)

     x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)

@@ -46,6 +45,12 @@ 
     return Image.fromarray(x_sample)

 

 

+def decode_first_stage(model, x):

+    x = model.decode_first_stage(x.to(devices.dtype_vae))

+

+    return x

+

+

 def sample_to_image(samples, index=0, approximation=None):

     return single_sample_to_image(samples[index], approximation)

 

@@ -85,13 +90,15 @@ class InterruptedException(BaseException):
     pass

 

 

-if opts.randn_source == "CPU":

+def replace_torchsde_browinan():

     import torchsde._brownian.brownian_interval

 

     def torchsde_randn(size, dtype, device, seed):

 

+        return devices.randn_local(seed, size).to(device=device, dtype=dtype)
+
+    torchsde._brownian.brownian_interval._randn = torchsde_randn
+

+replace_torchsde_browinan()
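The replace_torchsde_browinan hook above swaps torchsde's internal noise source for a seed-driven one so SDE samplers reproduce across machines. A standalone sketch of the underlying idea (the seeded_randn helper is illustrative, not the exact implementation):

import torch

def seeded_randn(size, dtype, device, seed):
    # generate on CPU with an explicit generator, then move to the target device
    generator = torch.Generator("cpu").manual_seed(int(seed))
    return torch.randn(size, dtype=dtype, device="cpu", generator=generator).to(device)

a = seeded_randn((2, 3), torch.float32, "cpu", seed=1234)
b = seeded_randn((2, 3), torch.float32, "cpu", seed=1234)
print(torch.equal(a, b))  # True: same seed, same noise on any machine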





diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bdae8b404b50505530b35400825f95752b143691..4a8396f97ec3fd75b2eabdc9fbf97ac34af27e6a 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -12,11 +12,11 @@ 
 samplers_data_compvis = [
+        if isinstance(cond, dict):
+            if self.conditioning_key == "crossattn-adm":
+                image_conditioning = cond["c_adm"]
 ]

 

 





diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b981ca80c355cbb6a92915d422cfa19e51b21c4
--- /dev/null
+++ b/modules/sd_samplers_extra.py
@@ -0,0 +1,74 @@
+import torch

+import tqdm

+import k_diffusion.sampling

+

+

[email protected]_grad()

+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):

+    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)

+    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}

+    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list

+    """

+    extra_args = {} if extra_args is None else extra_args

+    s_in = x.new_ones([x.shape[0]])

+    step_id = 0

+    from k_diffusion.sampling import to_d, get_sigmas_karras

+

+    def heun_step(x, old_sigma, new_sigma, second_order=True):

+        nonlocal step_id

+        denoised = model(x, old_sigma * s_in, **extra_args)

+        d = to_d(x, old_sigma, denoised)

+        if callback is not None:

+            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})

+        dt = new_sigma - old_sigma

+        if new_sigma == 0 or not second_order:

+            # Euler method

+            x = x + d * dt

+        else:

+            # Heun's method

+            x_2 = x + d * dt

+            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)

+            d_2 = to_d(x_2, new_sigma, denoised_2)

+            d_prime = (d + d_2) / 2

+            x = x + d_prime * dt

+        step_id += 1

+        return x

+

+    steps = sigmas.shape[0] - 1

+    if restart_list is None:

+        if steps >= 20:

+            restart_steps = 9

+            restart_times = 1

+            if steps >= 36:

+                restart_steps = steps // 4

+                restart_times = 2

+            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)

+            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}

+        else:

+            restart_list = {}

+

+    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}

+

+    step_list = []

+    for i in range(len(sigmas) - 1):

+        step_list.append((sigmas[i], sigmas[i + 1]))

+        if i + 1 in restart_list:

+            restart_steps, restart_times, restart_max = restart_list[i + 1]

+            min_idx = i + 1

+            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))

+            if max_idx < min_idx:

+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]

+                while restart_times > 0:

+                    restart_times -= 1

+                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])

+

+    last_sigma = None

+    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):

+        if last_sigma is None:

+            last_sigma = old_sigma

+        elif last_sigma < old_sigma:

+            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5

+        x = heun_step(x, old_sigma, new_sigma)

+        last_sigma = new_sigma

+

+    return x
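The restart sampler periodically re-injects noise to climb back to a higher sigma before integrating down again. A standalone sketch of just that variance step, mirroring the (old_sigma ** 2 - last_sigma ** 2) ** 0.5 line above (reinject_noise is an illustrative helper, not part of the module):

import torch

def reinject_noise(x, last_sigma, restart_sigma, s_noise=1.0):
    # going from noise level last_sigma back up to restart_sigma adds exactly
    # the missing variance, so the result again has std ~ restart_sigma
    return x + torch.randn_like(x) * s_noise * (restart_sigma ** 2 - last_sigma ** 2) ** 0.5

x = torch.zeros(10000)
x_restarted = reinject_noise(x, last_sigma=0.1, restart_sigma=2.0)
print(x_restarted.std())  # roughly sqrt(2.0**2 - 0.1**2) ~= 2.0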





diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 71581b763593c2cf1da347f2e6ea3898a891d6b5..8bb639f57a9f1b7c90a9d15e02845d52991758cc 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,7 +2,7 @@ from collections import deque
 import torch

 import inspect

 import k_diffusion.sampling

-from modules import prompt_parser, devices, sd_samplers_common

+from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra

 

 from modules.shared import opts, state

 import modules.shared as shared

@@ -31,12 +31,16 @@     ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
     ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),

     ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),

+    ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
 ]
 
 
 samplers_data_k_diffusion = [
     sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
     for label, funcname, aliases, options in samplers_k_diffusion
+    if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
 ]

 

 sampler_extra_params = {

@@ -52,6 +56,28 @@     'karras': k_diffusion.sampling.get_sigmas_karras,
     'exponential': k_diffusion.sampling.get_sigmas_exponential,

     'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential

 }

+

+

+def catenate_conds(conds):

+    if not isinstance(conds[0], dict):

+        return torch.cat(conds)

+

+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

+

+

+def subscript_cond(cond, a, b):

+    if not isinstance(cond, dict):

+        return cond[a:b]

+

+    return {key: vec[a:b] for key, vec in cond.items()}

+

+

+def pad_cond(tensor, repeats, empty):

+    if not isinstance(tensor, dict):

+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)

+

+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)

+    return tensor

 

 

 class CFGDenoiser(torch.nn.Module):

@@ -106,11 +132,14 @@         repeats = [len(conds_list[i]) for i in range(batch_size)]
 

         if shared.sd_model.model.conditioning_key == "crossattn-adm":

             image_uncond = torch.zeros_like(image_cond)

-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}

+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}

         else:

             image_uncond = image_cond

+            if isinstance(uncond, dict):

+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}

+            else:

+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

 

         if not is_edit_model:

             x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])

@@ -142,34 +171,33 @@             empty = shared.sd_model.cond_stage_model_empty_prompt
             num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]

 

             if num_repeats < 0:

-from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback

+import inspect

 import k_diffusion.sampling

+from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback

                 self.padded_cond_uncond = True

             elif num_repeats > 0:

-                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)

+                uncond = pad_cond(uncond, num_repeats, empty)

                 self.padded_cond_uncond = True

 

         if tensor.shape[1] == uncond.shape[1] or skip_uncond:

             if is_edit_model:

-                cond_in = torch.cat([tensor, uncond, uncond])
+                cond_in = catenate_conds([tensor, uncond, uncond])

             elif skip_uncond:

                 cond_in = tensor

             else:

-                cond_in = torch.cat([tensor, uncond])
+                cond_in = catenate_conds([tensor, uncond])

 

             if shared.batch_cond_uncond:

-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))

+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))

             else:

                 x_out = torch.zeros_like(x_in)

                 for batch_offset in range(0, x_out.shape[0], batch_size):

                     a = batch_offset

                     b = a + batch_size

-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))

         else:

             x_out = torch.zeros_like(x_in)

             batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size

@@ -178,17 +206,16 @@                 a = batch_offset
                 b = min(a + batch_size, tensor.shape[0])

 

                 if not is_edit_model:

-                    c_crossattn = [tensor[a:b]]
+                    c_crossattn = subscript_cond(tensor, a, b)

                 else:

                     c_crossattn = torch.cat([tensor[a:b]], uncond)

 

                 x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

 

             if not skip_uncond:

-                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))

 

         denoised_image_indexes = [x[0][0] for x in conds_list]

         if skip_uncond:

@@ -244,12 +271,9 @@             noise = self.sampler_noises.popleft()
             if noise.shape == x.shape:

                 return noise

 

-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
+        return devices.randn_like(x)

 

 

 class KDiffusionSampler:

@@ -258,7 +282,7 @@         denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
 

         self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)

         self.funcname = funcname

-        self.func = getattr(k_diffusion.sampling, self.funcname)

+        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

         self.extra_params = sampler_extra_params.get(funcname, [])

         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)

         self.sampler_noises = None

@@ -364,6 +388,9 @@         elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
             sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

 

             sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device)

+        elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential':

+            m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())

+            sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device)

         else:

             sigmas = self.model_wrap.get_sigmas(steps)
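The catenate_conds, subscript_cond and pad_cond helpers added above exist because SDXL conditioning is a dict of tensors while SD1.x conditioning is a single tensor, and CFGDenoiser should not care which it gets. A standalone sketch with illustrative tensor shapes (the 2048 and 2816 dimensions are assumptions used only for the example):

import torch

def catenate_conds(conds):
    # same helper as in the diff: concatenate tensors, or each key of a dict of tensors
    if not isinstance(conds[0], dict):
        return torch.cat(conds)
    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}

sd1_conds = [torch.zeros(2, 77, 768), torch.zeros(2, 77, 768)]
sdxl_conds = [{"crossattn": torch.zeros(2, 77, 2048), "vector": torch.zeros(2, 2816)}] * 2

print(catenate_conds(sd1_conds).shape)                # torch.Size([4, 77, 768])
print(catenate_conds(sdxl_conds)["crossattn"].shape)  # torch.Size([4, 77, 2048])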

 





diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index e4ff29946e8548e8dad447aa90f264d472acb4ed..0bd5e19bb3f9bbdfa937e91703249215f8e034e3 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,6 +1,6 @@
 import os
 import collections
-from modules import paths, shared, devices, script_callbacks, sd_models
+from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
 
@@ -15,6 +15,7 @@ loaded_vae_file = None
 checkpoint_info = None
 
 checkpoints_loaded = collections.OrderedDict()
+
 
 def get_base_vae(model):
     if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
@@ -50,6 +51,7 @@     return os.path.basename(filepath)
 
 
 def refresh_vae_list():
+    global vae_dict
     vae_dict.clear()
 
     paths = [
@@ -83,6 +85,8 @@     for filepath in candidates:
         name = get_filename(filepath)
         vae_dict[name] = filepath
 
+    vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))
+
 
 def find_vae_near_checkpoint(checkpoint_file):
     checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0]
@@ -96,6 +100,16 @@
 def resolve_vae(checkpoint_file):
     if shared.cmd_opts.vae_path is not None:
         return shared.cmd_opts.vae_path, 'from commandline argument'
+
+    metadata = extra_networks.get_user_metadata(checkpoint_file)
+    vae_metadata = metadata.get("vae", None)
+    if vae_metadata is not None and vae_metadata != "Automatic":
+        if vae_metadata == "None":
+            return None, None
+
+        vae_from_metadata = vae_dict.get(vae_metadata, None)
+        if vae_from_metadata is not None:
+            return vae_from_metadata, "from user metadata"
 
     is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
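With the extra_networks.get_user_metadata lookup added above, VAE selection now has a clear precedence: command line, then per-checkpoint user metadata, then the global sd_vae setting. A simplified standalone sketch of that ordering (the resolve_vae signature and paths here are illustrative, not the module's actual API):

def resolve_vae(cmd_line_vae, metadata_vae, vae_dict, global_setting):
    if cmd_line_vae is not None:
        return cmd_line_vae, 'from commandline argument'
    if metadata_vae not in (None, "Automatic"):
        if metadata_vae == "None":
            return None, None
        if metadata_vae in vae_dict:
            return vae_dict[metadata_vae], 'from user metadata'
    return vae_dict.get(global_setting), 'from settings'

print(resolve_vae(None, "vae-ft-mse", {"vae-ft-mse": "/models/VAE/vae-ft-mse.pt"}, "Automatic"))
# ('/models/VAE/vae-ft-mse.pt', 'from user metadata')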
 




diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index e2f004683be33bf0c717c2de770e250d404fc22f..86bd658ad32af5ba22d087ec668f450fedd693a0 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -2,9 +2,9 @@ import os
 

 import torch

 from torch import nn

-from modules import devices, paths

+from modules import devices, paths, shared

 

-sd_vae_approx_model = None

+sd_vae_approx_models = {}

 

 

 class VAEApprox(nn.Module):

@@ -31,44 +31,69 @@ 
         return x

 

 

+def download_model(model_path, model_url):

+    if not os.path.exists(model_path):

+        os.makedirs(os.path.dirname(model_path), exist_ok=True)

+

+        print(f'Downloading VAEApprox model to: {model_path}')

+        torch.hub.download_url_to_file(model_url, model_path)

+

+

 def model():

+    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"

+    loaded_model = sd_vae_approx_models.get(model_name)

 

+    if loaded_model is None:

+        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+        if not os.path.exists(model_path):
+            download_model(model_path, ...)
+
+        loaded_model = VAEApprox()
+        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu'))
+        loaded_model.eval()
+        loaded_model.to(devices.device, devices.dtype)
+        sd_vae_approx_models[model_name] = loaded_model
+
+    return loaded_model
 

 def cheap_approximation(sample):

     # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

 

+    if shared.sd_model.is_sdxl:
+        coeffs = [
+            ...
+        ]
+    else:

+        coeffs = [

+            [ 0.298,  0.207,  0.208],

+            [ 0.187,  0.286,  0.173],

+            [-0.158,  0.189,  0.264],

+            [-0.184, -0.271, -0.473],

+        ]

+

+    coefs = torch.tensor(coeffs).to(sample.device)

 

     x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
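The preview approximation above is just a fixed 4-to-3 channel projection applied per pixel. A standalone sketch using the SD1.x coefficients shown in the diff (the SDXL coefficient set differs and is not reproduced here):

import torch

coeffs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(4, 64, 64)                      # (latent channels, height, width)
rgb = torch.einsum("lxy,lr -> rxy", latent, coeffs)  # (3, height, width)
print(rgb.shape)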

 





diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5e8496e8739e67b509d9ded4213b88902982114d..5bf7c76e1dd8ca9a7fc3b624861ed9962387222e 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -8,9 +8,10 @@ import os
 import torch
 import torch.nn as nn
 
-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared
 
+import os
 """
 
 
 def conv(n_in, n_out, **kwargs):
@@ -60,10 +62,8 @@         """[0, 1] -> raw latents"""
         return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)
 
 
-def download_model(model_path):
-    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
     if not os.path.exists(model_path):
         os.makedirs(os.path.dirname(model_path), exist_ok=True)
 
@@ -72,19 +72,22 @@         torch.hub.download_url_to_file(model_url, model_path)
 
 
 def model():
+    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)
 
-    global sd_vae_taesd
-
-    if sd_vae_taesd is None:
-        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
-        download_model(model_path)
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)
 
         if os.path.exists(model_path):
-            sd_vae_taesd = TAESD(model_path)
-            sd_vae_taesd.eval()
-            sd_vae_taesd.to(devices.device, devices.dtype)
+            loaded_model = TAESD(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
         else:
             raise FileNotFoundError('TAESD model not found')
 
-    return sd_vae_taesd.decoder
+    return loaded_model.decoder
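The TAESD loader now keeps one decoder per file name, so SD1.x and SDXL decoders can coexist in the same process. A standalone sketch of that caching pattern (the loader callable is a stand-in for constructing TAESD):

loaded_models = {}

def get_decoder(is_sdxl, loader):
    model_name = "taesdxl_decoder.pth" if is_sdxl else "taesd_decoder.pth"
    model = loaded_models.get(model_name)
    if model is None:
        model = loader(model_name)          # download + construct in the real code
        loaded_models[model_name] = model
    return model

print(get_decoder(False, loader=lambda name: f"decoder<{name}>"))
print(get_decoder(True, loader=lambda name: f"decoder<{name}>"))
print(len(loaded_models))  # 2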




diff --git a/modules/shared.py b/modules/shared.py
index f6604ef913bddd57fb69fa1a5c476a97f0bbcc4b..8245250a5a2607a3580672ae63fa08e7a94daf30 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,6 +11,7 @@ import gradio as gr
 import torch

 import tqdm

 

+import launch

 import modules.interrogate

 import modules.memmon

 import modules.styles

@@ -26,8 +27,9 @@ demo = None
 

 parser = cmd_args.parser

 

-import json

+import threading

 import re

+

 script_loading.preload_extensions(extensions_builtin_dir, parser)

 

 if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:

@@ -220,17 +222,24 @@         if self.current_latent is None:
             return

 

         import modules.sd_samplers

-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
 
-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during genration, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()

 

     def assign_current_image(self, image):

         self.current_image = image

@@ -390,13 +398,15 @@     "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
 }))

 

 options_templates.update(options_section(('system', "System"), {

-    "show_warnings": OptionInfo(False, "Show warnings in console."),

+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),

+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),

     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),

     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),

     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),

     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),

     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),

     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),

+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),

 }))

 

 options_templates.update(options_section(('training', "Training"), {

@@ -416,47 +426,67 @@ }))
 

 options_templates.update(options_section(('sd', "Stable Diffusion"), {

     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),

+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),

+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
+    "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
+    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
+    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+    "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
+    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
+}))
+
+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_crop_left": OptionInfo(0, "crop left coordinate"),
+    "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))
 
 options_templates.update(options_section(('img2img', "img2img"), {
+    "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+    "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
+    "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker,  {}).info("brush color of inpaint mask").needs_restart(),
+    "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),

 }))

+

 

 options_templates.update(options_section(('optimizations', "Optimizations"), {

     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),

-    "directories_filename_pattern",

+hypernetworks = {}

 

     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),

     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),

@@ -474,7 +502,7 @@     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),

 }))

 

-options_templates.update(options_section(('interrogate', "Interrogate Options"), {

+options_templates.update(options_section(('interrogate', "Interrogate"), {

     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),

     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),

     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),

@@ -508,11 +536,7 @@ options_templates.update(options_section(('ui', "User interface"), {
     "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),

     "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),

     "outdir_grids",

-import datetime

-    "outdir_grids",

 import json

-    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),

-    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),

     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),

     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),

     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),

@@ -531,10 +555,11 @@     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),

     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),

     "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),

-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),

+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),

     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),

     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),

 }))

+

 

 options_templates.update(options_section(('infotext', "Infotext"), {

     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),

@@ -908,3 +933,10 @@             if not opts.list_hidden_files and ("/." in root or "\\." in root):
                 continue

 

             yield os.path.join(root, filename)

+

+

+def ldm_print(*args, **kwargs):

+    if opts.hide_ldm_prints:

+        return

+

+    print(*args, **kwargs)
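A tiny standalone sketch of how ldm_print silences Stability-AI module output when the new hide_ldm_prints option is on (the Opts class is a stand-in for modules.shared.opts):

class Opts:
    hide_ldm_prints = True

opts = Opts()

def ldm_print(*args, **kwargs):
    if opts.hide_ldm_prints:
        return
    print(*args, **kwargs)

ldm_print("LatentDiffusion: Running in eps-prediction mode")  # swallowed
opts.hide_ldm_prints = False
ldm_print("now visible")                                      # printed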





diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc51dbfb6415da3c953c96fd524a554ac0d..0740fe1b1c0ed15a5d44f25af144f3f2257fa0d9 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@         # Always keep a backup file around
         if os.path.exists(path):

             shutil.copy(path, f"{path}.bak")

 

-        fd = os.open(path, os.O_RDWR | os.O_CREAT)

-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:

-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,

-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()

+        with open(path, "w", encoding="utf-8-sig", newline='') as file:

             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)

             writer.writeheader()

             writer.writerows(style._asdict() for k, style in self.styles.items())





diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 5f15ac4fa94ca0eb1a41bac92623963bd6140189..cf24c6dd4a4effb272311b5d83df7673444fa108 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -109,12 +109,17 @@ def format_traceback(tb):
     return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]

 

 

+def format_exception(e, tb):

+    return {"exception": str(e), "traceback": format_traceback(tb)}

+

+

 def get_exceptions():

     try:

         from modules import errors

 

+        return [format_exception(e, tb) for e, tb in reversed(errors.exception_records)]

     except Exception as e:

         return str(e)

 





diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6166c76f6537b757ebea4ec40b5dc59eecfc2ff5..aa79dc09843ae6575af352bd61a7c05ba16376c5 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -13,7 +13,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin

 from torch.utils.tensorboard import SummaryWriter

 

-import os

+        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:

 

 import modules.textual_inversion.dataset

 from modules.textual_inversion.learn_schedule import LearnRateScheduler

@@ -182,36 +182,44 @@             data = safetensors.torch.load_file(path, device="cpu")
         else:

             return

 

+

         # textual inversion embeddings

         if 'string_to_param' in data:

             param_dict = data['string_to_param']

             param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11

             assert len(param_dict) == 1, 'embedding file has multiple terms in it'

             emb = next(iter(param_dict.items()))[1]

+            vec = emb.detach().to(devices.device, dtype=torch.float32)
+            shape = vec.shape[-1]
+            vectors = vec.shape[0]
+        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
+            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
+            vectors = data['clip_g'].shape[0]
         elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:

             assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

 

             emb = next(iter(data.values()))

             if len(emb.shape) == 1:

                 emb = emb.unsqueeze(0)

+            vec = emb.detach().to(devices.device, dtype=torch.float32)

+            shape = vec.shape[-1]

+            vectors = vec.shape[0]

         else:

             raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

 

-        vec = emb.detach().to(devices.device, dtype=torch.float32)
         embedding = Embedding(vec, name)

         embedding.step = data.get('step', None)

         embedding.sd_checkpoint = data.get('sd_checkpoint', None)

         embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)

-        embedding.vectors = vec.shape[0]
-        embedding.shape = vec.shape[-1]
+        embedding.vectors = vectors
+        embedding.shape = shape

 

         embedding.filename = path

         embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')

@@ -387,6 +394,8 @@         assert log_directory, "Log directory is empty"
 

 

 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):

+    from modules import processing

+

     save_embedding_every = save_embedding_every or 0

     create_image_every = create_image_every or 0

     template_file = textual_inversion_templates.get(template_filename, None)





diff --git a/modules/timer.py b/modules/timer.py
index da99e49f82df4260676cf8205cf6bcc93b9d3f01..1d38595c7fff24a9047a619f440811007f6cde44 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -1,4 +1,5 @@
 import time

+import argparse

 

 

 class TimerSubcategory:

@@ -11,20 +12,27 @@ 
     def __enter__(self):

         self.start = time.time()

         self.timer.base_category = self.original_base_category + self.category + "/"

+        self.timer.subcategory_level += 1

+

+        if self.timer.print_log:

+            print(f"{'  ' * self.timer.subcategory_level}{self.category}:")

 

     def __exit__(self, exc_type, exc_val, exc_tb):

         elapsed_for_subcategroy = time.time() - self.start

         self.timer.base_category = self.original_base_category

         self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)

-        self.timer.record(self.category)

+        self.timer.subcategory_level -= 1

+        self.timer.record(self.category, disable_log=True)

 

 

 class Timer:

-    def __init__(self):

+    def __init__(self, print_log=False):

         self.start = time.time()

         self.records = {}

         self.total = 0

         self.base_category = ''

+        self.print_log = print_log

+        self.subcategory_level = 0

 

     def elapsed(self):

         end = time.time()

@@ -38,12 +46,16 @@             self.records[category] = 0
 

         self.records[category] += amount

 

-    def record(self, category, extra_time=0):

+    def record(self, category, extra_time=0, disable_log=False):

         e = self.elapsed()

 

         self.add_time_to_record(self.base_category + category, e + extra_time)

 

         self.total += e + extra_time
 
+        if self.print_log and not disable_log:
+            print(f"{'  ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")
 

     def subcategory(self, name):

@@ -72,7 +84,11 @@     def reset(self):
         self.__init__()

 

 


+parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")

+args = parser.parse_known_args()[0]

+

+startup_timer = Timer(print_log=args.log_startup)

 

 startup_record = None
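The timer change above gives Timer an opt-in startup log: print_log enables printing, subcategory_level tracks nesting so TimerSubcategory can indent its output, and record() accepts disable_log so a subcategory exit does not log the same category twice. A minimal sketch of that behaviour, with the record/log bookkeeping simplified and the summary line format assumed rather than copied from the module:

import time


class Timer:
    def __init__(self, print_log=False):
        self.start = time.time()
        self.records = {}
        self.base_category = ''
        self.print_log = print_log        # new flag, wired to --log-startup at module level
        self.subcategory_level = 0        # new counter used for indentation

    def elapsed(self):
        end = time.time()
        res = end - self.start
        self.start = end
        return res

    def record(self, category, extra_time=0, disable_log=False):
        e = self.elapsed()
        key = self.base_category + category
        self.records[key] = self.records.get(key, 0) + e + extra_time
        if self.print_log and not disable_log:
            print(f"{'  ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")


# hypothetical usage: launching with --log-startup constructs Timer(print_log=True)
timer = Timer(print_log=True)
time.sleep(0.05)
timer.record("import torch")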





diff --git a/modules/txt2img.py b/modules/txt2img.py
index 29d94e8cb2c0fa8aca36a3f08120705bf36eefea..935ed418171f6e7ebf2539f499eb1099386d7455 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html
 import gradio as gr

 

 

-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):

+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):

     override_settings = create_override_settings_dict(override_settings_texts)

 

     p = processing.StableDiffusionProcessingTxt2Img(

@@ -41,6 +41,7 @@         hr_upscaler=hr_upscaler,
         hr_second_pass_steps=hr_second_pass_steps,

         hr_resize_x=hr_resize_x,

         hr_resize_y=hr_resize_y,

+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,

         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,

         hr_prompt=hr_prompt,

         hr_negative_prompt=hr_negative_prompt,
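The txt2img() signature above gains hr_checkpoint_name between hr_resize_y and hr_sampler_index; the dropdown's sentinel value 'Use same checkpoint' is turned into None so the hires pass only swaps models when an actual checkpoint title was picked. A small sketch of that mapping, using a hypothetical helper name:

def resolve_hr_checkpoint(hr_checkpoint_name):
    # hypothetical helper mirroring the inline expression in the diff:
    # the sentinel string means "keep the currently loaded model for the hires pass"
    if hr_checkpoint_name == 'Use same checkpoint':
        return None
    return hr_checkpoint_name


assert resolve_hr_checkpoint('Use same checkpoint') is None
assert resolve_hr_checkpoint('sd_xl_base_1.0.safetensors') == 'sd_xl_base_1.0.safetensors'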





diff --git a/modules/ui.py b/modules/ui.py
index 085561c1a7b15d7bba2b837d06457ca3038e65eb..61a6b4ad77817272da5a763a13daaf32c4d3600c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,41 +13,37 @@ from PIL import Image, PngImagePlugin  # noqa: F401
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

 

 import datetime

-import os

+refresh_symbol = '\U0001f504'  # 🔄

+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts

 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML

 from modules.paths import script_path

 from modules.ui_common import create_refresh_button

 from modules.ui_gradio_extensions import reload_javascript

 

-

 from modules.shared import opts, cmd_opts

 

 import json
 import mimetypes
 import os
+import sys
 from functools import reduce
+import gradio.utils
 import modules.shared as shared
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
-import modules.textual_inversion.ui
-from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras
 

 create_setting_component = ui_settings.create_setting_component

 

 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)

+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)

 

 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI

 mimetypes.init()
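The second filterwarnings call makes Gradio's deprecation warnings opt-in: they stay silenced unless the new show_gradio_deprecation_warnings option is enabled. A sketch of the mechanism with a stand-in warning class (gr.deprecation.GradioDeprecationWarning is the real category; the option name comes from the diff):

import warnings


class FakeDeprecationWarning(DeprecationWarning):
    """Stand-in for gr.deprecation.GradioDeprecationWarning."""


show_gradio_deprecation_warnings = False  # stand-in for opts.show_gradio_deprecation_warnings

# "default" prints each distinct warning once per location, "ignore" drops the category entirely
warnings.filterwarnings("default" if show_gradio_deprecation_warnings else "ignore", category=FakeDeprecationWarning)

warnings.warn("`.style()` is deprecated", FakeDeprecationWarning)  # silent with the default setting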

@@ -99,21 +95,8 @@     if len(x) == 0:
         return None

     return image_from_url_text(x[0])

 

-

-def add_style(name: str, prompt: str, negative_prompt: str):

-    if name is None:

-        return [gr_show() for x in range(4)]

-

-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)

-    shared.prompt_styles.styles[style.name] = style

-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we

-    # reserialize all styles every time we save them

-    shared.prompt_styles.save_styles(shared.styles_filename)

-

-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]

 

 

-

 import mimetypes

     from modules import processing, devices

 

@@ -136,13 +119,6 @@     if not target_width or not target_height:
         return "no image selected"

 

     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"

-

-

-def apply_styles(prompt, prompt_neg, styles):

-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)

-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)

-

-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]

 

 

 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):

@@ -182,8 +158,6 @@ def create_seed_inputs(target_interface):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):

         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")


         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')

 

@@ -195,7 +169,6 @@ 
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:

         seed_extras.append(seed_extra_row_1)

         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")

-        subseed.style(container=False)

         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")

         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")

         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")

@@ -279,127 +252,144 @@     return f"{token_count}/{max_length}"
 

 

-import modules.gfpgan_model
-import modules.hypernetworks.ui
-import modules.scripts

-            with gr.Row(elem_id=f"{id_part}_styles_row"):
-                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-                interrupt.click(
-                    fn=lambda: shared.state.interrupt(),

+                with gr.Row(elem_id=f"{id_part}_tools"):
+                    self.interrupt.click(
+                    )
+                    self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

+        self.prompt_img.change(
+            outputs=[self.prompt, self.prompt_img],
+        )

+        print(f"Will process {len(images)} images.")
+        if ii_output_dir != "":
+            os.makedirs(ii_output_dir, exist_ok=True)
+sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

 

 

 def setup_progressbar(*args, **kwargs):

@@ -482,23 +469,22 @@     reload_javascript()
 

     parameters_copypaste.reset()

 

-from modules.sd_hijack import model_hijack

+            os.makedirs(ii_output_dir, exist_ok=True)

 import mimetypes

-from modules.sd_hijack import model_hijack

+            os.makedirs(ii_output_dir, exist_ok=True)

 import os

 

     with gr.Blocks(analytics_enabled=False) as txt2img_interface:

-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, _, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)

+        toprow = Toprow(is_img2img=False)

 

         dummy_component = gr.Label(visible=False)

-        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)

 

         extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")

         extra_tabs.__enter__()

 

         with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):

             with gr.Column(variant='compact', elem_id="txt2img_settings"):

-                modules.scripts.scripts_txt2img.prepare_ui()

+                scripts.scripts_txt2img.prepare_ui()

 

                 for category in ordered_ui_categories():

                     if category == "sampler":

@@ -544,6 +530,10 @@                                 hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
                                 hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")

 

                             with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:

+

+                                hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")

+                                create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")

+

                                 hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")

 

                             with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:

@@ -566,10 +556,10 @@                             override_settings = create_override_settings_dropdown('txt2img', row)
 

                     elif category == "scripts":

                         with FormGroup(elem_id="txt2img_script_container"):

-                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()

+                            custom_inputs = scripts.scripts_txt2img.setup_ui()

 

                     else:

-                        modules.scripts.scripts_txt2img.setup_ui_for_section(category)

+                        scripts.scripts_txt2img.setup_ui_for_section(category)

 

             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]

 

@@ -600,10 +590,10 @@                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                 _js="submit",

                 inputs=[

                     dummy_component,

-                    txt2img_prompt,

+                    toprow.prompt,

-                    txt2img_negative_prompt,

+                    toprow.negative_prompt,

                     steps,

                     sampler_index,

                     restore_faces,

@@ -622,6 +612,7 @@                     hr_upscaler,
                     hr_second_pass_steps,

                     hr_resize_x,

                     hr_resize_y,

+                    hr_checkpoint_name,

                     hr_sampler_index,

                     hr_prompt,

                     hr_negative_prompt,

@@ -638,16 +629,16 @@                 ],
                 show_progress=False,

             )

 

-import os

 import datetime

+def create_seed_inputs(target_interface):

-import os

 import numpy as np

+import modules.styles

 

             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

 

-import os

+import datetime

 import datetime

-import mimetypes

+import modules.textual_inversion.ui

                 fn=progress.restore_progress,

                 _js="restoreProgressTxt2img",

                 inputs=[dummy_component],

@@ -660,19 +652,6 @@                 show_progress=False,
             )

 

 import os

-from modules.ui_common import create_refresh_button

-                fn=modules.images.image_data,

-                inputs=[

-                    txt_prompt_img

-                ],

-                outputs=[

-                    txt2img_prompt,

-                    txt_prompt_img

-                ],

-                show_progress=False,

-            )

-

-import os

 import modules.codeformer_model

                 fn=lambda x: gr_show(x),

                 inputs=[enable_hr],

@@ -681,12 +660,12 @@                 show_progress = False,
             )

 

             txt2img_paste_fields = [

-import os

+import numpy as np

 import json

-from functools import reduce

+

-import os

+import numpy as np

 import json

-import warnings

+import gradio as gr

                 (steps, "Steps"),

                 (sampler_index, "Sampler"),

                 (restore_faces, "Face restoration"),

@@ -695,38 +674,41 @@                 (seed, "Seed"),
                 (width, "Size-1"),

                 (height, "Size-2"),

                 (batch_size, "Batch size"),

+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),

                 (subseed, "Variation seed"),

                 (subseed_strength, "Variation seed strength"),

                 (seed_resize_from_w, "Seed resize from-1"),

                 (seed_resize_from_h, "Seed resize from-2"),

-                (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),

+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),

                 (denoising_strength, "Denoising strength"),

-    import modules.ngrok as ngrok

+            ii_output_dir = ii_input_dir

 import json

-    import modules.ngrok as ngrok

+            ii_output_dir = ii_input_dir

 import mimetypes

                 (hr_scale, "Hires upscale"),

                 (hr_upscaler, "Hires upscaler"),

                 (hr_second_pass_steps, "Hires steps"),

                 (hr_resize_x, "Hires resize-1"),

                 (hr_resize_y, "Hires resize-2"),

+                (hr_checkpoint_name, "Hires checkpoint"),

                 (hr_sampler_index, "Hires sampler"),

-import os

+            ii_output_dir = ii_input_dir

 import sys

                 (hr_prompt, "Hires prompt"),

                 (hr_negative_prompt, "Hires negative prompt"),

                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),

-                *modules.scripts.scripts_txt2img.infotext_fields

+                *scripts.scripts_txt2img.infotext_fields

             ]
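Each entry in txt2img_paste_fields pairs a component with either an infotext key or a callable that receives the parsed parameters dict; the new seed_checkbox entry, for instance, ticks the extra-seed checkbox whenever the pasted infotext carries a variation seed or a seed-resize value. A minimal sketch of how such a callable is evaluated, assuming the parsed infotext arrives as a plain dict (the real parser lives in generation_parameters_copypaste and is not reproduced here):

# parsed generation parameters as the paste logic would see them (illustrative values)
parsed = {"Steps": "20", "Sampler": "DPM++ 2M Karras", "Variation seed": "12345"}

seed_checkbox_rule = lambda d: "Variation seed" in d or "Seed resize from-1" in d

print(seed_checkbox_rule(parsed))           # True: the seed extras row gets enabled on paste
print(seed_checkbox_rule({"Steps": "20"}))  # False: a plain seed leaves the checkbox alone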

             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)

             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(

-                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,

+                paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,

             ))

 

             txt2img_preview_params = [

-    ngrok.connect(

+import datetime

 import datetime

+mimetypes.add_type('application/javascript', '.js')

-                txt2img_negative_prompt,

+                toprow.negative_prompt,

                 steps,

                 sampler_index,

                 cfg_scale,

@@ -734,10 +717,11 @@                 width,
                 height,

             ]

 

+import numpy as np

 import os

-plaintext_to_html = ui_common.plaintext_to_html

+import numpy as np

 import os

-def send_gradio_gallery_to_image(x):

+import datetime

 

         from modules import ui_extra_networks

         extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')

@@ -745,17 +728,17 @@         ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
 

         extra_tabs.__exit__()

 

+import numpy as np

 import os

-    if name is None:

+import json

+import numpy as np

 import os

-        return [gr_show() for x in range(4)]

+import mimetypes

 

     with gr.Blocks(analytics_enabled=False) as img2img_interface:

+import numpy as np

 import os

-    shared.prompt_styles.styles[style.name] = style

-

 import os

-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we

 

         extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")

         extra_tabs.__enter__()

@@ -782,22 +765,25 @@                 with gr.Tabs(elem_id="mode_img2img"):
                     img2img_selected_tab = gr.State(0)

 

                     with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:

+import numpy as np

 import os

-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)

+import sys

                         add_copy_image_controls('img2img', init_img)

 

                     with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:

-        )

+import datetime

 import datetime

+    ngrok.connect(

                         add_copy_image_controls('sketch', sketch)

 

                     with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:

-        )

+import numpy as np

 import os

+import warnings

                         add_copy_image_controls('inpaint', init_img_with_mask)

 

                     with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:

-                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)

+                        inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)

                         inpaint_color_sketch_orig = gr.State(None)

                         add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)

 

@@ -857,8 +843,9 @@ 
                 with FormRow():

                     resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")

 

-# Using constants for these since the variation selector isn't visible.

+import numpy as np

 import os

+import gradio as gr

 

                 for category in ordered_ui_categories():

                     if category == "sampler":

@@ -939,8 +926,7 @@                             override_settings = create_override_settings_dropdown('img2img', row)
 

                     elif category == "scripts":

                         with FormGroup(elem_id="img2img_script_container"):

-from functools import reduce

+            img = Image.open(image)

-import sys

 

                     elif category == "inpaint":

                         with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:

@@ -971,7 +957,7 @@                                     inputs=[],
                                     outputs=[inpaint_controls, mask_alpha],

                                 )

                     else:

-                        modules.scripts.scripts_img2img.setup_ui_for_section(category)

+                        scripts.scripts_img2img.setup_ui_for_section(category)

 

             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)

 

@@ -979,31 +965,18 @@             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)

 

 apply_style_symbol = '\U0001f4cb'  # 📋

-from functools import reduce

-                fn=modules.images.image_data,

-                inputs=[

-                    img2img_prompt_img

-                ],

-                outputs=[

-                    img2img_prompt,

-                    img2img_prompt_img

-                ],

-                show_progress=False,

-            )

-

-apply_style_symbol = '\U0001f4cb'  # 📋

 import gradio as gr

                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),

                 _js="submit_img2img",

                 inputs=[

                     dummy_component,

                     dummy_component,

-from functools import reduce

+import numpy as np

 import json

-

+import datetime

-clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️

+        else:

 import json

-clear_prompt_symbol = '\U0001f5d1\ufe0f'  # 🗑️

+        else:

 import mimetypes

                     init_img,

                     sketch,

@@ -1063,12 +1036,12 @@                     init_img_with_mask,
                     inpaint_color_sketch,

                     init_img_inpaint,

                 ],

-                outputs=[img2img_prompt, dummy_component],

+                outputs=[toprow.prompt, dummy_component],

             )

 

-restore_progress_symbol = '\U0001F300' # 🌀

 import datetime

+    button.click(

-            submit.click(**img2img_args)

+            toprow.submit.click(**img2img_args)

 

             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)

 

@@ -1080,9 +1053,8 @@                 outputs=[width, height],
                 show_progress=False,

             )

 

-import os

 import datetime

-import mimetypes

+        seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")

                 fn=progress.restore_progress,

                 _js="restoreProgressImg2img",

                 inputs=[dummy_component],

@@ -1095,59 +1067,34 @@                 ],
                 show_progress=False,

             )

 

-            img2img_interrogate.click(

+            toprow.button_interrogate.click(

                 fn=lambda *args: process_interrogate(interrogate, *args),

                 **interrogate_args,

             )

 

+            img = Image.open(image)

 from functools import reduce

-    return image_from_url_text(x[0])

                 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),

                 **interrogate_args,

             )

 

-detect_image_size_symbol = '\U0001F4D0'  # 📐

+import numpy as np

 import sys

-            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]

-detect_image_size_symbol = '\U0001F4D0'  # 📐

 import warnings

-

-            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):

-                button.click(

-                    fn=add_style,

-up_down_symbol = '\u2195\ufe0f' # ↕️

 import datetime

-                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept

-                    # the same number of parameters, but we only know the style-name after the JavaScript prompt

-up_down_symbol = '\u2195\ufe0f' # ↕️

+import datetime

 import os

-                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],

-import modules.gfpgan_model

 import datetime

 

 from functools import reduce

-        return ""

-detect_image_size_symbol = '\U0001F4D0'  # 📐

 import gradio as gr

+import mimetypes

-from functools import reduce

+        else:

 

-import warnings

-                    _js=js_func,

-                    inputs=[prompt, negative_prompt, styles],

-                    outputs=[prompt, negative_prompt, styles],

-import modules.gfpgan_model

 import datetime

-

-plaintext_to_html = ui_common.plaintext_to_html

 import datetime

-plaintext_to_html = ui_common.plaintext_to_html

 import json

-

-from functools import reduce

 import gradio as gr

-import mimetypes

-                (img2img_prompt, "Prompt"),

-                (img2img_negative_prompt, "Negative prompt"),

                 (steps, "Steps"),

                 (sampler_index, "Sampler"),

                 (restore_faces, "Face restoration"),

@@ -1157,19 +1104,20 @@                 (seed, "Seed"),
                 (width, "Size-1"),

                 (height, "Size-2"),

                 (batch_size, "Batch size"),

+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),

                 (subseed, "Variation seed"),

                 (subseed_strength, "Variation seed strength"),

                 (seed_resize_from_w, "Seed resize from-1"),

                 (seed_resize_from_h, "Seed resize from-2"),

-                (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),

+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),

                 (denoising_strength, "Denoising strength"),

                 (mask_blur, "Mask blur"),

-                *modules.scripts.scripts_img2img.infotext_fields

+                *scripts.scripts_img2img.infotext_fields

             ]

             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)

             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)

             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(

-                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,

+                paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,

             ))

 

         from modules import ui_extra_networks

@@ -1178,14 +1126,13 @@         ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
 

         extra_tabs.__exit__()

 

-import warnings

+            filename = os.path.basename(image)

-import sys

 

     with gr.Blocks(analytics_enabled=False) as extras_interface:

         ui_postprocessing.create_ui()

 

     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:

-        with gr.Row().style(equal_height=False):

+        with gr.Row(equal_height=False):

             with gr.Column(variant='panel'):

                 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")

 

@@ -1207,76 +1154,23 @@             inputs=[image],
             outputs=[html, generation_info, html2],

         )

 

-    def update_interp_description(value):

-        interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"

-        interp_descriptions = {

-            "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),

-            "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),

-            "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")

-        }
-        return interp_descriptions[value]
-
-    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='compact'):
-                interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
-                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
-                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
-                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
-                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-
-                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
-                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
-                with FormRow():
-                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
-                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
-
-                with FormRow():
-                    with gr.Column():
-
-                    with gr.Column():
-                        with FormRow():
-                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
-                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
-                with FormRow():
-                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
-                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
-                with gr.Group(elem_id="modelmerger_results_panel"):
-                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
-
-    with gr.Blocks(analytics_enabled=False) as train_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):

             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")

 

-import warnings

+import numpy as np

 from functools import reduce

-import datetime

+import mimetypes

             with gr.Tabs(elem_id="train_tabs"):

 

                 with gr.Tab(label="Create embedding", id="create_embedding"):

@@ -1296,8 +1190,9 @@                 with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
                     new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")

                     new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")

                     new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")

-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)

+import numpy as np

 from functools import reduce

+import os

                     new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")

                     new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")

                     new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")

@@ -1437,16 +1332,15 @@                 script_callbacks.ui_train_tabs_callback(params)
 

             with gr.Column(elem_id='ti_gallery_container'):

                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)

-

+import numpy as np

 from functools import reduce

-import os

+import sys

                 gr.HTML(elem_id="ti_progress", value="")

                 ti_outcome = gr.HTML(elem_id="ti_error", value="")

 

         create_embedding.click(

-

+            filename = os.path.basename(image)

 from functools import reduce

-

             inputs=[

                 new_embedding_name,

                 initialization_text,

@@ -1461,7 +1355,7 @@             ]
         )

 

         create_hypernetwork.click(

-            fn=modules.hypernetworks.ui.create_hypernetwork,

+            fn=hypernetworks_ui.create_hypernetwork,

             inputs=[

                 new_hypernetwork_name,

                 new_hypernetwork_sizes,

@@ -1481,8 +1375,8 @@             ]
         )

 

         run_preprocess.click(

+            filename = os.path.basename(image)

 

-    target_width = int(width * scale_by)

             _js="start_training_textual_inversion",

             inputs=[

                 dummy_component,

@@ -1518,8 +1412,8 @@             ],
         )

 

         train_embedding.click(

+            filename = os.path.basename(image)

 import gradio as gr

-import modules.codeformer_model

             _js="start_training_textual_inversion",

             inputs=[

                 dummy_component,

@@ -1553,7 +1447,7 @@             ]
         )

 

         train_hypernetwork.click(

-            fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),

+            fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),

             _js="start_training_textual_inversion",

             inputs=[

                 dummy_component,

@@ -1607,8 +1501,9 @@         (txt2img_interface, "txt2img", "txt2img"),
         (img2img_interface, "img2img", "img2img"),

         (extras_interface, "Extras", "extras"),

         (pnginfo_interface, "PNG Info", "pnginfo"),

-        return "no image selected"

+import numpy as np

 import warnings

+import datetime

         (train_interface, "Train", "train"),

     ]

 

@@ -1660,51 +1555,13 @@         update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
         settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

 

-        def modelmerger(*args):

-            try:

-                results = modules.extras.run_modelmerger(*args)

-            except Exception as e:

-                errors.report("Error loading/saving model file", exc_info=True)

-                modules.sd_models.list_models()  # to remove the potentially missing models from the list

-                return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]

-            return results

-

-            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),

-            _js='modelmerger',

-            inputs=[

-                dummy_component,

-                primary_model_name,

-                secondary_model_name,

-                tertiary_model_name,

-                interp_method,

-                interp_amount,

-                save_as_half,

-                custom_name,

-                checkpoint_format,

-                config_source,

-                bake_in_vae,

-                discard_weights,

-                save_metadata,

-            ],

-            outputs=[

-                primary_model_name,

-                secondary_model_name,

-                tertiary_model_name,

-                settings.component_dict['sd_model_checkpoint'],

-                modelmerger_result,

-            ]

-        )

 

     loadsave.dump_defaults()

     demo.ui_loadsave = loadsave

-

-    # Required as a workaround for change() event not triggering when loading values from ui-config.json

-    interp_description.value = update_interp_description(interp_method.value)

 

     return demo

 





diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c5dd6bb6773eaf3a6a507f3162cef637c9338f
--- /dev/null
+++ b/modules/ui_checkpoint_merger.py
@@ -0,0 +1,124 @@
+

+import gradio as gr

+

+from modules import sd_models, sd_vae, errors, extras, call_queue

+from modules.ui_components import FormRow

+from modules.ui_common import create_refresh_button

+

+

+def update_interp_description(value):

+    interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"

+    interp_descriptions = {

+        "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),

+        "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),

+        "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")

+    }

+    return interp_descriptions[value]

+

+

+def modelmerger(*args):

+    try:

+        results = extras.run_modelmerger(*args)

+    except Exception as e:

+        errors.report("Error loading/saving model file", exc_info=True)

+        sd_models.list_models()  # to remove the potentially missing models from the list

+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]

+    return results

+

+

+class UiCheckpointMerger:

+    def __init__(self):

+        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:

+            with gr.Row(equal_height=False):

+                with gr.Column(variant='compact'):

+                    self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")

+

+                    with FormRow(elem_id="modelmerger_models"):

+                        self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")

+                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")

+

+                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")

+                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")

+

+                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")

+                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")

+

+                    self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")

+                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")

+                    self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")

+                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])

+

+                    with FormRow():

+                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")

+                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")

+

+                    with FormRow():

+                        with gr.Column():

+                            self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")

+

+                        with gr.Column():

+                            with FormRow():

+                                self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")

+                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")

+

+                    with FormRow():

+                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

+

+                    with gr.Accordion("Metadata", open=False) as metadata_editor:

+                        with FormRow():

+                            self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")

+                            self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")

+                            self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")

+

+                        self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")

+                        self.read_metadata = gr.Button("Read metadata from selected checkpoints")

+

+                    with FormRow():

+                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

+

+                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):

+                    with gr.Group(elem_id="modelmerger_results_panel"):

+                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)

+

+        self.metadata_editor = metadata_editor

+        self.blocks = modelmerger_interface

+

+    def setup_ui(self, dummy_component, sd_model_checkpoint_component):

+        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)

+

+        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])

+

+        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])

+        self.modelmerger_merge.click(

+            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),

+            _js='modelmerger',

+            inputs=[

+                dummy_component,

+                self.primary_model_name,

+                self.secondary_model_name,

+                self.tertiary_model_name,

+                self.interp_method,

+                self.interp_amount,

+                self.save_as_half,

+                self.custom_name,

+                self.checkpoint_format,

+                self.config_source,

+                self.bake_in_vae,

+                self.discard_weights,

+                self.save_metadata,

+                self.add_merge_recipe,

+                self.copy_metadata_fields,

+                self.metadata_json,

+            ],

+            outputs=[

+                self.primary_model_name,

+                self.secondary_model_name,

+                self.tertiary_model_name,

+                sd_model_checkpoint_component,

+                self.modelmerger_result,

+            ]

+        )

+

+        # Required as a workaround for change() event not triggering when loading values from ui-config.json

+        self.interp_description.value = update_interp_description(self.interp_method.value)
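ui_checkpoint_merger.py moves the whole checkpoint-merger tab out of ui.py into a class: __init__ builds the Blocks into self.blocks, and setup_ui() wires the merge button once the caller can supply the shared dummy component and the sd_model_checkpoint settings component. A sketch of how ui.py is expected to consume it; the surrounding variables are assumptions based on the ui.py code removed above, not shown in this hunk:

from modules import ui_checkpoint_merger

# inside modules.ui.create_ui(), roughly:
modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()

# modelmerger_ui.blocks replaces the old modelmerger_interface when the tab list is assembled,
# e.g. (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger")

# after the settings UI exists, the button wiring from the old ui.py moves here:
# modelmerger_ui.setup_ui(dummy_component, settings.component_dict['sd_model_checkpoint'])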

+





diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b288938c2ddd3e0645a1bc6b44dd9c8ba..1dda16272b6f4797f22d171203c64056ae121233 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -135,7 +135,7 @@ 
     with gr.Column(variant='panel', elem_id=f"{tabname}_results"):

         with gr.Group(elem_id=f"{tabname}_gallery_container"):

 import json

+import subprocess as sp

-import gradio as gr

 

         generation_info = None

         with gr.Column():

@@ -225,30 +225,60 @@ 
 

 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):

 import json

+def create_output_panel(tabname, outdir):

+

+    label = None

+    for comp in refresh_components:

+        label = getattr(comp, 'label', None)

+        if label is not None:

+            break

+

+import json

         writer.writerow([data["prompt"], data["seed"], data["width"], data["height"], data["sampler_name"], data["cfg_scale"], data["steps"], filenames[0], data["negative_prompt"]])

         refresh_method()

         args = refreshed_args() if callable(refreshed_args) else refreshed_args

 

         for k, v in args.items():

-        zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]")

-        zip_filepath = os.path.join(path, f"{zip_filename}.zip")

     refresh_button.click(

         fn=refresh,

         inputs=[],

+        elif not os.path.isdir(f):

+            print(f"""

+WARNING

+An open_folder request was made with an argument that is not a folder.

+This could be an error or a malicious attempt to run code on your computer.

+    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""

+    dialog.visible = False

+    button_show.click(

+        fn=lambda: gr.update(visible=True),

+        inputs=[],

+            return html_info, gr.update()

+from modules import call_queue, shared

+        button_close.click(fn=None, _js="closePopup")
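The ui_common.py hunk above also pulls in subprocess and adds a guard around folder opening: the WARNING text is printed when an open_folder request points at something that is not a directory. A sketch of that guard; only the warning text and the isdir check come from the diff, while the platform dispatch is the usual pattern and not quoted from it:

import os
import platform
import subprocess as sp
import sys


def open_folder(f):
    # refuse anything that is not an existing directory
    if not os.path.exists(f):
        return
    elif not os.path.isdir(f):
        print("""
WARNING
An open_folder request was made with an argument that is not a folder.
This could be an error or a malicious attempt to run code on your computer.
""", file=sys.stderr)
        return

    # usual cross-platform "reveal in file manager" dispatch
    if platform.system() == "Windows":
        os.startfile(f)
    elif platform.system() == "Darwin":
        sp.Popen(["open", f])
    else:
        sp.Popen(["xdg-open", f])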

 





diff --git a/modules/ui_components.py b/modules/ui_components.py
index 64451df7a4e5ab2931ae95135e2d61a9387b2f33..8f8a70885173835694d1ca2dbb1ce20f1c45023d 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -35,7 +35,7 @@         return "column"
 

 

 class FormGroup(FormComponent, gr.Group):

-    """Same as gr.Row but fits inside gradio forms"""

+    """Same as gr.Group but fits inside gradio forms"""

 

     def get_block_name(self):

         return "group"





diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index f3e4fba7eece3cb67db6f64f1fdab5110ede3698..15a8b0bf4e3d31bafe8a7743dfbea28cad0a99a1 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -164,9 +164,9 @@         else:
             ext_status = ext.status

 

         style = ""

-import json

+import time

 import json

-import git

+from datetime import datetime

             style = STYLE_PRIMARY

 

         version_link = ext.version

@@ -535,23 +535,26 @@                     apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
                     apply = gr.Button(value=apply_label, variant="primary")

                     check = gr.Button(value="Check for updates")

                     extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")

-                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)

+                    extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)

                 html = ""

-                    html = """

+                    if shared.cmd_opts.disable_all_extensions:

+                        msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'

+                    elif shared.cmd_opts.disable_extra_extensions:

+                        msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'

+                    html = f'<span style="color: var(--primary-400);">{msg}</span>'

+

                 info = gr.HTML(html)

                 extensions_table = gr.HTML('Loading...')

                 ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])

@@ -574,8 +577,8 @@             with gr.TabItem("Available", id="available"):
                 with gr.Row():

                     refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")

                     extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")

-                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+                    available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)

 

                     extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)

                     install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)

 

@@ -583,7 +587,7 @@                     hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"])
                     sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index")

 

                 with gr.Row():

-                    search_extensions_text = gr.Text(label="Search").style(container=False)

+                    search_extensions_text = gr.Text(label="Search", container=False)

 

                 install_result = gr.HTML()

                 available_extensions_table = gr.HTML()



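Most of the .style() removals in this file (and in ui_postprocessing.py and ui.py) come from the gradio bump to 3.39.0 in requirements_versions.txt: layout options that used to be passed through Component.style() are now ordinary constructor keyword arguments. A minimal before/after sketch of the pattern (generic component, not a specific line from this commit):

    import gradio as gr

    # gradio 3.32 pattern, deprecated in 3.39:
    search = gr.Text(label="Search").style(container=False)

    # pattern used throughout this commit:
    search = gr.Text(label="Search", container=False)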


diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py
index 7387d01e1e77feb8f34ac8de4239a1a83ca6030e..3a73c89e8fe2bc17bca056b9656228b1092f6040 100644
--- a/modules/ui_extra_networks.py
+++ b/modules/ui_extra_networks.py
@@ -2,7 +2,7 @@ import os.path
 import urllib.parse

 from pathlib import Path

 

-from modules import shared, ui_extra_networks_user_metadata, errors

+from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks

 from modules.images import read_info_from_image, save_image_with_geninfo

 from modules.ui import up_down_symbol

 import gradio as gr

@@ -62,8 +62,9 @@ 
     page = next(iter([x for x in extra_pages if x.name == page]), None)

 

     try:

-        item = page.create_item(name)
+        item = page.create_item(name, enable_filter=False)
+        page.items[name] = item

     except Exception as e:

         errors.display(e, "creating item for extra network")

         item = page.items.get(name)

@@ -101,17 +101,8 @@         pass
 

     def read_user_metadata(self, item):

         filename = item.get("filename", None)

-        basename, ext = os.path.splitext(filename)
-        metadata_filename = basename + '.json'
-
-        metadata = {}
-        try:
-            if os.path.isfile(metadata_filename):
-                with open(metadata_filename, "r", encoding="utf8") as file:
-                    metadata = json.load(file)
-        except Exception as e:
-            errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+        metadata = extra_networks.get_user_metadata(filename)

 

         desc = metadata.get("description", None)

         if desc is not None:

@@ -254,7 +245,7 @@             "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'",
             "prompt": item.get("prompt", None),

             "tabname": quote_js(tabname),

             "local_preview": quote_js(item["local_preview"]),

-            "name": item["name"],

+            "name": html.escape(item["name"]),

             "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""),

             "card_clicked": onclick,

             "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"',



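The new enable_filter parameter threads through every page's create_item override (checkpoints, hypernetworks and textual inversion in the following hunks) so that the single-card refresh shown above can rebuild one item even if the current search filter would exclude it from the listing. A hypothetical custom page would follow the same shape; this is an illustrative sketch, not code from the commit:

    from modules import ui_extra_networks

    class ExamplePage(ui_extra_networks.ExtraNetworksPage):
        def __init__(self):
            super().__init__('Example')

        def create_item(self, name, index=None, enable_filter=True):
            # enable_filter=False is passed when a single card is re-created after
            # its user metadata is edited, so the item must be returned regardless
            # of any active filtering
            return {
                "name": name,
                "filename": f"/models/example/{name}.safetensors",
                "local_preview": f"/models/example/{name}.png",
                "sort_keys": {"default": index},
            }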


diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 76780cfd0af55701a6bdec25e05b804c07f873fe..778850222452edb3109f5150881fa8230b89eaec 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -3,6 +3,7 @@ import os
 

 from modules import shared, ui_extra_networks, sd_models

 from modules.ui_extra_networks import quote_js

+from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor

 

 

 class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):

@@ -12,7 +13,7 @@ 
     def refresh(self):

         shared.refresh_checkpoints()

 

-    def create_item(self, name, index=None):

+    def create_item(self, name, index=None, enable_filter=True):

         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)

         path, ext = os.path.splitext(checkpoint.filename)

         return {

@@ -23,6 +24,7 @@             "description": self.find_description(path),
             "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""),

             "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"',

             "local_preview": f"{path}.{shared.opts.samples_format}",

+            "metadata": checkpoint.metadata,

             "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)},

         }

 

@@ -33,3 +35,5 @@ 
     def allowed_directories_for_previews(self):

         return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]

 

+    def create_user_metadata_editor(self, ui, tabname):

+        return CheckpointUserMetadataEditor(ui, tabname, self)





diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c69aab866ec8c0b6ccf9da4bb1b01ccf48b1551
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -0,0 +1,60 @@
+import gradio as gr

+

+from modules import ui_extra_networks_user_metadata, sd_vae

+from modules.ui_common import create_refresh_button

+

+

+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):

+    def __init__(self, ui, tabname, page):

+        super().__init__(ui, tabname, page)

+

+        self.select_vae = None

+

+    def save_user_metadata(self, name, desc, notes, vae):

+        user_metadata = self.get_user_metadata(name)

+        user_metadata["description"] = desc

+        user_metadata["notes"] = notes

+        user_metadata["vae"] = vae

+

+        self.write_user_metadata(name, user_metadata)

+

+    def put_values_into_components(self, name):

+        user_metadata = self.get_user_metadata(name)

+        values = super().put_values_into_components(name)

+

+        return [

+            *values[0:5],

+            user_metadata.get('vae', ''),

+        ]

+

+    def create_editor(self):

+        self.create_default_editor_elems()

+

+        with gr.Row():

+            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")

+            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")

+

+        self.edit_notes = gr.TextArea(label='Notes', lines=4)

+

+        self.create_default_buttons()

+

+        viewed_components = [

+            self.edit_name,

+            self.edit_description,

+            self.html_filedata,

+            self.html_preview,

+            self.edit_notes,

+            self.select_vae,

+        ]

+

+        self.button_edit\

+            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\

+            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])

+

+        edited_components = [

+            self.edit_description,

+            self.edit_notes,

+            self.select_vae,

+        ]

+

+        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)



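CheckpointUserMetadataEditor extends the generic metadata editor with one extra field, the preferred VAE, persisted alongside the description and notes through the base class's write_user_metadata. Roughly, pressing Save amounts to the following (the editor instance and values are illustrative):

    # hypothetical call: what the Save button ends up doing for a card named "example"
    editor.save_user_metadata(
        name="example",
        desc="my go-to photoreal checkpoint",
        notes="works best at 30+ steps",
        vae="vae-ft-mse-840000-ema-pruned.safetensors",
    )
    # the user metadata for "example" now carries
    # {"description": ..., "notes": ..., "vae": ...} and is written back to disk

Presumably the VAE selection logic picks the "vae" key back up when the checkpoint is activated (sd_vae.py is also touched by this commit), but that code is outside the hunks shown here.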


diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb428925e2d589c379640123fe02a8cbbcec..514a45624e95520e56a42d466777c1fa7ec014b2 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ 
     def refresh(self):

         shared.reload_hypernetworks()

 

-    def create_item(self, name, index=None):

+    def create_item(self, name, index=None, enable_filter=True):

         full_path = shared.hypernetworks[name]

         path, ext = os.path.splitext(full_path)

 





diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e501c1c2525b5c644b217534d4693adf272..73134698ea1bdbea26fc1ccfdb416e4cf24a1a29 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ 
     def refresh(self):

         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)

 

-    def create_item(self, name, index=None):

+    def create_item(self, name, index=None, enable_filter=True):

         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)

 

         path, ext = os.path.splitext(embedding.filename)





diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 01ff4e4b470a94640e04f56f9d875ab91ed99e5c..1cb9eb6febde7df703e24ecbaa0c4a93cb6d7239 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -42,12 +42,17 @@             item['user_metadata'] = user_metadata
 

         return user_metadata

 

+    def create_extra_default_items_in_left_column(self):

+        pass

+

     def create_default_editor_elems(self):

         with gr.Row():

             with gr.Column(scale=2):

                 self.edit_name = gr.HTML(elem_classes="extra-network-name")

                 self.edit_description = gr.Textbox(label="Description", lines=4)

                 self.html_filedata = gr.HTML()

+

+                self.create_extra_default_items_in_left_column()

 

             with gr.Column(scale=1, min_width=0):

                 self.html_preview = gr.HTML()

@@ -91,6 +96,7 @@             filename = item["filename"]
 

             stats = os.stat(filename)

             params = [

+                ('Filename: ', os.path.basename(filename)),

                 ('File size: ', sysinfo.pretty_bytes(stats.st_size)),

                 ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),

             ]

@@ -111,8 +117,8 @@             params = []
 

         table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'

 

 

     def write_user_metadata(self, name, metadata):

         item = self.page.items.get(name, {})





diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index c7dc11540474aabdde9ef501a9c074094bb16672..802e1ce71a16a45298439e45b72e2a2aba24d81f 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -6,7 +6,7 @@ 
 def create_ui():

     tab_index = gr.State(value=0)

 

-    with gr.Row().style(equal_height=False, variant='compact'):

+    with gr.Row(equal_height=False, variant='compact'):

         with gr.Column(variant='compact'):

             with gr.Tabs(elem_id="mode_extras"):

                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:





diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 0000000000000000000000000000000000000000..85eb3a6417ede53510665060458bf7a8e3a4916c
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr

+

+from modules import shared, ui_common, ui_components, styles

+

+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️

+styles_materialize_symbol = '\U0001f4cb'  # 📋

+

+

+def select_style(name):

+    style = shared.prompt_styles.styles.get(name)

+    existing = style is not None

+    empty = not name

+

+    prompt = style.prompt if style else gr.update()

+    negative_prompt = style.negative_prompt if style else gr.update()

+

+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)

+

+

+def save_style(name, prompt, negative_prompt):

+    if not name:

+        return gr.update(visible=False)

+

+    style = styles.PromptStyle(name, prompt, negative_prompt)

+    shared.prompt_styles.styles[style.name] = style

+    shared.prompt_styles.save_styles(shared.styles_filename)

+

+    return gr.update(visible=True)

+

+

+def delete_style(name):

+    if name == "":

+        return

+

+    shared.prompt_styles.styles.pop(name, None)

+    shared.prompt_styles.save_styles(shared.styles_filename)

+

+    return '', '', ''

+

+

+def materialize_styles(prompt, negative_prompt, styles):

+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)

+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)

+

+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]

+

+

+def refresh_styles():

+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))

+

+

+class UiPromptStyles:

+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):

+        self.tabname = tabname

+

+        with gr.Row(elem_id=f"{tabname}_styles_row"):

+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")

+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")

+

+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:

+            with gr.Row():

+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")

+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")

+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")

+

+            with gr.Row():

+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)

+

+            with gr.Row():

+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)

+

+            with gr.Row():

+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)

+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)

+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')

+

+        self.selection.change(

+            fn=select_style,

+            inputs=[self.selection],

+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],

+            show_progress=False,

+        )

+

+        self.save.click(

+            fn=save_style,

+            inputs=[self.selection, self.prompt, self.neg_prompt],

+            outputs=[self.delete],

+            show_progress=False,

+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)

+

+        self.delete.click(

+            fn=delete_style,

+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',

+            inputs=[self.selection],

+            outputs=[self.selection, self.prompt, self.neg_prompt],

+            show_progress=False,

+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)

+

+        self.materialize.click(

+            fn=materialize_styles,

+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],

+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],

+            show_progress=False,

+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)

+

+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)

+

+

+

+



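UiPromptStyles is meant to be instantiated once per generation tab; the tab's main prompt and negative prompt textboxes are passed in so the 📋 button can materialize the selected styles into them, and the 🖌️ button opens the popup editor through ui_common.setup_dialog. A hedged wiring sketch with illustrative component names (the real hookup happens in modules/ui.py):

    import gradio as gr
    from modules.ui_prompt_styles import UiPromptStyles

    with gr.Blocks():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative_prompt = gr.Textbox(label="Negative prompt", lines=3)

        # builds the styles dropdown, the edit button and the popup editor, and
        # wires the select/save/delete/apply handlers defined above
        styles = UiPromptStyles("txt2img", prompt, negative_prompt)

        # styles.dropdown is the multiselect of currently chosen style names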


diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index a6076bf306001757f8d967e14859a7d1a420028b..6dde4b6aa04234622b940b8c1d0052a9756ee5b2 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -158,7 +158,7 @@                 with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
                     loadsave.create_ui()

 

                 with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"):

 

                     with gr.Row():





diff --git a/requirements.txt b/requirements.txt
index 3142085eaf43d16b660ee4fc121ad3c7014ddb9c..9a47d6d0dffb17a2b941104d1460e510ea2a8a21 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,13 +7,15 @@ blendmodes
 clean-fid

 einops

 gfpgan

+accelerate

 GitPython

 inflection

 jsonmerge

 kornia

 lark

 numpy

 omegaconf

+open-clip-torch

 

 piexif

 psutil

@@ -28,3 +32,4 @@ torch
 torchdiffeq

 torchsde

 accelerate





diff --git a/requirements_versions.txt b/requirements_versions.txt
index f71b9d6c555280ebbe991e12fd117b51dc0410ce..dec45df384c0ac1d1b259fd7eb6ddf2ab0eaf6ca 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -1,30 +1,33 @@
 GitPython==3.1.30

 Pillow==9.5.0

 accelerate==0.18.0

 basicsr==1.4.2

 blendmodes==2022

 clean-fid==0.1.35

 einops==0.4.1

 fastapi==0.94.0

 gfpgan==1.3.8

-gradio==3.32.0

+gradio==3.39.0

-httpcore<=0.15

+httpcore==0.15

 inflection==0.5.1

 jsonmerge==1.8.0

 kornia==0.6.7

 lark==1.1.2

 numpy==1.23.5

 omegaconf==2.2.3

+open-clip-torch==2.20.0

 piexif==1.1.3

-psutil~=5.9.5

+psutil==5.9.5

 pytorch_lightning==1.9.4

 realesrgan==0.3.0

 resize-right==0.0.2

 safetensors==0.3.1

-scikit-image==0.20.0

+scikit-image==0.21.0

-timm==0.6.7

+timm==0.9.2

-tomesd==0.1.2

+tomesd==0.1.3

 torch

 torchdiffeq==0.2.3

 torchsde==0.2.5





diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 7821cc655cd658f4ffed646df8d7a91891bd7a4b..d37b428fca95928e4ab52ed61ec5a83693ac2823 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -3,6 +3,7 @@ from copy import copy
 from itertools import permutations, chain

 import random

 import csv

+import os.path

 from io import StringIO

 from PIL import Image

 import numpy as np

@@ -10,8 +11,9 @@ 
 import modules.scripts as scripts

 import gradio as gr

 

-from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion
+from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img

 from modules.shared import opts, state

 import modules.shared as shared

@@ -68,14 +70,6 @@     p.prompt = prompt_tmp + p.prompt
 

 

-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):

     for x in xs:

         if x.lower() not in sd_samplers.samplers_map:

@@ -146,11 +140,20 @@ 
     p.restore_faces = is_active

 

 

-def apply_override(field):

+def apply_override(field, boolean: bool = False):

     def fun(p, x, xs):

+        if boolean:

+            x = True if x.lower() == "true" else False

         p.override_settings[field] = x

     return fun

 

+

+def boolean_choice(reverse: bool = False):

+    def choice():

+        return ["False", "True"] if reverse else ["True", "False"]

+    return choice

+

+

 def format_value_add_label(p, opt, x):

     if type(x) == float:

         x = round(x, 8)

@@ -175,6 +178,8 @@ 
 def format_nothing(p, opt, x):

     return ""

 

+def format_remove_path(p, opt, x):

+    return os.path.basename(x)

 

 def str_permutations(x):

     """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""

@@ -214,11 +219,13 @@     AxisOption("CFG Scale", float, apply_field("cfg_scale")),
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),

     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),

     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),

-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),

+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),

-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),

+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),

     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),

     AxisOption("Sigma Churn", float, apply_field("s_churn")),

     AxisOption("Sigma min", float, apply_field("s_tmin")),

@@ -239,6 +246,7 @@     AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
     AxisOption("Face restore", str, apply_face_restore, format_value=format_value),

     AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),

     AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),

+    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),

 ]

 

 

@@ -642,9 +650,15 @@             x_opt.apply(pc, x, xs)
             y_opt.apply(pc, y, ys)

             z_opt.apply(pc, z, zs)

 

-            res = process_images(pc)
+            try:
+                res = process_images(pc)
+            except Exception as e:
+                errors.display(e, "generating image for xyz plot")
+
+                res = Processed(p, [], p.seed, "")
 

             # Sets subgrid infotexts

             subgrid_index = 1 + iz



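apply_override(field, boolean=True) exists because grid cell values always arrive as strings; with boolean=True, "true" (case-insensitive) becomes True and anything else becomes False before the value is written into p.override_settings, while boolean_choice() supplies the matching "True"/"False" dropdown entries. A sketch of how another on/off setting could be exposed the same way, assuming the surrounding list is named axis_options as in this script (the axis below is hypothetical, not added by this commit):

    # hypothetical boolean axis following the same pattern as
    # "Always discard next-to-last sigma" above
    axis_options.append(AxisOption(
        "Pad conds",                                     # label shown in the X/Y/Z grid UI
        str,                                             # cell values arrive as strings
        apply_override('pad_cond_uncond', boolean=True),
        choices=boolean_choice(reverse=True),            # offers ["False", "True"]
    ))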


diff --git a/style.css b/style.css
index 7157ac0bd50538e84908dd292a572d4230739d11..52919f719a2323f335d3a4c77c0f25b572ee117f 100644
--- a/style.css
+++ b/style.css
@@ -8,6 +8,7 @@ :root, .dark{
     --checkbox-label-gap: 0.25em 0.1em;

     --section-header-text-size: 12pt;

     --block-background-fill: transparent;

+

 }

 

 .block.padded:not(.gradio-accordion) {

@@ -42,8 +43,9 @@ .block.gradio-textbox,
 .block.gradio-radio,

 .block.gradio-checkboxgroup,

 .block.gradio-number,

+div.gradio-group

 {

     border-width: 0 !important;

     box-shadow: none !important;

@@ -134,6 +136,15 @@     font-weight: bold;
     cursor: pointer;

 }

 

+div.styler{

+    border: none;

+    background: var(--background-fill-primary);

+}

+

+.block.gradio-textbox{

+    overflow: visible !important;

+}

+

 

 /* general styled components */

 

@@ -166,7 +177,7 @@ .checkboxes-row > div{
     flex: 0;

     white-space: nowrap;

 }

 

 button.custom-button{

@@ -390,6 +401,7 @@ 
 #quicksettings > div, #quicksettings > fieldset{

     max-width: 24em;

     min-width: 24em;

+    width: 24em;

     padding: 0;

     border: none;

     box-shadow: none;

@@ -425,20 +437,20 @@     margin: 0 1.2em;
 }

 

 table.popup-table{
-    gap: 0.5em;
+    color: var(--body-text-color);
     border-collapse: collapse;
     margin: 1em;
+    margin-left: 0em;
 }
 
 table.popup-table td{
     padding: 0.4em;
     max-width: 36em;
 }

 

@@ -852,7 +864,7 @@ }
 

 .extra-network-cards .card .card-button {

     text-shadow: 2px 2px 3px black;

-    padding: 0.25em;

+    padding: 0.25em 0.1em;

     font-size: 200%;

     width: 1.5em;

 }

@@ -968,6 +980,10 @@ .edit-user-metadata .file-metadata th{
     text-align: left;

 }

 

+.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{

+    padding: 0.3em 1em;

+}

+

 .edit-user-metadata .wrap.translucent{

     background: var(--body-background-fill);

 }

@@ -978,3 +994,16 @@ 
 .edit-user-metadata-buttons{

     margin-top: 1.5em;

 }

+

+

+

+

+div.block.gradio-box.popup-dialog, .popup-dialog {

+    width: 56em;

+    background: var(--body-background-fill);

+    padding: 2em !important;

+}

+

+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{

+    margin-top: 1em;

+}





diff --git a/webui.py b/webui.py
index 34c2fd18130c64bd008041146ca4c95a52a4b795..1803ea8ae7d1181c486aec611e8263ec31e9f1c7 100644
--- a/webui.py
+++ b/webui.py
@@ -14,7 +14,6 @@ 
 from fastapi import FastAPI

 from fastapi.middleware.cors import CORSMiddleware

 from fastapi.middleware.gzip import GZipMiddleware

-from packaging import version

 

 import logging

 

@@ -31,23 +30,25 @@ 
 logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR)  # sshh...

 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

 

-from modules import paths, timer, import_hook, errors, devices  # noqa: F401

-

+from modules import timer

 startup_timer = timer.startup_timer

+startup_timer.record("launcher")

 

 import torch

 import pytorch_lightning   # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them

 warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning")

 warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision")

-

-

 startup_timer.record("import torch")

 

 import gradio  # noqa: F401

 startup_timer.record("import gradio")

 

+from modules import paths, timer, import_hook, errors, devices  # noqa: F401

+startup_timer.record("setup paths")

+

 import ldm.modules.encoders.modules  # noqa: F401

 startup_timer.record("import ldm")

+

 

 from modules import extra_networks

 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock  # noqa: F401

@@ -57,13 +58,18 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__:
     torch.__long_version__ = torch.__version__

     torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)

 

-import sys
+from modules import shared
+
+if not shared.cmd_opts.skip_version_check:
+    errors.check_versions()
+
 import modules.codeformer_model as codeformer
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
 import modules.img2img

 

 import modules.lowvram

@@ -133,38 +139,6 @@     asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
 

 

-def check_versions():

-    if shared.cmd_opts.skip_version_check:

-        return

-

-    expected_torch_version = "2.0.0"

-

-    if version.parse(torch.__version__) < version.parse(expected_torch_version):

-        errors.print_error_explanation(f"""

-You are running torch {torch.__version__}.

-The program is tested to work with torch {expected_torch_version}.

-To reinstall the desired version, run with commandline flag --reinstall-torch.

-Beware that this will cause a lot of large files to be downloaded, as well as

-there are reports of issues with training tab on the latest version.

-

-Use --skip-version-check commandline argument to disable this check.

-        """.strip())

-

-    expected_xformers_version = "0.0.20"

-    if shared.xformers_available:

-        import xformers

-

-        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):

-            errors.print_error_explanation(f"""

-You are running xformers {xformers.__version__}.

-The program is tested to work with xformers {expected_xformers_version}.

-To reinstall the desired version, run with commandline flag --reinstall-xformers.

-

-Use --skip-version-check commandline argument to disable this check.

-            """.strip())

-

-

 def restore_config_state_file():
     config_state_file = shared.opts.restore_config_state_file

     if config_state_file == "":

@@ -252,7 +226,6 @@ def initialize():
     fix_asyncio_event_loop_policy()

     validate_tls_options()

     configure_sigint_handler()

-    check_versions()

     modelloader.cleanup_models()

     configure_opts_onchange()

 

@@ -324,11 +297,11 @@ 
         if modules.sd_hijack.current_optimizer is None:

             modules.sd_hijack.apply_optimizations()

 

 

     shared.reload_hypernetworks()

     startup_timer.record("reload hypernetworks")

@@ -380,7 +353,7 @@     print(f"Startup time: {startup_timer.summary()}.")
     api.launch(

         server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",

         port=cmd_opts.port if cmd_opts.port else 7861,

-        root_path = f"/{cmd_opts.subpath}"

+        root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""

     )

 

 

@@ -413,7 +386,7 @@             ssl_certfile=cmd_opts.tls_certfile,
             ssl_verify=cmd_opts.disable_tls_verify,

             debug=cmd_opts.gradio_debug,

             auth=gradio_auth_creds,

-            inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1',

+            inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',

             prevent_thread_lock=True,

             allowed_paths=cmd_opts.gradio_allowed_path,

             app_kwargs={




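Two small launch fixes in the last two hunks above are easy to miss: root_path is now only set when --subpath is actually supplied, and the inbrowser check drops a stray trailing space from the environment variable name, which had made the old comparison useless. A short standalone illustration of that second bug (plain Python, not webui code):

    import os

    os.environ['SD_WEBUI_RESTARTING'] = '1'   # what a restart would set

    os.getenv('SD_WEBUI_RESTARTING ')   # None -> old check "!= '1'" was always true,
                                        # so the browser reopened after every restart
    os.getenv('SD_WEBUI_RESTARTING')    # '1'  -> fixed check suppresses autolaunch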

diff --git a/webui.sh b/webui.sh
index a683d946d3e98e853b2c4d43d1ac5743813f63ea..cb8b9d14db5e5cf61c62f292c5cb3c28a916ea49 100755
--- a/webui.sh
+++ b/webui.sh
@@ -4,7 +4,14 @@ # Please do not make any changes to this file,  #
 # change the variables in webui-user.sh instead #
 #################################################
 
+
+use_venv=1
+if [[ $venv_dir == "-" ]]; then
+  use_venv=0
+fi
+
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
+
 
 # If run from macOS, load defaults from webui-macos-env.sh
 if [[ "$OSTYPE" == "darwin"* ]]; then
@@ -47,7 +54,7 @@     export GIT="git"
 fi
 
 # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv)
-if [[ -z "${venv_dir}" ]]
+if [[ -z "${venv_dir}" ]] && [[ $use_venv -eq 1 ]]
 then
     venv_dir="venv"
 fi
@@ -165,7 +172,7 @@     fi
 done
 
-if ! "${python_cmd}" -c "import venv" &>/dev/null
+if [[ $use_venv -eq 1 ]] && ! "${python_cmd}" -c "import venv" &>/dev/null
 then
     printf "\n%s\n" "${delimiter}"
     printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m"
@@ -186,7 +193,7 @@     cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; }
 fi
 
-if [[ -z "${VIRTUAL_ENV}" ]];
+if [[ $use_venv -eq 1 ]] && [[ -z "${VIRTUAL_ENV}" ]];
 then
     printf "\n%s\n" "${delimiter}"
     printf "Create and activate python venv"
@@ -210,7 +217,7 @@     fi
 else
     printf "\n%s\n" "${delimiter}"
-    printf "python venv already activate: ${VIRTUAL_ENV}"
+    printf "python venv already activate or run without venv: ${VIRTUAL_ENV}"
     printf "\n%s\n" "${delimiter}"
 fi
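
Taken together with the use_venv flag introduced at the top of the script, these guards mean that setting venv_dir="-" in webui-user.sh skips the python3-venv check and the venv creation/activation block entirely, so the launcher runs against the system Python environment instead.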