~/Projects/stable-diffusion-webui
git clone https://code.lsong.org/stable-diffusion-webui
Commit
- Commit: ef1698fd6dbd6387341a1eeeded068ff1476ee50
- Author: AUTOMATIC1111 <[email protected]>
- Date: 2023-08-05 08:01:38 +0300
- Diffstat
 .github/workflows/run_tests.yaml                       |   1
 CHANGELOG.md                                           |  90 ++
 README.md                                              |   4
 extensions-builtin/Lora/extra_networks_lora.py         |  36
 extensions-builtin/Lora/lora.py                        | 506 ------
 extensions-builtin/Lora/lyco_helpers.py                |  21
 extensions-builtin/Lora/network.py                     | 155 +++
 extensions-builtin/Lora/network_full.py                |  22
 extensions-builtin/Lora/network_hada.py                |  55 +
 extensions-builtin/Lora/network_ia3.py                 |  30
 extensions-builtin/Lora/network_lokr.py                |  64 +
 extensions-builtin/Lora/network_lora.py                |  86 ++
 extensions-builtin/Lora/networks.py                    | 468 +++++
 extensions-builtin/Lora/preload.py                     |   1
 extensions-builtin/Lora/scripts/lora_script.py         |  90
 extensions-builtin/Lora/ui_edit_user_metadata.py       |  30
 extensions-builtin/Lora/ui_extra_networks_lora.py      |  43
 extensions-builtin/mobile/javascript/mobile.js         |  26
 html/extra-networks-card.html                          |   2
 javascript/extraNetworks.js                            |   2
 javascript/hints.js                                    |  11
 javascript/localization.js                             |  10
 javascript/ui.js                                       |  14
 launch.py                                              |   7
 modules/api/api.py                                     |  51
 modules/api/models.py                                  |   8
 modules/cache.py                                       |  31
 modules/call_queue.py                                  |   5
 modules/cmd_args.py                                    |   6
 modules/devices.py                                     |  86 +
 modules/errors.py                                      |  53 +
 modules/extensions.py                                  |  18
 modules/extra_networks.py                              |  35
 modules/extras.py                                      |  38
 modules/generation_parameters_copypaste.py             |   3
 modules/gradio_extensons.py                            |  60 +
 modules/hypernetworks/hypernetwork.py                  |   7
 modules/images.py                                      |   4
 modules/img2img.py                                     |  34
 modules/launch_utils.py                                |  85 +
 modules/lowvram.py                                     |  67 +
 modules/paths.py                                       |  25
 modules/processing.py                                  | 229 +++-
 modules/prompt_parser.py                               | 131 ++
 modules/rng_philox.py                                  | 102 ++
 modules/script_loading.py                              |   5
 modules/scripts.py                                     |  81 -
 modules/sd_disable_initialization.py                   | 102 ++
 modules/sd_hijack.py                                   |  60 +
 modules/sd_hijack_clip.py                              |  37
 modules/sd_hijack_inpainting.py                        |   2
 modules/sd_hijack_open_clip.py                         |  34
 modules/sd_hijack_optimizations.py                     |  57 +
 modules/sd_hijack_unet.py                              |   8
 modules/sd_models.py                                   | 276 ++++-
 modules/sd_models_config.py                            |   9
 modules/sd_models_xl.py                                | 108 ++
 modules/sd_samplers.py                                 |   3
 modules/sd_samplers_common.py                          |  19
 modules/sd_samplers_compvis.py                         |   6
 modules/sd_samplers_extra.py                           |  74 +
 modules/sd_samplers_kdiffusion.py                      |  67 +
 modules/sd_vae.py                                      |  16
 modules/sd_vae_approx.py                               |  59
 modules/sd_vae_taesd.py                                |  24
 modules/shared.py                                      |  89 +
 modules/styles.py                                      |   5
 modules/sysinfo.py                                     |   7
 modules/textual_inversion/textual_inversion.py         |  24
 modules/timer.py                                       |  24
 modules/txt2img.py                                     |   3
 modules/ui.py                                          | 481 ++++------
 modules/ui_checkpoint_merger.py                        | 124 ++
 modules/ui_common.py                                   |  38
 modules/ui_components.py                               |   2
 modules/ui_extensions.py                               |  27
 modules/ui_extra_networks.py                           |  18
 modules/ui_extra_networks_checkpoints.py               |   6
 modules/ui_extra_networks_checkpoints_user_metadata.py |  60 +
 modules/ui_extra_networks_hypernets.py                 |   2
 modules/ui_extra_networks_textual_inversion.py         |   2
 modules/ui_extra_networks_user_metadata.py             |   8
 modules/ui_postprocessing.py                           |   2
 modules/ui_prompt_styles.py                            | 110 ++
 modules/ui_settings.py                                 |   2
 requirements.txt                                       |   3
 requirements_versions.txt                              |  17
 scripts/xyz_grid.py                                    |  40
 style.css                                              |  45
 webui.py                                               |  63 -
 webui.sh                                               |  15
Merge branch 'dev' into extra-networks-always-visible
diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml
index e9370cc0758f9fa40f72d0bfbc999f3acf935dd6..3dafaf8dcfcd14fd7a7ca3385806efad5550b871 100644
--- a/.github/workflows/run_tests.yaml
+++ b/.github/workflows/run_tests.yaml
@@ -41,6 +41,7 @@
 launch.py
 --skip-prepare-environment
 --skip-torch-cuda-test
 --test-server
+--do-not-download-clip
 --no-half
 --disable-opt-split-attention
 --use-cpu all
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 925403a9138f294fe1b2534471cf509c717c3dfe..b18c6867348e0cdd3610f1f3c1f292e9347c902d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,93 @@
+## 1.5.1
+
+### Minor:
+ * support parsing text encoder blocks in some new LoRAs
+ * delete scale checker script due to user demand
+
+### Extensions and API:
+ * add postprocess_batch_list script callback
+
+### Bug Fixes:
+ * fix TI training for SD1
+ * fix reload altclip model error
+ * prepend the pythonpath instead of overriding it
+ * fix typo in SD_WEBUI_RESTARTING
+ * if txt2img/img2img raises an exception, finally call state.end()
+ * fix composable diffusion weight parsing
+ * restyle Startup profile for black users
+ * fix webui not launching with --nowebui
+ * catch exception for non git extensions
+ * fix some options missing from /sdapi/v1/options
+ * fix for extension update status always saying "unknown"
+ * fix display of extra network cards that have `<>` in the name
+ * update lora extension to work with python 3.8
+
+
+## 1.5.0
+
+### Features:
+ * SD XL support
+ * user metadata system for custom networks
+ * extended Lora metadata editor: set activation text, default weight, view tags, training info
+ * Lora extension rework to include other types of networks (all that were previously handled by LyCORIS extension)
+ * show github stars for extensions
+ * img2img batch mode can read extra stuff from png info
+ * img2img batch works with subdirectories
+ * hotkeys to move prompt elements: alt+left/right
+ * restyle time taken/VRAM display
+ * add textual inversion hashes to infotext
+ * optimization: cache git extension repo information
+ * move generate button next to the generated picture for mobile clients
+ * hide cards for networks of incompatible Stable Diffusion version in Lora extra networks interface
+ * skip installing packages with pip if they all are already installed - startup speedup of about 2 seconds
+
+### Minor:
+ * checkbox to check/uncheck all extensions in the Installed tab
+ * add gradio user to infotext and to filename patterns
+ * allow gif for extra network previews
+ * add options to change colors in grid
+ * use natural sort for items in extra networks
+ * Mac: use empty_cache() from torch 2 to clear VRAM
+ * added automatic support for installing the right libraries for Navi3 (AMD)
+ * add option SWIN_torch_compile to accelerate SwinIR upscale
+ * suppress printing TI embedding info at start to console by default
+ * speedup extra networks listing
+ * added `[none]` filename token.
+ * removed thumbs extra networks view mode (use settings tab to change width/height/scale to get thumbs)
+ * add always_discard_next_to_last_sigma option to XYZ plot
+ * automatically switch to 32-bit float VAE if the generated picture has NaNs without the need for `--no-half-vae` commandline flag.
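
The last bullet above amounts to a retry at higher precision. A minimal sketch of that fallback, not webui's actual code; `vae`, `decode`, and `latents` are illustrative names:

    import torch

    def decode_with_nan_fallback(vae, latents):
        # First try decoding at the current precision (often float16).
        sample = vae.decode(latents)
        if torch.isnan(sample).any():
            # NaNs in the output: redo the decode with everything upcast to
            # float32, which --no-half-vae previously forced for every image.
            sample = vae.to(torch.float32).decode(latents.to(torch.float32))
        return sample
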
+
+### Extensions and API:
+ * api endpoints: /sdapi/v1/server-kill, /sdapi/v1/server-restart, /sdapi/v1/server-stop
+ * allow Script to have custom metaclass
+ * add model exists status check /sdapi/v1/options
+ * rename --add-stop-route to --api-server-stop
+ * add `before_hr` script callback
+ * add callback `after_extra_networks_activate`
+ * disable rich exception output in console for API by default, use WEBUI_RICH_EXCEPTIONS env var to enable
+ * return http 404 when thumb file not found
+ * allow replacing extensions index with environment variable
+
+### Bug Fixes:
+ * fix for catch errors when retrieving extension index #11290
+ * fix very slow loading speed of .safetensors files when reading from network drives
+ * API cache cleanup
+ * fix UnicodeEncodeError when writing to file CLIP Interrogator batch mode
+ * fix warning of 'has_mps' deprecated from PyTorch
+ * fix problem with extra network saving images as previews losing generation info
+ * fix throwing exception when trying to resize image with I;16 mode
+ * fix for #11534: canvas zoom and pan extension hijacking shortcut keys
+ * fixed launch script to be runnable from any directory
+ * don't add "Seed Resize: -1x-1" to API image metadata
+ * correctly remove end parenthesis with ctrl+up/down
+ * fixing --subpath on newer gradio version
+ * fix: check fill size none zero when resize (fixes #11425)
+ * use submit and blur for quick settings textbox
+ * save img2img batch with images.save_image()
+ * prevent running preload.py for disabled extensions
+ * fix: previously, model name was added together with directory name to infotext and to [model_name] filename pattern; directory name is now not included
+
+
 ## 1.4.1

 ### Bug Fixes:
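
The server-control endpoints added in 1.5.0 can be exercised with any HTTP client. A minimal sketch, assuming the default local address, that webui was launched with --api --api-server-stop, and that the endpoints accept an empty POST (check modules/api/api.py in this commit for the exact signatures):

    import requests

    base = "http://127.0.0.1:7860"

    requests.post(f"{base}/sdapi/v1/server-restart")   # restart the backend
    # requests.post(f"{base}/sdapi/v1/server-stop")    # stop gracefully
    # requests.post(f"{base}/sdapi/v1/server-kill")    # terminate immediately
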
diff --git a/README.md b/README.md
index e6d8e4bd423925b38f685b4c268b4751a8bd9bb0..2fd6e425de71d7e8eacba73672143e5f788a23ba 100644
--- a/README.md
+++ b/README.md
@@ -88,7 +88,7 @@
 - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions
 - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions
 - Now without any bad letters!
 - Load checkpoints in safetensors format
-- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64
+- Eased resolution restriction: generated image's dimension must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen
@@ -168,5 +168,7 @@
 - Instruct pix2pix - Tim Brooks (star), Aleksander Holynski (star), Alexei A. Efros (no star) - https://github.com/timothybrooks/instruct-pix2pix
 - Security advice - RyotaK
 - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC
 - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
+- LyCORIS - KohakuBlueleaf
+- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
 - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
 - (You)
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 66ee9c8563919fcda648f58d34278b85b4571f8d..ba2945c6fe1e77d87226f08fb20da0624959364b 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -1,5 +1,5 @@
 from modules import extra_networks, shared
-import lora
+import networks


 class ExtraNetworkLora(extra_networks.ExtraNetwork):
@@ -9,29 +9,44 @@ def activate(self, p, params_list):
         additional = shared.opts.sd_lora

-        if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional):
+        if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional):
             p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts]
             params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))

         names = []
-        multipliers = []
+        te_multipliers = []
+        unet_multipliers = []
+        dyn_dims = []

         for params in params_list:
             assert params.items

-            names.append(params.items[0])
-            multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+            names.append(params.positional[0])
+
+            te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+            te_multiplier = float(params.named.get("te", te_multiplier))
+
+            unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier
+            unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+            dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+            dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+            te_multipliers.append(te_multiplier)
+            unet_multipliers.append(unet_multiplier)
+            dyn_dims.append(dyn_dim)

-        lora.load_loras(names, multipliers)
+        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

         if shared.opts.lora_add_hashes_to_infotext:
-            lora_hashes = []
-            for item in lora.loaded_loras:
-                shorthash = item.lora_on_disk.shorthash
+            network_hashes = []
+            for item in networks.loaded_networks:
+                shorthash = item.network_on_disk.shorthash
                 if not shorthash:
                     continue
@@ -41,11 +56,12 @@ continue

             alias = alias.replace(":", "").replace(",", "")

-            lora_hashes.append(f"{alias}: {shorthash}")
+            network_hashes.append(f"{alias}: {shorthash}")

-        if lora_hashes:
-            p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes)
+        if network_hashes:
+            p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

     def deactivate(self, p):
         pass
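
The rework above changes what a <lora:...> prompt tag can carry: up to three extra positional values (te, unet, dyn) plus named overrides. A standalone sketch of those parsing rules; parse_lora_tag is an illustrative name, not part of the codebase:

    def parse_lora_tag(positional, named):
        # <lora:name:0.8>            -> te = unet = 0.8
        # <lora:name:0.5:0.9>        -> te = 0.5, unet = 0.9
        # <lora:name:te=0.5:dyn=16>  -> te = 0.5, unet = 0.5, dyn_dim = 16
        name = positional[0]
        te = float(positional[1]) if len(positional) > 1 else 1.0
        te = float(named.get("te", te))
        unet = float(positional[2]) if len(positional) > 2 else te
        unet = float(named.get("unet", unet))
        dyn = int(positional[3]) if len(positional) > 3 else None
        dyn = int(named["dyn"]) if "dyn" in named else dyn
        return name, te, unet, dyn

    print(parse_lora_tag(["myLora", "0.5"], {"dyn": "16"}))  # ('myLora', 0.5, 0.5, 16)
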
"ss_num_train_images": 10, "ss_tag_frequency": 20} - } -} - -def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex_text): - regex = re_compiled.get(regex_text) -import re import torch - regex = re.compile(regex_text) - re_compiled[regex_text] = regex - - r = re.match(regex, key) - if not r: return False - - match_list.clear() - match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) - return True m = [] - - if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): - suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - -import torch re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") import torch -from typing import Union - return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): - return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - -from typing import Union import torch - return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" - - if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): - if is_sd2: - if 'mlp_fc1' in m[1]: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" - elif 'mlp_fc2' in m[1]: - - else: - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" - - import torch - - return key - - -class LoraOnDisk: - def __init__(self, name, filename): - self.name = name - self.filename = filename - self.metadata = {} - self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - - def read_metadata(): - metadata = sd_models.read_metadata_from_safetensors(filename) -from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes, cache import torch - - return metadata - - if self.is_safetensors: - try: - self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) - except Exception as e: - errors.display(e, f"reading lora {filename}") - - if self.metadata: -metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} import os - for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): -metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} import torch - - self.metadata = m - - self.alias = self.metadata.get('ss_output_name', self.name) - - self.hash = None - self.shorthash = None - self.set_hash( - self.metadata.get('sshs_model_hash') or - hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or - '' - ) - -re_digits = re.compile(r"\d+") import torch - self.hash = v - self.shorthash = self.hash[0:12] - - if self.shorthash: - available_lora_hash_lookup[self.shorthash] = self - - def read_hash(self): - if not self.hash: - self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') - - def get_alias(self): -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") import re -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") import 
-
-
-class LoraOnDisk:
-    def __init__(self, name, filename):
-        self.name = name
-        self.filename = filename
-        self.metadata = {}
-        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
-
-        def read_metadata():
-            metadata = sd_models.read_metadata_from_safetensors(filename)
-            metadata.pop('ssmd_cover_images', None)  # those are cover images, and they are too big to display in UI as text
-
-            return metadata
-
-        if self.is_safetensors:
-            try:
-                self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata)
-            except Exception as e:
-                errors.display(e, f"reading lora {filename}")
-
-        if self.metadata:
-            m = {}
-            for k, v in sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)):
-                m[k] = v
-
-            self.metadata = m
-
-        self.alias = self.metadata.get('ss_output_name', self.name)
-
-        self.hash = None
-        self.shorthash = None
-        self.set_hash(
-            self.metadata.get('sshs_model_hash') or
-            hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or
-            ''
-        )
-
-    def set_hash(self, v):
-        self.hash = v
-        self.shorthash = self.hash[0:12]
-
-        if self.shorthash:
-            available_lora_hash_lookup[self.shorthash] = self
-
-    def read_hash(self):
-        if not self.hash:
-            self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '')
-
-    def get_alias(self):
-        if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases:
-            return self.name
-        else:
-            return self.alias
-
-
-class LoraModule:
-    def __init__(self, name, lora_on_disk: LoraOnDisk):
-        self.name = name
-        self.lora_on_disk = lora_on_disk
-        self.multiplier = 1.0
-        self.modules = {}
-        self.mtime = None
-
-        self.mentioned_name = None
-        """the text that was used to add lora to prompt - can be either name or an alias"""
-
-
-class LoraUpDownModule:
-    def __init__(self):
-        self.up = None
-        self.down = None
-        self.alpha = None
-
-
-def assign_lora_names_to_compvis_modules(sd_model):
-    lora_layer_mapping = {}
-
-    for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
-
-    for name, module in shared.sd_model.model.named_modules():
-        lora_name = name.replace(".", "_")
-        lora_layer_mapping[lora_name] = module
-        module.lora_layer_name = lora_name
-
-    sd_model.lora_layer_mapping = lora_layer_mapping
-
-
-def load_lora(name, lora_on_disk):
-    lora = LoraModule(name, lora_on_disk)
-    lora.mtime = os.path.getmtime(lora_on_disk.filename)
-
-    sd = sd_models.read_state_dict(lora_on_disk.filename)
-
-    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
-    if not hasattr(shared.sd_model, 'lora_layer_mapping'):
-        assign_lora_names_to_compvis_modules(shared.sd_model)
-
-    keys_failed_to_match = {}
-    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
-
-    for key_diffusers, weight in sd.items():
-        key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
-        key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
-
-        sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
-
-        if sd_module is None:
-            m = re_x_proj.match(key)
-            if m:
-                sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
-
-        if sd_module is None:
-            keys_failed_to_match[key_diffusers] = key
-            continue
-
-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
-        if lora_key == "alpha":
-            lora_module.alpha = weight.item()
-            continue
-
-        if type(sd_module) == torch.nn.Linear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.MultiheadAttention:
-            module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
-        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
-            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
-        else:
-            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
-            continue
-            raise AssertionError(f"Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}")
-
-        with torch.no_grad():
-            module.weight.copy_(weight)
-
-        module.to(device=devices.cpu, dtype=devices.dtype)
-
-        if lora_key == "lora_up.weight":
-            lora_module.up = module
-        elif lora_key == "lora_down.weight":
-            lora_module.down = module
-        else:
-            raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha")
-
-    if keys_failed_to_match:
-        print(f"Failed to match keys when loading Lora 
{lora_on_disk.filename}: {keys_failed_to_match}") - - return lora - - -def load_loras(names, multipliers=None): - already_loaded = {} - - for lora in loaded_loras: - if lora.name in names: - already_loaded[lora.name] = lora - - loaded_loras.clear() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - if any(x is None for x in loras_on_disk): - list_available_loras() - - loras_on_disk = [available_lora_aliases.get(name, None) for name in names] - - failed_to_load_loras = [] - - for i, name in enumerate(names): - lora = already_loaded.get(name, None) - - lora_on_disk = loras_on_disk[i] - - if lora_on_disk is not None: - if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: - try: - lora = load_lora(name, lora_on_disk) - except Exception as e: - errors.display(e, f"loading Lora {lora_on_disk.filename}") - continue - - lora.mentioned_name = name - - lora_on_disk.read_hash() - - if lora is None: - failed_to_load_loras.append(name) - print(f"Couldn't find Lora with name {name}") - continue - - lora.multiplier = multipliers[i] if multipliers else 1.0 - loaded_loras.append(lora) - - if failed_to_load_loras: - sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) - - -def lora_calc_updown(lora, module, target): - with torch.no_grad(): - up = module.up.weight.to(target.device, dtype=target.dtype) - down = module.down.weight.to(target.device, dtype=target.dtype) - - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) - else: - updown = up @ down - - updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return updown - - -def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - weights_backup = getattr(self, "lora_weights_backup", None) - - if weights_backup is None: - return - - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) - - -def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): - """ - Applies the currently selected set of Loras to the weights of torch layer self. - If weights already have this particular set of loras applied, does nothing. - If not, restores orginal weights from backup and alters weights according to loras. 
- """ - - lora_layer_name = getattr(self, 'lora_layer_name', None) - if lora_layer_name is None: - return - - current_names = getattr(self, "lora_current_names", ()) - wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) - - weights_backup = getattr(self, "lora_weights_backup", None) - if weights_backup is None: - if isinstance(self, torch.nn.MultiheadAttention): - weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) - else: - weights_backup = self.weight.to(devices.cpu, copy=True) - - self.lora_weights_backup = weights_backup - - if current_names != wanted_names: - lora_restore_weights_from_backup(self) - - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None and hasattr(self, 'weight'): - self.weight += lora_calc_updown(lora, module, self.weight) - continue - - module_q = lora.modules.get(lora_layer_name + "_q_proj", None) - module_k = lora.modules.get(lora_layer_name + "_k_proj", None) - module_v = lora.modules.get(lora_layer_name + "_v_proj", None) - module_out = lora.modules.get(lora_layer_name + "_out_proj", None) - - if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: - updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) - updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) - updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) - updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - - self.in_proj_weight += updown_qkv - self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) - continue - - if module is None: - continue - - print(f'failed to calculate lora weights for layer {lora_layer_name}') - - self.lora_current_names = wanted_names - - -def lora_forward(module, input, original_forward): - """ - Old way of applying Lora by executing operations during layer's forward. - Stacking many loras this way results in big performance degradation. 
- """ - - if len(loaded_loras) == 0: - return original_forward(module, input) - - input = devices.cond_cast_unet(input) - - lora_restore_weights_from_backup(module) - lora_reset_cached_weight(module) - - res = original_forward(module, input) - - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is None: - continue - - module.up.to(device=devices.device) - module.down.to(device=devices.device) - - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - - return res - - -def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): - self.lora_current_names = () - self.lora_weights_backup = None - - -def lora_Linear_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Linear_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Linear_forward_before_lora(self, input) - - -def lora_Linear_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_Conv2d_forward(self, input): - if shared.opts.lora_functional: - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora) - - lora_apply_weights(self) - - return torch.nn.Conv2d_forward_before_lora(self, input) - - -def lora_Conv2d_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_forward(self, *args, **kwargs): - lora_apply_weights(self) - - return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) - - -def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): - lora_reset_cached_weight(self) - - return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) - - -def list_available_loras(): - available_loras.clear() - available_lora_aliases.clear() - forbidden_lora_aliases.clear() - available_lora_hash_lookup.clear() - forbidden_lora_aliases.update({"none": 1, "Addams": 1}) - - os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) - - candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) - for filename in candidates: - if os.path.isdir(filename): - continue - - name = os.path.splitext(os.path.basename(filename))[0] - try: - entry = LoraOnDisk(name, filename) - except OSError: # should catch FileNotFoundError and PermissionError etc. 
- errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True) - continue - - available_loras[name] = entry - - if entry.alias in available_lora_aliases: - forbidden_lora_aliases[entry.alias.lower()] = 1 - - available_lora_aliases[name] = entry - available_lora_aliases[entry.alias] = entry - - -re_lora_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") - - -def infotext_pasted(infotext, params): - if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: - return # if the other extension is active, it will handle those fields, no need to do anything - - added = [] - - for k in params: - if not k.startswith("AddNet Model "): - continue - - num = k[13:] - - if params.get("AddNet Module " + num) != "LoRA": - continue - - name = params.get("AddNet Model " + num) - if name is None: - continue - - m = re_lora_name.match(name) - if m: - name = m.group(1) - - multiplier = params.get("AddNet Weight A " + num, "1.0") - - added.append(f"<lora:{name}:{multiplier}>") - - if added: - params["Prompt"] += "\n" + "".join(added) - - -available_loras = {} -available_lora_aliases = {} -available_lora_hash_lookup = {} -forbidden_lora_aliases = {} -loaded_loras = [] - -list_available_loras() diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..279b34bc928bcb52979fd67068be0c2ca35b847b --- /dev/null +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -0,0 +1,21 @@ +import torch + + +def make_weight_cp(t, wa, wb): + temp = torch.einsum('i j k l, j r -> i r k l', t, wb) + return torch.einsum('i j k l, i r -> r j k l', temp, wa) + + +def rebuild_conventional(up, down, shape, dyn_dim=None): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + if dyn_dim is not None: + up = up[:, :dyn_dim] + down = down[:dyn_dim, :] + return (up @ down).reshape(shape) + + +def rebuild_cp_decomposition(up, down, mid): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py new file mode 100644 index 0000000000000000000000000000000000000000..0a18d69eb26412d54c3b36e7801b5557691e2b68 --- /dev/null +++ b/extensions-builtin/Lora/network.py @@ -0,0 +1,155 @@ +from __future__ import annotations +import os +from collections import namedtuple +import enum + +from modules import sd_models, cache, errors, hashes, shared + +NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd_module']) + +metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} + + +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 + + +class NetworkOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + + def read_metadata(): + metadata = sd_models.read_metadata_from_safetensors(filename) + metadata.pop('ssmd_cover_images', None) # those are cover images, and they are too big to display in UI as text + + return metadata + + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "lora/" + self.name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading lora {filename}") + + if self.metadata: + m = {} + for k, v in 
sorted(self.metadata.items(), key=lambda x: metadata_tags_order.get(x[0], 999)): + m[k] = v + + self.metadata = m + + self.alias = self.metadata.get('ss_output_name', self.name) + + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', "")) == "True": + return SdVersion.SD2 + elif len(self.metadata): + return SdVersion.SD1 + + return SdVersion.Unknown + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + import networks + networks.available_network_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + import networks + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in networks.forbidden_network_aliases: + return self.name + else: + return self.alias + + +class Network: # LoraModule + def __init__(self, name, network_on_disk: NetworkOnDisk): + self.name = name + self.network_on_disk = network_on_disk + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None + self.modules = {} + self.mtime = None + + self.mentioned_name = None + """the text that was used to add the network to prompt - can be either name or an alias""" + + +class ModuleType: + def create_module(self, net: Network, weights: NetworkWeights) -> Network | None: + return None + + +class NetworkModule: + def __init__(self, net: Network, weights: NetworkWeights): + self.network = net + self.network_key = weights.network_key + self.sd_key = weights.sd_key + self.sd_module = weights.sd_module + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def multiplier(self): + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + else: + return self.network.unet_multiplier + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + + return 1.0 + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + return updown * self.calc_scale() * self.multiplier() + + def calc_updown(self, target): + raise NotImplementedError() + + def forward(self, x, y): + raise NotImplementedError() + diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py new file mode 100644 index 0000000000000000000000000000000000000000..109b4c2c594e5079067d55331271ebafcf6c9fe4 --- /dev/null +++ b/extensions-builtin/Lora/network_full.py @@ -0,0 +1,22 @@ +import network + + +class 
ModuleTypeFull(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["diff"]): + return NetworkModuleFull(net, weights) + + return None + + +class NetworkModuleFull(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.weight = weights.w.get("diff") + + def calc_updown(self, orig_weight): + output_shape = self.weight.shape + updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py new file mode 100644 index 0000000000000000000000000000000000000000..5fcb0695fbb36f01a1dbdeaba87ea288a3d00eb2 --- /dev/null +++ b/extensions-builtin/Lora/network_hada.py @@ -0,0 +1,55 @@ +import lyco_helpers +import network + + +class ModuleTypeHada(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): + return NetworkModuleHada(net, weights) + + return None + + +class NetworkModuleHada(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["hada_w1_a"] + self.w1b = weights.w["hada_w1_b"] + self.dim = self.w1b.shape[0] + self.w2a = weights.w["hada_w2_a"] + self.w2b = weights.w["hada_w2_b"] + + self.t1 = weights.w.get("hada_t1") + self.t2 = weights.w.get("hada_t2") + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + + if self.t1 is not None: + output_shape = [w1a.size(1), w1b.size(1)] + t1 = self.t1.to(orig_weight.device, dtype=orig_weight.dtype) + updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) + output_shape += t1.shape[2:] + else: + if len(w1b.shape) == 4: + output_shape += w1b.shape[2:] + updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) + + if self.t2 is not None: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + else: + updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) + + updown = updown1 * updown2 + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py new file mode 100644 index 0000000000000000000000000000000000000000..7edc4249791c0c58528ef4eb3f7df04c81d56020 --- /dev/null +++ b/extensions-builtin/Lora/network_ia3.py @@ -0,0 +1,30 @@ +import network + + +class ModuleTypeIa3(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["weight"]): + return NetworkModuleIa3(net, weights) + + return None + + +class NetworkModuleIa3(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w = weights.w["weight"] + self.on_input = weights.w["on_input"].item() + + def calc_updown(self, 
orig_weight): + w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w.size(0), orig_weight.size(1)] + if self.on_input: + output_shape.reverse() + else: + w = w.reshape(-1, 1) + + updown = orig_weight * w + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py new file mode 100644 index 0000000000000000000000000000000000000000..340acdab3d0c7d2275db9d2684c8e70a0ea6d889 --- /dev/null +++ b/extensions-builtin/Lora/network_lokr.py @@ -0,0 +1,64 @@ +import torch + +import lyco_helpers +import network + + +class ModuleTypeLokr(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + has_1 = "lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) + if has_1 and has_2: + return NetworkModuleLokr(net, weights) + + return None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class NetworkModuleLokr(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w1 = weights.w.get("lokr_w1") + self.w1a = weights.w.get("lokr_w1_a") + self.w1b = weights.w.get("lokr_w1_b") + self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim + self.w2 = weights.w.get("lokr_w2") + self.w2a = weights.w.get("lokr_w2_a") + self.w2b = weights.w.get("lokr_w2_b") + self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim + self.t2 = weights.w.get("lokr_t2") + + def calc_updown(self, orig_weight): + if self.w1 is not None: + w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + else: + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = w1a @ w1b + + if self.w2 is not None: + w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + elif self.t2 is None: + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = w2a @ w2b + else: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + + output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] + if len(orig_weight.shape) == 4: + output_shape = orig_weight.shape + + updown = make_kron(output_shape, w1, w2) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..26c0a72c237f36f11ac6bf39d0ddc3d4d286bda2 --- /dev/null +++ b/extensions-builtin/Lora/network_lora.py @@ -0,0 +1,86 @@ +import torch + +import lyco_helpers +import network +from modules import devices + + +class ModuleTypeLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["lora_up.weight", "lora_down.weight"]): + return NetworkModuleLora(net, weights) + + return None + + +class NetworkModuleLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: 
network.NetworkWeights): + super().__init__(net, weights) + + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) + + if weight is None and none_ok: + return None + + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') + + with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) + module.weight.copy_(weight) + + module.to(device=devices.cpu, dtype=devices.dtype) + module.weight.requires_grad_(False) + + return module + + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] + else: + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) + + return self.finalize_updown(updown, orig_weight, output_shape) + + def forward(self, x, y): + self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) + + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() + + diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..17cbe1bb7fe383ff1213b17d06d0665f7765a392 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,468 @@ +import os +import re + +import network +import network_lora +import network_hada +import network_ia3 +import network_lokr +import network_full + +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors, scripts, sd_hijack + +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), + network_ia3.ModuleTypeIa3(), + 
network_lokr.ModuleTypeLokr(), + network_full.ModuleTypeFull(), +] + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_conv_in(.*)"): + return f'diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, r"lora_unet_conv_out(.*)"): + return f'diffusion_model_out_2{m[0]}' + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + + +def assign_network_names_to_compvis_modules(sd_model): + network_layer_mapping = {} + + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + network_name = f'{i}_{name.replace(".", "_")}' + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + 
module.network_layer_name = network_name + + for name, module in shared.sd_model.model.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + sd_model.network_layer_mapping = network_layer_mapping + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + sd = sd_models.read_state_dict(network_on_disk.filename) + + # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + if not hasattr(shared.sd_model, 'network_layer_mapping'): + assign_network_names_to_compvis_modules(shared.sd_model) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping + + matched_networks = {} + + for key_network, weight in sd.items(): + key_network_without_network_parts, network_part = key_network.split(".", 1) + + key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None) + + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + # some SD1 Loras also have correct compvis keys + if sd_module is None: + key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + + matched_networks[key].w[network_part] = weight + + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + break + + if net_module is None: + raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") + + net.modules[key] = net_module + + if keys_failed_to_match: + print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + + return net + + +def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + already_loaded = {} + + for net in loaded_networks: + if net.name in names: + already_loaded[net.name] = net + + loaded_networks.clear() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + + failed_to_load_networks = [] + + for i, name in enumerate(names): + net = 
already_loaded.get(name, None) + + network_on_disk = networks_on_disk[i] + + if network_on_disk is not None: + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + + net.mentioned_name = name + + network_on_disk.read_hash() + + if net is None: + failed_to_load_networks.append(name) + print(f"Couldn't find network with name {name}") + continue + + net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 + net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0 + net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 + loaded_networks.append(net) + + if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "network_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores orginal weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.network_weights_backup = weights_backup + + if current_names != wanted_names: + network_restore_weights_from_backup(self) + + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is not None and hasattr(self, 'weight'): + with torch.no_grad(): + updown = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. 
zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + continue + + module_q = net.modules.get(network_layer_name + "_q_proj", None) + module_k = net.modules.get(network_layer_name + "_k_proj", None) + module_v = net.modules.get(network_layer_name + "_v_proj", None) + module_out = net.modules.get(network_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + with torch.no_grad(): + updown_q = module_q.calc_updown(self.in_proj_weight) + updown_k = module_k.calc_updown(self.in_proj_weight) + updown_v = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + updown_out = module_out.calc_updown(self.out_proj.weight) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += updown_out + continue + + if module is None: + continue + + print(f'failed to calculate network weights for layer {network_layer_name}') + + self.network_current_names = wanted_names + + +def network_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_networks) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + network_restore_weights_from_backup(module) + network_reset_cached_weight(module) + + y = original_forward(module, input) + + network_layer_name = getattr(module, 'network_layer_name', None) + for lora in loaded_networks: + module = lora.modules.get(network_layer_name, None) + if module is None: + continue + + y = module.forward(y, input) + + return y + + +def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + self.network_current_names = () + self.network_weights_backup = None + + +def network_Linear_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Linear_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Linear_forward_before_network(self, input) + + +def network_Linear_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Conv2d_forward_before_network(self, input) + + +def network_Conv2d_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_forward(self, *args, **kwargs): + network_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, 
allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"<lora:{name}:{multiplier}>") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() diff --git a/extensions-builtin/Lora/preload.py b/extensions-builtin/Lora/preload.py index 863dc5c0b510a20e03ed68adec0108ebf1e03474..50961be33d7e29214b00b6e69daf0c3e5e8f108c 100644 --- a/extensions-builtin/Lora/preload.py +++ b/extensions-builtin/Lora/preload.py @@ -4,3 +4,4 @@ def preload(parser): parser.add_argument("--lora-dir", type=str, help="Path to directory with Lora networks.", default=os.path.join(paths.models_path, 'Lora')) + parser.add_argument("--lyco-dir-backcompat", type=str, help="Path to directory with LyCORIS networks (for backawards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS')) diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index e650f469f3663cee190818444e5c18f67e3ef6e4..cd28afc92e7ae82d9df4329febcc28f40a254abe 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -4,87 +4,93 @@ import torch import gradio as gr from fastapi import FastAPI -import lora +import network +import networks +import lora # noqa:F401 import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Linear.forward = torch.nn.Linear_forward_before_network - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora + torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora + torch.nn.Conv2d._load_from_state_dict = 
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index e650f469f3663cee190818444e5c18f67e3ef6e4..cd28afc92e7ae82d9df4329febcc28f40a254abe 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -4,87 +4,93 @@
 import torch
 import gradio as gr
 from fastapi import FastAPI
 
-import lora
+import network
+import networks
+import lora  # noqa:F401
 import extra_networks_lora
 import ui_extra_networks_lora
 from modules import script_callbacks, ui_extra_networks, extra_networks, shared
 
 
 def unload():
-    torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora
-    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
-    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
-    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
-    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
-    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
+    torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
+    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
+    torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
+    torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
+    torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
+    torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
 
 
 def before_ui():
     ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora())
 
-    extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora())
+    extra_network = extra_networks_lora.ExtraNetworkLora()
+    extra_networks.register_extra_network(extra_network)
+    extra_networks.register_extra_network_alias(extra_network, "lyco")
 
 
-if not hasattr(torch.nn, 'Linear_forward_before_lora'):
-    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
+if not hasattr(torch.nn, 'Linear_forward_before_network'):
+    torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
 
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'):
-    torch.nn.Linear_load_state_dict_before_lora = torch.nn.Linear._load_from_state_dict
+if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
+    torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
 
-if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
-    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward
+if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
+    torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
 
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
-    torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
+    torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
 
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
-    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
+    torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
 
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
-    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
+    torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
 
-torch.nn.Linear.forward = lora.lora_Linear_forward
-torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
-torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
-torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
+torch.nn.Linear.forward = networks.network_Linear_forward
+torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
+torch.nn.Conv2d.forward = networks.network_Conv2d_forward
+torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
 
-script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
+script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
 script_callbacks.on_script_unloaded(unload)
 script_callbacks.on_before_ui(before_ui)
+script_callbacks.on_infotext_pasted(networks.infotext_pasted)
 
 
 shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
     "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}),
     "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
+    "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
+    "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
 }))
 
 
 shared.options_templates.update(shared.options_section(('compatibility', "Compatibility"), {
-    "lora_functional": shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
+    "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"),
shared.OptionInfo(False, "Lora: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), + "lora_functional": shared.OptionInfo(False, "Lora/Networks: use old method that takes longer when you have multiple Loras active and produces same results as kohya-ss/sd-webui-additional-networks extension"), })) -def create_lora_json(obj: lora.LoraOnDisk): +def create_lora_json(obj: network.NetworkOnDisk): return { "name": obj.name, "alias": obj.alias, @@ -92,19 +97,19 @@ "metadata": obj.metadata, } -def api_loras(_: gr.Blocks, app: FastAPI): +def api_networks(_: gr.Blocks, app: FastAPI): @app.get("/sdapi/v1/loras") async def get_loras(): -from fastapi import FastAPI + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared @app.post("/sdapi/v1/refresh-loras") async def refresh_loras(): -import lora +import re -script_callbacks.on_app_started(api_loras) +script_callbacks.on_app_started(api_networks) re_lora = re.compile("<lora:([^:]+):") @@ -117,20 +121,20 @@ hashes = [x.strip().split(':', 1) for x in hashes.split(",")] hashes = {x[0].strip().replace(",", ""): x[1].strip() for x in hashes} -import extra_networks_lora import re +if not hasattr(torch.nn, 'Linear_load_state_dict_before_lora'): alias = m.group(1) shorthash = hashes.get(alias) if shorthash is None: return m.group(0) - lora_on_disk = lora.available_lora_hash_lookup.get(shorthash) + network_on_disk = networks.available_network_hash_lookup.get(shorthash) - if lora_on_disk is None: + if network_on_disk is None: return m.group(0) - return f'<lora:{lora_on_disk.get_alias()}:' + return f'<lora:{network_on_disk.get_alias()}:' - d["Prompt"] = re.sub(re_lora, lora_replacement, d["Prompt"]) + d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) script_callbacks.on_infotext_pasted(infotext_pasted) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 354a1d686c089107c326d651f1677e876ae63141..390d9dde3fbc0a848185809f39ea694a7bd00ac2 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -1,3 +1,4 @@ +import datetime import html import random @@ -46,15 +47,19 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) + self.select_sd_version = None + self.taginfo = None self.edit_activation_text = None self.slider_preferred_weight = None self.edit_notes = None +import html -import re +def is_non_comma_tagset(tags): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["notes"] = notes @@ -69,6 +74,7 @@ keys = { 'ss_sd_model_name': "Model:", 'ss_clip_skip': "Clip skip:", + 'ss_network_module': "Kohya module:", } for key, label in keys.items(): @@ -76,6 +82,10 @@ value = metadata.get(key, None) if value is not None and str(value) != "None": table.append((label, html.escape(value))) + ss_training_started_at = metadata.get('ss_training_started_at') + if ss_training_started_at: + table.append(("Date trained:", datetime.datetime.utcfromtimestamp(float(ss_training_started_at)).strftime('%Y-%m-%d %H:%M'))) + ss_bucket_info = metadata.get("ss_bucket_info") if 
ss_bucket_info and "buckets" in ss_bucket_info: resolutions = {} @@ -113,13 +123,12 @@ tags = build_tags(metadata) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] return [ - *values[0:4], + *values[0:5], + item.get("sd_version", "Unknown"), gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), average_tag_length = sum(len(x) for x in tags.keys()) / len(tags) -import random - average_tag_length = sum(len(x) for x in tags.keys()) / len(tags) gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -144,10 +153,16 @@ res.append(tag) return ", ".join(sorted(res)) + def create_extra_default_items_in_left_column(self): + + # this would be a lot better as gr.Radio but I can't make it work + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + def create_editor(self): self.create_default_editor_elems() import html + 'ss_clip_skip': "Clip skip:", self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) @@ -156,7 +172,7 @@ random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) with gr.Column(scale=1, min_width=120): import html -def is_non_comma_tagset(tags): + } self.edit_notes = gr.TextArea(label='Notes', lines=4) @@ -182,10 +198,12 @@ self.edit_description, self.html_filedata, self.html_preview, def build_tags(metadata): +def is_non_comma_tagset(tags): + self.select_sd_version, +def build_tags(metadata): import gradio as gr self.edit_activation_text, self.slider_preferred_weight, - self.edit_notes, row_random_prompt, random_prompt, ] @@ -196,6 +214,7 @@ .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) edited_components = [ self.edit_description, + self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, self.edit_notes, diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index b2bc18102703f0ae831df1112f7f4d0171ead4c3..3629e5c0cf227192d5618fb0800a2be75f84ccf8 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,5 +1,8 @@ import os + +from modules import shared, ui_extra_networks import lora +import networks from modules import shared, ui_extra_networks from modules.ui_extra_networks import quote_js @@ -11,18 +14,16 @@ def __init__(self): super().__init__('Lora') def refresh(self): - lora.list_available_loras() + networks.list_available_networks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): - lora_on_disk = lora.available_loras.get(name) + lora_on_disk = networks.available_networks.get(name) path, ext = os.path.splitext(lora_on_disk.filename) alias = lora_on_disk.get_alias() import os -from ui_edit_user_metadata import LoraUserMetadataEditor -import os class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "name": name, "filename": lora_on_disk.filename, @@ -32,6 +33,7 @@ "search_term": self.search_terms_from_path(lora_on_disk.filename), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, 
**self.get_sort_keys(lora_on_disk.filename)}, + "sd_version": lora_on_disk.sd_version.name, } self.read_user_metadata(item) @@ -42,22 +44,43 @@ if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: +from modules.ui_extra_networks import quote_js + if shared.opts.lora_show_all or not enable_filter: + pass + elif sd_version == network.SdVersion.Unknown: + model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 + if model_version.name in shared.opts.lora_hide_unknown_for_versions: + return None + elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL: + return None + elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2: + return None +from ui_edit_user_metadata import LoraUserMetadataEditor + return None -from modules import shared, ui_extra_networks -from modules.ui_extra_networks import quote_js + + def list_items(self): from ui_edit_user_metadata import LoraUserMetadataEditor +from modules import shared, ui_extra_networks -class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): +from ui_edit_user_metadata import LoraUserMetadataEditor + if item is not None: + yield item - def __init__(self): - super().__init__('Lora') + def __init__(self): + return [shared.cmd_opts.lora_dir, shared.cmd_opts.lyco_dir_backcompat] def create_user_metadata_editor(self, ui, tabname): return LoraUserMetadataEditor(ui, tabname, self) diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js new file mode 100644 index 0000000000000000000000000000000000000000..12cae4b75764779f7da3e424a959f966c06a8648 --- /dev/null +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -0,0 +1,26 @@ +var isSetupForMobile = false; + +function isMobile() { + for (var tab of ["txt2img", "img2img"]) { + var imageTab = gradioApp().getElementById(tab + '_results'); + if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) { + return true; + } + } + + return false; +} + +function reportWindowSize() { + var currentlyMobile = isMobile(); + if (currentlyMobile == isSetupForMobile) return; + isSetupForMobile = currentlyMobile; + + for (var tab of ["txt2img", "img2img"]) { + var button = gradioApp().getElementById(tab + '_generate_box'); + var target = gradioApp().getElementById(currentlyMobile ? 
tab + '_results' : tab + '_actions_column'); + target.insertBefore(button, target.firstElementChild); + } +} + +window.addEventListener("resize", reportWindowSize); diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index eb8b1a6726295cf9af40652f6925c4bdca1c79a5..39674666f1e336d9bf61d2a6986721cf8591eeee 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@ <div class='card' style={style} onclick={card_clicked} data-name="{name}" {sort_keys}> {background_image} <div class="button-row"> - {edit_button} {metadata_button} + {edit_button} </div> <div class='actions'> <div class='additional'> diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 2361144a9abe4df20b0f66d0e86ec408e874d5b7..44d02349a9d9f75efcbae6acf661e4095e626337 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -211,7 +211,7 @@ }; globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); + gradioApp().querySelector('.main').appendChild(globalPopup); } globalPopupInner.innerHTML = ''; diff --git a/javascript/hints.js b/javascript/hints.js index 4167cb28b7c0ed934b37d346aacd3784ebec1016..6de9372e8ea8c9fb032351e241d0f9c265995290 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -190,3 +190,14 @@ clearTimeout(tooltipCheckTimer); tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); } }); + +onUiLoaded(function() { + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip && comp.props.elem_id) { + var elem = gradioApp().getElementById(comp.props.elem_id); + if (elem) { + elem.title = comp.props.webui_tooltip; + } + } + } +}); diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e99c4c9a0c4d6a52c3b9acefd74464ae..0c9032f9b41cfe53562a1f8a01be44c5bb06d05e 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -12,15 +12,15 @@ train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', -var ignore_ids_for_localization = { + extras_upscaler_2: 'SPAN', - setting_sd_hypernetwork: 'OPTION', +}; - setting_sd_model_checkpoint: 'OPTION', +var re_num = /^[.\d]+$/; - modelmerger_primary_model_name: 'OPTION', +var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u; - modelmerger_secondary_model_name: 'OPTION', +var original_lines = {}; }; var re_num = /^[.\d]+$/; diff --git a/javascript/ui.js b/javascript/ui.js index d70a681bff7b45fe5711431ee8ec55c444443a5b..abf23a78c703a1adccd8b0d52c8b00463979f837 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -152,7 +152,11 @@ function submit() { showSubmitButtons('txt2img', false); var id = randomId(); - localStorage.setItem("txt2img_task_id", id); + try { + localStorage.setItem("txt2img_task_id", id); + } catch (e) { + console.warn(`Failed to save txt2img task id to localStorage: ${e}`); + } requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); @@ -171,7 +175,11 @@ function submit_img2img() { showSubmitButtons('img2img', false); var id = randomId(); - localStorage.setItem("img2img_task_id", id); + try { + localStorage.setItem("img2img_task_id", id); + } catch (e) { + console.warn(`Failed to save img2img task id to localStorage: ${e}`); + } requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), 
gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); @@ -190,8 +198,6 @@ function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); var id = localStorage.getItem("txt2img_task_id"); - - id = localStorage.getItem("txt2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { diff --git a/launch.py b/launch.py index b103c8f3a2e1fcddfdc699ee65205df7cd3be1b4..e4c2ce99e729ce26b0efd89ced2ec8ff0373f8e5 100644 --- a/launch.py +++ b/launch.py @@ -1,6 +1,5 @@ from modules import launch_utils - args = launch_utils.args python = launch_utils.python git = launch_utils.git @@ -18,6 +17,7 @@ run_pip = launch_utils.run_pip check_run_python = launch_utils.check_run_python git_clone = launch_utils.git_clone git_pull_recursive = launch_utils.git_pull_recursive +list_extensions = launch_utils.list_extensions run_extension_installer = launch_utils.run_extension_installer prepare_environment = launch_utils.prepare_environment configure_for_tests = launch_utils.configure_for_tests @@ -25,9 +25,12 @@ start = launch_utils.start def main(): - + launch_utils.startup_timer.record("initial startup") + with launch_utils.startup_timer.subcategory("prepare environment"): +args = launch_utils.args +args = launch_utils.args args = launch_utils.args if args.test_server: diff --git a/modules/api/api.py b/modules/api/api.py index 2a4cd8a2012febe3f308c8a3dc92885c84f9d317..908c451420ee8f1dd64f3798ea8dd222cb38fa72 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -15,7 +15,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -197,6 +197,7 @@ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) @@ -333,21 +334,24 @@ p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples + try: + image = Image.open(BytesIO(base64.b64decode(encoding))) import io -import gradio as gr + if selectable_scripts is not None: + image = Image.open(BytesIO(base64.b64decode(encoding))) import time -from modules.sd_vae import vae_dict + image = 
Image.open(BytesIO(base64.b64decode(encoding))) import datetime -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) import uvicorn -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) import gradio as gr -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) from threading import Lock -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) from io import BytesIO -from modules.sd_models_config import find_checkpoint_config_near_filename + return image -from modules.sd_models_config import find_checkpoint_config_near_filename + return image import base64 b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -398,21 +402,26 @@ p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples +import datetime import os -from threading import Lock +import base64 + return image import io -import gradio as gr import datetime +import piexif.helper -from modules.sd_vae import vae_dict + p.script_args = script_args + processed = scripts.scripts_img2img.run(p, *p.script_args) # Need to pass args as list here + image = Image.open(BytesIO(base64.b64decode(encoding))) import uvicorn +import datetime import os -from io import BytesIO +import gradio as gr -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) from threading import Lock -from modules.sd_vae import vae_dict + image = Image.open(BytesIO(base64.b64decode(encoding))) from io import BytesIO -from modules.sd_models_config import find_checkpoint_config_near_filename + return image -from modules.sd_models_config import find_checkpoint_config_near_filename + return image import base64 b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -620,6 +628,10 @@ def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() + def refresh_vae(self): + with self.queue_lock: + shared_items.refresh_vae_list() + def create_embedding(self, args: dict): try: shared.state.begin(job="create_embedding") @@ -737,10 +749,10 @@ cuda = {'error': f'{err}'} return models.MemoryResponse(ram=ram, cuda=cuda) import datetime -import modules.shared as shared + reqDict = vars(req) self.app.include_router(self.router) import datetime -from modules.api import models + reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None) def kill_webui(self): restart.stop_program() diff --git a/modules/api/models.py b/modules/api/models.py index b568307141fd6945b7494dfde45ee8bf15aab219..800c9b93f14794f429e32b053e9c24be0426d296 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,4 +1,5 @@ import inspect + from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal @@ -207,14 +208,13 @@ fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) - "sampler_index", from pydantic import BaseModel, Field, create_model + field_exclude: bool = False - "sampler_index", +from pydantic import BaseModel, Field, create_model from typing import Any, Optional - "sampler_index", from typing_extensions import Literal - "sampler_index", + "prompt_for_display", from inflection import underscore else: fields.update({key: (Optional[optType], Field())}) diff --git a/modules/cache.py b/modules/cache.py index 
28d42a8cfe5e727579dec1593df67211be8d9102..71fe630213410d64c51cc77876dc86c36944c55b 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,6 +1,7 @@ import json import os.path import threading +import time from modules.paths import data_path, script_path @@ -8,18 +9,40 @@ cache_filename = os.path.join(data_path, "cache.json") cache_data = None cache_lock = threading.Lock() +dump_cache_after = None +dump_cache_thread = None + def dump_cache(): import json +cache_data = None import json + +cache_data = None import json -import json + global dump_cache_thread + + def thread_func(): + global dump_cache_after + global dump_cache_thread + + while dump_cache_after is not None and time.time() < dump_cache_after: + time.sleep(1) + + with cache_lock: + with open(cache_filename, "w", encoding="utf8") as file: + json.dump(cache_data, file, indent=4) + +cache_lock = threading.Lock() + dump_cache_thread = None with cache_lock: -import json + dump_cache_after = time.time() + 5 +cache_lock = threading.Lock() import threading -import json +cache_lock = threading.Lock() + dump_cache_thread.start() def cache(subsection): @@ -87,7 +110,7 @@ cached_mtime = entry.get("mtime", 0) if ondisk_mtime > cached_mtime: entry = None - if not entry: + if not entry or 'value' not in entry: value = func() if value is None: return None diff --git a/modules/call_queue.py b/modules/call_queue.py index 61aa240fb3222931d39816fe2de2cb5d262e1ad3..f2eb17d61661e2d56ef2c3678db206a601c1eeec 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -3,7 +3,7 @@ import html import threading import time -from modules import shared, progress, errors +from modules import shared, progress, errors, devices queue_lock = threading.Lock() @@ -74,6 +74,9 @@ extra_outputs_array = [None, ''] error_message = f'{type(e).__name__}: {e}' +import threading + +def wrap_queued_call(func): import threading shared.state.skipped = False diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ae78f469efafe79aa0cc24122654676e6364db83..64f21e011c9399c0234b4bb61cfc1c868092699b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -14,8 +14,11 @@ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") import argparse +parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed") +import argparse parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -66,6 +69,7 @@ 
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -110,3 +114,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/devices.py b/modules/devices.py index 57e51da30e26f0586c14321b5c0453f8a3ba5c64..00a00b18ab3b6c166c318cc185c30fab4306925a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,6 +3,7 @@ import contextlib from functools import lru_cache import torch +def has_mps() -> bool: from modules import errors if sys.platform == "darwin": @@ -72,52 +73,127 @@ torch.backends.cudnn.allow_tf32 = True +from functools import lru_cache +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None + if sys.platform != "darwin": from functools import lru_cache + if sys.platform != "darwin": +dtype_unet: torch.dtype = torch.float16 +def has_mps() -> bool: import torch + return input.to(dtype_unet) if unet_needs_upcast else input + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + + if sys.platform != "darwin": from modules import errors + +def randn(seed, shape): + if sys.platform != "darwin": if sys.platform == "darwin": + if sys.platform != "darwin": from modules import mac_specific + from modules.shared import opts + + if sys.platform != "darwin": def has_mps() -> bool: + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) import torch +if sys.platform == "darwin": + 
+
+
+def randn_local(seed, shape):
+    """Generate a tensor with random numbers from a normal distribution using seed.
+
+    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        rng = rng_philox.Generator(seed)
+        return torch.asarray(rng.randn(shape), device=device)
+
+    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+    local_generator = torch.Generator(local_device).manual_seed(int(seed))
+    return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+    if opts.randn_source == "CPU" or x.device.type == 'mps':
+        return torch.randn_like(x, device=cpu).to(x.device)
+
+    return torch.randn_like(x)
+
+
+def randn_without_seed(shape):
+    """Generate a tensor with random numbers from a normal distribution using the previously initialized generator.
+
+    Use either randn() or manual_seed() to initialize the generator."""
+
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        return torch.asarray(nv_rng.randn(shape), device=device)
+
+    if opts.randn_source == "CPU" or device.type == 'mps':
+        return torch.randn(shape, device=cpu).to(device)
+
+    return torch.randn(shape, device=device)
+
+
+def manual_seed(seed):
+    """Set up a global random number generator using the specified seed."""
+    from modules.shared import opts
+
+    if opts.randn_source == "NV":
+        global nv_rng
+        nv_rng = rng_philox.Generator(seed)
+        return
+
+    torch.manual_seed(seed)
 
 
 def autocast(disable=False):
diff --git a/modules/errors.py b/modules/errors.py
index 5271a9fe1de9923ca31dbcfc78f7e52472fc75b0..192cd8ffd62d2bce40ca025aa34e889f049d1d00 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -14,7 +14,8 @@
     if exception_records and exception_records[-1] == e:
         return
 
-    exception_records.append((e, tb))
+    from modules import sysinfo
+    exception_records.append(sysinfo.format_exception(e, tb))
 
     if len(exception_records) > 5:
         exception_records.pop(0)
@@ -83,3 +84,53 @@
     try:
         code()
     except Exception as e:
         display(task, e)
+
+
+def check_versions():
+    from packaging import version
+    from modules import shared
+
+    import torch
+    import gradio
+
+    expected_torch_version = "2.0.0"
+    expected_xformers_version = "0.0.20"
+    expected_gradio_version = "3.39.0"
+
+    if version.parse(torch.__version__) < version.parse(expected_torch_version):
+        print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, and there
+are reports of issues with the training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+        """.strip())
+
+    if shared.xformers_available:
+        import xformers
+
+        if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+            print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you have the mismatched gradio version can be: + - you use --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs the incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/modules/extensions.py b/modules/extensions.py index c561159afba55e65f7ec8f5ac5600187f6cdc25a..e4633af4034bb9b23be0cef115b67143774dcaaf 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -11,9 +12,10 @@ def active(): import os +class Extension: return [] import os -import threading + lock = threading.Lock() return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -56,10 +58,12 @@ self.do_read_info_from_repo() return self.to_dict() - + try: - d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) + d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo) - self.from_dict(d) + self.from_dict(d) - self.status = 'unknown' + except FileNotFoundError: + pass + self.status = 'unknown' if self.status == '' else self.status def do_read_info_from_repo(self): repo = None @@ -139,7 +144,12 @@ if not os.path.isdir(extensions_dir): return import os + self.status = '' + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": print("*** \"Disable all extensions\" option was set, will not load any extensions ***") + elif shared.cmd_opts.disable_extra_extensions: + print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***") elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a9849cde7da05b28bc64274b1d09284e2..fa28ac752ac24f7a2c26240baa76a807eb958fd9 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,17 +1,25 @@ +import json +import os import re from collections import defaultdict from modules import errors extra_network_registry = {} +extra_network_aliases = {} def initialize(): extra_network_registry.clear() + extra_network_aliases.clear() def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network + + +def register_extra_network_alias(extra_network, alias): + extra_network_aliases[alias] = extra_network def register_default_extra_networks(): @@ -82,20 +90,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" + activated = [] + for extra_network_name, extra_network_args in extra_network_data.items(): extra_network = extra_network_registry.get(extra_network_name, None) + + if extra_network is 
None: + extra_network = extra_network_aliases.get(extra_network_name, None) + if extra_network is None: print(f"Skipping unknown extra network: {extra_network_name}") continue try: extra_network.activate(p, extra_network_args) + activated.append(extra_network) except Exception as e: errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) -def initialize(): +def register_default_extra_networks(): from modules import errors continue @@ -166,3 +180,20 @@ res.append(updated_prompt) return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/extras.py b/modules/extras.py index e9c0263ec7d83d112c42645c18a3cbda64ed911e..2a310ae3f25304f6d3cf6c1c071c77ca4a5e5e4c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,8 +72,21 @@ return tensor +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} +<p>{plaintext_to_html(str(text))}</p> import json + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -242,15 +255,30 @@ shared.state.nextjob() shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") + metadata = {} + +</div> import os -import tqdm + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: +</div> import json +<div> - info = '' + metadata.update(tertiary_model_info.metadata) info = '' + +</div> import torch + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + metadata["format"] = "pt" + + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -266,7 +294,6 @@ "discard_weights": discard_weights, "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -286,12 +313,13 @@ add_model_metadata(secondary_model_info) if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = 
json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": import re -import gradio as gr + return else: torch.save(theta_0, output_modelname) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9db8615d5d10e2cc6a18e182a22c1ee92..4e2865587a28b8c3c1601feeb102941294bd1d00 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -280,6 +280,9 @@ if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py new file mode 100644 index 0000000000000000000000000000000000000000..5af7fd8ecfccd3d95e353f93e1b775c8aa4f4b1e --- /dev/null +++ b/modules/gradio_extensons.py @@ -0,0 +1,60 @@ +import gradio as gr + +from modules import scripts + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + return config + + +def BlockContext_init(self, *args, **kwargs): + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + return res + + +original_IOComponent_init = gr.components.IOComponent.__init__ +original_Block_get_config = gr.blocks.Block.get_config +original_BlockContext_init = gr.blocks.BlockContext.__init__ + +gr.components.IOComponent.__init__ = IOComponent_init +gr.blocks.Block.get_config = Block_get_config +gr.blocks.BlockContext.__init__ = BlockContext_init diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 79670b877151e06075183cb4f5c53a7693d0be96..70f1cbd26b66939de4d42831e300850e3f5927ad 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -378,8 +378,8 @@ return context_k, context_v - "tanh": torch.nn.Tanh, import inspect + x = state_dict.get(fr, None) h = self.heads q = 
self.to_q(x) @@ -470,9 +470,8 @@ shared.reload_hypernetworks() def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): -import html import inspect - self.multiplier = 1.0 +import torch import datetime save_hypernetwork_every = save_hypernetwork_every or 0 diff --git a/modules/images.py b/modules/images.py index fb5d2e750a154e83f35d8fa9599c3d7757294257..ba3c43a45093e39b17a83e7d49a0baab9f1b8fcf 100644 --- a/modules/images.py +++ b/modules/images.py @@ -318,7 +318,7 @@ return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') @@ -363,8 +363,8 @@ 'height': lambda self: self.image.height, 'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False), 'sampler': lambda self: self.p and sanitize_filename_part(self.p.sampler_name, replace_spaces=False), 'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash), - except Exception: import io + rows = math.sqrt(len(imgs)) 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp), diff --git a/modules/img2img.py b/modules/img2img.py index a811e7a4b1b44e22d7e7d433e708f8c539c82267..d8e1c534c3d46d8cbe257c62bc5a9b4b93c15c23 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,14 +3,13 @@ from contextlib import closing from pathlib import Path import numpy as np -from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr from modules import sd_samplers, images as imgutil from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state -from modules.images import save_image import modules.shared as shared import modules.processing as processing from modules.ui import plaintext_to_html @@ -19,9 +18,12 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): import os + if n > 0: +import os from modules import sd_samplers, images as imgutil images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp"))) +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError is_inpaint_batch = False if inpaint_mask_dir: @@ -32,11 +34,6 @@ if is_inpaint_batch: print(f"\nInpaint batch is enabled. 
{len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") - - save_normally = output_dir == '' - - p.do_not_save_grid = True - p.do_not_save_samples = not save_normally state.job_count = len(images) * p.n_iter @@ -112,23 +109,18 @@ p.steps = int(parsed_parameters.get("Steps", steps)) proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: - proc = process_images(p) - -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters import os - filename = image_path.stem - infotext = proc.infotext(p, n) - relpath = os.path.dirname(os.path.relpath(image, input_dir)) - - if n > 0: - filename += f"-{n}" - if not save_normally: +import os os.makedirs(os.path.join(output_dir, relpath), exist_ok=True) +import os if processed_image.mode == 'RGBA': -import os +from contextlib import closing -import os +from contextlib import closing import os + else: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}' + process_images(p) def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): @@ -144,10 +136,8 @@ image = sketch.convert("RGB") mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] -from modules.shared import opts, state from contextlib import closing - mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask).convert('L') +import numpy as np image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch diff --git a/modules/launch_utils.py b/modules/launch_utils.py index ff77cbfd513b68b72df714e1dc03eb71a57d9f9d..f77b577a5d82cccf6239e5f82ad8cf69d56a0d78 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,5 +1,7 @@ # this scripts installs necessary requirements and launches main program in webui.py import subprocess +@lru_cache() +import subprocess import os import sys import importlib.util @@ -9,6 +11,7 @@ from functools import lru_cache from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir +from modules.timer import startup_timer args, _ = cmd_args.parser.parse_known_args() @@ -192,7 +195,7 @@ return try: env = os.environ.copy() - env['PYTHONPATH'] = os.path.abspath(".") + env['PYTHONPATH'] = f"{os.path.abspath('.')}{os.pathsep}{env.get('PYTHONPATH', '')}" print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: @@ -222,10 +225,54 @@ 
 def run_extensions_installers(settings_file):
     if not os.path.isdir(extensions_dir):
         return
 
-    for dirname_extension in list_extensions(settings_file):
-        run_extension_installer(os.path.join(extensions_dir, dirname_extension))
+    with startup_timer.subcategory("run extensions installers"):
+        for dirname_extension in list_extensions(settings_file):
+            path = os.path.join(extensions_dir, dirname_extension)
+
+            if os.path.isdir(path):
+                run_extension_installer(path)
+
+
+re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*")
+
+
+def requirements_met(requirements_file):
+    """
+    Does a simple parse of a requirements.txt file to determine if all requirements in it
+    are already installed. Returns True if so, False if not installed or parsing fails.
+    """
+
+    import importlib.metadata
+    import packaging.version
+
+    with open(requirements_file, "r", encoding="utf8") as file:
+        for line in file:
+            if line.strip() == "":
+                continue
+
+            m = re.match(re_requirement, line)
+            if m is None:
+                return False
+
+            package = m.group(1).strip()
+            version_required = (m.group(2) or "").strip()
+
+            if version_required == "":
+                continue
+
+            try:
+                version_installed = importlib.metadata.version(package)
+            except Exception:
+                return False
+
+            if packaging.version.parse(version_required) != packaging.version.parse(version_installed):
+                return False
+
+    return True
 
 
 def prepare_environment():
@@ -239,11 +286,13 @@
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -251,17 +300,20 @@
     try:
         # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
         os.remove(os.path.join(script_path, "tmp", "restart"))
     except OSError:
         pass
 
     if not args.skip_python_version_check:
         check_python_version()
 
+    startup_timer.record("checks")
+
     commit = commit_hash()
     tag = git_tag()
+    startup_timer.record("git version info")
 
     print(f"Python {sys.version}")
     print(f"Version: {tag}")
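The per-package check inside requirements_met above can be exercised on its own; a minimal sketch (hypothetical helper name, same importlib.metadata/packaging logic):

import importlib.metadata
import packaging.version

def pin_satisfied(package: str, version_required: str) -> bool:
    # mirrors requirements_met: an exact-pin mismatch means "not met"
    try:
        version_installed = importlib.metadata.version(package)
    except Exception:
        return False
    return packaging.version.parse(version_required) == packaging.version.parse(version_installed)

print(pin_satisfied("packaging", "23.1"))  # True only if exactly 23.1 is installed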
 def is_installed(package):
+    try:
+        spec = importlib.util.find_spec(package)
+    except ModuleNotFoundError:
+        return False
+
+    return spec is not None
 
 
 def prepare_environment():
@@ -239,11 +286,13 @@
     clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip")
     openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip")
 
     stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git")
+    stable_diffusion_xl_repo = os.environ.get('STABLE_DIFFUSION_XL_REPO', "https://github.com/Stability-AI/generative-models.git")
     k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git')
     codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git')
     blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git')
 
     stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf")
+    stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87")
     k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf")
     codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
     blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
@@ -251,17 +300,20 @@
     try:
         # the existence of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution
        os.remove(os.path.join(script_path, "tmp", "restart"))
     except OSError:
         pass
 
     if not args.skip_python_version_check:
         check_python_version()
 
+    startup_timer.record("checks")
+
     commit = commit_hash()
     tag = git_tag()
+    startup_timer.record("git version info")
 
     print(f"Python {sys.version}")
     print(f"Version: {tag}")
@@ -269,21 +321,27 @@
     print(f"Commit hash: {commit}")
 
     if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
         run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
+        startup_timer.record("install torch")
 
     if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
         raise RuntimeError(
             'Torch is not able to use GPU; '
             'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'
         )
+    startup_timer.record("torch GPU test")
 
     if not is_installed("gfpgan"):
         run_pip(f"install {gfpgan_package}", "gfpgan")
+        startup_timer.record("install gfpgan")
 
     if not is_installed("clip"):
         run_pip(f"install {clip_package}", "clip")
+        startup_timer.record("install clip")
 
     if not is_installed("open_clip"):
         run_pip(f"install {openclip_package}", "open_clip")
+        startup_timer.record("install open_clip")
 
     if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
         if platform.system() == "Windows":
@@ -298,38 +356,55 @@
         elif platform.system() == "Linux":
             run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
+        startup_timer.record("install xformers")
 
     if not is_installed("ngrok") and args.ngrok:
         run_pip("install ngrok", "ngrok")
+        startup_timer.record("install ngrok")
 
     os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
 
     git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash)
+    git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash)
     git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash)
     git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
     git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash)
+    startup_timer.record("clone repositories")
 
     if not is_installed("lpips"):
         run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer")
+        startup_timer.record("install CodeFormer requirements")
 
     if not os.path.isfile(requirements_file):
         requirements_file = os.path.join(script_path, requirements_file)
 
-    run_pip(f"install -r \"{requirements_file}\"", "requirements")
+    if not requirements_met(requirements_file):
+        run_pip(f"install -r \"{requirements_file}\"", "requirements")
+        startup_timer.record("install requirements")
 
     run_extensions_installers(settings_file=args.ui_settings_file)
 
     if args.update_check:
         version_check(commit)
+        startup_timer.record("check version")
 
     if args.update_all_extensions:
         git_pull_recursive(extensions_dir)
 
     if "--exit" in sys.argv:
         print("Exiting because of --exit argument")
         exit(0)
 
 
 def configure_for_tests():
diff --git a/modules/lowvram.py b/modules/lowvram.py
index d95bcfbf0f3b8cd6adff29ed094b834eb0d41b8e..96f52b7b4dad7ec38c520f5c2e17ffe2dcea5545 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,9 @@
 module_in_gpu = None
 
 
 def setup_for_low_vram(sd_model, use_medvram):
+    if getattr(sd_model, 'lowvram', False):
+        return
+
+    sd_model.lowvram = True
 
     parents = {}
@@ -53,47 +56,73 @@
     def first_stage_model_decode_wrap(z):
         send_me_to_gpu(first_stage_model, None)
         return first_stage_model_decode(z)
 
+    to_remain_in_cpu = [
+        (sd_model, 'first_stage_model'),
+        (sd_model, 'depth_model'),
+        (sd_model, 'embedder'),
+        (sd_model, 'model'),
+    ]
+
+    is_sdxl = hasattr(sd_model, 'conditioner')
+    is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model')
+
+    if is_sdxl:
+        to_remain_in_cpu.append((sd_model, 'conditioner'))
+    elif is_sd2:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'model'))
+    else:
+        to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer'))
+
     # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None
+    stored = []
+    for obj, field in to_remain_in_cpu:
+        module = getattr(obj, field, None)
+        stored.append(module)
+        setattr(obj, field, None)
 
     # send the model to GPU.
     sd_model.to(devices.device)
 
-    sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored
+    # put modules back. the modules will be in CPU.
+    for (obj, field), module in zip(to_remain_in_cpu, stored):
+        setattr(obj, field, module)
 
     # register hooks for the first three models
-    sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+    if is_sdxl:
+        sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu)
+    elif is_sd2:
+        sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu)
+        sd_model.cond_stage_model.model.token_embedding.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.model.token_embedding] = sd_model.cond_stage_model
+    else:
+        sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+        parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
 
     if sd_model.depth_model:
         sd_model.depth_model.register_forward_pre_hook(send_me_to_gpu)
 
     if use_medvram:
         sd_model.model.register_forward_pre_hook(send_me_to_gpu)
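The whole lowvram scheme rests on `register_forward_pre_hook`: every big submodel gets a hook that pulls itself onto the GPU and evicts whichever module currently holds it. A standalone sketch of that trick with toy modules (illustrative names, not the webui API; a CUDA device is required to run it):

```
# Toy reproduction of the send_me_to_gpu pattern: at most one module occupies
# the GPU at a time; the forward pre-hook swaps modules in on demand.
import torch

cpu = torch.device("cpu")
gpu = torch.device("cuda")

module_in_gpu = None

def send_me_to_gpu(module, _inputs):
    global module_in_gpu
    if module_in_gpu is module:
        return
    if module_in_gpu is not None:
        module_in_gpu.to(cpu)           # evict the previous occupant
    module.to(gpu)
    module_in_gpu = module

unet = torch.nn.Linear(8, 8)            # stand-ins for the real submodels
vae = torch.nn.Linear(8, 8)
for m in (unet, vae):
    m.register_forward_pre_hook(send_me_to_gpu)

x = torch.zeros(1, 8, device=gpu)
unet(x)                                  # unet hops onto the GPU
vae(x)                                   # vae evicts unet, then runs
```

The new `getattr(sd_model, 'lowvram', False)` early return above also makes the setup idempotent, which matters once model weights can be reloaded mid-generation.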
diff --git a/modules/paths.py b/modules/paths.py
index bada804e6c1ac88d77be49915c2e357ee7a57169..2505233999b2a8fe1945dbc3cdbd6da36a403719 100644
--- a/modules/paths.py
+++ b/modules/paths.py
@@ -5,6 +5,21 @@
 import modules.safe  # noqa: F401
 
 
+def mute_sdxl_imports():
+    """create fake modules that SDXL wants to import but doesn't actually use for our purposes"""
+
+    class Dummy:
+        pass
+
+    module = Dummy()
+    module.LPIPS = None
+    sys.modules['taming.modules.losses.lpips'] = module
+
+    module = Dummy()
+    module.StableDataModuleFromConfig = None
+    sys.modules['sgm.data'] = module
+
+
 # data_path = cmd_opts_pre.data
 sys.path.insert(0, script_path)
 
@@ -18,8 +33,11 @@
         break
 
 assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}"
 
+mute_sdxl_imports()
+
 path_dirs = [
     (sd_path, 'ldm', 'Stable Diffusion', []),
+    (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]),
     (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []),
     (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []),
     (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]),
@@ -35,6 +53,13 @@
     else:
         d = os.path.abspath(d)
         if "atstart" in options:
             sys.path.insert(0, d)
+        elif "sgm" in options:
+            # The Stable Diffusion XL repo has a scripts dir with __init__.py in it which ruins every extension's scripts dir, so we
+            # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir.
+
+            sys.path.insert(0, d)
+            import sgm  # noqa: F401
+            sys.path.pop(0)
         else:
             sys.path.append(d)
         paths[what] = d
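mute_sdxl_imports() works because Python consults sys.modules by full dotted name before running any import finder: planting a stub object there makes the import a no-op. A generic sketch of the same trick with invented module names:

```
# Pre-seeding sys.modules makes `from heavy_dep.losses import LPIPS` succeed
# without heavy_dep ever existing on disk. Names here are made up for the demo.
import sys
import types

sys.modules["heavy_dep"] = types.ModuleType("heavy_dep")

stub = types.ModuleType("heavy_dep.losses")
stub.LPIPS = None                       # the attribute the importer will ask for
sys.modules["heavy_dep.losses"] = stub

from heavy_dep.losses import LPIPS     # resolves against the stub
assert LPIPS is None
```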
diff --git a/modules/processing.py b/modules/processing.py
index 49441e7761b5f5feeba6849fdb9a602d29eed716..ae58b108a411352ad55ec091ae33b521a7e593d6 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,7 +14,7 @@
 from skimage import exposure
 from typing import Any, Dict, List
 
 import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
 from modules.sd_hijack import model_hijack
 from modules.shared import opts, cmd_opts, state
 import modules.shared as shared
@@ -30,6 +30,7 @@
 from einops import repeat, rearrange
 from blendmodes.blend import blendLayers, BlendType
 
+decode_first_stage = sd_samplers_common.decode_first_stage
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -330,11 +331,24 @@
         computed result is stored.
 
         caches is a list with items described above.
         """
+
+        cached_params = (
+            required_prompts,
+            steps,
+            shared.sd_model.sd_checkpoint_info,
+            extra_network_data,
+            opts.sdxl_crop_left,
+            opts.sdxl_crop_top,
+            self.width,
+            self.height,
+        )
+
         for cache in caches:
-            if cache[0] is not None and (required_prompts, steps) == cache[0]:
+            if cache[0] is not None and cached_params == cache[0]:
                 return cache[1]
 
         cache = caches[0]
@@ -342,18 +356,20 @@
         with devices.autocast():
             cache[1] = function(shared.sd_model, required_prompts, steps)
 
-        cache[0] = (required_prompts, steps)
+        cache[0] = cached_params
         return cache[1]
 
     def setup_conds(self):
+        prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height)
+        negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True)
+
         sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
         self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
 
     def parse_extra_network_prompts(self):
         self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
@@ -483,8 +499,8 @@
     for i, seed in enumerate(seeds):
         noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
 
         subnoise = None
         if subseeds is not None:
             subseed = 0 if i >= len(subseeds) else subseeds[i]
             subnoise = devices.randn(subseed, noise_shape)
@@ -516,7 +532,7 @@
     if sampler_noises is not None:
         cnt = p.sampler.number_of_needed_noises(p)
 
         if eta_noise_seed_delta > 0:
-            torch.manual_seed(seed + eta_noise_seed_delta)
+            devices.manual_seed(seed + eta_noise_seed_delta)
 
         for j in range(cnt):
             sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -530,18 +546,50 @@
     x = torch.stack(xs).to(shared.device)
     return x
 
+class DecodedSamples(list):
+    already_decoded = True
+
+
+def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
+    samples = DecodedSamples()
+
+    for i in range(batch.shape[0]):
+        sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if check_for_nans:
+            try:
+                devices.test_for_nans(sample, "vae")
+            except devices.NansException as e:
+                if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+                    raise e
+
+                errors.print_error_explanation(
+                    "A tensor with all NaNs was produced in VAE.\n"
+                    "Web UI will now convert VAE into 32-bit float and retry.\n"
+                    "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
+                    "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+                )
+
+                devices.dtype_vae = torch.float32
+                model.first_stage_model.to(devices.dtype_vae)
+                batch = batch.to(devices.dtype_vae)
+
+                sample = decode_first_stage(model, batch[i:i + 1])[0]
+
+        if target_device is not None:
+            sample = sample.to(target_device)
+
+        samples.append(sample)
+
+    return samples
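The NaN branch above is the code path behind the changelog entry about automatically switching to a 32-bit VAE. Reduced to its skeleton, the pattern is decode, check, promote, retry; this sketch uses a generic decode callable in place of decode_first_stage:

```
# Skeleton of the auto-promotion above: one failed half-precision decode
# upgrades the VAE to float32 for this and all following samples.
import torch

def decode_with_fallback(vae_decode, vae, latent, dtype):
    sample = vae_decode(latent.to(dtype))
    if torch.isnan(sample).any() and dtype != torch.float32:
        dtype = torch.float32          # sticky: later batches keep float32
        vae.to(dtype)
        sample = vae_decode(latent.to(dtype))
    return sample, dtype
```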
 def get_fixed_seed(seed):
@@ -566,10 +614,14 @@
     return res
 
 
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False, index=None, all_negative_prompts=None):
+    if index is None:
+        index = position_in_batch + iteration * p.batch_size
+
+    if all_negative_prompts is None:
+        all_negative_prompts = p.all_negative_prompts
 
     clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
     enable_hr = getattr(p, 'enable_hr', False)
@@ -585,13 +637,13 @@
         "Steps": p.steps,
         "Sampler": p.sampler_name,
         "CFG scale": p.cfg_scale,
         "Image CFG scale": getattr(p, 'image_cfg_scale', None),
-        "Seed": all_seeds[index],
+        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
         "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
         "Size": f"{p.width}x{p.height}",
         "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
-        "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
+        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
-        "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
+        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
         "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
         "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
         "Denoising strength": getattr(p, 'denoising_strength', None),
@@ -601,8 +653,8 @@
         "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
         "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
         "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
         "Init image hash": getattr(p, 'init_img_hash', None),
         "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
         **p.extra_generation_params,
         "Version": program_version() if opts.add_version_to_infotext else None,
@@ -612,7 +664,7 @@
     generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
 
     prompt_text = p.prompt if use_main_prompt else all_prompts[index]
-    negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
+    negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
 
     return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -685,9 +737,6 @@
     if type(subseed) == list:
         p.all_subseeds = subseed
     else:
         p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
 
-    def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
-        return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
-
     if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
         model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -762,10 +811,12 @@
             with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
 
-            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
-            for x in x_samples_ddim:
-                devices.test_for_nans(x, "vae")
+            if getattr(samples_ddim, 'already_decoded', False):
+                x_samples_ddim = samples_ddim
+            else:
+                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
 
             x_samples_ddim = torch.stack(x_samples_ddim).float()
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -780,6 +831,16 @@
             if p.scripts is not None:
                 p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
 
+                p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+                p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+                batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+                p.scripts.postprocess_batch_list(p, batch_params, batch_number=n)
+                x_samples_ddim = batch_params.images
+
+            def infotext(index=0, use_main_prompt=False):
+                return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+
             for i, x_sample in enumerate(x_samples_ddim):
                 p.batch_index = i
 
@@ -788,7 +849,7 @@
                 x_sample = x_sample.astype(np.uint8)
 
                 if p.restore_faces:
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
-                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
 
                     devices.torch_gc()
@@ -805,18 +866,17 @@
                 if p.color_corrections is not None and i < len(p.color_corrections):
                     if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
                         image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
-                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                     image = apply_color_correction(p.color_corrections[i], image)
 
                 image = apply_overlay(image, p.paste_to, i, p.overlay_images)
 
                 if opts.samples_save and not p.do_not_save_samples:
-                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
 
-                text = infotext(n, i)
+                text = infotext(i)
                 infotexts.append(text)
                 if opts.enable_pnginfo:
                     image.info["parameters"] = text
@@ -826,10 +887,11 @@
                     image_mask = p.mask_for_overlay.convert('RGB')
                     image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
 
                     if opts.save_mask:
-                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+                        images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
 
                     if opts.save_mask_composite:
-                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+                        images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
 
                     if opts.return_mask:
@@ -871,9 +933,8 @@
     res = Processed(
         p,
         images_list=output_images,
         seed=p.all_seeds[0],
-        info=infotext(),
+        info=infotexts[0],
         comments="".join(f"{comment}\n" for comment in comments),
         subseed=p.all_subseeds[0],
         index_of_first_image=index_of_first_image,
@@ -903,7 +964,7 @@
     sampler = None
     cached_hr_uc = [None, None]
     cached_hr_c = [None, None]
 
-    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+    def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
         super().__init__(**kwargs)
         self.enable_hr = enable_hr
         self.denoising_strength = denoising_strength
@@ -914,11 +975,14 @@
         self.hr_resize_x = hr_resize_x
         self.hr_resize_y = hr_resize_y
         self.hr_upscale_to_x = hr_resize_x
         self.hr_upscale_to_y = hr_resize_y
+        self.hr_checkpoint_name = hr_checkpoint_name
+        self.hr_checkpoint_info = None
         self.hr_sampler_name = hr_sampler_name
         self.hr_prompt = hr_prompt
         self.hr_negative_prompt = hr_negative_prompt
         self.all_hr_prompts = None
         self.all_hr_negative_prompts = None
+        self.latent_scale_mode = None
 
         if firstphase_width != 0 or firstphase_height != 0:
             self.hr_upscale_to_x = self.width
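From the API side, the only genuinely new knob here is hr_checkpoint_name, which init() (shown in the next hunk) resolves through sd_models.get_closet_checkpoint_match() and records in infotext as "Hires checkpoint". A hedged construction sketch, with all other arguments elided and a hypothetical checkpoint title:

```
# Sketch: run the hires-fix second pass on a different checkpoint. The
# checkpoint title is hypothetical; any name the matcher resolves works.
from modules.processing import StableDiffusionProcessingTxt2Img

p = StableDiffusionProcessingTxt2Img(
    prompt="a fantasy landscape with a lake",
    enable_hr=True,
    hr_scale=2.0,
    hr_upscaler="Latent",
    hr_checkpoint_name="sd_xl_base_1.0",   # resolved in init(); raises if not found
)
```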
@@ -941,6 +1005,14 @@
         self.hr_uc = None
 
     def init(self, all_prompts, all_seeds, all_subseeds):
         if self.enable_hr:
+            if self.hr_checkpoint_name:
+                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+                if self.hr_checkpoint_info is None:
+                    raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+                self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
             if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
                 self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
 
@@ -950,6 +1022,11 @@
             if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
                 self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
 
+            self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+            if self.enable_hr and self.latent_scale_mode is None:
+                if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+                    raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
             if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
                 self.hr_resize_x = self.width
                 self.hr_resize_y = self.height
@@ -989,15 +1066,6 @@
             self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
             self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
 
-            if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
-                self.enable_hr = False
-                self.denoising_strength = None
-                self.extra_generation_params.pop("Hires upscale", None)
-                self.extra_generation_params.pop("Hires resize", None)
-                return
-
         if not state.processing_has_refined_job_count:
             if state.job_count == -1:
                 state.job_count = self.n_iter
@@ -1016,28 +1084,44 @@
     def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
         self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
 
+        if not self.enable_hr:
+            return samples
+
+        if self.latent_scale_mode is None:
+            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+        else:
+            decoded_samples = None
+
+        current = shared.sd_model.sd_checkpoint_info
+        try:
+            if self.hr_checkpoint_info is not None:
+                self.sampler = None
+                sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+                devices.torch_gc()
+
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+        finally:
+            devices.torch_gc()
@@ -1053,15 +1137,24 @@
     def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
         target_height = self.hr_upscale_to_y
 
         def save_intermediate(image, index):
             """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
 
             if not isinstance(image, Image.Image):
                 image = sd_samplers.sample_to_image(image, index, approximation=0)
 
             info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
 
+        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
+            img2img_sampler_name = 'DDIM'
+
+        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
         if self.latent_scale_mode is not None:
             for i in range(samples.shape[0]):
                 save_intermediate(samples, i)
 
             # Avoid making the inpainting conditioning unless necessary as
             # this does need some extra compute to decode / encode the image again.
@@ -1071,7 +1164,6 @@
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
-            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
             # Pretty sure we can just make this a 1x1 image since it's not going to be used besides its batch size.
             batch_images = []
@@ -1090,6 +1182,7 @@
             decoded_samples = torch.from_numpy(np.array(batch_images))
             decoded_samples = decoded_samples.to(shared.device)
             decoded_samples = 2. * decoded_samples - 1.
+            decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
 
             samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
@@ -1098,20 +1191,11 @@
         shared.state.nextjob()
 
-        img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
-        if self.sampler_name in ['PLMS', 'UniPC']:  # PLMS/UniPC do not support img2img so we just silently switch to DDIM
-            img2img_sampler_name = 'DDIM'
-
-        self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
         noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
 
         # GC now before running the next img2img to prevent running out of memory
-        x = None
         devices.torch_gc()
 
         if not self.disable_extra_networks:
@@ -1131,10 +1215,11 @@
         sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
 
+        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
         self.is_hr_pass = False
 
-        return samples
+        return decoded_samples
 
     def close(self):
         super().close()
@@ -1173,12 +1258,15 @@
     def calculate_hr_conds(self):
         if self.hr_c is not None:
             return
 
+        hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+        hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
 
     def setup_conds(self):
         super().setup_conds()
@@ -1186,7 +1272,7 @@
         self.hr_uc = None
         self.hr_c = None
 
-        if self.enable_hr:
+        if self.enable_hr and self.hr_checkpoint_info is None:
             if shared.opts.hires_fix_use_firstpass_conds:
                 self.calculate_hr_conds()
@@ -1338,11 +1424,11 @@
             raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
 
         image = torch.from_numpy(batch_images)
         image = 2. * image - 1.
-        image = image.to(shared.device)
+        image = image.to(shared.device, dtype=devices.dtype_vae)
 
         self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+        devices.torch_gc()
 
         if self.resize_mode == 3:
             self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index 0069d8b0e1a5376dc663a249428ef77421d74dca..32d214e3a1a80ddaaeade96ef2e9d922127bc0de 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import re
 from collections import namedtuple
 from typing import List
@@ -17,9 +19,11 @@
 prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)*
 !emphasized: "(" prompt ")"
         | "(" prompt ":" prompt ")"
         | "[" prompt "]"
-scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER "]"
+scheduled: "[" [prompt ":"] prompt ":" [WHITESPACE] NUMBER [WHITESPACE] "]"
 alternate: "[" prompt ("|" prompt)+ "]"
 WHITESPACE: /\s+/
 plain: /([^\\\[\]():|]|\\.)+/
 %import common.SIGNED_NUMBER -> NUMBER
@@ -52,6 +57,11 @@
     [[3, '((a][:b:c '], [10, '((a][:b:c d']]
     >>> g("[a|(b:1.1)]")
     [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']]
+    >>> g("[fe|]male")
+    [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']]
+    >>> g("[fe|||]male")
+    [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']]
     """
@@ -59,11 +69,11 @@
     res = [steps]
 
     class CollectSteps(lark.Visitor):
         def scheduled(self, tree):
-            tree.children[-1] = float(tree.children[-1])
-            if tree.children[-1] < 1:
-                tree.children[-1] *= steps
-
-            tree.children[-1] = min(steps, int(tree.children[-1]))
-            res.append(tree.children[-1])
+            tree.children[-2] = float(tree.children[-2])
+            if tree.children[-2] < 1:
+                tree.children[-2] *= steps
+
+            tree.children[-2] = min(steps, int(tree.children[-2]))
+            res.append(tree.children[-2])
 
         def alternate(self, tree):
             res.extend(range(1, steps+1))
@@ -74,10 +84,12 @@
     def at_step(step, tree):
         class AtStep(lark.Transformer):
             def scheduled(self, args):
-                before, after, _, when = args
+                before, after, _, when, _ = args
                 yield before or () if step <= when else after
 
             def alternate(self, args):
+                args = ["" if not arg else arg for arg in args]
                 yield args[(step - 1) % len(args)]
 
         def start(self, args):
             def flatten(x):
                 if type(x) == str:
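The doctest-style `g(...)` lines above are the quickest way to see what the scheduler produces; the grammar change just tolerates whitespace around the number, and the alternate fix makes empty options like `[fe|]male` legal. Something like this should reproduce the documented output (assuming webui's modules are importable):

```
# Each prompt becomes a list of [end_step, text] pairs; with the relaxed
# grammar, "a [b: 3]" and "a [b:3]" yield the same schedule.
from modules.prompt_parser import get_learned_conditioning_prompt_schedules

g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0]

print(g("a [b:3]"))    # [[3, 'a '], [10, 'a b']]
print(g("[fe|]male"))  # alternates: 'female' on odd steps, 'male' on even steps
```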
+ """ + def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): + super().__init__() + self.extend(prompts) + + if copy_from is None: + copy_from = prompts + + self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) + [[3, 'a '], [10, 'a b']] # will be represented with prompt_schedule like this (assuming steps=100): + self.height = height or getattr(copy_from, 'height', None) + + + +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -141,13 +172,19 @@ if cached is not None: res.append(cached) continue - texts = [x[1] for x in prompt_schedule] + texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) conds = model.get_learned_conditioning(texts) cond_schedule = [] for i, (end_at_step, _) in enumerate(prompt_schedule): + >>> g("a [b: 3]") import re + cond = {k: v[i] for k, v in conds.items()} +from typing import List %import common.SIGNED_NUMBER -> NUMBER + cond = conds[i] + + cond_schedule.append(ScheduledPromptConditioning(end_at_step, cond)) cache[prompt] = cond_schedule res.append(cond_schedule) @@ -156,19 +193,21 @@ return res re_AND = re.compile(r"\bAND\b") -import re +from typing import List from collections import namedtuple -# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" -import re + +from typing import List from collections import namedtuple -# [50, 'fantasy landscape with a lake and an oak in foreground in background shoddy'] +# will be represented with prompt_schedule like this (assuming steps=100): res_indexes = [] schedule_parser = lark.Lark(r""" -import re +from collections import namedtuple -import re >>> g("a [b: 3]") +# [25, 'fantasy landscape with a mountain and an oak in foreground shoddy'] + prompt_flat_list.clear() for prompt in prompts: subprompts = re_AND.split(prompt) @@ -205,6 +244,7 @@ def __init__(self, shape, batch): self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. 
@@ -223,24 +263,64 @@
     return MulticondLearnedConditioning(shape=(len(prompts),), batch=res)
 
 
+class DictWithShape(dict):
+    def __init__(self, x, shape):
+        super().__init__()
+        self.update(x)
+
+    @property
+    def shape(self):
+        return self["crossattn"].shape
+
+
 def reconstruct_cond_batch(c: List[List[ScheduledPromptConditioning]], current_step):
     param = c[0][0].cond
-    res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+    is_dict = isinstance(param, dict)
+
+    if is_dict:
+        dict_cond = param
+        res = {k: torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype) for k, param in dict_cond.items()}
+        res = DictWithShape(res, (len(c),) + dict_cond['crossattn'].shape)
+    else:
+        res = torch.zeros((len(c),) + param.shape, device=param.device, dtype=param.dtype)
+
     for i, cond_schedule in enumerate(c):
         target_index = 0
         for current, entry in enumerate(cond_schedule):
             if current_step <= entry.end_at_step:
                 target_index = current
                 break
 
+        if is_dict:
+            for k, param in cond_schedule[target_index].cond.items():
+                res[k][i] = param
+        else:
+            res[i] = cond_schedule[target_index].cond
-        res[i] = cond_schedule[target_index].cond
 
     return res
 
 
+def stack_conds(tensors):
+    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
+    # and won't be able to torch.stack them. So this fixes that.
+    token_count = max([x.shape[0] for x in tensors])
+    for i in range(len(tensors)):
+        if tensors[i].shape[0] != token_count:
+            last_vector = tensors[i][-1:]
+            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
+            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+
+    return torch.stack(tensors)
+
+
 def reconstruct_multicond_batch(c: MulticondLearnedConditioning, current_step):
@@ -264,22 +344,20 @@
             tensors.append(composable_prompt.schedules[target_index].cond)
 
         conds_list.append(conds_for_batch)
 
-    # if prompts have wildly different lengths above the limit we'll get tensors of different shapes
-    # and won't be able to torch.stack them. So this fixes that.
-    token_count = max([x.shape[0] for x in tensors])
-    for i in range(len(tensors)):
-        if tensors[i].shape[0] != token_count:
-            last_vector = tensors[i][-1:]
-            last_vector_repeated = last_vector.repeat([token_count - tensors[i].shape[0], 1])
-            tensors[i] = torch.vstack([tensors[i], last_vector_repeated])
+    if isinstance(tensors[0], dict):
+        keys = list(tensors[0].keys())
+        stacked = {k: stack_conds([x[k] for x in tensors]) for k in keys}
+        stacked = DictWithShape(stacked, stacked['crossattn'].shape)
+    else:
+        stacked = stack_conds(tensors).to(device=param.device, dtype=param.dtype)
 
-    return conds_list, torch.stack(tensors).to(device=param.device, dtype=param.dtype)
+    return conds_list, stacked
 
 
 re_attention = re.compile(r"""
@@ -291,7 +369,7 @@
 \\\\|
 \\|
 \(|
 \[|
-:([+-]?[.\d]+)\)|
+:\s*([+-]?[.\d]+)\s*\)|
 \)|
 ]|
 [^\\()\[\]:]+|
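SDXL conds are dicts rather than single tensors, and DictWithShape is the shim that lets such dicts flow through code that only ever inspects `cond.shape`. A short illustration (the tensor shapes are illustrative SDXL sizes, not pulled from the diff):

```
# DictWithShape lets dict-style SDXL conds pass through code that expects
# tensor-like objects exposing a .shape attribute.
import torch
from modules.prompt_parser import DictWithShape

cond = DictWithShape({
    "crossattn": torch.zeros(2, 77, 2048),
    "vector": torch.zeros(2, 2816),
}, None)

print(cond.shape)     # torch.Size([2, 77, 2048]) - mirrors the crossattn tensor
print(cond["vector"].shape)
```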
diff --git a/modules/rng_philox.py b/modules/rng_philox.py
new file mode 100644
index 0000000000000000000000000000000000000000..5532cf9dd676012ae37b7fda3433b9b1d122f80a
--- /dev/null
+++ b/modules/rng_philox.py
@@ -0,0 +1,102 @@
+"""RNG imitating torch cuda randn on CPU. You are welcome.
+
+Usage:
+
+```
+g = Generator(seed=0)
+print(g.randn(shape=(3, 4)))
+```
+
+Expected output:
+```
+[[-0.92466259 -0.42534415 -2.6438457   0.14518388]
+ [-0.12086647 -0.57972564 -0.62285122 -0.32838709]
+ [-1.07454231 -0.36314407 -1.67105067  2.26550497]]
+```
+"""
+
+import numpy as np
+
+philox_m = [0xD2511F53, 0xCD9E8D57]
+philox_w = [0x9E3779B9, 0xBB67AE85]
+
+two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
+two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
+
+
+def uint32(x):
+    """Converts (N,) np.uint64 array into (2, N) np.uint32 array."""
+    return x.view(np.uint32).reshape(-1, 2).transpose(1, 0)
+
+
+def philox4_round(counter, key):
+    """A single round of the Philox 4x32 random number generator."""
+
+    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
+    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])
+
+    counter[0] = v2[1] ^ counter[1] ^ key[0]
+    counter[1] = v2[0]
+    counter[2] = v1[1] ^ counter[3] ^ key[1]
+    counter[3] = v1[0]
+
+
+def philox4_32(counter, key, rounds=10):
+    """Generates 32-bit random numbers using the Philox 4x32 random number generator.
+
+    Parameters:
+        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
+        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
+        rounds (int): The number of rounds to perform.
+
+    Returns:
+        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.
+    """
+
+    for _ in range(rounds - 1):
+        philox4_round(counter, key)
+
+        key[0] = key[0] + philox_w[0]
+        key[1] = key[1] + philox_w[1]
+
+    philox4_round(counter, key)
+    return counter
+
+
+def box_muller(x, y):
+    """Returns just the first out of two numbers generated by the Box-Muller transform algorithm."""
+    u = x * two_pow32_inv + two_pow32_inv / 2
+    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2
+
+    s = np.sqrt(-2.0 * np.log(u))
+
+    r1 = s * np.sin(v)
+    return r1.astype(np.float32)
+
+
+class Generator:
+    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""
+
+    def __init__(self, seed):
+        self.seed = seed
+        self.offset = 0
+
+    def randn(self, shape):
+        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""
+
+        n = 1
+        for x in shape:
+            n *= x
+
+        counter = np.zeros((4, n), dtype=np.uint32)
+        counter[0] = self.offset
+        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
+        self.offset += 1
+
+        key = np.empty(n, dtype=np.uint64)
+        key.fill(self.seed)
+        key = uint32(key)
+
+        g = philox4_32(counter, key)
+
+        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]
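Because the generator is fully counter-based, the same seed at the same offset always yields the same noise, which is what makes seeds portable between the CPU and CUDA RNG sources. A small check of those properties:

```
# Counter-based determinism: same seed + same offset -> identical noise;
# each randn() call advances the offset, producing a fresh batch.
import numpy as np
from modules.rng_philox import Generator

a = Generator(seed=1234).randn((2, 2))
b = Generator(seed=1234).randn((2, 2))
assert np.array_equal(a, b)        # reproducible across generator instances

g = Generator(seed=1234)
x1 = g.randn((2, 2))
x2 = g.randn((2, 2))               # offset advanced: different values
assert not np.array_equal(x1, x2)
```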
diff --git a/modules/script_loading.py b/modules/script_loading.py
index 306a1f35f778ee06b12910b92a38c0cc2be75a31..0d55f1932ee20122bff1cef60551548f90d6a634 100644
--- a/modules/script_loading.py
+++ b/modules/script_loading.py
@@ -12,11 +12,12 @@
     return module
 
 
-def preload_extensions(extensions_dir, parser):
+def preload_extensions(extensions_dir, parser, extension_list=None):
     if not os.path.isdir(extensions_dir):
         return
 
-    for dirname in sorted(os.listdir(extensions_dir)):
+    extensions = extension_list if extension_list is not None else os.listdir(extensions_dir)
+    for dirname in sorted(extensions):
         preload_script = os.path.join(extensions_dir, dirname, "preload.py")
         if not os.path.isfile(preload_script):
             continue
diff --git a/modules/scripts.py b/modules/scripts.py
index 7d9dd59f2ad40d567254df4a90d9fc6a9d8d9dcd..f7d060aa59cbe5691fb7128064079db12388f1ef 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -16,6 +16,11 @@
     def __init__(self, image):
         self.image = image
 
 
+class PostprocessBatchListArgs:
+    def __init__(self, images):
+        self.images = images
+
+
 class Script:
     name = None
     """script's internal name derived from title"""
@@ -119,7 +124,7 @@
         pass
 
     def after_extra_networks_activate(self, p, *args, **kwargs):
         """
-        Calledafter extra networks activation, before conds calculation
+        Called after extra networks activation, before conds calculation
         allows modification of the network after extra networks activation has been applied
         won't be called if p.disable_extra_networks
@@ -152,6 +157,25 @@
         **kwargs will have same items as process_batch, and also:
           - batch_number - index of current batch, from 0 to number of batches-1
           - images - torch tensor with all generated images, with values ranging from 0 to 1;
+        """
+
+        pass
+
+    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
+        """
+        Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
+        This is useful when you want to update the entire batch instead of individual images.
+
+        You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
+        If the number of images is different from the batch size when returning,
+        then the script has the responsibility to also update the following attributes in the processing object (p):
+          - p.prompts
+          - p.negative_prompts
+          - p.seeds
+          - p.subseeds
+
+        **kwargs will have same items as process_batch, and also:
+          - batch_number - index of current batch, from 0 to number of batches-1
         """
 
         pass
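A minimal always-on script exercising the new hook might look like this (hypothetical extension code, not part of the commit); note how it keeps the bookkeeping lists on `p` in step with the enlarged batch, exactly as the docstring requires:

```
# Hypothetical extension script: duplicates every image in the batch, so it
# must also duplicate the per-image bookkeeping lists on the processing object.
from modules import scripts

class DuplicateBatch(scripts.Script):
    def title(self):
        return "Duplicate batch (demo)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def postprocess_batch_list(self, p, pp: scripts.PostprocessBatchListArgs, *args, **kwargs):
        pp.images = [img for img in pp.images for _ in range(2)]
        p.prompts = [x for x in p.prompts for _ in range(2)]
        p.negative_prompts = [x for x in p.negative_prompts for _ in range(2)]
        p.seeds = [x for x in p.seeds for _ in range(2)]
        p.subseeds = [x for x in p.subseeds for _ in range(2)]
```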
@@ -537,6 +561,15 @@
             except Exception:
                 errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
 
+    def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+        for script in self.alwayson_scripts:
+            try:
+                script_args = p.script_args[script.args_from:script.args_to]
+                script.postprocess_batch_list(p, pp, *script_args, **kwargs)
+            except Exception:
+                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+
     def before_component(self, component, **kwargs):
         for script in self.alwayson_scripts:
             try:
@@ -600,49 +633,3 @@
     scripts_img2img.reload_sources(cache)
 
 reload_scripts = load_scripts  # compatibility alias
-
-
-def add_classes_to_gradio_component(comp):
-    """
-    this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
-    """
-
-    comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
-
-    if getattr(comp, 'multiselect', False):
-        comp.elem_classes.append('multiselect')
-
-
-def IOComponent_init(self, *args, **kwargs):
-    if scripts_current is not None:
-        scripts_current.before_component(self, **kwargs)
-
-    script_callbacks.before_component_callback(self, **kwargs)
-
-    res = original_IOComponent_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    script_callbacks.after_component_callback(self, **kwargs)
-
-    if scripts_current is not None:
-        scripts_current.after_component(self, **kwargs)
-
-    return res
-
-
-original_IOComponent_init = gr.components.IOComponent.__init__
-gr.components.IOComponent.__init__ = IOComponent_init
-
-
-def BlockContext_init(self, *args, **kwargs):
-    res = original_BlockContext_init(self, *args, **kwargs)
-
-    add_classes_to_gradio_component(self)
-
-    return res
-
-
-original_BlockContext_init = gr.blocks.BlockContext.__init__
-gr.blocks.BlockContext.__init__ = BlockContext_init
diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py
index 9fc89dc6a75985e68bb39f564dd55ce928ec6b29..695c573626947078aeffa550b40a5d7a38c6467c 100644
--- a/modules/sd_disable_initialization.py
+++ b/modules/sd_disable_initialization.py
@@ -3,8 +3,32 @@
 import open_clip
 import torch
 import transformers.utils.hub
+from modules import shared
 
 
-class DisableInitialization:
+class ReplaceHelper:
+    def __init__(self):
+        self.replaced = []
+
+    def replace(self, obj, field, func):
+        original = getattr(obj, field, None)
+        if original is None:
+            return None
+
+        self.replaced.append((obj, field, original))
+        setattr(obj, field, func)
+
+        return original
+
+    def restore(self):
+        for obj, field, original in self.replaced:
+            setattr(obj, field, original)
+
+        self.replaced.clear()
+
+
+class DisableInitialization(ReplaceHelper):
     """
     When an object of this class enters a `with` block, it starts:
     - preventing torch's layer initialization functions from working
@@ -21,7 +45,7 @@
     ```
     """
 
     def __init__(self, disable_clip=True):
-        self.replaced = []
+        super().__init__()
         self.disable_clip = disable_clip
 
     def replace(self, obj, field, func):
@@ -86,11 +110,87 @@
         self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file)
         self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache)
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        for obj, field, original in self.replaced:
-            setattr(obj, field, original)
-
-        self.replaced.clear()
+        self.restore()
+
+
+class InitializeOnMeta(ReplaceHelper):
+    """
+    Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on the meta device,
+    which results in those parameters having no values and taking no memory. model.to() will be broken and
+    will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict.
+
+    Usage:
+    ```
+    with sd_disable_initialization.InitializeOnMeta():
+        sd_model = instantiate_from_config(sd_config.model)
+    ```
+    """
+
+    def __enter__(self):
+        if shared.cmd_opts.disable_model_loading_ram_optimization:
+            return
+
+        def set_device(x):
+            x["device"] = "meta"
+            return x
+
+        linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs)))
+        conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs)))
+        mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs)))
+        self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None)
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
+
+
+class LoadStateDictOnMeta(ReplaceHelper):
+    """
+    Context manager that allows reading parameters from state_dict into a model that has some of its parameters on the meta device.
+    As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory.
+    Meant to be used together with InitializeOnMeta above.
+
+    Usage:
+    ```
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict):
+        model.load_state_dict(state_dict, strict=False)
+    ```
+    """
+
+    def __init__(self, state_dict, device):
+        super().__init__()
+        self.state_dict = state_dict
+        self.device = device
+
+    def __enter__(self):
+        if shared.cmd_opts.disable_model_loading_ram_optimization:
+            return
+
+        sd = self.state_dict
+        device = self.device
+
+        def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs):
+            params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta]
+
+            for name, param in params:
+                if param.is_meta:
+                    self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad)
+
+            original(self, state_dict, prefix, *args, **kwargs)
+
+            for name, _ in params:
+                key = prefix + name
+                if key in sd:
+                    del sd[key]
+
+        linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs))
+        conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs))
+        mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
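Outside webui, the same two-phase idea can be written in a few lines of plain torch; this is the essence of what InitializeOnMeta and LoadStateDictOnMeta automate for the full SD model. A sketch assuming torch >= 2.0 (for the `torch.device` context manager):

```
# Core idea behind the meta-device loading path: build the model with empty
# (meta) parameters, then materialize weights straight from the state dict
# instead of allocating and initializing a throwaway copy first.
import torch

with torch.device("meta"):                 # torch >= 2.0
    model = torch.nn.Linear(4, 4)          # params exist but own no memory
assert model.weight.is_meta

state_dict = {"weight": torch.ones(4, 4), "bias": torch.zeros(4)}

model = model.to_empty(device="cpu")       # allocate real, uninitialized storage
model.load_state_dict(state_dict)          # fill it directly from the checkpoint
print(model.weight.sum())                  # tensor(16.)
```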
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 6b5aae4b5f8b89dcb106d4b92f5db235915cafaf..9ad98199818bdea18f278168a5d66813d1bad66d 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -2,11 +2,10 @@
 import torch
 from torch.nn.functional import silu
 from types import MethodType
 
-import modules.textual_inversion.textual_inversion
 from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
 from modules.hypernetworks import hypernetwork
 from modules.shared import cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting
 
 import ldm.modules.attention
 import ldm.modules.diffusionmodules.model
@@ -15,6 +14,11 @@
 import ldm.models.diffusion.ddim
 import ldm.models.diffusion.plms
 import ldm.modules.encoders.modules
 
+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+import sgm.modules.diffusionmodules.openaimodel
+import sgm.modules.encoders.modules
+
 attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
 diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
 diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
@@ -26,9 +30,13 @@
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 
 # new memory efficient cross attention blocks do not support hypernets and we already
 # have memory efficient cross attention anyway, so this disables SD2.0's memory efficient cross attention
 ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
 
 # silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
 
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
@@ -58,6 +66,9 @@
     ldm.modules.diffusionmodules.model.nonlinearity = silu
     ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
+    sgm.modules.diffusionmodules.model.nonlinearity = silu
+    sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
+
     if current_optimizer is not None:
         current_optimizer.undo()
         current_optimizer = None
@@ -90,6 +101,10 @@
 def undo_optimizations():
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+    sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+    sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
 
 
 def fix_checkpoint():
@@ -155,13 +170,13 @@
     clip = None
     optimization_method = None
 
     def __init__(self):
+        import modules.textual_inversion.textual_inversion
+
         self.extra_generation_params = {}
         self.comments = []
 
+        self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()
         self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)
 
     def apply_optimizations(self, option=None):
@@ -172,6 +187,32 @@
             errors.display(e, "applying cross attention optimization")
             undo_optimizations()
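The `ldm.*.print = shared.ldm_print` assignments work because a function's call to print() resolves through its own module's globals before the builtins: rebinding the name at module scope intercepts the call without editing the vendored source. A tiny self-contained illustration:

```
# Rebinding `print` in a module's namespace shadows the builtin for every
# function defined in that module - the trick behind the ldm_print redirection.
import types

noisy = types.ModuleType("noisy")
exec("def talk():\n    print('spam')", noisy.__dict__)

noisy.talk()                          # prints: spam
noisy.print = lambda *a, **k: None    # shadow the builtin for this module
noisy.talk()                          # silenced
```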
     def hijack(self, m):
+        conditioner = getattr(m, 'conditioner', None)
+        if conditioner:
+            text_cond_models = []
+
+            for i in range(len(conditioner.embedders)):
+                embedder = conditioner.embedders[i]
+                typename = type(embedder).__name__
+                if typename == 'FrozenOpenCLIPEmbedder':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+                if typename == 'FrozenCLIPEmbedder':
+                    model_embeddings = embedder.transformer.text_model.embeddings
+                    model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
+                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+                if typename == 'FrozenOpenCLIPEmbedder2':
+                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
+                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
+                    text_cond_models.append(conditioner.embedders[i])
+
+            if len(text_cond_models) == 1:
+                m.cond_stage_model = text_cond_models[0]
+            else:
+                m.cond_stage_model = conditioner
+
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             model_embeddings = m.cond_stage_model.roberta.embeddings
             model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -209,9 +250,8 @@
         ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward
 
     def undo_hijack(self, m):
         if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
             m.cond_stage_model = m.cond_stage_model.wrapped
 
         elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
@@ -260,11 +300,12 @@
 class EmbeddingsWithFixes(torch.nn.Module):
-    def __init__(self, wrapped, embeddings):
+    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
         super().__init__()
         self.wrapped = wrapped
         self.embeddings = embeddings
+        self.textual_inversion_key = textual_inversion_key
 
     def forward(self, input_ids):
         batch_fixes = self.embeddings.fixes
@@ -278,8 +319,9 @@
         vecs = []
         for fixes, tensor in zip(batch_fixes, inputs_embeds):
             for offset, embedding in fixes:
-                emb = devices.cond_cast_unet(embedding.vec)
+                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
+                emb = devices.cond_cast_unet(vec)
                 emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                 tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index c1d780a3317634cff6440baf461126709d7bd033..8f29057a9cfcfafcf18e5c5c3eb095a8a1649198 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@
         self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
         self.chunk_length = 75
 
+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
     def empty_chunk(self):
         """creates an empty PromptChunk and returns it"""
 
@@ -157,7 +161,7 @@
                 chunk.multipliers.append(weight)
                 position += 1
                 continue
 
-            emb_len = int(embedding.vec.shape[0])
+            emb_len = int(embedding.vectors)
             if len(chunk.tokens) + emb_len > self.chunk_length:
                 next_chunk()
diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py
index c1d780a3317634cff6440baf461126709d7bd033..8f29057a9cfcfafcf18e5c5c3eb095a8a1649198 100644
--- a/modules/sd_hijack_clip.py
+++ b/modules/sd_hijack_clip.py
@@ -42,6 +42,10 @@
        self.hijack: sd_hijack.StableDiffusionModelHijack = hijack
        self.chunk_length = 75

+        self.is_trainable = getattr(wrapped, 'is_trainable', False)
+        self.input_key = getattr(wrapped, 'input_key', 'txt')
+        self.legacy_ucg_val = None
+
    def empty_chunk(self):
        """creates an empty PromptChunk and returns it"""
@@ -157,7 +161,7 @@
                    chunk.multipliers.append(weight)
                    position += 1
                    continue

-                emb_len = int(embedding.vec.shape[0])
+                emb_len = int(embedding.vectors)
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()
@@ -199,10 +203,10 @@
    def forward(self, texts):
        """
        Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts.
        Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will
        be a multiple of 77; and C is dimensionality of each token.
        An example shape returned by this function can be: (2, 77, 768).
+        For SDXL, instead of returning one tensor as above, it returns a tuple with two: the other one with shape (B, 1280) with pooled values.
        Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one element is
        when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream"
        """
@@ -242,9 +246,14 @@
                name = name.replace(":", "").replace(",", "")
                hashes.append(f"{name}: {shorthash}")

            if hashes:
+                if self.hijack.extra_generation_params.get("TI hashes"):
+                    hashes.append(self.hijack.extra_generation_params.get("TI hashes"))
                self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes)

-        return torch.hstack(zs)
+        if getattr(self.wrapped, 'return_pooled', False):
+            return torch.hstack(zs), zs[0].pooled
+        else:
+            return torch.hstack(zs)

    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        """
@@ -264,12 +273,17 @@
            tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad

        z = self.encode_with_transformers(tokens)

+        pooled = getattr(z, 'pooled', None)
+
        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z = z * (original_mean / new_mean)
+
+        if pooled is not None:
+            z.pooled = pooled

        return z
@@ -326,3 +340,18 @@
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0)

        return embedded
+
+
+class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+    def encode_with_transformers(self, tokens):
+        outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden")
+
+        if self.wrapped.layer == "last":
+            z = outputs.last_hidden_state
+        else:
+            z = outputs.hidden_states[self.wrapped.layer_idx]
+
+        return z
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index c1977b194190858fbb9726eb02256cdeafc654a0..2d44b856668a1a6a361a672a68beaf513ad717b8 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -92,6 +92,4 @@
    return x_prev, pred_x0, e_t


def do_inpainting_hijack():
-    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
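An illustrative sketch (not webui code) of the pooled-output pattern used above: SDXL's text encoder returns both per-token hidden states and a pooled vector, and the diff smuggles the pooled vector through as a plain attribute on the hidden-state tensor so existing SD1/SD2 code paths stay unchanged. Because tensor operations return new tensors, the attribute must be re-attached after rescaling.

import torch

z = torch.randn(1, 77, 2048)      # per-token hidden states
z.pooled = torch.randn(1, 1280)   # attach pooled output as a plain attribute

pooled = getattr(z, 'pooled', None)
z = z * 1.1                       # emphasis rescale returns a *new* tensor...
if pooled is not None:
    z.pooled = pooled             # ...so the attribute must be re-attached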
diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py
index f733e8529fb6cd68d97b2f255bc705d0cd949fbc..25c5e9831a3527b9555971b5d9add68519a2208e 100644
--- a/modules/sd_hijack_open_clip.py
+++ b/modules/sd_hijack_open_clip.py
@@ -35,3 +35,37 @@
        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
        embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)

        return embedded
+
+
+class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
+    def __init__(self, wrapped, hijack):
+        super().__init__(wrapped, hijack)
+
+        self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
+        self.id_start = tokenizer.encoder["<start_of_text>"]
+        self.id_end = tokenizer.encoder["<end_of_text>"]
+        self.id_pad = 0
+
+    def tokenize(self, texts):
+        assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
+
+        tokenized = [tokenizer.encode(text) for text in texts]
+
+        return tokenized
+
+    def encode_with_transformers(self, tokens):
+        d = self.wrapped.encode_with_transformer(tokens)
+        z = d[self.wrapped.layer]
+
+        pooled = d.get("pooled")
+        if pooled is not None:
+            z.pooled = pooled
+
+        return z
+
+    def encode_embedding_init_text(self, init_text, nvpt):
+        ids = tokenizer.encode(init_text)
+        ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
+        embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
+
+        return embedded
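A quick sketch of how the OpenCLIP token ids above are used (assumed usage, mirroring the module-level "tokenizer = open_clip.tokenizer._tokenizer" in this file): encode() returns bare token ids, so the wrapper brackets each 75-token chunk with <start_of_text>/<end_of_text> and pads with id 0.

import open_clip.tokenizer

tokenizer = open_clip.tokenizer._tokenizer  # same object the webui wrapper uses

ids = tokenizer.encode("a photograph of an astronaut riding a horse")
id_start = tokenizer.encoder["<start_of_text>"]
id_end = tokenizer.encoder["<end_of_text>"]

chunk = [id_start] + ids[:75] + [id_end]
chunk += [0] * (77 - len(chunk))  # id_pad = 0 for OpenCLIP, unlike CLIP's end token
print(len(chunk))  # 77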
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 53e27adea2caaf39f3d58a9a015bb36eb0395bba..0e810eec8a9a01f28ca96007595eb9d00e08eff9 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -14,7 +14,11 @@
import ldm.modules.attention
import ldm.modules.diffusionmodules.model

+import sgm.modules.attention
+import sgm.modules.diffusionmodules.model
+
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
+sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
@@ -39,6 +43,9 @@
    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

+        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward
+

class SdOptimizationXformers(SdOptimization):
    name = "xformers"
@@ -51,6 +58,8 @@
    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
@@ -65,6 +74,8 @@
    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
@@ -76,6 +87,8 @@
    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
@@ -86,6 +99,8 @@
    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
@@ -94,9 +109,10 @@
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -110,6 +126,7 @@
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
@@ -120,6 +137,9 @@
    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
@@ -157,7 +177,7 @@
    return psutil.virtual_memory().available


# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
-def split_cross_attention_forward_v1(self, x, context=None, mask=None):
+def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
@@ -198,8 +218,8 @@
    return self.to_out(r2)


# taken from https://github.com/Doggettx/stable-diffusion and modified
-def split_cross_attention_forward(self, x, context=None, mask=None):
+def split_cross_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
@@ -241,9 +261,9 @@
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free')

-        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
+        slice_size = q.shape[1] // steps
        for i in range(0, q.shape[1], slice_size):
-            end = i + slice_size
+            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)
            s2 = s1.softmax(dim=-1, dtype=q.dtype)
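Why the slicing change in the last hunk matters, shown as a tiny standalone demo (illustrative numbers, not webui code): with the old fallback, any sequence length not divisible by `steps` disabled slicing entirely; clamping the end index keeps slicing while still covering the remainder.

seq_len, steps = 4099, 4
slice_size = seq_len // steps  # 1024
covered = []
for i in range(0, seq_len, slice_size):
    end = min(i + slice_size, seq_len)  # last slice is clamped to the remainder
    covered.append((i, end))
print(covered[-1])  # (4096, 4099) - the tail is still processed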
@@ -265,17 +285,20 @@
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
mem_total_gb = psutil.virtual_memory().total // (1 << 30)

+
def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)

+
def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r

+
def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
@@ -283,6 +306,7 @@
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r

+
def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
@@ -293,11 +317,13 @@
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)

+
def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)

+
def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
@@ -309,6 +335,7 @@
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))

+
def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
@@ -318,6 +345,7 @@
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide factor of safety as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

+
def einsum_op(q, k, v):
    if q.device.type == 'cuda':
@@ -332,7 +360,8 @@
    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)

-def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
+
+def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
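A hedged sketch of the dispatch idea behind the einsum_op_* family above (the budget number mirrors the "32" MB used for CPU, but the arithmetic here is illustrative): estimate the size of the attention matrix q @ k^T, then derive how many slices are needed to stay under a memory budget.

import torch

def attention_bytes(q: torch.Tensor, k: torch.Tensor) -> int:
    # s = q @ k^T has shape (batch, q_len, k_len); count the bytes of that product
    return q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()

q = torch.randn(16, 4096, 64)
k = torch.randn(16, 4096, 64)
budget_mb = 32  # same spirit as einsum_op_tensor_mem(q, k, v, 32)
div = 1 + attention_bytes(q, k) // (budget_mb << 20)
print(div)  # 33 -> the batch/sequence loop would run in ~33 slices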
@@ -360,8 +389,8 @@
# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
-def sub_quad_attention_forward(self, x, context=None, mask=None):
+def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads
@@ -396,6 +425,7 @@
    x = out_proj(x)
    x = dropout(x)

    return x

+
def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits//8
@@ -447,8 +477,8 @@
        return None


-def xformers_attention_forward(self, x, context=None, mask=None):
+def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)
@@ -470,11 +500,12 @@
    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)

+
# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
-def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None, **kwargs):
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
@@ -514,11 +545,13 @@
    # dropout
    hidden_states = self.to_out[1](hidden_states)

    return hidden_states

+
-def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None):
+def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)

+
def cross_attention_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
@@ -577,6 +610,7 @@
    h3 += x

    return h3

+
def xformers_attnblock_forward(self, x):
    try:
        h_ = x
@@ -599,6 +633,7 @@
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)

+
def sdp_attnblock_forward(self, x):
    h_ = x
@@ -620,9 +655,11 @@
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out

+
def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)

+
def sub_quad_attnblock_forward(self, x):
    h_ = x
diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py
index ca1daf45f3c625fc6947b9547f4130ca7ed80ab2..2101f1a04152bd53214934d61b06e0d22af71ca7 100644
--- a/modules/sd_hijack_unet.py
+++ b/modules/sd_hijack_unet.py
@@ -39,8 +39,11 @@
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):
    if isinstance(cond, dict):
        for y in cond.keys():
-            cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+            if isinstance(cond[y], list):
+                cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]
+            else:
+                cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y]

    with devices.autocast():
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()
@@ -78,3 +81,6 @@
first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devices.dtype_vae), **kwargs)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond)
CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond)
+
+CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast)
+CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast)
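A rough sketch of the CondFunc pattern used above (simplified; the real helper lives in modules/sd_hijack_utils.py and handles more cases): replace a dotted-path function with a wrapper that only diverts to the substitute when a condition holds, otherwise calls the original.

import importlib

def cond_func_patch(path: str, sub_func, cond_func):
    module_path, name = path.rsplit('.', 1)
    module = importlib.import_module(module_path)
    orig = getattr(module, name)

    def wrapper(*args, **kwargs):
        if cond_func(orig, *args, **kwargs):
            return sub_func(orig, *args, **kwargs)  # substitute receives the original
        return orig(*args, **kwargs)

    setattr(module, name, wrapper)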
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 060e0007c10aacd6e7517d476d0c3fad43c7e593..f60516046324fecad4df2f5baf42c2bea5fa3e42 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,9 +14,9 @@
import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -34,6 +34,8 @@
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

+        self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
@@ -44,35 +46,43 @@
        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'

-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

        self.metadata = {}

-        _, ext = os.path.splitext(self.filename)
-        if ext.lower() == ".safetensors":
+        if self.is_safetensors:
            try:
                self.metadata = read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading checkpoint metadata: {filename}")

    def register(self):
        checkpoints_list[self.title] = self
@@ -89,9 +99,10 @@
        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

-        checkpoints_list.pop(self.title)
+        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
        self.register()

        return self.shorthash
@@ -112,14 +123,9 @@
    enable_midas_autodownload()


-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]


def list_models():
@@ -142,18 +148,26 @@
            shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()

+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]
@@ -301,16 +315,24 @@
    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

+    model.is_sdxl = hasattr(model, 'conditioner')
+    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
+    model.is_sd1 = not model.is_sdxl and not model.is_sd2
+
+    if model.is_sdxl:
+        sd_models_xl.extend_sdxl(model)
+
    model.load_state_dict(state_dict, strict=False)
-    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        checkpoints_loaded[checkpoint_info] = state_dict
+
+    del state_dict

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
@@ -334,7 +356,7 @@
        model.depth_model = depth_model

    timer.record("apply half()")

-    devices.dtype_unet = model.model.diffusion_model.dtype
+    devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
@@ -349,9 +371,10 @@
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

-    model.logvar = model.logvar.to(devices.device)  # fix for training
+    if hasattr(model, 'logvar'):
+        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
@@ -408,11 +431,12 @@
    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

-    if shared.cmd_opts.no_half:
-        sd_config.model.params.unet_config.params.use_fp16 = False
-    elif shared.cmd_opts.upcast_sampling:
-        sd_config.model.params.unet_config.params.use_fp16 = True
+    if hasattr(sd_config.model.params, 'unet_config'):
+        if shared.cmd_opts.no_half:
+            sd_config.model.params.unet_config.params.use_fp16 = False
+        elif shared.cmd_opts.upcast_sampling:
+            sd_config.model.params.unet_config.params.use_fp16 = True

    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"
@@ -425,11 +449,15 @@

sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
+sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'


class SdModelData:
    def __init__(self):
        self.sd_model = None
+        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()
@@ -444,6 +472,7 @@
                    return self.sd_model

                try:
                    load_model()
+
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
@@ -455,34 +484,79 @@

    def set_sd_model(self, v):
        self.sd_model = v

+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()


+def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
+    if hasattr(sd_model, 'conditioner'):
+        d = sd_model.get_learned_conditioning([""])
+        return d['crossattn']
+    else:
+        return sd_model.cond_stage_model([""])
+
+
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

+    timer = Timer()
+
    if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        devices.torch_gc()

+    timer.record("unload existing model")
+
    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
-    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
+    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")
@@ -503,32 +576,32 @@
    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
-        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
-            sd_model = instantiate_from_config(sd_config.model)
-    except Exception:
-        pass
+        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip):
+            with sd_disable_initialization.InitializeOnMeta():
+                sd_model = instantiate_from_config(sd_config.model)
+
+    except Exception as e:
+        errors.display(e, "creating model quickly", full_traceback=True)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
-        sd_model = instantiate_from_config(sd_config.model)
+
+        with sd_disable_initialization.InitializeOnMeta():
+            sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

-    load_model_weights(sd_model, checkpoint_info, state_dict, timer)
-
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
+    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
+        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")

+    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)
@@ -535,9 +609,8 @@
    timer.record("hijack")

    sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -549,8 +622,8 @@
    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
-        sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""])
+        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")
@@ -559,55 +632,107 @@
    return sd_model


+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+        if shared.opts.sd_checkpoints_keep_in_cpu:
+            send_model_to_cpu(sd_model)
+            timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

+    timer = Timer()
+
    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
-
-        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-            lowvram.send_everything_to_cpu()
-        else:
-            sd_model.to(devices.cpu)
-
-        sd_hijack.model_hijack.undo_hijack(sd_model)
+            return sd_model
+
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model

-    timer = Timer()
+    if sd_model is not None:
+        sd_unet.apply_unet("None")
+        send_model_to_cpu(sd_model)
+        sd_hijack.model_hijack.undo_hijack(sd_model)

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
@@ -616,9 +739,10 @@
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            send_model_to_trash(sd_model)
+
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model
@@ -640,6 +764,8 @@
        sd_model.to(devices.device)
        timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")
+
+    model_data.set_sd_model(sd_model)

    return sd_model
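A toy illustration (not webui code) of the loaded_sd_models bookkeeping introduced above: set_sd_model keeps a most-recently-used list, and the reuse logic evicts from the tail once the configured limit is exceeded.

loaded = []  # stands in for model_data.loaded_sd_models
limit = 2    # stands in for shared.opts.sd_checkpoints_limit

def set_current(model):
    if model in loaded:
        loaded.remove(model)
    loaded.insert(0, model)  # most recent first

for name in ["sd15", "sdxl_base", "sdxl_refiner"]:
    set_current(name)
    while len(loaded) > limit > 0:
        print("evicting", loaded.pop())  # oldest model is sent to trash

print(loaded)  # ['sdxl_refiner', 'sdxl_base']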
diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py
index 9bfe1237d1ded6d8ade2e565fe33622b444b8b02..8266fa39797b2044ddd0abd5e921bca5c6f87bea 100644
--- a/modules/sd_models_config.py
+++ b/modules/sd_models_config.py
@@ -6,12 +6,15 @@
from modules import shared, paths, sd_disable_initialization

sd_configs_path = shared.sd_configs_path
sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion")
+sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "configs", "inference")

config_default = shared.sd_default_config
config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml")
config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml")
config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml")
+config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml")
+config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml")
config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml")
config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml")
config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml")
@@ -68,7 +71,11 @@
    sd2_cond_proj_weight = sd.get('cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight', None)
    diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None)
    sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None)

-    if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
+    if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None:
+        return config_sdxl
+    if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None:
+        return config_sdxl_refiner
+    elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None:
        return config_depth_model
    elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768:
        return config_unclip
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
new file mode 100644
index 0000000000000000000000000000000000000000..0112332161fa21807e705cc1763681eefd7456e5
--- /dev/null
+++ b/modules/sd_models_xl.py
@@ -0,0 +1,108 @@
+from __future__ import annotations
+
+import torch
+
+import sgm.models.diffusion
+import sgm.modules.diffusionmodules.denoiser_scaling
+import sgm.modules.diffusionmodules.discretizer
+from modules import devices, shared, prompt_parser
+
+
+def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]):
+    for embedder in self.conditioner.embedders:
+        embedder.ucg_rate = 0.0
+
+    width = getattr(batch, 'width', 1024)
+    height = getattr(batch, 'height', 1024)
+    is_negative_prompt = getattr(batch, 'is_negative_prompt', False)
+    aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score
+
+    devices_args = dict(device=devices.device, dtype=devices.dtype)
+
+    sdxl_conds = {
+        "txt": batch,
+        "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1),
+        "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1),
+        "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1),
+    }
+
+    force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch)
+    c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else [])
+
+    return c
+
+
+def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond):
+    return self.model(x, t, cond)
+
+
+def get_first_stage_encoding(self, x):  # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility
+    return x
+
+
+sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+sgm.models.diffusion.DiffusionEngine.apply_model = apply_model
+sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding
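A hedged sketch of what get_learned_conditioning above consumes: SdConditioning (the class name is real; this stub is illustrative) is a list subclass carrying width/height/is_negative_prompt, so a plain list of strings also works and falls back to the 1024x1024 defaults via getattr.

import torch

class SdConditioningStub(list):  # illustrative stand-in for prompt_parser.SdConditioning
    def __init__(self, prompts, width=None, height=None, is_negative_prompt=False):
        super().__init__(prompts)
        self.width = width or 1024
        self.height = height or 1024
        self.is_negative_prompt = is_negative_prompt

batch = SdConditioningStub(["a cat", "a dog"], width=832, height=1216)
size = torch.tensor([batch.height, batch.width]).repeat(len(batch), 1)
print(size.tolist())  # [[1216, 832], [1216, 832]] - one size row per prompt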
+
+
+def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt):
+    res = []
+
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]:
+        encoded = embedder.encode_embedding_init_text(init_text, nvpt)
+        res.append(encoded)
+
+    return torch.cat(res, dim=1)
+
+
+def tokenize(self: sgm.modules.GeneralConditioner, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
+        return embedder.tokenize(texts)
+
+    raise AssertionError('no tokenizer available')
+
+
+def process_texts(self, texts):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
+        return embedder.process_texts(texts)
+
+
+def get_target_prompt_token_count(self, token_count):
+    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]:
+        return embedder.get_target_prompt_token_count(token_count)
+
+
+# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in existing code
+sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
+sgm.modules.GeneralConditioner.tokenize = tokenize
+sgm.modules.GeneralConditioner.process_texts = process_texts
+sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count
+
+
+def extend_sdxl(model):
+    """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase."""
+
+    dtype = next(model.model.diffusion_model.parameters()).dtype
+    model.model.diffusion_model.dtype = dtype
+    model.model.conditioning_key = 'crossattn'
+    model.cond_stage_key = 'txt'
+    # model.cond_stage_model will be set in sd_hijack
+
+    model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps"
+
+    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
+    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)
+
+    model.conditioner.wrapped = torch.nn.Module()
+
+
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
+
+# this gets the code to load the vanilla attention that we override
+sgm.modules.attention.SDP_IS_AVAILABLE = True
+sgm.modules.attention.XFORMERS_IS_AVAILABLE = False
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f22aad8f2165c5eb350d147b673b1accacbcd818..bea2684c4db6171075a38427bb9c01b34d9e688a 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -28,6 +28,9 @@
    config = find_sampler_config(name)

    assert config is not None, f'bad sampler name: {name}'

+    if model.is_sdxl and config.options.get("no_sdxl", False):
+        raise Exception(f"Sampler {config.name} is not supported for SDXL")
+
    sampler = config.constructor(model)
    sampler.config = config
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py
index 763829f1ca34be1878cdcafd5d7e594488f6f2f8..b3d344e777b61da442c0fd8d8d2b7e7b7bb0ffc0 100644
--- a/modules/sd_samplers_common.py
+++ b/modules/sd_samplers_common.py
@@ -2,10 +2,9 @@
from collections import namedtuple
import numpy as np
import torch
from PIL import Image
-from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd
+from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
-import modules.shared as shared

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -37,7 +36,7 @@
    elif approximation == 3:
        x_sample = sample * 1.5
        x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
    else:
-        x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5
+        x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5

    x_sample = torch.clamp(x_sample, min=0.0, max=1.0)
    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
@@ -46,6 +45,12 @@
    return Image.fromarray(x_sample)


+def decode_first_stage(model, x):
+    x = model.decode_first_stage(x.to(devices.dtype_vae))
+
+    return x
+
+
def sample_to_image(samples, index=0, approximation=None):
    return single_sample_to_image(samples[index], approximation)
@@ -85,13 +90,15 @@
class InterruptedException(BaseException):
    pass

-if opts.randn_source == "CPU":
+
+def replace_torchsde_browinan():
    import torchsde._brownian.brownian_interval

    def torchsde_randn(size, dtype, device, seed):
-        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
-        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+        return devices.randn_local(seed, size).to(device=device, dtype=dtype)

    torchsde._brownian.brownian_interval._randn = torchsde_randn
+
+
+replace_torchsde_browinan()
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index bdae8b404b50505530b35400825f95752b143691..4a8396f97ec3fd75b2eabdc9fbf97ac34af27e6a 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -12,11 +12,11 @@
samplers_data_compvis = [
-    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True}),
-    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+    sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}),
+    sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}),
]
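A sketch of why replace_torchsde_browinan exists (the typo is in the source identifier; randn_local below is an assumed stand-in for devices.randn_local): SDE samplers draw noise through torchsde's BrownianInterval, which by default uses its own RNG, so routing it through a seed-keyed CPU generator makes SDE results reproducible across devices.

import torch

def randn_local(seed, shape):
    generator = torch.Generator("cpu").manual_seed(int(seed))
    return torch.randn(shape, generator=generator, device="cpu")

def torchsde_randn(size, dtype, device, seed):
    return randn_local(seed, size).to(device=device, dtype=dtype)

print(torch.equal(torchsde_randn((2, 3), torch.float32, "cpu", 42),
                  torchsde_randn((2, 3), torch.float32, "cpu", 42)))  # True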
diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b981ca80c355cbb6a92915d422cfa19e51b21c4
--- /dev/null
+++ b/modules/sd_samplers_extra.py
@@ -0,0 +1,74 @@
+import torch
+import tqdm
+import k_diffusion.sampling
+
+
[email protected]_grad()
+def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None):
+    """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)
+    Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}
+    If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list
+    """
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    step_id = 0
+    from k_diffusion.sampling import to_d, get_sigmas_karras
+
+    def heun_step(x, old_sigma, new_sigma, second_order=True):
+        nonlocal step_id
+        denoised = model(x, old_sigma * s_in, **extra_args)
+        d = to_d(x, old_sigma, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised})
+        dt = new_sigma - old_sigma
+        if new_sigma == 0 or not second_order:
+            # Euler method
+            x = x + d * dt
+        else:
+            # Heun's method
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, new_sigma * s_in, **extra_args)
+            d_2 = to_d(x_2, new_sigma, denoised_2)
+            d_prime = (d + d_2) / 2
+            x = x + d_prime * dt
+        step_id += 1
+        return x
+
+    steps = sigmas.shape[0] - 1
+    if restart_list is None:
+        if steps >= 20:
+            restart_steps = 9
+            restart_times = 1
+            if steps >= 36:
+                restart_steps = steps // 4
+                restart_times = 2
+            sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device)
+            restart_list = {0.1: [restart_steps + 1, restart_times, 2]}
+        else:
+            restart_list = {}
+
+    restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()}
+
+    step_list = []
+    for i in range(len(sigmas) - 1):
+        step_list.append((sigmas[i], sigmas[i + 1]))
+        if i + 1 in restart_list:
+            restart_steps, restart_times, restart_max = restart_list[i + 1]
+            min_idx = i + 1
+            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
+            if max_idx < min_idx:
+                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1]
+                while restart_times > 0:
+                    restart_times -= 1
+                    step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])])
+
+    last_sigma = None
+    for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable):
+        if last_sigma is None:
+            last_sigma = old_sigma
+        elif last_sigma < old_sigma:
+            x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5
+
+        x = heun_step(x, old_sigma, new_sigma)
+        last_sigma = new_sigma
+
+    return x
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index 71581b763593c2cf1da347f2e6ea3898a891d6b5..8bb639f57a9f1b7c90a9d15e02845d52991758cc 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -2,7 +2,7 @@
from collections import deque
import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common
+from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra

from modules.shared import opts, state
import modules.shared as shared
@@ -31,12 +31,16 @@
    ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}),
    ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}),
    ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}),
+    ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}),
]

samplers_data_k_diffusion = [
    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options)
    for label, funcname, aliases, options in samplers_k_diffusion
+    if callable(funcname) or hasattr(k_diffusion.sampling, funcname)
]

sampler_extra_params = {
@@ -52,6 +56,28 @@
    'karras': k_diffusion.sampling.get_sigmas_karras,
    'exponential': k_diffusion.sampling.get_sigmas_exponential,
    'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
}


+def catenate_conds(conds):
+    if not isinstance(conds[0], dict):
+        return torch.cat(conds)
+
+    return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()}
+
+
+def subscript_cond(cond, a, b):
+    if not isinstance(cond, dict):
+        return cond[a:b]
+
+    return {key: vec[a:b] for key, vec in cond.items()}
+
+
+def pad_cond(tensor, repeats, empty):
+    if not isinstance(tensor, dict):
+        return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1)
+
+    tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
+    return tensor
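A quick illustration (not webui code) of why the three helpers above exist: for SDXL, a "cond" is a dict of tensors (crossattn, vector, ...) rather than a single tensor, so concatenation and slicing must be applied per key.

import torch

cond_a = {"crossattn": torch.randn(2, 77, 2048), "vector": torch.randn(2, 2816)}
cond_b = {"crossattn": torch.randn(2, 77, 2048), "vector": torch.randn(2, 2816)}

merged = {key: torch.cat([cond_a[key], cond_b[key]]) for key in cond_a}  # catenate_conds
first = {key: value[0:1] for key, value in merged.items()}               # subscript_cond
print(merged["crossattn"].shape, first["vector"].shape)  # (4, 77, 2048) and (1, 2816)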

class CFGDenoiser(torch.nn.Module):
@@ -106,11 +132,14 @@
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        if shared.sd_model.model.conditioning_key == "crossattn-adm":
            image_uncond = torch.zeros_like(image_cond)
-            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm}
+            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm}
        else:
            image_uncond = image_cond
-            make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": [c_concat]}
+            if isinstance(uncond, dict):
+                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]}
+            else:
+                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]}

        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
@@ -142,34 +171,33 @@
            empty = shared.sd_model.cond_stage_model_empty_prompt
            num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1]

            if num_repeats < 0:
-                tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1)
+                tensor = pad_cond(tensor, -num_repeats, empty)
                self.padded_cond_uncond = True
            elif num_repeats > 0:
-                uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1)
+                uncond = pad_cond(uncond, num_repeats, empty)
                self.padded_cond_uncond = True

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            if is_edit_model:
-                cond_in = torch.cat([tensor, uncond, uncond])
+                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
-                cond_in = torch.cat([tensor, uncond])
+                cond_in = catenate_conds([tensor, uncond])

            if shared.batch_cond_uncond:
-                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict([cond_in], image_cond_in))
+                x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in))
            else:
                x_out = torch.zeros_like(x_in)
                for batch_offset in range(0, x_out.shape[0], batch_size):
                    a = batch_offset
                    b = a + batch_size
-                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict([cond_in[a:b]], image_cond_in[a:b]))
+                    x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            x_out = torch.zeros_like(x_in)
            batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size
@@ -178,17 +206,16 @@
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])

                if not is_edit_model:
-                    c_crossattn = [tensor[a:b]]
+                    c_crossattn = subscript_cond(tensor, a, b)
                else:
                    c_crossattn = torch.cat([tensor[a:b]], uncond)

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))

        if not skip_uncond:
-            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+            x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))

        denoised_image_indexes = [x[0][0] for x in conds_list]
        if skip_uncond:
@@ -244,12 +271,9 @@
                noise = self.sampler_noises.popleft()
                if noise.shape == x.shape:
                    return noise

-        if opts.randn_source == "CPU" or x.device.type == 'mps':
-            return torch.randn_like(x, device=devices.cpu).to(x.device)
-        else:
-            return torch.randn_like(x)
+        return devices.randn_like(x)
True, "uses_ensd": True}), import inspect -import modules.shared as shared +from modules import prompt_parser, devices, sd_samplers_common from modules.shared import opts, state - return torch.randn_like(x) class KDiffusionSampler: @@ -258,7 +282,7 @@ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None @@ -364,6 +388,9 @@ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff29946e8548e8dad447aa90f264d472acb4ed..0bd5e19bb3f9bbdfa937e91703249215f8e034e3 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -15,6 +15,7 @@ loaded_vae_file = None checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -50,6 +51,7 @@ return os.path.basename(filepath) def refresh_vae_list(): + global vae_dict vae_dict.clear() paths = [ @@ -83,6 +85,8 @@ for filepath in candidates: name = get_filename(filepath) vae_dict[name] = filepath + vae_dict = dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0]))) + def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] @@ -96,6 +100,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index e2f004683be33bf0c717c2de770e250d404fc22f..86bd658ad32af5ba22d087ec668f450fedd693a0 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -2,9 +2,9 @@ import os import torch from torch import nn -from modules import devices, 
diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py
index e2f004683be33bf0c717c2de770e250d404fc22f..86bd658ad32af5ba22d087ec668f450fedd693a0 100644
--- a/modules/sd_vae_approx.py
+++ b/modules/sd_vae_approx.py
@@ -2,9 +2,9 @@
import os

import torch
from torch import nn
-from modules import devices, paths
+from modules import devices, paths, shared

-sd_vae_approx_model = None
+sd_vae_approx_models = {}


class VAEApprox(nn.Module):
@@ -31,44 +31,69 @@
        return x


+def download_model(model_path, model_url):
+    if not os.path.exists(model_path):
+        os.makedirs(os.path.dirname(model_path), exist_ok=True)
+
+        print(f'Downloading VAEApprox model to: {model_path}')
+        torch.hub.download_url_to_file(model_url, model_path)
+
+
def model():
-    global sd_vae_approx_model
+    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
+    loaded_model = sd_vae_approx_models.get(model_name)

-    if sd_vae_approx_model is None:
-        model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
-        sd_vae_approx_model = VAEApprox()
-        if not os.path.exists(model_path):
-            model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
-        sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
-        sd_vae_approx_model.eval()
-        sd_vae_approx_model.to(devices.device, devices.dtype)
+    if loaded_model is None:
+        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)
+
+        if not os.path.exists(model_path):
+            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
+            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)
+
+        loaded_model = VAEApprox()
+        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
+        loaded_model.eval()
+        loaded_model.to(devices.device, devices.dtype)
+        sd_vae_approx_models[model_name] = loaded_model

-    return sd_vae_approx_model
+    return loaded_model


def cheap_approximation(sample):
    # https://discuss.huggingface.co/t/decoding-latents-to-rgb-without-upscaling/23204/2

-    coefs = torch.tensor([
-        [0.298, 0.207, 0.208],
-        [0.187, 0.286, 0.173],
-        [-0.158, 0.189, 0.264],
-        [-0.184, -0.271, -0.473],
-    ]).to(sample.device)
+    if shared.sd_model.is_sdxl:
+        coeffs = [
+            [ 0.3448,  0.4168,  0.4395],
+            [-0.1953, -0.0290,  0.0250],
+            [ 0.1074,  0.0886, -0.0163],
+            [-0.3730, -0.2499, -0.2088],
+        ]
+    else:
+        coeffs = [
+            [ 0.298,  0.207,  0.208],
+            [ 0.187,  0.286,  0.173],
+            [-0.158,  0.189,  0.264],
+            [-0.184, -0.271, -0.473],
+        ]
+
+    coefs = torch.tensor(coeffs).to(sample.device)

    x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs)
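The einsum above projects each of the four latent channels to RGB with a fixed 4x3 matrix - a one-matmul preview decoder. A self-contained demo, reusing the SD1.5 coefficients from the hunk above:

import torch

coefs = torch.tensor([
    [ 0.298,  0.207,  0.208],
    [ 0.187,  0.286,  0.173],
    [-0.158,  0.189,  0.264],
    [-0.184, -0.271, -0.473],
])

latent = torch.randn(4, 64, 64)                     # (channels, h, w) latent sample
rgb = torch.einsum("lxy,lr -> rxy", latent, coefs)  # (3, 64, 64) rough RGB preview
print(rgb.shape)  # torch.Size([3, 64, 64])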
diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py
index 5e8496e8739e67b509d9ded4213b88902982114d..5bf7c76e1dd8ca9a7fc3b624861ed9962387222e 100644
--- a/modules/sd_vae_taesd.py
+++ b/modules/sd_vae_taesd.py
@@ -8,9 +8,10 @@
Tiny AutoEncoder for Stable Diffusion

https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn

-from modules import devices, paths_internal
+from modules import devices, paths_internal, shared

-sd_vae_taesd = None
+sd_vae_taesd_models = {}


def conv(n_in, n_out, **kwargs):
@@ -60,10 +62,8 @@
    """[0, 1] -> raw latents"""
    return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)


-def download_model(model_path):
-    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
-
+def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
@@ -72,19 +72,22 @@
        torch.hub.download_url_to_file(model_url, model_path)


def model():
+    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
+    loaded_model = sd_vae_taesd_models.get(model_name)

-    global sd_vae_taesd
-
-    if sd_vae_taesd is None:
-        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
-        download_model(model_path)
+    if loaded_model is None:
+        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
+        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

        if os.path.exists(model_path):
-            sd_vae_taesd = TAESD(model_path)
-            sd_vae_taesd.eval()
-            sd_vae_taesd.to(devices.device, devices.dtype)
+            loaded_model = TAESD(model_path)
+            loaded_model.eval()
+            loaded_model.to(devices.device, devices.dtype)
+            sd_vae_taesd_models[model_name] = loaded_model
        else:
            raise FileNotFoundError('TAESD model not found')

-    return sd_vae_taesd.decoder
+    return loaded_model.decoder
diff --git a/modules/shared.py b/modules/shared.py
index f6604ef913bddd57fb69fa1a5c476a97f0bbcc4b..8245250a5a2607a3580672ae63fa08e7a94daf30 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,6 +11,7 @@
import gradio as gr
import torch
import tqdm

+import launch
import modules.interrogate
import modules.memmon
import modules.styles
@@ -26,8 +27,9 @@
demo = None

parser = cmd_args.parser

-script_loading.preload_extensions(extensions_dir, parser)
+script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file))
script_loading.preload_extensions(extensions_builtin_dir, parser)

if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -220,17 +222,24 @@
        if self.current_latent is None:
            return

        import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))

-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+        except Exception:
+            # when switching models during generation, VAE would be on CPU, so creating an image will fail.
+            # we silently ignore this error
+            errors.record_exception()

    def assign_current_image(self, image):
        self.current_image = image
@@ -390,13 +398,15 @@
    "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
}))

options_templates.update(options_section(('system', "System"), {
-    "show_warnings": OptionInfo(False, "Show warnings in console."),
+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
    "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
    "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
    "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
    "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
    "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))

options_templates.update(options_section(('training', "Training"), {
@@ -416,47 +426,67 @@
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
diff --git a/modules/shared.py b/modules/shared.py
index f6604ef913bddd57fb69fa1a5c476a97f0bbcc4b..8245250a5a2607a3580672ae63fa08e7a94daf30 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -11,6 +11,7 @@
 import gradio as gr
 import torch
 import tqdm
+import launch
 import modules.interrogate
 import modules.memmon
 import modules.styles
@@ -26,8 +27,9 @@
 demo = None

 parser = cmd_args.parser

 script_loading.preload_extensions(extensions_builtin_dir, parser)

 if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None:
@@ -220,17 +222,24 @@
         if self.current_latent is None:
             return

         import modules.sd_samplers
-        if opts.show_progress_grid:
-            self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
-        else:
-            self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
-
-        self.current_image_sampling_step = self.sampling_step
+        try:
+            if opts.show_progress_grid:
+                self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent))
+            else:
+                self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent))
+
+            self.current_image_sampling_step = self.sampling_step
+
+        except Exception:
+            # when switching models during genration, VAE would be on CPU, so creating an image will fail.
+            errors.record_exception()

     def assign_current_image(self, image):
         self.current_image = image
@@ -390,13 +398,15 @@
     "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
 }))

 options_templates.update(options_section(('system', "System"), {
-    "show_warnings": OptionInfo(False, "Show warnings in console."),
+    "show_warnings": OptionInfo(False, "Show warnings in console.").needs_restart(),
+    "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_restart(),
     "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"),
     "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
     "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
 }))

 options_templates.update(options_section(('training', "Training"), {
@@ -416,47 +426,67 @@
 }))

 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
     "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"),
+    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_restart(),
+    "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"),
+    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"),
+    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
+    "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
+    "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"),
 }))

+options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), {
+    "sdxl_crop_top": OptionInfo(0, "crop top coordinate"),
+    "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"),
+}))

+    "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(),
+    "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_restart(),
+    "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_restart(),
+    "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_restart(),
 }))

 options_templates.update(options_section(('optimizations', "Optimizations"), {
     "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
     "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
     "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),
@@ -474,7 +502,7 @@
     "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."),
     "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."),
 }))

-options_templates.update(options_section(('interrogate', "Interrogate Options"), {
+options_templates.update(options_section(('interrogate', "Interrogate"), {
     "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"),
     "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"),
     "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
@@ -508,11 +536,7 @@
 options_templates.update(options_section(('ui', "User interface"), {
     "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_restart(),
     "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}).needs_restart(),
-    "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"),
-    "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"),
     "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"),
     "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"),
     "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"),
@@ -531,10 +555,11 @@
     "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(),
     "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(),
     "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(),
-    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(),
+    "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(),
     "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(),
     "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(),
 }))

 options_templates.update(options_section(('infotext', "Infotext"), {
     "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
@@ -908,3 +933,10 @@
     if not opts.list_hidden_files and ("/." in root or "\\." in root):
         continue
     yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+    if opts.hide_ldm_prints:
+        return
+
+    print(*args, **kwargs)
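Many of the settings above lean on OptionInfo's fluent helpers, .info(), .link() and .needs_restart(), which simply accumulate UI metadata and return self. A simplified stand-in showing why the calls can be chained (not the actual modules/shared.py class):

```python
# Simplified stand-in for the fluent OptionInfo-style chaining seen above:
# each helper records extra UI metadata and returns self so calls stack.
class OptionInfo:
    def __init__(self, default, label, component=None, component_args=None):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args
        self.comment_after = ""       # extra HTML rendered next to the setting
        self.restart_required = False  # flags the setting as needing a restart

    def info(self, text):
        self.comment_after += f"<span class='info'>({text})</span>"
        return self

    def link(self, label, url):
        self.comment_after += f"<a href='{url}' target='_blank'>[{label}]</a>"
        return self

    def needs_restart(self):
        self.restart_required = True
        return self


# usage: chaining reads left to right and still yields the same OptionInfo
opt = OptionInfo(20, "Prompt word wrap length limit").info("in tokens").needs_restart()
```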
diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc51dbfb6415da3c953c96fd524a554ac0d..0740fe1b1c0ed15a5d44f25af144f3f2257fa0d9 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@
         # Always keep a backup file around
         if os.path.exists(path):
             shutil.copy(path, f"{path}.bak")

-        fd = os.open(path, os.O_RDWR | os.O_CREAT)
-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+        with open(path, "w", encoding="utf-8-sig", newline='') as file:
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer.writeheader()
             writer.writerows(style._asdict() for k, style in self.styles.items())
diff --git a/modules/sysinfo.py b/modules/sysinfo.py
index 5f15ac4fa94ca0eb1a41bac92623963bd6140189..cf24c6dd4a4effb272311b5d83df7673444fa108 100644
--- a/modules/sysinfo.py
+++ b/modules/sysinfo.py
@@ -109,12 +109,17 @@
 def format_traceback(tb):
     return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]


+def format_exception(e, tb):
+    return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
 def get_exceptions():
     try:
         from modules import errors
     except Exception as e:
         return str(e)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 6166c76f6537b757ebea4ec40b5dc59eecfc2ff5..aa79dc09843ae6575af352bd61a7c05ba16376c5 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -13,7 +13,7 @@
 import numpy as np
 from PIL import Image, PngImagePlugin
 from torch.utils.tensorboard import SummaryWriter
+from contextlib import closing

 import modules.textual_inversion.dataset
 from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -182,36 +182,44 @@
         data = safetensors.torch.load_file(path, device="cpu")
     else:
         return

+    # textual inversion embeddings
     if 'string_to_param' in data:
         param_dict = data['string_to_param']
         param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
         assert len(param_dict) == 1, 'embedding file has multiple terms in it'
         emb = next(iter(param_dict.items()))[1]
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
+    elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
+        shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
     elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
         assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
         emb = next(iter(data.values()))
         if len(emb.shape) == 1:
             emb = emb.unsqueeze(0)
+        vec = emb.detach().to(devices.device, dtype=torch.float32)
+        shape = vec.shape[-1]
+        vectors = vec.shape[0]
     else:
         raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

     embedding.step = data.get('step', None)
     embedding.sd_checkpoint = data.get('sd_checkpoint', None)
     embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
     embedding.filename = path
     embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')
@@ -387,6 +394,8 @@
     assert log_directory, "Log directory is empty"


 def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+    from modules import processing
+
     save_embedding_every = save_embedding_every or 0
     create_image_every = create_image_every or 0
     template_file = textual_inversion_templates.get(template_filename, None)
diff --git a/modules/timer.py b/modules/timer.py
index da99e49f82df4260676cf8205cf6bcc93b9d3f01..1d38595c7fff24a9047a619f440811007f6cde44 100644
--- a/modules/timer.py
+++ b/modules/timer.py
@@ -1,4 +1,5 @@
 import time
+import argparse


 class TimerSubcategory:
     def __init__(self, timer, category):
         self.timer = timer
         self.category = category
@@ -11,20 +12,27 @@
     def __enter__(self):
         self.start = time.time()
         self.timer.base_category = self.original_base_category + self.category + "/"
+        self.timer.subcategory_level += 1
+
+        if self.timer.print_log:
+            print(f"{'  ' * self.timer.subcategory_level}{self.category}:")

     def __exit__(self, exc_type, exc_val, exc_tb):
         elapsed_for_subcategroy = time.time() - self.start
         self.timer.base_category = self.original_base_category
         self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy)
-        self.timer.record(self.category)
+        self.timer.subcategory_level -= 1
+        self.timer.record(self.category, disable_log=True)


 class Timer:
-    def __init__(self):
+    def __init__(self, print_log=False):
         self.start = time.time()
         self.records = {}
         self.total = 0
         self.base_category = ''
+        self.print_log = print_log
+        self.subcategory_level = 0

     def elapsed(self):
         end = time.time()
@@ -38,12 +46,16 @@
             self.records[category] = 0

         self.records[category] += amount

-    def record(self, category, extra_time=0):
+    def record(self, category, extra_time=0, disable_log=False):
         e = self.elapsed()
         self.add_time_to_record(self.base_category + category, e + extra_time)

+        if self.print_log and not disable_log:
+            print(f"{'  ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s")

     def subcategory(self, name):
@@ -72,7 +84,11 @@
     def reset(self):
         self.__init__()


-startup_timer = Timer()
+parser = argparse.ArgumentParser(add_help=False)
+parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup")
+args = parser.parse_known_args()[0]
+
+startup_timer = Timer(print_log=args.log_startup)

 startup_record = None
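The timer.py changes above add two things: nested subcategories that fold their time into the parent record, and an optional live log (wired to the new --log-startup flag) that prints each record as it lands. A compact, self-contained rendition of that design (illustrative, not the exact modules/timer.py code):

```python
# Compact stand-in for the Timer / TimerSubcategory pattern above: a
# subcategory is a context manager that indents the live log and folds its
# elapsed time into the parent's record.
import time


class Timer:
    def __init__(self, print_log=False):
        self.start = time.time()
        self.records = {}
        self.base_category = ''
        self.print_log = print_log
        self.level = 0

    def elapsed(self):
        now = time.time()
        res = now - self.start
        self.start = now  # reset so records measure time since the last record
        return res

    def record(self, category, disable_log=False):
        e = self.elapsed()
        key = self.base_category + category
        self.records[key] = self.records.get(key, 0) + e
        if self.print_log and not disable_log:
            print(f"{'  ' * self.level}{category}: done in {e:.3f}s")

    def subcategory(self, name):
        self.elapsed()  # flush time spent before the subcategory starts
        return TimerSubcategory(self, name)


class TimerSubcategory:
    def __init__(self, timer, category):
        self.timer = timer
        self.category = category

    def __enter__(self):
        self.timer.base_category += self.category + "/"
        self.timer.level += 1
        if self.timer.print_log:
            print(f"{'  ' * self.timer.level}{self.category}:")

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.timer.base_category = self.timer.base_category[:-len(self.category) - 1]
        self.timer.level -= 1
        self.timer.record(self.category, disable_log=True)
```

With print_log=True, entering a subcategory prints an indented header and each record() prints its elapsed time immediately, which is what --log-startup enables for the startup timer.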
diff --git a/modules/txt2img.py b/modules/txt2img.py
index 29d94e8cb2c0fa8aca36a3f08120705bf36eefea..935ed418171f6e7ebf2539f499eb1099386d7455 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -9,7 +9,7 @@
 from modules.ui import plaintext_to_html
 import gradio as gr


-def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
+def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args):
     override_settings = create_override_settings_dict(override_settings_texts)

     p = processing.StableDiffusionProcessingTxt2Img(
@@ -41,6 +41,7 @@
         hr_upscaler=hr_upscaler,
         hr_second_pass_steps=hr_second_pass_steps,
         hr_resize_x=hr_resize_x,
         hr_resize_y=hr_resize_y,
+        hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name,
         hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None,
         hr_prompt=hr_prompt,
         hr_negative_prompt=hr_negative_prompt,
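The new hr_checkpoint_name argument shows the sentinel pattern used throughout this UI: the dropdown carries a human-readable "inherit" entry, and the handler converts it to None before constructing the processing object. A tiny generic helper in the same spirit (hypothetical, not a webui API):

```python
# The dropdown uses "Use same checkpoint" as a sentinel for "no override";
# the handler maps it to None before building the processing object.
def dropdown_override(value, sentinel):
    """Return None when the dropdown is left on its 'inherit' sentinel."""
    return None if value == sentinel else value


dropdown_override("Use same checkpoint", "Use same checkpoint")  # -> None
dropdown_override("sd_xl_base_1.0", "Use same checkpoint")       # -> "sd_xl_base_1.0"
```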
diff --git a/modules/ui.py b/modules/ui.py
index 085561c1a7b15d7bba2b837d06457ca3038e65eb..61a6b4ad77817272da5a763a13daaf32c4d3600c 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -13,41 +13,37 @@
 from PIL import Image, PngImagePlugin  # noqa: F401

 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
 from modules.ui_gradio_extensions import reload_javascript

 from modules.shared import opts, cmd_opts

 import modules.shared as shared
+import gradio.utils
 from functools import reduce
 from modules import prompt_parser
 from modules.sd_hijack import model_hijack
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.generation_parameters_copypaste import image_from_url_text
-import modules.extras

 create_setting_component = ui_settings.create_setting_component

 warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning)
+warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning)

 # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI
 mimetypes.init()
@@ -99,21 +95,8 @@
     if len(x) == 0:
         return None
     return image_from_url_text(x[0])


-def add_style(name: str, prompt: str, negative_prompt: str):
-    if name is None:
-        return [gr_show() for x in range(4)]
-
-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
-    shared.prompt_styles.styles[style.name] = style
-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
-    # reserialize all styles every time we save them
-    shared.prompt_styles.save_styles(shared.styles_filename)
-
-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
@@ -136,13 +119,6 @@
     if not target_width or not target_height:
         return "no image selected"

     return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"


-def apply_styles(prompt, prompt_neg, styles):
-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]


 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
@@ -182,8 +158,6 @@
 def create_seed_inputs(target_interface):
     with FormRow(elem_id=f"{target_interface}_seed_row", variant="compact"):
         seed = (gr.Textbox if cmd_opts.use_textbox_seed else gr.Number)(label='Seed', value=-1, elem_id=f"{target_interface}_seed")
-        seed.style(container=False)
-
         random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed')
         reuse_seed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_seed", label='Reuse seed')
@@ -195,7 +169,6 @@
     with FormRow(visible=False, elem_id=f"{target_interface}_subseed_row") as seed_extra_row_1:
         seed_extras.append(seed_extra_row_1)
         subseed = gr.Number(label='Variation seed', value=-1, elem_id=f"{target_interface}_subseed")
-        subseed.style(container=False)
         random_subseed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_subseed")
         reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed")
         subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength")
@@ -279,127 +252,144 @@
     return f"{token_count}/{max_length}"


-def create_toprow(is_img2img):
+class Toprow:
+    def __init__(self, is_img2img):
+        id_part = "img2img" if is_img2img else "txt2img"
+        self.id_part = id_part

-            negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-            negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+            self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+            self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")

+                with gr.Row(elem_id=f"{id_part}_tools"):
-                interrupt.click(
+                self.interrupt.click(
                     fn=lambda: shared.state.interrupt(),
                 )

-        with gr.Row(elem_id=f"{id_part}_styles_row"):
+        self.prompt_img.change(
+            fn=modules.images.image_data,
+            inputs=[self.prompt_img],
+            outputs=[self.prompt, self.prompt_img],
+            show_progress=False,
+        )
@@ -482,23 +469,22 @@
     reload_javascript()

     parameters_copypaste.reset()

     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, _, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+        toprow = Toprow(is_img2img=False)

         dummy_component = gr.Label(visible=False)
-        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)

         extra_tabs = gr.Tabs(elem_id="txt2img_extra_tabs")
         extra_tabs.__enter__()

         with gr.Tab("Generation", id="txt2img_generation") as txt2img_generation_tab, gr.Row().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
-                modules.scripts.scripts_txt2img.prepare_ui()
+                scripts.scripts_txt2img.prepare_ui()

                 for category in ordered_ui_categories():
                     if category == "sampler":
@@ -544,6 +530,10 @@
                             hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x")
                             hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y")

                         with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container:
+
+                            hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint")
+                            create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh")
+
                             hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index")

                         with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container:
@@ -566,10 +556,10 @@
                             override_settings = create_override_settings_dropdown('txt2img', row)

                     elif category == "scripts":
                         with FormGroup(elem_id="txt2img_script_container"):
-                            custom_inputs = modules.scripts.scripts_txt2img.setup_ui()
+                            custom_inputs = scripts.scripts_txt2img.setup_ui()

                     else:
-                        modules.scripts.scripts_txt2img.setup_ui_for_section(category)
+                        scripts.scripts_txt2img.setup_ui_for_section(category)

             hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y]
@@ -600,10 +590,10 @@
                 fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']),
                 _js="submit",
                 inputs=[
                     dummy_component,
-                    txt2img_prompt,
-                    txt2img_negative_prompt,
+                    toprow.prompt,
+                    toprow.negative_prompt,
                     steps,
                     sampler_index,
                     restore_faces,
@@ -622,6 +612,7 @@
                     hr_upscaler,
                     hr_second_pass_steps,
                     hr_resize_x,
                     hr_resize_y,
+                    hr_checkpoint_name,
                     hr_sampler_index,
                     hr_prompt,
                     hr_negative_prompt,
@@ -638,16 +629,16 @@
                 ],
                 show_progress=False,
             )

-            txt2img_prompt.submit(**txt2img_args)
-            submit.click(**txt2img_args)
+            toprow.prompt.submit(**txt2img_args)
+            toprow.submit.click(**txt2img_args)

             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)

-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 _js="restoreProgressTxt2img",
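The txt2img wiring above is the visible half of a larger refactor: create_toprow() used to return a fifteen-element tuple that every caller had to unpack positionally, and the new Toprow class hangs the same widgets off named attributes (toprow.prompt, toprow.submit, toprow.paste, and so on). A trimmed sketch of the shape of that class, simplified from what this diff shows; the real one also builds styles, token counters, and tool buttons:

```python
# Sketch of the tuple-to-class refactor above: callers hold one object and
# reference named attributes instead of unpacking fifteen return values.
# (Simplified; not the full Toprow class from ui.py.)
import gradio as gr


class Toprow:
    def __init__(self, is_img2img):
        id_part = "img2img" if is_img2img else "txt2img"
        self.id_part = id_part

        with gr.Row(elem_id=f"{id_part}_toprow"):
            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", lines=3)
            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", lines=3)
            self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
            self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt")
```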
                 inputs=[dummy_component],
@@ -660,19 +652,6 @@
                 show_progress=False,
             )

-            txt_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    txt_prompt_img
-                ],
-                outputs=[
-                    txt2img_prompt,
-                    txt_prompt_img
-                ],
-                show_progress=False,
-            )
-
             enable_hr.change(
                 fn=lambda x: gr_show(x),
                 inputs=[enable_hr],
@@ -681,12 +660,12 @@
                 show_progress = False,
             )

             txt2img_paste_fields = [
-                (txt2img_prompt, "Prompt"),
-                (txt2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
@@ -695,38 +674,41 @@
                 (seed, "Seed"),
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
                 (hr_scale, "Hires upscale"),
                 (hr_upscaler, "Hires upscaler"),
                 (hr_second_pass_steps, "Hires steps"),
                 (hr_resize_x, "Hires resize-1"),
                 (hr_resize_y, "Hires resize-2"),
+                (hr_checkpoint_name, "Hires checkpoint"),
                 (hr_sampler_index, "Hires sampler"),
                 (hr_prompt, "Hires prompt"),
                 (hr_negative_prompt, "Hires negative prompt"),
                 (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()),
-                *modules.scripts.scripts_txt2img.infotext_fields
+                *scripts.scripts_txt2img.infotext_fields
             ]
             parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
             ))

             txt2img_preview_params = [
-                txt2img_prompt,
-                txt2img_negative_prompt,
+                toprow.prompt,
+                toprow.negative_prompt,
                 steps,
                 sampler_index,
                 cfg_scale,
@@ -734,10 +717,11 @@
                 width,
                 height,
             ]

             from modules import ui_extra_networks
             extra_networks_ui = ui_extra_networks.create_ui(txt2img_interface, [txt2img_generation_tab], 'txt2img')
             ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)

             extra_tabs.__exit__()
@@ -745,17 +728,17 @@
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
+        toprow = Toprow(is_img2img=True)

         extra_tabs = gr.Tabs(elem_id="img2img_extra_tabs")
         extra_tabs.__enter__()
@@ -782,22 +765,25 @@
             with gr.Tabs(elem_id="mode_img2img"):
                 img2img_selected_tab = gr.State(0)

                 with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img:
                     add_copy_image_controls('img2img', init_img)

                 with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch:
                     add_copy_image_controls('sketch', sketch)

                 with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint:
                     add_copy_image_controls('inpaint', init_img_with_mask)

                 with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color:
-                    inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA").style(height=opts.img2img_editor_height)
+                    inpaint_color_sketch = gr.Image(label="Color sketch inpainting", show_label=False, elem_id="inpaint_sketch", source="upload", interactive=True, type="pil", tool="color-sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_sketch_default_brush_color)
                     inpaint_color_sketch_orig = gr.State(None)
                     add_copy_image_controls('inpaint_sketch', inpaint_color_sketch)
@@ -857,8 +843,9 @@
             with FormRow():
                 resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize")

             for category in ordered_ui_categories():
                 if category == "sampler":
@@ -939,8 +926,7 @@
                     override_settings = create_override_settings_dropdown('img2img', row)

                 elif category == "scripts":
                     with FormGroup(elem_id="img2img_script_container"):
-                        custom_inputs = modules.scripts.scripts_img2img.setup_ui()
+                        custom_inputs = scripts.scripts_img2img.setup_ui()

                 elif category == "inpaint":
                     with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls:
@@ -971,7 +957,7 @@
                         inputs=[],
                         outputs=[inpaint_controls, mask_alpha],
                     )
                 else:
-                    modules.scripts.scripts_img2img.setup_ui_for_section(category)
+                    scripts.scripts_img2img.setup_ui_for_section(category)

             img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples)
@@ -979,31 +965,18 @@
             connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False)
             connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True)

-            img2img_prompt_img.change(
-                fn=modules.images.image_data,
-                inputs=[
-                    img2img_prompt_img
-                ],
-                outputs=[
-                    img2img_prompt,
-                    img2img_prompt_img
-                ],
-                show_progress=False,
-            )
-
             img2img_args = dict(
                 fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']),
                 _js="submit_img2img",
                 inputs=[
                     dummy_component,
                     dummy_component,
-                    img2img_prompt,
-                    img2img_negative_prompt,
+                    toprow.prompt,
+                    toprow.negative_prompt,
                     init_img,
                     sketch,
@@ -1063,12 +1036,12 @@
                     init_img_with_mask,
                     inpaint_color_sketch,
                     init_img_inpaint,
                 ],
-                outputs=[img2img_prompt, dummy_component],
+                outputs=[toprow.prompt, dummy_component],
             )

-            img2img_prompt.submit(**img2img_args)
-            submit.click(**img2img_args)
+            toprow.prompt.submit(**img2img_args)
+            toprow.submit.click(**img2img_args)

             res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
@@ -1080,9 +1053,8 @@
                 outputs=[width, height],
                 show_progress=False,
             )

-            restore_progress_button.click(
+            toprow.restore_progress_button.click(
                 fn=progress.restore_progress,
                 _js="restoreProgressImg2img",
                 inputs=[dummy_component],
@@ -1095,59 +1067,34 @@
                 ],
                 show_progress=False,
             )

-            img2img_interrogate.click(
+            toprow.button_interrogate.click(
                 fn=lambda *args: process_interrogate(interrogate, *args),
                 **interrogate_args,
             )

-            img2img_deepbooru.click(
+            toprow.button_deepbooru.click(
                 fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                 **interrogate_args,
             )

-            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
-
-            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
-                button.click(
-                    fn=add_style,
-                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
-                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
-                    inputs=[dummy_component, prompt, negative_prompt],
-                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
-                )
-
-            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
-                button.click(
-                    fn=apply_styles,
-                    _js=js_func,
-                    inputs=[prompt, negative_prompt, styles],
-                    outputs=[prompt, negative_prompt, styles],
-                )
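The interrogate buttons above route through ui.py's existing process_interrogate() helper, whose batch branch walks an input directory and writes one caption file per image. A loose sketch of that path, reassembled from the fragments visible in this hunk (details simplified):

```python
# Loose sketch of the batch-interrogate path wired to the buttons above
# (condensed from the existing ui.py helper; details simplified).
import os
from glob import glob
from PIL import Image


def interrogate_batch(interrogation_function, ii_input_dir, ii_output_dir):
    images = glob(os.path.join(ii_input_dir, "*"))
    print(f"Will process {len(images)} images.")

    if ii_output_dir != "":
        os.makedirs(ii_output_dir, exist_ok=True)
    else:
        ii_output_dir = ii_input_dir  # write captions next to the inputs

    for image_path in images:
        img = Image.open(image_path)
        filename = os.path.basename(image_path)
        left, _ = os.path.splitext(filename)
        caption = interrogation_function(img)
        with open(os.path.join(ii_output_dir, f"{left}.txt"), 'a', encoding='utf-8') as f:
            f.write(caption)
```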
             img2img_paste_fields = [
-                (img2img_prompt, "Prompt"),
-                (img2img_negative_prompt, "Negative prompt"),
+                (toprow.prompt, "Prompt"),
+                (toprow.negative_prompt, "Negative prompt"),
                 (steps, "Steps"),
                 (sampler_index, "Sampler"),
                 (restore_faces, "Face restoration"),
@@ -1157,19 +1104,20 @@
                 (seed, "Seed"),
                 (width, "Size-1"),
                 (height, "Size-2"),
                 (batch_size, "Batch size"),
+                (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d),
                 (subseed, "Variation seed"),
                 (subseed_strength, "Variation seed strength"),
                 (seed_resize_from_w, "Seed resize from-1"),
                 (seed_resize_from_h, "Seed resize from-2"),
-                (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                 (denoising_strength, "Denoising strength"),
                 (mask_blur, "Mask blur"),
-                *modules.scripts.scripts_img2img.infotext_fields
+                *scripts.scripts_img2img.infotext_fields
             ]
             parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
             parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
             parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+                paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
             ))
@@ -1178,14 +1126,13 @@
             from modules import ui_extra_networks
             extra_networks_ui_img2img = ui_extra_networks.create_ui(img2img_interface, [img2img_generation_tab], 'img2img')
             ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)

             extra_tabs.__exit__()

     with gr.Blocks(analytics_enabled=False) as extras_interface:
         ui_postprocessing.create_ui()

     with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             with gr.Column(variant='panel'):
                 image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil")
@@ -1207,76 +1154,23 @@
                 inputs=[image],
                 outputs=[html, generation_info, html2],
             )

-    def update_interp_description(value):
-        interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
-        interp_descriptions = {
-            "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
-            "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
-            "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
-        }
-        return interp_descriptions[value]
-
-    with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
-        with gr.Row().style(equal_height=False):
-            with gr.Column(variant='compact'):
-                interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
-
-                with FormRow(elem_id="modelmerger_models"):
-                    primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
-                    create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
-
-                    secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
-                    create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
-
-                    tertiary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
-                    create_refresh_button(tertiary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
-
-                custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
-                interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
-                interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
-                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])
-
-                with FormRow():
-                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
-                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
-                    save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")
-
-                with FormRow():
-                    with gr.Column():
-                        config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
-
-                    with gr.Column():
-                        with FormRow():
-                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
-                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
-
-                with FormRow():
-                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
-
-                with gr.Row():
-                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
-
-            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
-                with gr.Group(elem_id="modelmerger_results_panel"):
-                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)

     with gr.Blocks(analytics_enabled=False) as train_interface:
-        with gr.Row().style(equal_height=False):
+        with gr.Row(equal_height=False):
             gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>")

         with gr.Tabs(elem_id="train_tabs"):
@@ -1296,8 +1190,9 @@
             with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"):
                 new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name")
                 new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes")
                 new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure")
-                new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func")
+                new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func")
                 new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option")
                 new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm")
                 new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout")
@@ -1437,16 +1332,15 @@
             script_callbacks.ui_train_tabs_callback(params)

             with gr.Column(elem_id='ti_gallery_container'):
                 ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)
                 gr.HTML(elem_id="ti_progress", value="")
                 ti_outcome = gr.HTML(elem_id="ti_error", value="")

             create_embedding.click(
-                fn=modules.textual_inversion.ui.create_embedding,
+                fn=textual_inversion_ui.create_embedding,
                 inputs=[
                     new_embedding_name,
                     initialization_text,
@@ -1461,7 +1355,7 @@
                 ]
             )

             create_hypernetwork.click(
-                fn=modules.hypernetworks.ui.create_hypernetwork,
+                fn=hypernetworks_ui.create_hypernetwork,
                 inputs=[
                     new_hypernetwork_name,
                     new_hypernetwork_sizes,
@@ -1481,8 +1375,8 @@
                 ]
             )

             run_preprocess.click(
-                fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]),
+                fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]),
                 _js="start_training_textual_inversion",
                 inputs=[
                     dummy_component,
@@ -1518,8 +1412,8 @@
                 ],
             )

             train_embedding.click(
-                fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]),
+                fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]),
                 _js="start_training_textual_inversion",
                 inputs=[
                     dummy_component,
@@ -1553,7 +1447,7 @@
                 ]
             )

             train_hypernetwork.click(
-                fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]),
+                fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]),
                 _js="start_training_textual_inversion",
                 inputs=[
                     dummy_component,
@@ -1607,8 +1501,9 @@
     interfaces = [
         (txt2img_interface, "txt2img", "txt2img"),
         (img2img_interface, "img2img", "img2img"),
         (extras_interface, "Extras", "extras"),
         (pnginfo_interface, "PNG Info", "pnginfo"),
-        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
+        (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"),
         (train_interface, "Train", "train"),
     ]
@@ -1660,51 +1555,13 @@
         update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
         settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
         demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

-        def modelmerger(*args):
-            try:
-                results = modules.extras.run_modelmerger(*args)
-            except Exception as e:
-                errors.report("Error loading/saving model file", exc_info=True)
-                modules.sd_models.list_models()  # to remove the potentially missing models from the list
-                return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
-            return results
-
-        modelmerger_merge.click(
-            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
-            _js='modelmerger',
-            inputs=[
-                dummy_component,
-                primary_model_name,
-                secondary_model_name,
-                tertiary_model_name,
-                interp_method,
-                interp_amount,
-                save_as_half,
-                custom_name,
-                checkpoint_format,
-                config_source,
-                bake_in_vae,
-                discard_weights,
-                save_metadata,
-            ],
-            outputs=[
-                primary_model_name,
-                secondary_model_name,
-                tertiary_model_name,
-                settings.component_dict['sd_model_checkpoint'],
-                modelmerger_result,
-            ]
-        )
+        modelmerger_ui.setup_ui(dummy_component=dummy_component, sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])

     loadsave.dump_defaults()
     demo.ui_loadsave = loadsave
-
-    # Required as a workaround for change() event not triggering when loading values from ui-config.json
-    interp_description.value = update_interp_description(interp_method.value)

     return demo
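Everything deleted from ui.py above reappears in the self-contained module introduced in the new file below; ui.py then only needs to instantiate the class and hand it the two components it cannot create itself. A paraphrase of the consumer side from this commit's ui.py changes (not shown in full in this diff; dummy_component and settings come from the surrounding create_ui() scope):

```python
# Paraphrased consumer side: ui.py builds the merger UI once, registers its
# Blocks as a tab via the .blocks attribute, and wires cross-component events
# through setup_ui(). (Sketch, not a verbatim excerpt.)
from modules import ui_checkpoint_merger

modelmerger_ui = ui_checkpoint_merger.UiCheckpointMerger()  # constructs the Blocks tree

# (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger") goes into the
# interfaces list; once the settings tab exists, events are connected:
modelmerger_ui.setup_ui(dummy_component=dummy_component,
                        sd_model_checkpoint_component=settings.component_dict['sd_model_checkpoint'])
```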
diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9c5dd6bb6773eaf3a6a507f3162cef637c9338f
--- /dev/null
+++ b/modules/ui_checkpoint_merger.py
@@ -0,0 +1,124 @@
+
+import gradio as gr
+
+from modules import sd_models, sd_vae, errors, extras, call_queue
+from modules.ui_components import FormRow
+from modules.ui_common import create_refresh_button
+
+
+def update_interp_description(value):
+    interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>"
+    interp_descriptions = {
+        "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."),
+        "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"),
+        "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M")
+    }
+    return interp_descriptions[value]
+
+
+def modelmerger(*args):
+    try:
+        results = extras.run_modelmerger(*args)
+    except Exception as e:
+        errors.report("Error loading/saving model file", exc_info=True)
+        sd_models.list_models()  # to remove the potentially missing models from the list
+        return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
+    return results
+
+
+class UiCheckpointMerger:
+    def __init__(self):
+        with gr.Blocks(analytics_enabled=False) as modelmerger_interface:
+            with gr.Row(equal_height=False):
+                with gr.Column(variant='compact'):
+                    self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description")
+
+                    with FormRow(elem_id="modelmerger_models"):
+                        self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)")
+                        create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A")
+
+                        self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)")
+                        create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B")
+
+                        self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)")
+                        create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C")
+
+                    self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name")
+                    self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount")
+                    self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method")
+                    self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description])
+
+                    with FormRow():
+                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
+                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
+
+                    with FormRow():
+                        with gr.Column():
+                            self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method")
+
+                        with gr.Column():
+                            with FormRow():
+                                self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
+                                create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")
+
+                    with FormRow():
+                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")
+
+                    with gr.Accordion("Metadata", open=False) as metadata_editor:
+                        with FormRow():
+                            self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
+                            self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
+                            self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")
+
+                        self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
+                        self.read_metadata = gr.Button("Read metadata from selected checkpoints")
+
+                    with FormRow():
+                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')
+
+                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
+                    with gr.Group(elem_id="modelmerger_results_panel"):
+                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)
+
+        self.metadata_editor = metadata_editor
+        self.blocks = modelmerger_interface
+
+    def setup_ui(self, dummy_component, sd_model_checkpoint_component):
+        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)
+
+        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])
+
+        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
+        self.modelmerger_merge.click(
+            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
+            _js='modelmerger',
+            inputs=[
+                dummy_component,
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                self.interp_method,
+                self.interp_amount,
+                self.save_as_half,
+                self.custom_name,
+                self.checkpoint_format,
+                self.config_source,
+                self.bake_in_vae,
+                self.discard_weights,
+                self.save_metadata,
+                self.add_merge_recipe,
+                self.copy_metadata_fields,
+                self.metadata_json,
+            ],
+            outputs=[
+                self.primary_model_name,
+                self.secondary_model_name,
+                self.tertiary_model_name,
+                sd_model_checkpoint_component,
+                self.modelmerger_result,
+            ]
+        )
+
+        # Required as a workaround for change() event not triggering when loading values from ui-config.json
+        self.interp_description.value = update_interp_description(self.interp_method.value)
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b288938c2ddd3e0645a1bc6b44dd9c8ba..1dda16272b6f4797f22d171203c64056ae121233 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
+import subprocess as sp
+from modules import call_queue, shared
@@ -135,7 +135,7 @@
     with gr.Column(variant='panel', elem_id=f"{tabname}_results"):
         with gr.Group(elem_id=f"{tabname}_gallery_container"):

             generation_info = None
             with gr.Column():
@@ -225,30 +225,60 @@
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+    label = None
+    for comp in refresh_components:
+        label = getattr(comp, 'label', None)
+        if label is not None:
+            break
+
     def refresh():
         refresh_method()
         args = refreshed_args() if callable(refreshed_args) else refreshed_args

         for k, v in args.items():
-            setattr(refresh_component, k, v)
+            for comp in refresh_components:
+                setattr(comp, k, v)

-        return gr.update(**(args or {}))
+        return [gr.update(**(args or {})) for _ in refresh_components]

     refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
     refresh_button.click(
         fn=refresh,
         inputs=[],
-        outputs=[refresh_component]
+        outputs=[*refresh_components]
     )
     return refresh_button

     elif not os.path.isdir(f):
+        print(f"""
+WARNING
+An open_folder request was made with an argument that is not a folder.
+This could be an error or a malicious attempt to run code on your computer.
+""", file=sys.stderr)
+        return

+def setup_dialog(button_show, dialog, *, button_close=None):
+    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+    dialog.visible = False
+
+    button_show.click(
+        fn=lambda: gr.update(visible=True),
+        inputs=[],
+        outputs=[dialog],
+    )
+
+    if button_close:
+        button_close.click(fn=None, _js="closePopup")
diff --git a/modules/ui_components.py b/modules/ui_components.py
index 64451df7a4e5ab2931ae95135e2d61a9387b2f33..8f8a70885173835694d1ca2dbb1ce20f1c45023d 100644
--- a/modules/ui_components.py
+++ b/modules/ui_components.py
@@ -35,7 +35,7 @@
         return "column"
 
 
 class FormGroup(FormComponent, gr.Group):
-    """Same as gr.Row but fits inside gradio forms"""
+    """Same as gr.Group but fits inside gradio forms"""
 
     def get_block_name(self):
         return "group"
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index f3e4fba7eece3cb67db6f64f1fdab5110ede3698..15a8b0bf4e3d31bafe8a7743dfbea28cad0a99a1 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -1,10 +1,11 @@
 import json
 import os
 import threading
+import time
+from datetime import datetime
 
-import git
 import gradio as gr
 import html
@@ -164,9 +164,9 @@
         else:
             ext_status = ext.status
 
         style = ""
-        if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
+        if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all":
             style = STYLE_PRIMARY
 
         version_link = ext.version
@@ -535,23 +535,26 @@
             apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit")
             apply = gr.Button(value=apply_label, variant="primary")
             check = gr.Button(value="Check for updates")
             extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all")
-            extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False)
+            extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False)
 
             html = ""
-            if shared.opts.disable_all_extensions != "none":
-                html = """
-<span style="color: var(--primary-400);">
-    "Disable all extensions" was set, change it to "none" to load all extensions again
-</span>
-                """
+            if shared.opts.disable_all_extensions != "none" or shared.cmd_opts.disable_extra_extensions:
+                if shared.opts.disable_all_extensions != "none":
+                    msg = '"Disable all extensions" was set, change it to "none" to load all extensions again'
+                elif shared.cmd_opts.disable_extra_extensions:
+                    msg = '"--disable-extra-extensions" was used, remove it to load all extensions again'
+                html = f'<span style="color: var(--primary-400);">{msg}</span>'
+
             info = gr.HTML(html)
             extensions_table = gr.HTML('Loading...')
             ui.load(fn=extension_table, inputs=[], outputs=[extensions_table])
@@ -574,8 +577,8 @@
         with gr.TabItem("Available", id="available"):
             with gr.Row():
                 refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary")
                 extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json")
-                available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL").style(container=False)
+                available_extensions_index = gr.Text(value=extensions_index_url, label="Extension index URL", container=False)
                 extension_to_install = gr.Text(elem_id="extension_to_install", visible=False)
                 install_extension_button = gr.Button(elem_id="install_extension_button", visible=False)
gr.Button(elem_id="install_extension_button", visible=False) @@ -583,7 +587,7 @@ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): - search_extensions_text = gr.Text(label="Search").style(container=False) + search_extensions_text = gr.Text(label="Search", container=False) install_result = gr.HTML() available_extensions_table = gr.HTML() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 7387d01e1e77feb8f34ac8de4239a1a83ca6030e..3a73c89e8fe2bc17bca056b9656228b1092f6040 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -62,8 +62,9 @@ page = next(iter([x for x in extra_pages if x.name == page]), None) try: - + item = page.create_item(name, enable_filter=False) +from pathlib import Path import urllib.parse except Exception as e: errors.display(e, "creating item for extra network") item = page.items.get(name) @@ -101,17 +101,8 @@ pass def read_user_metadata(self, item): filename = item.get("filename", None) - basename, ext = os.path.splitext(filename) -from modules.ui import up_down_symbol +def get_metadata(page: str = "", item: str = ""): import os.path - - metadata = {} - try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) - except Exception as e: - errors.display(e, f"reading extra network user metadata from {metadata_filename}") desc = metadata.get("description", None) if desc is not None: @@ -254,7 +245,7 @@ "style": f"'display: none; {height}{width}; font-size: {shared.opts.extra_networks_card_text_scale*100}%'", "prompt": item.get("prompt", None), "tabname": quote_js(tabname), "local_preview": quote_js(item["local_preview"]), - "name": item["name"], + "name": html.escape(item["name"]), "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 76780cfd0af55701a6bdec25e05b804c07f873fe..778850222452edb3109f5150881fa8230b89eaec 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models from modules.ui_extra_networks import quote_js +from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -12,7 +13,7 @@ def refresh(self): shared.refresh_checkpoints() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = 
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c69aab866ec8c0b6ccf9da4bb1b01ccf48b1551
--- /dev/null
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import ui_extra_networks_user_metadata, sd_vae
+from modules.ui_common import create_refresh_button
+
+
+class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
+    def __init__(self, ui, tabname, page):
+        super().__init__(ui, tabname, page)
+
+        self.select_vae = None
+
+    def save_user_metadata(self, name, desc, notes, vae):
+        user_metadata = self.get_user_metadata(name)
+        user_metadata["description"] = desc
+        user_metadata["notes"] = notes
+        user_metadata["vae"] = vae
+
+        self.write_user_metadata(name, user_metadata)
+
+    def put_values_into_components(self, name):
+        user_metadata = self.get_user_metadata(name)
+        values = super().put_values_into_components(name)
+
+        return [
+            *values[0:5],
+            user_metadata.get('vae', ''),
+        ]
+
+    def create_editor(self):
+        self.create_default_editor_elems()
+
+        with gr.Row():
+            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
+            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")
+
+        self.edit_notes = gr.TextArea(label='Notes', lines=4)
+
+        self.create_default_buttons()
+
+        viewed_components = [
+            self.edit_name,
+            self.edit_description,
+            self.html_filedata,
+            self.html_preview,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.button_edit\
+            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
+            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])
+
+        edited_components = [
+            self.edit_description,
+            self.edit_notes,
+            self.select_vae,
+        ]
+
+        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb428925e2d589c379640123fe02a8cbbcec..514a45624e95520e56a42d466777c1fa7ec014b2 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@
     def refresh(self):
         shared.reload_hypernetworks()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e501c1c2525b5c644b217534d4693adf272..73134698ea1bdbea26fc1ccfdb416e4cf24a1a29 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
         path, ext = os.path.splitext(embedding.filename)
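CheckpointUserMetadataEditor above shows the intended extension points of UserMetadataEditor: put_values_into_components returns the stock values first (values[0:5]) and appends any custom fields, and setup_save_handler mirrors that order when saving. A sketch of the same pattern for a hypothetical page with one extra field (create_editor would add the matching component, as in the checkpoint editor above):

    from modules import ui_extra_networks_user_metadata

    class WidgetUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):  # hypothetical
        def save_user_metadata(self, name, desc, notes, flavor):
            user_metadata = self.get_user_metadata(name)
            user_metadata["description"] = desc
            user_metadata["notes"] = notes
            user_metadata["flavor"] = flavor  # hypothetical extra field
            self.write_user_metadata(name, user_metadata)

        def put_values_into_components(self, name):
            user_metadata = self.get_user_metadata(name)
            values = super().put_values_into_components(name)
            # stock values first, custom fields appended after them
            return [*values[0:5], user_metadata.get('flavor', '')]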
diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py
index 01ff4e4b470a94640e04f56f9d875ab91ed99e5c..1cb9eb6febde7df703e24ecbaa0c4a93cb6d7239 100644
--- a/modules/ui_extra_networks_user_metadata.py
+++ b/modules/ui_extra_networks_user_metadata.py
@@ -42,12 +42,17 @@
         item['user_metadata'] = user_metadata
 
         return user_metadata
 
+    def create_extra_default_items_in_left_column(self):
+        pass
+
     def create_default_editor_elems(self):
         with gr.Row():
             with gr.Column(scale=2):
                 self.edit_name = gr.HTML(elem_classes="extra-network-name")
                 self.edit_description = gr.Textbox(label="Description", lines=4)
                 self.html_filedata = gr.HTML()
+
+                self.create_extra_default_items_in_left_column()
 
             with gr.Column(scale=1, min_width=0):
                 self.html_preview = gr.HTML()
@@ -91,6 +96,7 @@
         filename = item["filename"]
 
         stats = os.stat(filename)
         params = [
+            ('Filename: ', os.path.basename(filename)),
            ('File size: ', sysinfo.pretty_bytes(stats.st_size)),
            ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')),
         ]
@@ -111,8 +117,8 @@
             params = []
 
         table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>'
 
     def write_user_metadata(self, name, metadata):
         item = self.page.items.get(name, {})
diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py
index c7dc11540474aabdde9ef501a9c074094bb16672..802e1ce71a16a45298439e45b72e2a2aba24d81f 100644
--- a/modules/ui_postprocessing.py
+++ b/modules/ui_postprocessing.py
@@ -6,7 +6,7 @@
 def create_ui():
     tab_index = gr.State(value=0)
 
-    with gr.Row().style(equal_height=False, variant='compact'):
+    with gr.Row(equal_height=False, variant='compact'):
         with gr.Column(variant='compact'):
             with gr.Tabs(elem_id="mode_extras"):
                 with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single:
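The .style() removals here and in the files above all follow from the gradio upgrade to 3.39.0 pinned in requirements_versions.txt below: layout arguments that 3.32 only accepted through .style() become constructor keywords. The pattern, in the form it recurs throughout this diff:

    import gradio as gr

    # before (gradio 3.32): layout options went through .style()
    #   with gr.Row().style(equal_height=False, variant='compact'): ...
    #   gr.Text(label="Search").style(container=False)

    # after (gradio 3.39): the same options are constructor keywords
    with gr.Blocks():
        with gr.Row(equal_height=False, variant='compact'):
            search = gr.Text(label="Search", container=False)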
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 0000000000000000000000000000000000000000..85eb3a6417ede53510665060458bf7a8e3a4916c
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
+styles_materialize_symbol = '\U0001f4cb'  # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
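UiPromptStyles is self-contained: a tab only has to construct it next to its prompt boxes, and the class wires up the dropdown, the edit dialog, and the apply button itself. A minimal sketch of embedding it (txt2img and img2img do essentially this; the surrounding components here are illustrative):

    import gradio as gr
    from modules import ui_prompt_styles

    with gr.Blocks():
        prompt = gr.Textbox(label="Prompt", lines=3)
        negative_prompt = gr.Textbox(label="Negative prompt", lines=3)

        # tabname feeds elem_ids and the update_<tabname>_tokens JS call,
        # so it should match an existing tab for the token counter to resolve
        styles = ui_prompt_styles.UiPromptStyles("txt2img", prompt, negative_prompt)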
name : ""; }', + inputs=[self.selection], + outputs=[self.selection, self.prompt, self.neg_prompt], + show_progress=False, + ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) + + self.materialize.click( + fn=materialize_styles, + inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + show_progress=False, + ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) + + ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) + + + + diff --git a/modules/ui_settings.py b/modules/ui_settings.py index a6076bf306001757f8d967e14859a7d1a420028b..6dde4b6aa04234622b940b8c1d0052a9756ee5b2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -158,7 +158,7 @@ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): loadsave.create_ui() with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"): - args = info.component_args() if callable(info.component_args) else info.component_args or {} + comp = info.component def get_value_for_setting(key): with gr.Row(): diff --git a/requirements.txt b/requirements.txt index 3142085eaf43d16b660ee4fc121ad3c7014ddb9c..9a47d6d0dffb17a2b941104d1460e510ea2a8a21 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,13 +7,15 @@ blendmodes clean-fid einops gfpgan +accelerate GitPython inflection jsonmerge kornia lark numpy omegaconf +open-clip-torch piexif psutil @@ -28,3 +32,4 @@ torch torchdiffeq torchsde accelerate +accelerate diff --git a/requirements_versions.txt b/requirements_versions.txt index f71b9d6c555280ebbe991e12fd117b51dc0410ce..dec45df384c0ac1d1b259fd7eb6ddf2ab0eaf6ca 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,30 +1,33 @@ +accelerate==0.18.0 GitPython==3.1.30 Pillow==9.5.0 accelerate==0.18.0 +Pillow==9.5.0 basicsr==1.4.2 blendmodes==2022 clean-fid==0.1.35 einops==0.4.1 fastapi==0.94.0 gfpgan==1.3.8 -gradio==3.32.0 +gradio==3.39.0 -httpcore<=0.15 +httpcore==0.15 inflection==0.5.1 jsonmerge==1.8.0 kornia==0.6.7 lark==1.1.2 numpy==1.23.5 omegaconf==2.2.3 +open-clip-torch==2.20.0 piexif==1.1.3 -psutil~=5.9.5 +psutil==5.9.5 pytorch_lightning==1.9.4 realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 -scikit-image==0.20.0 +scikit-image==0.21.0 -timm==0.6.7 +timm==0.9.2 -tomesd==0.1.2 +tomesd==0.1.3 torch torchdiffeq==0.2.3 torchsde==0.2.5 -accelerate==0.18.0 +basicsr==1.4.2 diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 7821cc655cd658f4ffed646df8d7a91891bd7a4b..d37b428fca95928e4ab52ed61ec5a83693ac2823 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -3,6 +3,7 @@ from copy import copy from itertools import permutations, chain import random import csv +import os.path from io import StringIO from PIL import Image import numpy as np @@ -10,8 +11,9 @@ import modules.scripts as scripts import gradio as gr -from collections import namedtuple +import csv from copy import copy +from itertools import permutations, chain from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, state import modules.shared as shared @@ -68,14 +70,6 @@ p.prompt = prompt_tmp + p.prompt import csv - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - raise RuntimeError(f"Unknown sampler: {x}") - - p.sampler_name = sampler_name - - -import csv import csv for x in xs: if x.lower() not in 
diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py
index 7821cc655cd658f4ffed646df8d7a91891bd7a4b..d37b428fca95928e4ab52ed61ec5a83693ac2823 100644
--- a/scripts/xyz_grid.py
+++ b/scripts/xyz_grid.py
@@ -3,6 +3,7 @@
 from copy import copy
 from itertools import permutations, chain
 import random
 import csv
+import os.path
 from io import StringIO
 from PIL import Image
 import numpy as np
@@ -10,8 +11,9 @@
 import modules.scripts as scripts
 import gradio as gr
 
-from modules import images, sd_samplers, processing, sd_models, sd_vae
+from modules import images, sd_samplers, processing, sd_models, sd_vae, errors
 from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
 from modules.shared import opts, state
 import modules.shared as shared
@@ -68,14 +70,6 @@
     p.prompt = prompt_tmp + p.prompt
 
 
-def apply_sampler(p, x, xs):
-    sampler_name = sd_samplers.samplers_map.get(x.lower(), None)
-    if sampler_name is None:
-        raise RuntimeError(f"Unknown sampler: {x}")
-
-    p.sampler_name = sampler_name
-
-
 def confirm_samplers(p, xs):
     for x in xs:
         if x.lower() not in sd_samplers.samplers_map:
@@ -146,11 +140,20 @@
     p.restore_faces = is_active
 
 
-def apply_override(field):
+def apply_override(field, boolean: bool = False):
     def fun(p, x, xs):
+        if boolean:
+            x = True if x.lower() == "true" else False
         p.override_settings[field] = x
     return fun
 
 
+def boolean_choice(reverse: bool = False):
+    def choice():
+        return ["False", "True"] if reverse else ["True", "False"]
+    return choice
+
+
 def format_value_add_label(p, opt, x):
     if type(x) == float:
         x = round(x, 8)
@@ -175,6 +178,8 @@
 def format_nothing(p, opt, x):
     return ""
 
+def format_remove_path(p, opt, x):
+    return os.path.basename(x)
 
 def str_permutations(x):
     """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
@@ -214,11 +219,13 @@
     AxisOption("CFG Scale", float, apply_field("cfg_scale")),
     AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
     AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
     AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
-    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
+    AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
-    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
+    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_remove_path, confirm=confirm_checkpoints, cost=1.0, choices=lambda: sorted(sd_models.checkpoints_list, key=str.casefold)),
     AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")),
     AxisOption("Sigma Churn", float, apply_field("s_churn")),
     AxisOption("Sigma min", float, apply_field("s_tmin")),
@@ -239,6 +246,7 @@
     AxisOption("UniPC Order", int, apply_uni_pc_order, cost=0.5),
     AxisOption("Face restore", str, apply_face_restore, format_value=format_value),
     AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')),
     AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')),
+    AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)),
 ]
@@ -642,9 +650,15 @@
             x_opt.apply(pc, x, xs)
             y_opt.apply(pc, y, ys)
             z_opt.apply(pc, z, zs)
 
+            try:
+                res = process_images(pc)
+            except Exception as e:
+                errors.display(e, "generating image for xyz plot")
+
+                res = Processed(p, [], p.seed, "")
 
             # Sets subgrid infotexts
             subgrid_index = 1 + iz
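The new boolean plumbing in xyz_grid is small enough to check in isolation: apply_override(field, boolean=True) converts the cell's string value into a real bool before writing it into override_settings, and boolean_choice supplies the axis values. A self-contained sketch (the stand-in processing object is hypothetical):

    def apply_override(field, boolean: bool = False):
        def fun(p, x, xs):
            if boolean:
                x = True if x.lower() == "true" else False
            p.override_settings[field] = x
        return fun

    def boolean_choice(reverse: bool = False):
        def choice():
            return ["False", "True"] if reverse else ["True", "False"]
        return choice

    class FakeProcessing:  # stand-in for StableDiffusionProcessing
        def __init__(self):
            self.override_settings = {}

    p = FakeProcessing()
    apply_override('always_discard_next_to_last_sigma', boolean=True)(p, "True", ["True", "False"])
    assert p.override_settings['always_discard_next_to_last_sigma'] is True

    # reverse=True puts "False" first, so the axis defaults to the off state
    assert boolean_choice(reverse=True)() == ["False", "True"]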
diff --git a/style.css b/style.css
index 7157ac0bd50538e84908dd292a572d4230739d11..52919f719a2323f335d3a4c77c0f25b572ee117f 100644
--- a/style.css
+++ b/style.css
@@ -8,6 +8,7 @@
 :root, .dark{
     --checkbox-label-gap: 0.25em 0.1em;
     --section-header-text-size: 12pt;
+    --block-background-fill: transparent;
 }
@@ -42,8 +43,9 @@
 .block.gradio-textbox,
 .block.gradio-radio,
 .block.gradio-checkboxgroup,
 .block.gradio-number,
-@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
+.gradio-dropdown ul.options li.item,
+div.gradio-group
 {
     border-width: 0 !important;
     box-shadow: none !important;
 }
@@ -134,6 +136,15 @@
     font-weight: bold;
     cursor: pointer;
 }
 
+div.styler{
+    border: none;
+    background: var(--background-fill-primary);
+}
+
+.block.gradio-textbox{
+    overflow: visible !important;
+}
+
 
 /* general styled components */
@@ -166,7 +177,7 @@
 .checkboxes-row > div{
     flex: 0;
     white-space: nowrap;
 }
 
+.gap.compact{
+}
 
 button.custom-button{
@@ -390,6 +401,7 @@
 #quicksettings > div, #quicksettings > fieldset{
     max-width: 24em;
     min-width: 24em;
+    width: 24em;
     padding: 0;
     border: none;
     box-shadow: none;
@@ -425,20 +437,20 @@
     margin: 0 1.2em;
 }
 
 table.popup-table{
-    gap: 0.5em;
+    color: var(--body-text-color);
     border-collapse: collapse;
     margin: 1em;
+    margin-left: 0em;
 }
 
 table.popup-table td{
     padding: 0.4em;
 
     max-width: 36em;
 }
@@ -852,7 +864,7 @@
 .extra-network-cards .card .card-button {
     text-shadow: 2px 2px 3px black;
-    padding: 0.25em;
+    padding: 0.25em 0.1em;
     font-size: 200%;
     width: 1.5em;
 }
@@ -968,6 +980,10 @@
 .edit-user-metadata .file-metadata th{
     text-align: left;
 }
+
+.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{
+    padding: 0.3em 1em;
+}
 
 .edit-user-metadata .wrap.translucent{
     background: var(--body-background-fill);
 }
@@ -978,3 +994,16 @@
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
 }
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}
diff --git a/webui.py b/webui.py
index 34c2fd18130c64bd008041146ca4c95a52a4b795..1803ea8ae7d1181c486aec611e8263ec31e9f1c7 100644
--- a/webui.py
+++ b/webui.py
@@ -14,7 +14,6 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
 
 import logging
logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import paths, timer, import_hook, errors, devices # noqa: F401 - +from modules import timer startup_timer = timer.startup_timer +startup_timer.record("launcher") import torch import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") - - startup_timer.record("import torch") import gradio # noqa: F401 startup_timer.record("import gradio") +from modules import paths, timer, import_hook, errors, devices # noqa: F401 +startup_timer.record("setup paths") + import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") + from modules import extra_networks from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 @@ -57,13 +58,18 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -import sys +from modules import shared + +if not shared.cmd_opts.skip_version_check: +import os import importlib + import modules.codeformer_model as codeformer import sys -import re +import warnings +from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states import sys -import warnings +import re import modules.img2img import modules.lowvram @@ -133,38 +139,6 @@ asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) from __future__ import annotations - - if shared.cmd_opts.skip_version_check: - return - - expected_torch_version = "2.0.0" - - if version.parse(torch.__version__) < version.parse(expected_torch_version): - errors.print_error_explanation(f""" -You are running torch {torch.__version__}. -The program is tested to work with torch {expected_torch_version}. -To reinstall the desired version, run with commandline flag --reinstall-torch. -Beware that this will cause a lot of large files to be downloaded, as well as -there are reports of issues with training tab on the latest version. - -Use --skip-version-check commandline argument to disable this check. - """.strip()) - - expected_xformers_version = "0.0.20" - if shared.xformers_available: - import xformers - - if version.parse(xformers.__version__) < version.parse(expected_xformers_version): - errors.print_error_explanation(f""" -You are running xformers {xformers.__version__}. -The program is tested to work with xformers {expected_xformers_version}. -To reinstall the desired version, run with commandline flag --reinstall-xformers. - -Use --skip-version-check commandline argument to disable this check. 
- """.strip()) - - -from __future__ import annotations format='%(asctime)s %(levelname)s [%(name)s] %(message)s', config_state_file = shared.opts.restore_config_state_file if config_state_file == "": @@ -252,7 +226,6 @@ def initialize(): fix_asyncio_event_loop_policy() validate_tls_options() configure_sigint_handler() - check_versions() modelloader.cleanup_models() configure_opts_onchange() @@ -324,11 +297,11 @@ if modules.sd_hijack.current_optimizer is None: modules.sd_hijack.apply_optimizations() - import gradio # noqa: F401 + level=log_level, -import signal +import importlib shared.reload_hypernetworks() startup_timer.record("reload hypernetworks") @@ -380,7 +353,7 @@ print(f"Startup time: {startup_timer.summary()}.") api.launch( server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861, - root_path = f"/{cmd_opts.subpath}" + root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) @@ -413,7 +386,7 @@ ssl_certfile=cmd_opts.tls_certfile, ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=gradio_auth_creds, - inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1', + inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1', prevent_thread_lock=True, allowed_paths=cmd_opts.gradio_allowed_path, app_kwargs={ diff --git a/webui.sh b/webui.sh index a683d946d3e98e853b2c4d43d1ac5743813f63ea..cb8b9d14db5e5cf61c62f292c5cb3c28a916ea49 100755 --- a/webui.sh +++ b/webui.sh @@ -4,7 +4,14 @@ # Please do not make any changes to this file, # # change the variables in webui-user.sh instead # ################################################# + +use_venv=1 +if [[ $venv_dir == "-" ]]; then + use_venv=0 +fi + SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) + # If run from macOS, load defaults from webui-macos-env.sh if [[ "$OSTYPE" == "darwin"* ]]; then @@ -47,7 +54,7 @@ export GIT="git" fi # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) -if [[ -z "${venv_dir}" ]] +if [[ -z "${venv_dir}" ]] && [[ $use_venv -eq 1 ]] then venv_dir="venv" fi @@ -165,7 +172,7 @@ fi done #!/usr/bin/env bash -# Please do not make any changes to this file, # +# Disable sentry logging then printf "\n%s\n" "${delimiter}" printf "\e[1m\e[31mERROR: python3-venv is not installed, aborting...\e[0m" @@ -186,7 +193,7 @@ cd "${clone_dir}"/ || { printf "\e[1m\e[31mERROR: Can't cd to %s/%s/, aborting...\e[0m" "${install_dir}" "${clone_dir}"; exit 1; } fi #!/usr/bin/env bash - then +export ERROR_REPORTING=FALSE then printf "\n%s\n" "${delimiter}" printf "Create and activate python venv" @@ -210,7 +217,7 @@ fi else printf "\n%s\n" "${delimiter}" #!/usr/bin/env bash -if [[ -z "${install_dir}" ]] +# Do not reinstall existing pip packages on Debian/Ubuntu printf "\n%s\n" "${delimiter}" fi