~/Projects/stable-diffusion-webui
git clone https://code.lsong.org/stable-diffusion-webui
Commit
- Commit
- 189229bbf9276fb73e48c783856b02fc57ab5c9b
- Author
- AUTOMATIC1111 <[email protected]>
- Date
- 2023-08-24 11:09:04 +0300 +0300
- Diffstat
.eslintrc.js | 6 CHANGELOG.md | 131 CITATION.cff | 7 README.md | 12 extensions-builtin/Lora/extra_networks_lora.py | 10 extensions-builtin/Lora/lora_patches.py | 31 extensions-builtin/Lora/network.py | 6 extensions-builtin/Lora/network_full.py | 7 extensions-builtin/Lora/network_norm.py | 28 extensions-builtin/Lora/networks.py | 206 extensions-builtin/Lora/scripts/lora_script.py | 44 extensions-builtin/Lora/ui_edit_user_metadata.py | 2 extensions-builtin/Lora/ui_extra_networks_lora.py | 3 extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js | 198 extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py | 2 extensions-builtin/canvas-zoom-and-pan/style.css | 3 extensions-builtin/extra-options-section/scripts/extra_options_section.py | 45 extensions-builtin/mobile/javascript/mobile.js | 6 javascript/extraNetworks.js | 49 javascript/hints.js | 11 javascript/imageviewer.js | 5 javascript/inputAccordion.js | 37 javascript/localStorage.js | 26 javascript/localization.js | 43 javascript/progressbar.js | 75 javascript/resizeHandle.js | 139 javascript/ui.js | 35 launch.py | 6 modules/api/api.py | 60 modules/api/models.py | 12 modules/cache.py | 8 modules/call_queue.py | 5 modules/cmd_args.py | 11 modules/config_states.py | 14 modules/devices.py | 55 modules/errors.py | 50 modules/extensions.py | 13 modules/extra_networks.py | 19 modules/extras.py | 38 modules/fifo_lock.py | 37 modules/generation_parameters_copypaste.py | 62 modules/gradio_extensons.py | 73 modules/hypernetworks/hypernetwork.py | 5 modules/images.py | 51 modules/img2img.py | 62 modules/initialize.py | 168 modules/initialize_util.py | 202 modules/interrogate.py | 4 modules/launch_utils.py | 79 modules/localization.py | 4 modules/logging_config.py | 16 modules/lowvram.py | 21 modules/mac_specific.py | 7 modules/options.py | 245 modules/patches.py | 64 modules/postprocessing.py | 61 modules/processing.py | 829 modules/processing_scripts/refiner.py | 49 modules/processing_scripts/seed.py | 111 modules/progress.py | 52 modules/prompt_parser.py | 54 modules/realesrgan_model.py | 1 modules/rng.py | 170 modules/rng_philox.py | 102 modules/script_callbacks.py | 27 modules/scripts.py | 204 modules/sd_disable_initialization.py | 146 modules/sd_hijack.py | 39 modules/sd_hijack_clip.py | 4 modules/sd_hijack_inpainting.py | 97 modules/sd_hijack_optimizations.py | 13 modules/sd_models.py | 296 modules/sd_models_config.py | 3 modules/sd_models_types.py | 31 modules/sd_models_xl.py | 17 modules/sd_samplers.py | 19 modules/sd_samplers_cfg_denoiser.py | 230 modules/sd_samplers_common.py | 250 modules/sd_samplers_compvis.py | 224 modules/sd_samplers_extra.py | 74 modules/sd_samplers_kdiffusion.py | 411 modules/sd_samplers_timesteps.py | 167 modules/sd_samplers_timesteps_impl.py | 137 modules/sd_unet.py | 2 modules/sd_vae.py | 109 modules/sd_vae_approx.py | 2 modules/sd_vae_taesd.py | 50 modules/shared.py | 890 modules/shared_cmd_options.py | 18 modules/shared_gradio_themes.py | 67 modules/shared_init.py | 49 modules/shared_items.py | 55 modules/shared_options.py | 330 modules/shared_state.py | 159 modules/shared_total_tqdm.py | 37 modules/styles.py | 5 modules/sub_quadratic_attention.py | 6 modules/sysinfo.py | 10 modules/textual_inversion/textual_inversion.py | 24 modules/timer.py | 24 modules/txt2img.py | 18 modules/ui.py | 778 modules/ui_checkpoint_merger.py | 124 modules/ui_common.py | 53 modules/ui_components.py | 73 modules/ui_extensions.py | 261 modules/ui_extra_networks.py | 90 modules/ui_extra_networks_checkpoints.py | 10 modules/ui_extra_networks_checkpoints_user_metadata.py | 66 modules/ui_extra_networks_hypernets.py | 9 modules/ui_extra_networks_textual_inversion.py | 5 modules/ui_extra_networks_user_metadata.py | 11 modules/ui_loadsave.py | 10 modules/ui_postprocessing.py | 2 modules/ui_prompt_styles.py | 110 modules/ui_settings.py | 2 modules/ui_tempdir.py | 8 modules/util.py | 58 requirements.txt | 5 requirements_versions.txt | 13 scripts/xyz_grid.py | 188 style.css | 205 test/conftest.py | 15 webui-macos-env.sh | 3 webui.py | 417 webui.sh | 7
Merge branch 'dev' into release_candidate
diff --git a/.eslintrc.js b/.eslintrc.js index f33aca09fa022638e45e8737386402711e464656..4777c276e9b13fa04ce3e9c7222df3d357fd824e 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -87,5 +87,11 @@ modalPrevImage: "readonly", modalNextImage: "readonly", // token-counters.js setupTokenCounters: "readonly", + // localStorage.js + localSet: "readonly", + localGet: "readonly", + localRemove: "readonly", + // resizeHandle.js + setupResizeHandle: "writable" } }; diff --git a/CHANGELOG.md b/CHANGELOG.md index 461fef9a784f1a4406d354eae35ee5ae30dd8f1a..d55925e420dfdd108b656546030a71df54ea559b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,134 @@ +## 1.6.0 + +### Features: + * refiner support [#12371](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12371) + * add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards + * add style editor dialog + * hires fix: add an option to use a different checkpoint for second pass ([#12181](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12181)) + * option to keep multiple loaded models in memory ([#12227](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12227)) + * new samplers: Restart, DPM++ 2M SDE Exponential, DPM++ 2M SDE Heun, DPM++ 2M SDE Heun Karras, DPM++ 2M SDE Heun Exponential, DPM++ 3M SDE, DPM++ 3M SDE Karras, DPM++ 3M SDE Exponential ([#12300](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12300), [#12519](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12519), [#12542](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12542)) + * rework DDIM, PLMS, UniPC to use CFG denoiser same as in k-diffusion samplers: + * makes all of them work with img2img + * makes prompt composition posssible (AND) + * makes them available for SDXL + * always show extra networks tabs in the UI ([#11808](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11808)) + * use less RAM when creating models ([#11958](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11958), [#12599](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12599)) + * textual inversion inference support for SDXL + * extra networks UI: show metadata for SD checkpoints + * checkpoint merger: add metadata support + * prompt editing and attention: add support for whitespace after the number ([ red : green : 0.5 ]) (seed breaking change) ([#12177](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12177)) + * VAE: allow selecting own VAE for each checkpoint (in user metadata editor) + * VAE: add selected VAE to infotext + * options in main UI: add own separate setting for txt2img and img2img, correctly read values from pasted infotext, add setting for column count ([#12551](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12551)) + * add resize handle to txt2img and img2img tabs, allowing to change the amount of horizontable space given to generation parameters and resulting image gallery ([#12687](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12687), [#12723](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12723)) + * change default behavior for batching cond/uncond -- now it's on by default, and is disabled by an UI setting (Optimizatios -> Batch cond/uncond) - if you are on lowvram/medvram and are getting OOM exceptions, you will need to enable it + * show current position in queue and make it so that requests are processed in the order of arrival ([#12707](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12707)) + * add `--medvram-sdxl` flag that only enables `--medvram` for SDXL models + * prompt editing timeline has separate range for first pass and hires-fix pass (seed breaking change) ([#12457](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12457)) + +### Minor: + * img2img batch: RAM savings, VRAM savings, .tif, .tiff in img2img batch ([#12120](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12120), [#12514](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12514), [#12515](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12515)) + * postprocessing/extras: RAM savings ([#12479](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12479)) + * XYZ: in the axis labels, remove pathnames from model filenames + * XYZ: support hires sampler ([#12298](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12298)) + * XYZ: new option: use text inputs instead of dropdowns ([#12491](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12491)) + * add gradio version warning + * sort list of VAE checkpoints ([#12297](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12297)) + * use transparent white for mask in inpainting, along with an option to select the color ([#12326](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12326)) + * move some settings to their own section: img2img, VAE + * add checkbox to show/hide dirs for extra networks + * Add TAESD(or more) options for all the VAE encode/decode operation ([#12311](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12311)) + * gradio theme cache, new gradio themes, along with explanation that the user can input his own values ([#12346](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12346), [#12355](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12355)) + * sampler fixes/tweaks: s_tmax, s_churn, s_noise, s_tmax ([#12354](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12354), [#12356](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12356), [#12357](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12357), [#12358](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12358), [#12375](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12375), [#12521](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12521)) + * update README.md with correct instructions for Linux installation ([#12352](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12352)) + * option to not save incomplete images, on by default ([#12338](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12338)) + * enable cond cache by default + * git autofix for repos that are corrupted ([#12230](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12230)) + * allow to open images in new browser tab by middle mouse button ([#12379](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12379)) + * automatically open webui in browser when running "locally" ([#12254](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12254)) + * put commonly used samplers on top, make DPM++ 2M Karras the default choice + * zoom and pan: option to auto-expand a wide image, improved integration ([#12413](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12413), [#12727](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12727)) + * option to cache Lora networks in memory + * rework hires fix UI to use accordion + * face restoration and tiling moved to settings - use "Options in main UI" setting if you want them back + * change quicksettings items to have variable width + * Lora: add Norm module, add support for bias ([#12503](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12503)) + * Lora: output warnings in UI rather than fail for unfitting loras; switch to logging for error output in console + * support search and display of hashes for all extra network items ([#12510](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12510)) + * add extra noise param for img2img operations ([#12564](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12564)) + * support for Lora with bias ([#12584](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12584)) + * make interrupt quicker ([#12634](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12634)) + * configurable gallery height ([#12648](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12648)) + * make results column sticky ([#12645](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12645)) + * more hash filename patterns ([#12639](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12639)) + * make image viewer actually fit the whole page ([#12635](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12635)) + * make progress bar work independently from live preview display which results in it being updated a lot more often + * forbid Full live preview method for medvram and add a setting to undo the forbidding + * make it possible to localize tooltips and placeholders + +### Extensions and API: + * gradio 3.39 + * also bump versions for packages: transformers, GitPython, accelerate, scikit-image, timm, tomesd + * support tooltip kwarg for gradio elements: gr.Textbox(label='hello', tooltip='world') + * properly clear the total console progressbar when using txt2img and img2img from API + * add cmd_arg --disable-extra-extensions and --disable-all-extensions ([#12294](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12294)) + * shared.py and webui.py split into many files + * add --loglevel commandline argument for logging + * add a custom UI element that combines accordion and checkbox + * avoid importing gradio in tests because it spams warnings + * put infotext label for setting into OptionInfo definition rather than in a separate list + * make `StableDiffusionProcessingImg2Img.mask_blur` a property, make more inline with PIL `GaussianBlur` ([#12470](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12470)) + * option to make scripts UI without gr.Group + * add a way for scripts to register a callback for before/after just a single component's creation + * use dataclass for StableDiffusionProcessing + * store patches for Lora in a specialized module instead of inside torch + * support http/https URLs in API ([#12663](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12663), [#12698](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12698)) + * add extra noise callback ([#12616](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12616)) + * dump current stack traces when exiting with SIGINT + * add type annotations for extra fields of shared.sd_model + +### Bug Fixes: + * Don't crash if out of local storage quota for javascriot localStorage + * XYZ plot do not fail if an exception occurs + * fix missing TI hash in infotext if generation uses both negative and positive TI ([#12269](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12269)) + * localization fixes ([#12307](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12307)) + * fix sdxl model invalid configuration after the hijack + * correctly toggle extras checkbox for infotext paste ([#12304](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12304)) + * open raw sysinfo link in new page ([#12318](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12318)) + * prompt parser: Account for empty field in alternating words syntax ([#12319](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12319)) + * add tab and carriage return to invalid filename chars ([#12327](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12327)) + * fix api only Lora not working ([#12387](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12387)) + * fix options in main UI misbehaving when there's just one element + * make it possible to use a sampler from infotext even if it's hidden in the dropdown + * fix styles missing from the prompt in infotext when making a grid of batch of multiplie images + * prevent bogus progress output in console when calculating hires fix dimensions + * fix --use-textbox-seed + * fix broken `Lora/Networks: use old method` option ([#12466](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12466)) + * properly return `None` for VAE hash when using `--no-hashing` ([#12463](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12463)) + * MPS/macOS fixes and optimizations ([#12526](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12526)) + * add second_order to samplers that mistakenly didn't have it + * when refreshing cards in extra networks UI, do not discard user's custom resolution + * fix processing error that happens if batch_size is not a multiple of how many prompts/negative prompts there are ([#12509](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12509)) + * fix inpaint upload for alpha masks ([#12588](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12588)) + * fix exception when image sizes are not integers ([#12586](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12586)) + * fix incorrect TAESD Latent scale ([#12596](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12596)) + * auto add data-dir to gradio-allowed-path ([#12603](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12603)) + * fix exception if extensuions dir is missing ([#12607](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12607)) + * fix issues with api model-refresh and vae-refresh ([#12638](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12638)) + * fix img2img background color for transparent images option not being used ([#12633](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12633)) + * attempt to resolve NaN issue with unstable VAEs in fp32 mk2 ([#12630](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12630)) + * implement missing undo hijack for SDXL + * fix xyz swap axes ([#12684](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12684)) + * fix errors in backup/restore tab if any of config files are broken ([#12689](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12689)) + * fix SD VAE switch error after model reuse ([#12685](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12685)) + * fix trying to create images too large for the chosen format ([#12667](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12667)) + * create Gradio temp directory if necessary ([#12717](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12717)) + * prevent possible cache loss if exiting as it's being written by using an atomic operation to replace the cache with the new version + * set devices.dtype_unet correctly + * run RealESRGAN on GPU for non-CUDA devices ([#12737](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12737)) + * prevent extra network buttons being obscured by description for very small card sizes ([#12745](https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/12745)) + + ## 1.5.2 ### Bug Fixes: diff --git a/CITATION.cff b/CITATION.cff new file mode 100644 index 0000000000000000000000000000000000000000..2c781aff450c8604eb3cf876d2c3585a96a5a590 --- /dev/null +++ b/CITATION.cff @@ -0,0 +1,7 @@ +cff-version: 1.2.0 +message: "If you use this software, please cite it as below." +authors: + - given-names: AUTOMATIC1111 +title: "Stable Diffusion Web UI" +date-released: 2022-08-22 +url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui" diff --git a/README.md b/README.md index b796d15004187e1db3dc6f012977e27dcbc2502d..4e08344008caf22fc8a8865de7bc9744061ffec0 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime) - Clip skip - Hypernetworks - Loras (same as Hypernetworks but more pretty) + - a man in a `((tuxedo))` - will pay more attention to tuxedo - One click install and run script (but you still must install python and git) - Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API @@ -87,12 +89,16 @@ - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions - [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions - Now without any bad letters! - Load checkpoints in safetensors format + - a man in a `((tuxedo))` - will pay more attention to tuxedo - Outpainting - Now with a license! - Reorder elements in the UI from settings screen ## Installation and Running -Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for: +- [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) +- [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. +- [Intel CPUs, Intel GPUs (both integrated and discrete)](https://github.com/openvinotoolkit/stable-diffusion-webui/wiki/Installation-on-Intel-Silicon) (external wiki page) Alternatively, use online services (like Google Colab): @@ -114,7 +121,7 @@ 1. Install the dependencies: ```bash # Debian-based: # Stable Diffusion web UI -- Color Sketch +- Checkpoint Merger, a tab that allows you to merge up to 3 checkpoints into one # Red Hat-based: sudo dnf install wget git python3 # Arch-based: @@ -123,7 +130,7 @@ ``` 2. Navigate to the directory you would like the webui to be installed and execute the following command: ```bash # Stable Diffusion web UI -- Loopback, run img2img processing multiple times +- [Custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Scripts) with many extensions from community ``` 3. Run `webui.sh`. 4. Check `webui-user.sh` for options. @@ -169,5 +176,6 @@ - Security advice - RyotaK - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd - LyCORIS - KohakuBlueleaf +- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ba2945c6fe1e77d87226f08fb20da0624959364b..005ff32cbe3718c1d13deba2090025a060a0b704 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -6,8 +6,13 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def __init__(self): super().__init__('lora') + self.errors = {} + """mapping of network names to the number of errors the network had during operation""" + def activate(self, p, params_list): additional = shared.opts.sd_lora + + self.errors.clear() if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] @@ -57,4 +62,7 @@ p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes) def deactivate(self, p): class ExtraNetworkLora(extra_networks.ExtraNetwork): -import networks + super().__init__('lora') + p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) + + self.errors.clear() diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py new file mode 100644 index 0000000000000000000000000000000000000000..b394d8e9ed41ff2fc56de03d6e6194aa50b5f2b2 --- /dev/null +++ b/extensions-builtin/Lora/lora_patches.py @@ -0,0 +1,31 @@ +import torch + +import networks +from modules import patches + + +class LoraPatches: + def __init__(self): + self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) + self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) + self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) + self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) + self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) + self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) + self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) + self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict) + self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) + self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) + + def undo(self): + self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') + self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') + self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') + self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') + self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') + self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') + self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') + self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') + self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') + self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') + diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 0a18d69eb26412d54c3b36e7801b5557691e2b68..d8e8dfb7ff0420c98f83ecc9ab92d02b3d40c8b5 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -133,7 +134,8 @@ return 1.0 from __future__ import annotations from __future__ import annotations +from collections import namedtuple if self.bias is not None: updown = updown.reshape(self.bias.shape) updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) @@ -146,7 +148,10 @@ if orig_weight.size().numel() == updown.size().numel(): updown = updown.reshape(orig_weight.shape) from __future__ import annotations -class SdVersion(enum.Enum): +class NetworkOnDisk: + ex_bias = ex_bias * self.multiplier() + + return updown * self.calc_scale() * self.multiplier(), ex_bias def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py index 109b4c2c594e5079067d55331271ebafcf6c9fe4..bf6930e96c03924695f8c5f372eee6cec3a68670 100644 --- a/extensions-builtin/Lora/network_full.py +++ b/extensions-builtin/Lora/network_full.py @@ -14,10 +14,15 @@ def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) self.weight = weights.w.get("diff") + self.ex_bias = weights.w.get("diff_b") def calc_updown(self, orig_weight): output_shape = self.weight.shape updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + if self.ex_bias is not None: + ex_bias = self.ex_bias.to(orig_weight.device, dtype=orig_weight.dtype) + else: + + import network - if all(x in weights.w for x in ["diff"]): diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..ce450158068ef85ebe11cc60756ed991465c0e54 --- /dev/null +++ b/extensions-builtin/Lora/network_norm.py @@ -0,0 +1,28 @@ +import network + + +class ModuleTypeNorm(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["w_norm", "b_norm"]): + return NetworkModuleNorm(net, weights) + + return None + + +class NetworkModuleNorm(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w_norm = weights.w.get("w_norm") + self.b_norm = weights.w.get("b_norm") + + def calc_updown(self, orig_weight): + output_shape = self.w_norm.shape + updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype) + + if self.b_norm is not None: + ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype) + else: + ex_bias = None + + return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 17cbe1bb7fe383ff1213b17d06d0665f7765a392..96f935b236fdf2afec46d902029b9ae2031b4ebd 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -1,12 +1,15 @@ +import logging import os import re +import lora_patches import network import network_lora import network_hada import network_ia3 import network_lokr import network_full +import network_norm import torch from typing import Union @@ -19,6 +22,7 @@ network_hada.ModuleTypeHada(), network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), + network_norm.ModuleTypeNorm(), ] @@ -31,6 +35,8 @@ "attentions": {}, "resnets": { "conv1": "in_layers_2", "conv2": "out_layers_3", + "norm1": "in_layers_0", + "norm2": "out_layers_0", "time_emb_proj": "emb_layers_1", "conv_shortcut": "skip_connection", } @@ -190,12 +196,20 @@ net.modules[key] = net_module if keys_failed_to_match: + import os + + network_lora.ModuleTypeLora(), import network +def purge_networks_from_memory(): + while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0: + name = next(iter(networks_in_memory)) + network_lora.ModuleTypeLora(), -import network + + devices.torch_gc() def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): @@ -215,19 +228,25 @@ networks_on_disk = [available_network_aliases.get(name, None) for name in names] failed_to_load_networks = [] + import os -import network_lora import network_hada net = already_loaded.get(name, None) network_hada.ModuleTypeHada(), -import network_lokr +import network_full + network_lokr.ModuleTypeLokr(), + import os - if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): +import network_lokr + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: try: net = load_network(name, network_on_disk) + + networks_in_memory.pop(name, None) + networks_in_memory[name] = net except Exception as e: errors.display(e, f"loading network {network_on_disk.filename}") continue @@ -238,8 +257,8 @@ network_on_disk.read_hash() if net is None: failed_to_load_networks.append(name) +def convert_diffusers_name_to_compvis(key, is_sd2): import os - return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" continue net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 @@ -248,37 +267,54 @@ net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) + + purge_networks_from_memory() + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): network_lokr.ModuleTypeLokr(), -import network_hada +import network_lokr + "resnets": { + if weights_backup is None and bias_backup is None: import os -import network_ia3 +def assign_network_names_to_compvis_modules(sd_model): + +def convert_diffusers_name_to_compvis(key, is_sd2): import network_ia3 import os -import network_ia3 + module.network_layer_name = network_name +def convert_diffusers_name_to_compvis(key, is_sd2): import network_lokr + "conv_shortcut": "skip_connection", -import os import network_ia3 -import network_full +import network_ia3 - network_full.ModuleTypeFull(), + def match(match_list, regex_text): - network_full.ModuleTypeFull(), + def match(match_list, regex_text): import os import os -import network_lokr + module.network_layer_name = network_name + def match(match_list, regex_text): import re - network_full.ModuleTypeFull(), + else: + def match(match_list, regex_text): else: import os -import network_lokr +import network_full import network + regex = re_compiled.get(regex_text) - + else: - network_full.ModuleTypeFull(), + def match(match_list, regex_text): import network_lora + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]): """ Applies the currently selected set of networks to the weights of torch layer self. If weights already have this particular set of networks applied, does nothing. @@ -293,9 +329,12 @@ current_names = getattr(self, "network_current_names", ()) wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) weights_backup = getattr(self, "network_weights_backup", None) -import os + def match(match_list, regex_text): import network_ia3 + if current_names != (): + def match(match_list, regex_text): import network_full + if isinstance(self, torch.nn.MultiheadAttention): weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) else: @@ -303,28 +342,52 @@ weights_backup = self.weight.to(devices.cpu, copy=True) self.network_weights_backup = weights_backup + bias_backup = getattr(self, "network_bias_backup", None) + if bias_backup is None: + if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None: + bias_backup = self.out_proj.bias.to(devices.cpu, copy=True) + elif getattr(self, 'bias', None) is not None: + bias_backup = self.bias.to(devices.cpu, copy=True) + else: + bias_backup = None + self.network_bias_backup = bias_backup + if current_names != wanted_names: network_restore_weights_from_backup(self) for net in loaded_networks: module = net.modules.get(network_layer_name, None) if module is not None and hasattr(self, 'weight'): -re_digits = re.compile(r"\d+") + try: + if match(m, r"lora_unet_conv_out(.*)"): -re_digits = re.compile(r"\d+") + import network +import network_lokr -import re + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + if regex is None: + import network_lora +import os + + if regex is None: import re + if ex_bias is not None and hasattr(self, 'bias'): + if self.bias is None: + self.bias = torch.nn.Parameter(ex_bias) + if regex is None: import network_hada -re_digits = re.compile(r"\d+") + if regex is None: import network_ia3 -re_digits = re.compile(r"\d+") +import network_lora import network_lokr -import os + logging.debug(f"Network {net.name} layer {network_layer_name}: {e}") + import network_hada + +import network_lokr import network_lora module_q = net.modules.get(network_layer_name + "_q_proj", None) @@ -332,31 +396,44 @@ module_v = net.modules.get(network_layer_name + "_v_proj", None) module_out = net.modules.get(network_layer_name + "_out_proj", None) if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: -re_digits = re.compile(r"\d+") + try: + with torch.no_grad(): + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + regex = re.compile(regex_text) import re -import os + updown_v, _ = module_v.calc_updown(self.in_proj_weight) + regex = re.compile(regex_text) import network -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + regex = re.compile(regex_text) import network_lora -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + + regex = re.compile(regex_text) import network_hada -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + regex = re.compile(regex_text) import network_ia3 -re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") + regex = re.compile(regex_text) import network_lokr + else: + self.out_proj.bias = torch.nn.Parameter(ex_bias) + else: + re_compiled[regex_text] = regex import re -import os + + except RuntimeError as e: + if regex is None: import network_full -re_compiled = {} + regex = re.compile(regex_text) - network_ia3.ModuleTypeIa3(), + +import network_lokr import network_lora if module is None: continue - print(f'failed to calculate network weights for layer {network_layer_name}') + logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation") + extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1 self.network_current_names = wanted_names @@ -383,8 +460,8 @@ module = lora.modules.get(network_layer_name, None) if module is None: continue -import re re_compiled[regex_text] = regex +import network_lora return y @@ -396,59 +473,93 @@ def network_Linear_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Linear_forward_before_network) + return network_forward(self, input, originals.Linear_forward) network_apply_weights(self) - return torch.nn.Linear_forward_before_network(self, input) + return originals.Linear_forward(self, input) def network_Linear_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) + return originals.Linear_load_state_dict(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): "attentions": {}, -import network_full + + return key import re +import network import network_lora + + r = re.match(regex, key) + + import re -import network +import network_lora import re -import network_lora + return f'diffusion_model_out_2{m[0]}' + + r = re.match(regex, key) import os + +def network_GroupNorm_forward(self, input): "attentions": {}, -import network_lora + + return network_forward(self, input, originals.GroupNorm_forward) import re +import network import network_lora -import re + + return originals.GroupNorm_forward(self, input) -import re + r = re.match(regex, key) import network_lora + network_reset_cached_weight(self) + return originals.GroupNorm_load_state_dict(self, *args, **kwargs) + + +def network_LayerNorm_forward(self, input): "attentions": {}, + + r = re.match(regex, key) import network_lokr import re +import network import network_lora + + return originals.LayerNorm_forward(self, input) + + +def network_LayerNorm_load_state_dict(self, *args, **kwargs): +import re import network +import network_lokr + + return originals.LayerNorm_load_state_dict(self, *args, **kwargs) def network_MultiheadAttention_forward(self, *args, **kwargs): network_apply_weights(self) + if not r: import re - return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) def list_available_networks(): @@ -516,9 +627,14 @@ if added: params["Prompt"] += "\n" + "".join(added) +originals: lora_patches.LoraPatches = None + +extra_network_lora = None + available_networks = {} available_network_aliases = {} loaded_networks = [] +networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index cd28afc92e7ae82d9df4329febcc28f40a254abe..ef23968c563351b3409bb472e1ddb81ee7b489ea 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,64 +1,35 @@ import re -import torch import gradio as gr from fastapi import FastAPI import network import networks import lora # noqa:F401 +import lora_patches import extra_networks_lora import ui_extra_networks_lora from modules import script_callbacks, ui_extra_networks, extra_networks, shared -import re import re -import torch + - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network -import re +import lora # noqa:F401 import lora # noqa:F401 def before_ui(): ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_network = extra_networks_lora.ExtraNetworkLora() - extra_networks.register_extra_network(extra_network) - extra_networks.register_extra_network_alias(extra_network, "lyco") - - -if not hasattr(torch.nn, 'Linear_forward_before_network'): - torch.nn.Linear_forward_before_network = torch.nn.Linear.forward - -if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): - torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict - - import lora # noqa:F401 - import extra_networks_lora - -import torch +import extra_networks_lora -import torch +import extra_networks_lora import re -import torch - torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward - -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): - torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict +import extra_networks_lora -torch.nn.Linear.forward = networks.network_Linear_forward -torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict -torch.nn.Conv2d.forward = networks.network_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict -torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) @@ -72,6 +43,7 @@ "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), + "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}), })) @@ -128,3 +100,5 @@ d["Prompt"] = re.sub(re_lora, network_replacement, d["Prompt"]) script_callbacks.on_infotext_pasted(infotext_pasted) + +shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2ca997f7ce94d76ec89a62f7556eecf8d502775d..390d9dde3fbc0a848185809f39ea694a7bd00ac2 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -168,7 +168,7 @@ random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False) with gr.Column(scale=1, min_width=120): import datetime - tag = tag.strip() + table = super().get_metadata_table(name) self.edit_notes = gr.TextArea(label='Notes', lines=4) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 3629e5c0cf227192d5618fb0800a2be75f84ccf8..55409a7829d828a45e85cd9d8f63ed71b2c1cdcb 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -25,9 +25,10 @@ item = { "name": name, "filename": lora_on_disk.filename, + "shorthash": lora_on_disk.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(lora_on_disk.filename), + "search_term": self.search_terms_from_path(lora_on_disk.filename) + " " + (lora_on_disk.hash or ""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js index 30199dcd60aa3df4b5440c1dfa0de0319ac1374a..234238910a9ee1a7552b5690dfcd3b831839009f 100644 --- a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -13,9 +13,25 @@ "Sketch": elementIDs.sketch }; onUiLoaded(async() => { + img2imgTabs: "#mode_img2img .tab-nav", +onUiLoaded(async() => { inpaint: "#img2maskimg", // Get active tab onUiLoaded(async() => { + img2imgTabs: "#mode_img2img .tab-nav", + /** + * Waits for an element to be present in the DOM. + */ + const waitForElement = (id) => new Promise(resolve => { + const checkForElement = () => { + const element = document.querySelector(id); + if (element) return resolve(element); + setTimeout(checkForElement, 100); + }; + checkForElement(); + }); + +onUiLoaded(async() => { rangeGroup: "#img2img_column_size", const tabs = elements.img2imgTabs.querySelectorAll("button"); @@ -36,13 +52,19 @@ } // Wait until opts loaded async function waitForOpts() { - img2imgTabs: "#mode_img2img .tab-nav", + inpaintSketch: "#inpaint_sketch", onUiLoaded(async() => { + inpaintSketch: "#inpaint_sketch", if (window.opts && Object.keys(window.opts).length) { return window.opts; } await new Promise(resolve => setTimeout(resolve, 100)); } + } + + // Detect whether the element has a horizontal scroll bar + function hasHorizontalScrollbar(element) { + return element.scrollWidth > element.clientWidth; } // Function for defining the "Ctrl", "Shift" and "Alt" keys @@ -204,9 +226,11 @@ canvas_hotkey_move: "KeyF", canvas_hotkey_overlap: "KeyO", canvas_disabled_functions: [], canvas_show_tooltip: true, + inpaintSketch: "#inpaint_sketch", onUiLoaded(async() => { + const tabNameToElementId = { inpaintSketch: "#inpaint_sketch", - }; + if (tab.classList.contains("selected")) { }; const functionMap = { @@ -254,8 +278,8 @@ for (const input of rangeInputs) { input?.addEventListener("input", () => restoreImgRedMask(elements)); } + /[a-z]/i.test(value)) || onUiLoaded(async() => { - } const targetElement = gradioApp().querySelector(elemId); if (!targetElement) { @@ -367,6 +391,10 @@ panX: 0, panY: 0 }; + if (isExtension) { + targetElement.style.overflow = "hidden"; + } + fixCanvas(); targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; @@ -377,31 +405,51 @@ toggleOverlap("off"); fullScreenMode = false; - if ( + const closeBtn = targetElement.querySelector("button[aria-label='Remove Image']"); + inpaintSketch: "#inpaint_sketch", const elementIDs = { - // Check for conflicting hotkeys + inpaintSketch: "#inpaint_sketch", + inpaintSketch: "#inpaint_sketch", const elementIDs = { - if (!usedKeys.has(normalizedUserValue)) { + rangeGroup: "#img2img_column_size", const elementIDs = { - }; const elementIDs = { - sketch: "#img2img_sketch" + inpaintSketch: "#inpaint_sketch", + const activeTab = getActiveTab(elements); + inpaintSketch: "#inpaint_sketch", return tabNameToElementId[activeTab.innerText]; + if ( + inpaintSketch: "#inpaint_sketch", img2imgTabs: "#mode_img2img .tab-nav", + parseFloat(canvas.style.width) > parentElement.offsetWidth && + specialKeys.includes(value) const elementIDs = { - }; + ) { + fitToElement(); inpaint: "#img2maskimg", + } + } + } + if ( return tabNameToElementId[activeTab.innerText]; + specialKeys.includes(value) inpaintSketch: "#inpaint_sketch", return tabNameToElementId[activeTab.innerText]; - rangeGroup: "#img2img_column_size", +onUiLoaded(async() => { return tabNameToElementId[activeTab.innerText]; + const elementIDs = { sketch: "#img2img_sketch" + inpaintSketch: "#inpaint_sketch", const elementIDs = { + result[key] = normalizedUserValue; const elementIDs = { + } else { + } + + targetElement.style.width = ""; } // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements @@ -468,6 +515,10 @@ targetElement.style.transformOrigin = "0 0"; targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; toggleOverlap("on"); + if (isExtension) { + targetElement.style.overflow = "visible"; + } + return newZoomLevel; } @@ -490,9 +541,8 @@ fullScreenMode = false; elemData[elemId].zoomLevel = updateZoom( elemData[elemId].zoomLevel + - img2imgTabs: "#mode_img2img .tab-nav", inpaintSketch: "#inpaint_sketch", -onUiLoaded(async() => { + switch (key) { zoomPosX - targetElement.getBoundingClientRect().left, zoomPosY - targetElement.getBoundingClientRect().top ); @@ -509,11 +559,20 @@ function fitToElement() { //Reset Zoom targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + let parentElement; + + if (isExtension) { + parentElement = targetElement.closest('[id^="component-"]'); + } else { + parentElement = targetElement.parentElement; + } + + // Get element and screen dimensions const elementWidth = targetElement.offsetWidth; const elementHeight = targetElement.offsetHeight; +onUiLoaded(async() => { img2imgTabs: "#mode_img2img .tab-nav", - // Iterate through defaultHotkeysConfig keys const screenWidth = parentElement.clientWidth; const screenHeight = parentElement.clientHeight; @@ -564,13 +623,19 @@ const canvas = gradioApp().querySelector( `${elemId} canvas[key="interface"]` ); + if (isExtension) { + inpaintSketch: "#inpaint_sketch", img2imgTabs: "#mode_img2img .tab-nav", - console.error( + rangeGroup: "#img2img_column_size", + } + return event.ctrlKey; - rangeGroup: "#img2img_column_size", + inpaintSketch: "#inpaint_sketch", +onUiLoaded(async() => { img2imgTabs: "#mode_img2img .tab-nav", - userValue + if (canvas.offsetWidth > 862 || isExtension) { + targetElement.style.width = (canvas.offsetWidth + 2) + "px"; } if (fullScreenMode) { @@ -672,8 +737,49 @@ mouseX = e.offsetX; mouseY = e.offsetY; } + inpaintSketch: "#inpaint_sketch", return event.altKey; + // We detect if an image has a scroll bar or not, make a fullscreen to reveal the image, then reduce it to fit into the element. + // We hide the image and show it to the user when it is ready. + + targetElement.isExpanded = false; + function autoExpand() { + ); }; + if (canvas) { + if (hasHorizontalScrollbar(targetElement) && targetElement.isExpanded === false) { + targetElement.style.visibility = "hidden"; + setTimeout(() => { + fitToScreen(); + resetZoom(); + targetElement.style.visibility = "visible"; + targetElement.isExpanded = true; + }, 10); + } + } + } + + targetElement.addEventListener("mousemove", getMousePosition); + + //observers + // Creating an observer with a callback function to handle DOM changes + const observer = new MutationObserver((mutationsList, observer) => { + for (let mutation of mutationsList) { + // If the style attribute of the canvas has changed, by observation it happens only when the picture changes + if (mutation.type === 'attributes' && mutation.attributeName === 'style' && + mutation.target.tagName.toLowerCase() === 'canvas') { + targetElement.isExpanded = false; + setTimeout(resetZoom, 10); + } + } + }); + + // Apply auto expand if enabled + if (hotkeysConfig.canvas_auto_expand) { + targetElement.addEventListener("mousemove", autoExpand); + // Set up an observer to track attribute changes + observer.observe(targetElement, {attributes: true, childList: true, subtree: true}); + } // Handle events only inside the targetElement let isKeyDownHandlerAttached = false; @@ -779,6 +885,11 @@ function handleMoveByKey(e) { if (isMoving && elemId === activeElement) { updatePanPosition(e.movementX, e.movementY); targetElement.style.pointerEvents = "none"; + + if (isExtension) { + targetElement.style.overflow = "visible"; + } + } else { targetElement.style.pointerEvents = "auto"; } @@ -793,14 +904,66 @@ gradioApp().addEventListener("mousemove", handleMoveByKey); } inpaintSketch: "#inpaint_sketch", + typeof userValue === "object" || + applyZoomAndPan(elementIDs.inpaint, false); + applyZoomAndPan(elementIDs.inpaintSketch, false); + + inpaintSketch: "#inpaint_sketch", -onUiLoaded(async() => { + inpaint: "#img2maskimg", + const applyZoomAndPanIntegration = async(id, elementIDs) => { + const mainEl = document.querySelector(id); + if (id.toLocaleLowerCase() === "none") { + for (const elementID of elementIDs) { inpaintSketch: "#inpaint_sketch", + }; + if (!el) break; + applyZoomAndPan(elementID); + const elementIDs = { const elementIDs = { + return; + } + inpaintSketch: "#inpaint_sketch", + result[key] = normalizedUserValue; + mainEl.addEventListener("click", async() => { + for (const elementID of elementIDs) { + // Format hotkey for display + if (!el) break; + applyZoomAndPan(elementID); + } + }, {once: true}); + }; +onUiLoaded(async() => { img2imgTabs: "#mode_img2img .tab-nav", + window.applyZoomAndPan = applyZoomAndPan; // Only 1 elements, argument elementID, for example applyZoomAndPan("#txt2img_controlnet_ControlNet_input_image") inpaintSketch: "#inpaint_sketch", + userValue + + /* + The function `applyZoomAndPanIntegration` takes two arguments: + + function formatHotkeyForDisplay(hotkey) { + If the `id` value is "none", the functionality will be applied to all elements specified in the second argument without a click event. + + 2. `elementIDs`: An array of string identifiers for elements. Zoom and pan functionality will be applied to each of these elements on click of the element specified by the first argument. + If "none" is specified in the first argument, the functionality will be applied to each of these elements without a click event. + + function formatHotkeyForDisplay(hotkey) { inpaint: "#img2maskimg", inpaintSketch: "#inpaint_sketch", + const tabNameToElementId = { inpaintSketch: "#inpaint_sketch", + In this example, zoom and pan functionality will be applied to the element with the identifier "txt2img_controlnet_ControlNet_input_image" upon clicking the element with the identifier "txt2img_controlnet". + */ + + // More examples + // Add integration with ControlNet txt2img One TAB + // applyZoomAndPanIntegration("#txt2img_controlnet", ["#txt2img_controlnet_ControlNet_input_image"]); + + // Add integration with ControlNet txt2img Tabs + // applyZoomAndPanIntegration("#txt2img_controlnet",Array.from({ length: 10 }, (_, i) => `#txt2img_controlnet_ControlNet-${i}_input_image`)); + + // Add integration with Inpaint Anything + // applyZoomAndPanIntegration("None", ["#ia_sam_image", "#ia_sel_mask"]); }); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py index 380176ce26ccecbdfa1a64791543f3061eba64ed..2d8d2d1c014be5dc1bac24b2c71079351fe1177e 100644 --- a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -10,6 +10,8 @@ "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, neededs for testing"), "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), import gradio as gr + "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally, 'Alt' can cause a little trouble in firefox"), +import gradio as gr from modules import shared "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable function that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size", "Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), })) diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css index 6bcc9570c45cf9b2ac426dd5981d78dcb0ac72d0..5d8054e65196408c97791727088088650f102b21 100644 --- a/extensions-builtin/canvas-zoom-and-pan/style.css +++ b/extensions-builtin/canvas-zoom-and-pan/style.css @@ -61,3 +61,6 @@ from {opacity: 0;} to {opacity: 1;} } +.styler { + overflow:inherit !important; +} \ No newline at end of file diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py index a05e10d865ab5a5e30dd9db936bad4cb96c4643a..983f87ff0335ef951cd091949c914fe3d597b665 100644 --- a/extensions-builtin/extra-options-section/scripts/extra_options_section.py +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -1,5 +1,7 @@ +import math + import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +from modules import scripts, shared, ui_components, ui_settings, generation_parameters_copypaste from modules.ui_components import FormColumn @@ -19,25 +21,46 @@ def ui(self, is_img2img): self.comps = [] self.setting_names = [] + self.infotext_fields = [] + extra_options = shared.opts.extra_options_img2img if is_img2img else shared.opts.extra_options_txt2img + + mapping = {k: v for v, k in generation_parameters_copypaste.infotext_to_setting_name_mapping} with gr.Blocks() as interface: + import gradio as gr + + row_count = math.ceil(len(extra_options) / shared.opts.extra_options_cols) + + for row in range(row_count): + with gr.Row(): + for col in range(shared.opts.extra_options_cols): + index = row * shared.opts.extra_options_cols + col + if index >= len(extra_options): + self.comps = None -import gradio as gr + + self.setting_names = None -from modules import scripts, shared, ui_components, ui_settings + +class ExtraOptionsSection(scripts.Script): -from modules import scripts, shared, ui_components, ui_settings +class ExtraOptionsSection(scripts.Script): import gradio as gr -from modules import scripts, shared, ui_components, ui_settings +class ExtraOptionsSection(scripts.Script): from modules import scripts, shared, ui_components, ui_settings -from modules import scripts, shared, ui_components, ui_settings +class ExtraOptionsSection(scripts.Script): from modules.ui_components import FormColumn -from modules import scripts, shared, ui_components, ui_settings + setting_infotext_name = mapping.get(setting_name) + if setting_infotext_name is not None: + self.infotext_fields.append((comp, setting_infotext_name)) from modules import scripts, shared, ui_components, ui_settings + class ExtraOptionsSection(scripts.Script): + def __init__(self): + return res[0] if len(res) == 1 else res interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) @@ -50,6 +73,10 @@ p.override_settings[name] = value shared.options_templates.update(shared.options_section(('ui', "User interface"), { - "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), + "extra_options_txt2img": shared.OptionInfo([], "Options in main UI - txt2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img interfaces").needs_reload_ui(), - "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") + "extra_options_img2img": shared.OptionInfo([], "Options in main UI - img2img", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in img2img interfaces").needs_reload_ui(), + "extra_options_cols": shared.OptionInfo(1, "Options in main UI - number of columns", gr.Number, {"precision": 0}).needs_reload_ui(), + "extra_options_accordion": shared.OptionInfo(False, "Options in main UI - place into an accordion").needs_reload_ui() })) + + diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js index 12cae4b75764779f7da3e424a959f966c06a8648..652f07ac7eceb7ac780d6c19c1be85480471491a 100644 --- a/extensions-builtin/mobile/javascript/mobile.js +++ b/extensions-builtin/mobile/javascript/mobile.js @@ -20,7 +20,13 @@ for (var tab of ["txt2img", "img2img"]) { var button = gradioApp().getElementById(tab + '_generate_box'); var target = gradioApp().getElementById(currentlyMobile ? tab + '_results' : tab + '_actions_column'); target.insertBefore(button, target.firstElementChild); + + gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile); } } window.addEventListener("resize", reportWindowSize); + +onUiLoaded(function() { + reportWindowSize(); +}); diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 5582a6e5d3b4ea553e661f00048fa98b12564d90..3bc723d3718326418820a7efdba109060e6af421 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,23 +1,49 @@ +function toggleCss(key, css, enable) { + var style = document.getElementById(key); + if (enable && !style) { + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); function setupExtraNetworksForTab(tabname) { + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + if (style && !enable) { + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); + var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); - var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + if (style) { - + style.innerHTML == ''; + var text = elem.querySelector('.name').textContent.toLowerCase() + " " + elem.querySelector('.search_term').textContent.toLowerCase(); search.classList.add('search'); + } +} + function setupExtraNetworksForTab(tabname) { + gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + + var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); + var visible = text.indexOf(searchTerm) != -1; + var visible = text.indexOf(searchTerm) != -1; function setupExtraNetworksForTab(tabname) { -function setupExtraNetworksForTab(tabname) { + var sort = gradioApp().getElementById(tabname + '_extra_sort'); -function setupExtraNetworksForTab(tabname) { + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); + var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var visible = text.indexOf(searchTerm) != -1; + var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); + function setupExtraNetworksForTab(tabname) { - + gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); + tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -83,6 +109,15 @@ applySort(); }); extraNetworksApplyFilter[tabname] = applyFilter; + + var showDirsUpdate = function() { + var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; + toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); + localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); + }; + showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; + showDirs.addEventListener("change", showDirsUpdate); + showDirsUpdate(); } function applyExtraNetworkFilter(tabname) { @@ -182,8 +217,8 @@ event.preventDefault(); } function extraNetworksSearchButton(tabs_id, event) { -function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); + var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); @@ -309,7 +345,7 @@ newDiv.innerHTML = data.html; var newCard = newDiv.firstElementChild; gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); - var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + if (!sortKey || sortKeyStore == sort.dataset.sortkey) { card.parentElement.insertBefore(newCard, card); card.parentElement.removeChild(card); } diff --git a/javascript/hints.js b/javascript/hints.js index 4167cb28b7c0ed934b37d346aacd3784ebec1016..6de9372e8ea8c9fb032351e241d0f9c265995290 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -190,3 +190,14 @@ clearTimeout(tooltipCheckTimer); tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); } }); + +onUiLoaded(function() { + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip && comp.props.elem_id) { + var elem = gradioApp().getElementById(comp.props.elem_id); + if (elem) { + elem.title = comp.props.webui_tooltip; + } + } + } +}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 677e95c1bc7b700e61bd5e6263e980dd703165e2..c21d396eefd5283691091fc5b87aba570a325297 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -136,6 +136,11 @@ // For other browsers the event is click to make it possiblr to drag picture. var event = isFirefox ? 'mousedown' : 'click'; e.addEventListener(event, function(evt) { + if (evt.button == 1) { + open(evt.target.src); + evt.preventDefault(); + return; + } if (!opts.js_modal_lightbox || evt.button != 0) return; modalZoomSet(gradioApp().getElementById('modalImage'), opts.js_modal_lightbox_initially_zoomed); diff --git a/javascript/inputAccordion.js b/javascript/inputAccordion.js new file mode 100644 index 0000000000000000000000000000000000000000..f2839852ee710bc1f4ae03e6788c1781001006a0 --- /dev/null +++ b/javascript/inputAccordion.js @@ -0,0 +1,37 @@ +var observerAccordionOpen = new MutationObserver(function(mutations) { + mutations.forEach(function(mutationRecord) { + var elem = mutationRecord.target; + var open = elem.classList.contains('open'); + + var accordion = elem.parentNode; + accordion.classList.toggle('input-accordion-open', open); + + var checkbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); + checkbox.checked = open; + updateInput(checkbox); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + extra.style.display = open ? "" : "none"; + } + }); +}); + +function inputAccordionChecked(id, checked) { + var label = gradioApp().querySelector('#' + id + " .label-wrap"); + if (label.classList.contains('open') != checked) { + label.click(); + } +} + +onUiLoaded(function() { + for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { + var labelWrap = accordion.querySelector('.label-wrap'); + observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); + + var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); + if (extra) { + labelWrap.insertBefore(extra, labelWrap.lastElementChild); + } + } +}); diff --git a/javascript/localStorage.js b/javascript/localStorage.js new file mode 100644 index 0000000000000000000000000000000000000000..dc1a36c328799ea3df1843001d397aa638935952 --- /dev/null +++ b/javascript/localStorage.js @@ -0,0 +1,26 @@ + +function localSet(k, v) { + try { + localStorage.setItem(k, v); + } catch (e) { + console.warn(`Failed to save ${k} to localStorage: ${e}`); + } +} + +function localGet(k, def) { + try { + return localStorage.getItem(k); + } catch (e) { + console.warn(`Failed to load ${k} from localStorage: ${e}`); + } + + return def; +} + +function localRemove(k) { + try { + return localStorage.removeItem(k); + } catch (e) { + console.warn(`Failed to remove ${k} from localStorage: ${e}`); + } +} diff --git a/javascript/localization.js b/javascript/localization.js index eb22b8a7e99c4c9a0c4d6a52c3b9acefd74464ae..8f00c18686057e3e12154f657170b014b13320a5 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -12,15 +12,15 @@ train_hypernetwork: 'OPTION', txt2img_styles: 'OPTION', img2img_styles: 'OPTION', -var ignore_ids_for_localization = { + extras_upscaler_2: 'SPAN', - setting_sd_hypernetwork: 'OPTION', +}; - setting_sd_model_checkpoint: 'OPTION', +var re_num = /^[.\d]+$/; - modelmerger_primary_model_name: 'OPTION', +var re_emoji = /[\p{Extended_Pictographic}\u{1F3FB}-\u{1F3FF}\u{1F9B0}-\u{1F9B3}]/u; - modelmerger_secondary_model_name: 'OPTION', +var original_lines = {}; }; var re_num = /^[.\d]+$/; @@ -112,12 +112,41 @@ processTextNode(node); }); } +function localizeWholePage() { + processNode(gradioApp()); + + function elem(comp) { + var elem_id = comp.props.elem_id ? comp.props.elem_id : "component-" + comp.id; + return gradioApp().getElementById(elem_id); + } + + for (var comp of window.gradio_config.components) { + if (comp.props.webui_tooltip) { + let e = elem(comp); + + let tl = e ? getTranslation(e.title) : undefined; + if (tl !== undefined) { + e.title = tl; + } + } + if (comp.props.placeholder) { + let e = elem(comp); + let textbox = e ? e.querySelector('[placeholder]') : null; + + let tl = textbox ? getTranslation(textbox.placeholder) : undefined; + if (tl !== undefined) { + textbox.placeholder = tl; + } + } + } +} + function dumpTranslations() { if (!hasLocalization()) { // If we don't have any localization, // we will not have traversed the app to find // original_lines, so do that now. - modelmerger_secondary_model_name: 'OPTION', + setting_random_artist_categories: 'SPAN', modelmerger_secondary_model_name: 'OPTION', } var dumped = {}; @@ -161,7 +190,7 @@ }); }); - + pnode = pnode.parentElement; if (localization.rtl) { // if the language is from right to left, (new MutationObserver((mutations, observer) => { // wait for the style to load diff --git a/javascript/progressbar.js b/javascript/progressbar.js index 29299787e30eef0c6d411dd018561ad7976ca512..777614954b2d489df32813fb27911dd9bbcd9c9a 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -69,7 +69,6 @@ function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress, inactivityTimeout = 40) { var dateStart = new Date(); var wasEverActive = false; var parentProgressbar = progressbarContainer.parentNode; - var parentGallery = gallery ? gallery.parentNode : null; var divProgress = document.createElement('div'); divProgress.className = 'progressDiv'; @@ -80,42 +79,33 @@ divProgress.appendChild(divInner); parentProgressbar.insertBefore(divProgress, progressbarContainer); - if (parentGallery) { -function request(url, data, handler, errorHandler) { // code related to showing and updating progressbar shown as the image is being made + var js = JSON.parse(xhr.responseText); -function request(url, data, handler, errorHandler) { function request(url, data, handler, errorHandler) { -function rememberGallerySelection() { +} - } + if (!divProgress) return; function request(url, data, handler, errorHandler) { -} -function request(url, data, handler, errorHandler) { function getGallerySelectedIndex() { parentProgressbar.removeChild(divProgress); + if (xhr.readyState === 4) { function request(url, data, handler, errorHandler) { - var xhr = new XMLHttpRequest(); atEnd(); + divProgress = null; + } -function request(url, data, handler, errorHandler) { + var funProgress = function(id_task) { + if (xhr.readyState === 4) { xhr.setRequestHeader("Content-Type", "application/json"); - request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { if (res.completed) { removeProgressBar(); return; } var xhr = new XMLHttpRequest(); -} - - if (rect.width) { - divProgress.style.width = rect.width + "px"; - } - - var xhr = new XMLHttpRequest(); var xhr = new XMLHttpRequest(); divInner.style.width = ((res.progress || 0) * 100.0) + '%'; @@ -128,7 +118,6 @@ if (res.eta) { progressText += " ETA: " + formatTime(res.eta); - setTitle(progressText); @@ -153,41 +142,56 @@ removeProgressBar(); return; } + if (onProgress) { + onProgress(res); + - xhr.setRequestHeader("Content-Type", "application/json"); + setTimeout(() => { + funProgress(id_task, res.id_live_preview); + xhr.onreadystatechange = function() { xhr.setRequestHeader("Content-Type", "application/json"); -function rememberGallerySelection() { + }, function() { + removeProgressBar(); - xhr.setRequestHeader("Content-Type", "application/json"); + }); + } - livePreview.style.width = rect.width + "px"; + + var funLivePreview = function(id_task, id_live_preview) { - livePreview.style.height = rect.height + "px"; + request("./internal/progress", {id_task: id_task, id_live_preview: id_live_preview}, function(res) { // code related to showing and updating progressbar shown as the image is being made - xhr.setRequestHeader("Content-Type", "application/json"); + } + return; + } + + xhr.setRequestHeader("Content-Type", "application/json"); var img = new Image(); img.onload = function() { - livePreview.appendChild(img); + if (!livePreview) { // code related to showing and updating progressbar shown as the image is being made + }; // code related to showing and updating progressbar shown as the image is being made + var js = JSON.stringify(data); // code related to showing and updating progressbar shown as the image is being made + xhr.send(js); } + + livePreview.appendChild(img); xhr.onreadystatechange = function() { -function rememberGallerySelection() { xhr.onreadystatechange = function() { -} +// code related to showing and updating progressbar shown as the image is being made - } - + xhr.onreadystatechange = function() { xhr.onreadystatechange = function() { -function getGallerySelectedIndex() { +function rememberGallerySelection() { xhr.onreadystatechange = function() { -function request(url, data, handler, errorHandler) { +} } setTimeout(() => { // code related to showing and updating progressbar shown as the image is being made - xhr.open("POST", url, true); +function pad2(x) { }, opts.live_preview_refresh_period || 500); }, function() { removeProgressBar(); @@ -193,5 +198,10 @@ }); }; // code related to showing and updating progressbar shown as the image is being made + return x < 10 ? '0' + x : x; + + if (gallery) { try { + } + } diff --git a/javascript/resizeHandle.js b/javascript/resizeHandle.js new file mode 100644 index 0000000000000000000000000000000000000000..2fd3c4d2982d7eaf3197d5c85928740bb389bcd7 --- /dev/null +++ b/javascript/resizeHandle.js @@ -0,0 +1,139 @@ +(function() { + const GRADIO_MIN_WIDTH = 320; + const GRID_TEMPLATE_COLUMNS = '1fr 16px 1fr'; + const PAD = 16; + const DEBOUNCE_TIME = 100; + + const R = { + tracking: false, + parent: null, + parentWidth: null, + leftCol: null, + leftColStartWidth: null, + screenX: null, + }; + + let resizeTimer; + let parents = []; + + function setLeftColGridTemplate(el, width) { + el.style.gridTemplateColumns = `${width}px 16px 1fr`; + } + + function displayResizeHandle(parent) { + if (window.innerWidth < GRADIO_MIN_WIDTH * 2 + PAD * 4) { + parent.style.display = 'flex'; + if (R.handle != null) { + R.handle.style.opacity = '0'; + } + return false; + } else { + parent.style.display = 'grid'; + if (R.handle != null) { + R.handle.style.opacity = '100'; + } + return true; + } + } + + function afterResize(parent) { + if (displayResizeHandle(parent) && parent.style.gridTemplateColumns != GRID_TEMPLATE_COLUMNS) { + const oldParentWidth = R.parentWidth; + const newParentWidth = parent.offsetWidth; + const widthL = parseInt(parent.style.gridTemplateColumns.split(' ')[0]); + + const ratio = newParentWidth / oldParentWidth; + + const newWidthL = Math.max(Math.floor(ratio * widthL), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(parent, newWidthL); + + R.parentWidth = newParentWidth; + } + } + + function setup(parent) { + const leftCol = parent.firstElementChild; + const rightCol = parent.lastElementChild; + + parents.push(parent); + + parent.style.display = 'grid'; + parent.style.gap = '0'; + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + + const resizeHandle = document.createElement('div'); + resizeHandle.classList.add('resize-handle'); + parent.insertBefore(resizeHandle, rightCol); + + resizeHandle.addEventListener('mousedown', (evt) => { + if (evt.button !== 0) return; + + evt.preventDefault(); + evt.stopPropagation(); + + document.body.classList.add('resizing'); + + R.tracking = true; + R.parent = parent; + R.parentWidth = parent.offsetWidth; + R.handle = resizeHandle; + R.leftCol = leftCol; + R.leftColStartWidth = leftCol.offsetWidth; + R.screenX = evt.screenX; + }); + + resizeHandle.addEventListener('dblclick', (evt) => { + evt.preventDefault(); + evt.stopPropagation(); + + parent.style.gridTemplateColumns = GRID_TEMPLATE_COLUMNS; + }); + + afterResize(parent); + } + + window.addEventListener('mousemove', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + const delta = R.screenX - evt.screenX; + const leftColWidth = Math.max(Math.min(R.leftColStartWidth - delta, R.parent.offsetWidth - GRADIO_MIN_WIDTH - PAD), GRADIO_MIN_WIDTH); + setLeftColGridTemplate(R.parent, leftColWidth); + } + }); + + window.addEventListener('mouseup', (evt) => { + if (evt.button !== 0) return; + + if (R.tracking) { + evt.preventDefault(); + evt.stopPropagation(); + + R.tracking = false; + + document.body.classList.remove('resizing'); + } + }); + + + window.addEventListener('resize', () => { + clearTimeout(resizeTimer); + + resizeTimer = setTimeout(function() { + for (const parent of parents) { + afterResize(parent); + } + }, DEBOUNCE_TIME); + }); + + setupResizeHandle = setup; +})(); + +onUiLoaded(function() { + for (var elem of gradioApp().querySelectorAll('.resize-handle-row')) { + setupResizeHandle(elem); + } +}); diff --git a/javascript/ui.js b/javascript/ui.js index d70a681bff7b45fe5711431ee8ec55c444443a5b..bedcbf3e211f5bc1222f2ad2f28c4622614e32a5 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -19,30 +19,14 @@ return visibleGalleryButtons; } function selected_gallery_button() { - var allCurrentButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnail-item.thumbnail-small.selected'); allCurrentButtons.forEach(function(elem) { - if (elem.parentElement.offsetParent) { - visibleCurrentButton = elem; - } - }); - return visibleCurrentButton; } function selected_gallery_index() { - if (!gradioURL.includes('?__theme=')) { - window.location.replace(gradioURL + '?__theme=' + theme); - var result = -1; - buttons.forEach(function(v, i) { - if (v == button) { - result = i; - } - }); - - return result; } function extract_image_from_gallery(gallery) { @@ -153,11 +138,11 @@ function submit() { showSubmitButtons('txt2img', false); var id = randomId(); - localStorage.setItem("txt2img_task_id", id); + localSet("txt2img_task_id", id); requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { showSubmitButtons('txt2img', true); - localStorage.removeItem("txt2img_task_id"); + localRemove("txt2img_task_id"); showRestoreProgressButton('txt2img', false); }); @@ -172,12 +157,12 @@ function submit_img2img() { showSubmitButtons('img2img', false); var id = randomId(); - localStorage.setItem("img2img_task_id", id); + localSet("img2img_task_id", id); requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { showSubmitButtons('img2img', true); - var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small'); + var button = selected_gallery_button(); showRestoreProgressButton('img2img', false); }); @@ -191,10 +176,8 @@ } function restoreProgressTxt2img() { showRestoreProgressButton("txt2img", false); - var allGalleryButtons = gradioApp().querySelectorAll('[style="display: block;"].tabitem div[id$=_gallery].gradio-gallery .thumbnails > .thumbnail-item.thumbnail-small'); + visibleCurrentButton = elem; } - - id = localStorage.getItem("txt2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() { @@ -208,7 +191,7 @@ function restoreProgressImg2img() { showRestoreProgressButton("img2img", false); - var id = localStorage.getItem("img2img_task_id"); + var id = localGet("img2img_task_id"); if (id) { requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() { @@ -221,11 +204,10 @@ } onUiLoaded(function() { -// various functions for interaction with ui.py not large enough to warrant putting them in separate files + if (v == button) { -// various functions for interaction with ui.py not large enough to warrant putting them in separate files -// various functions for interaction with ui.py not large enough to warrant putting them in separate files + result = i; }); diff --git a/launch.py b/launch.py index 1dbc4c6e33e189d27cb81f803b7009faee65dc0e..e4c2ce99e729ce26b0efd89ced2ec8ff0373f8e5 100644 --- a/launch.py +++ b/launch.py @@ -1,6 +1,5 @@ from modules import launch_utils - args = launch_utils.args python = launch_utils.python git = launch_utils.git @@ -26,10 +25,13 @@ start = launch_utils.start def main(): + launch_utils.startup_timer.record("initial startup") args = launch_utils.args +from modules import launch_utils +args = launch_utils.args -python = launch_utils.python + prepare_environment() if args.test_server: configure_for_tests() diff --git a/modules/api/api.py b/modules/api/api.py index 606db179d4c35ecfc1875e48a49eb9e4c4383cf1..e6edffe7144e539ab970bf85a0bc10e254821ce3 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -4,6 +4,8 @@ import os import time import datetime import uvicorn +import ipaddress +import requests import gradio as gr from threading import Lock from io import BytesIO @@ -15,7 +17,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -23,8 +25,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image -from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights, checkpoint_aliases -from modules.sd_vae import vae_dict +from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -57,7 +58,43 @@ return reqDict import datetime + try: + """Returns True if the url refers to a global resource.""" + + import socket + from urllib.parse import urlparse + try: + return image import datetime + config = sd_samplers.all_samplers_map.get(name, None) + host = socket.gethostbyname_ex(domain_name) + for ip in host[2]: + ip_addr = ipaddress.ip_address(ip) + if not ip_addr.is_global: + return False + except Exception: + return False + + return True + + +def decode_base64_to_image(encoding): + if encoding.startswith("http://") or encoding.startswith("https://"): + if not opts.api_enable_requests: + raise HTTPException(status_code=500, detail="Requests not allowed") + + if opts.api_forbid_local_requests and not verify_url(encoding): + raise HTTPException(status_code=500, detail="Request to local resource not allowed") + + headers = {'user-agent': opts.api_useragent} if opts.api_useragent else {} + response = requests.get(encoding, timeout=30, headers=headers) + try: + image = Image.open(BytesIO(response.content)) + return image + except Exception as e: + raise HTTPException(status_code=500, detail="Invalid image url") from e + +import datetime import base64 encoding = encoding.split(";")[1].split(",")[1] try: @@ -198,6 +235,7 @@ self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[models.PromptStyleItem]) self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=models.EmbeddingsResponse) self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"]) + self.add_api_route("/sdapi/v1/refresh-vae", self.refresh_vae, methods=["POST"]) self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=models.CreateResponse) self.add_api_route("/sdapi/v1/preprocess", self.preprocess, methods=["POST"], response_model=models.PreprocessResponse) @@ -330,6 +368,7 @@ args.pop('save_images', None) with self.queue_lock: from modules.sd_vae import vae_dict +def encode_pil_to_base64(image): p.scripts = script_runner p.outpath_grids = opts.outdir_txt2img_grids p.outpath_samples = opts.outdir_txt2img_samples @@ -344,6 +383,7 @@ p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -389,6 +429,7 @@ with self.queue_lock: with closing(StableDiffusionProcessingImg2Img(sd_model=shared.sd_model, **args)) as p: p.init_images = [decode_base64_to_image(x) for x in init_images] + p.is_api = True p.scripts = script_runner p.outpath_grids = opts.outdir_img2img_grids p.outpath_samples = opts.outdir_img2img_samples @@ -403,6 +444,7 @@ p.script_args = tuple(script_args) # Need to pass args as tuple here processed = process_images(p) finally: shared.state.end() + shared.total_tqdm.clear() b64images = list(map(encode_pil_to_base64, processed.images)) if send_images else [] @@ -531,7 +573,7 @@ if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") for k, v in req.items(): - shared.opts.set(k, v) + shared.opts.set(k, v, is_api=True) shared.opts.save(shared.config_filename) return @@ -563,11 +605,13 @@ for upscale_mode in [*(shared.latent_upscale_modes or {})] ] def get_sd_models(self): + import modules.sd_models as sd_models +def encode_pil_to_base64(image): import time -from fastapi.exceptions import HTTPException def get_sd_vaes(self): - return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + import modules.sd_vae as sd_vae + return [{"model_name": x, "filename": sd_vae.vae_dict[x]} for x in sd_vae.vae_dict.keys()] def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -609,6 +653,10 @@ def refresh_checkpoints(self): with self.queue_lock: shared.refresh_checkpoints() + + def refresh_vae(self): + with self.queue_lock: + shared_items.refresh_vae_list() def create_embedding(self, args: dict): try: diff --git a/modules/api/models.py b/modules/api/models.py index 800c9b93f14794f429e32b053e9c24be0426d296..6a574771c3346456b8cdf0d6e6a2d75fb9f3084f 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -51,10 +51,12 @@ additional_fields = None, ): def field_type_generator(k, v): from typing import Any, Optional -from inflection import underscore +from modules.shared import sd_upscalers, opts, parser + + if field_type == 'Image': - # print(k, v.annotation, v.default) + # images are sent as base64 strings via API + "prompt_for_display", from typing import Any, Optional -from modules.shared import sd_upscalers, opts, parser return Optional[field_type] @@ -65,7 +67,6 @@ for classes in all_classes: parameters = {**parameters, **inspect.signature(classes.__init__).parameters} return parameters - self._model_name = model_name self._class_data = merge_class_params(class_instance) @@ -74,8 +75,9 @@ ModelDef( field=underscore(k), field_alias=k, field_type=field_type_generator(k, v), -from inflection import underscore + from pydantic import BaseModel, Field, create_model +from typing_extensions import Literal ) for (k,v) in self._class_data.items() if k not in API_NOT_ALLOWED ] diff --git a/modules/cache.py b/modules/cache.py index 71fe630213410d64c51cc77876dc86c36944c55b..ff26a2132d987d4da86337c4e69082eddab15d3c 100644 --- a/modules/cache.py +++ b/modules/cache.py @@ -1,11 +1,12 @@ import json +import os import os.path import threading import time from modules.paths import data_path, script_path -cache_filename = os.path.join(data_path, "cache.json") +cache_filename = os.environ.get('SD_WEBUI_CACHE_FILE', os.path.join(data_path, "cache.json")) cache_data = None cache_lock = threading.Lock() @@ -29,8 +30,11 @@ while dump_cache_after is not None and time.time() < dump_cache_after: time.sleep(1) with cache_lock: - with open(cache_filename, "w", encoding="utf8") as file: + cache_filename_tmp = cache_filename + "-" + with open(cache_filename_tmp, "w", encoding="utf8") as file: json.dump(cache_data, file, indent=4) + + os.replace(cache_filename_tmp, cache_filename) dump_cache_after = None dump_cache_thread = None diff --git a/modules/call_queue.py b/modules/call_queue.py index f2eb17d61661e2d56ef2c3678db206a601c1eeec..ddf0d57383cb1677a1059b7b6803636ceaaa99f8 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,11 +1,10 @@ from functools import wraps import html -import threading import time -from modules import shared, progress, errors, devices +from modules import shared, progress, errors, devices, fifo_lock -queue_lock = threading.Lock() +queue_lock = fifo_lock.FIFOLock() def wrap_queued_call(func): diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e401f6413a48a8f42ea398b54e8f2d998c67385e..f0f361bdef766312205e4dd026be92977aea391b 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -14,8 +14,11 @@ parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") import argparse +parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") +import argparse parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--loglevel", type=str, help="log level; one of: CRITICAL, ERROR, WARNING, INFO, DEBUG", default=None) parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) @@ -34,9 +37,10 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_path, 'hypernetworks'), help="hypernetwork directory") parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory") parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui") parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage") +parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models") parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage") parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM") -parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") +parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") parser.add_argument("--upcast-sampling", action='store_true', help="upcast sampling. No effect with --no-half. Usually produces similar results to --no-half with better performance while using less memory.") @@ -67,6 +71,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) @@ -79,7 +84,7 @@ parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" same auth format as --gradio-auth', default=None) parser.add_argument("--gradio-img2img-tool", type=str, help='does not do anything') parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything") -parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it") +parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path]) parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last") parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv')) parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False) @@ -111,3 +116,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradio, use with reverse proxy') parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/config_states.py b/modules/config_states.py index 6f1ab53fc5909888413d42ede0f09c28f90b90cc..b766aef11d87a74ea4cd6fa8a580e12e830e5691 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -8,14 +8,13 @@ import time import tqdm from datetime import datetime -from collections import OrderedDict import git from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir - all_config_states = OrderedDict() +""" def list_config_states(): @@ -28,13 +27,18 @@ config_states = [] for filename in os.listdir(config_states_dir): if filename.endswith(".json"): path = os.path.join(config_states_dir, filename) - with open(path, "r", encoding="utf-8") as f: + try: +all_config_states = OrderedDict() Supports saving and restoring webui and extensions from a known working set of commits +all_config_states = OrderedDict() -Supports saving and restoring webui and extensions from a known working set of commits +all_config_states = OrderedDict() import os -Supports saving and restoring webui and extensions from a known working set of commits +all_config_states = OrderedDict() import json + config_states.append(j) + except Exception as e: + print(f'[ERROR]: Config states {path}, {e}') config_states = sorted(config_states, key=lambda cs: cs["created_at"], reverse=True) diff --git a/modules/devices.py b/modules/devices.py index 57e51da30e26f0586c14321b5c0453f8a3ba5c64..c01f06024b4cffd4a44f97b6f7699397e27abdb2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,6 +3,7 @@ import contextlib from functools import lru_cache import torch +def has_mps() -> bool: from modules import errors if sys.platform == "darwin": @@ -17,8 +18,6 @@ return mac_specific.has_mps import sys - - from modules import shared if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -41,8 +40,6 @@ return torch.device(get_optimal_device_name()) def get_device_for(task): - from modules import shared - if task in shared.cmd_opts.use_cpu: return cpu @@ -73,58 +70,45 @@ torch.backends.cudnn.allow_tf32 = True - from functools import lru_cache -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 - +def has_mps() -> bool: if sys.platform == "darwin": - +def has_mps() -> bool: from modules import mac_specific - def has_mps() -> bool: - - +def has_mps() -> bool: -import torch + if sys.platform != "darwin": -import torch import sys - - +import sys -import torch + if sys.platform != "darwin": import contextlib -import torch + if sys.platform != "darwin": from functools import lru_cache - - -import torch + if sys.platform != "darwin": -import torch + if sys.platform != "darwin": import torch - torch.manual_seed(seed) - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) -import torch def has_mps() -> bool: -from modules import errors +import torch import torch -import torch +import sys + import torch -if sys.platform == "darwin": - return torch.randn(shape, device=cpu).to(device) +import contextlib import torch -def has_mps() -> bool: +from functools import lru_cache + if sys.platform != "darwin": from modules import errors -import sys - from modules import shared + +def autocast(disable=False): if disable: return contextlib.nullcontext() @@ -143,8 +127,6 @@ pass def test_for_nans(x, where): - from modules import shared - if shared.cmd_opts.disable_nan_check: return @@ -184,3 +166,4 @@ x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + diff --git a/modules/errors.py b/modules/errors.py index dffabe45c067c405b7ffcea8ae0cf95cae050f5f..a56fd30ca3a8dfe1e3e7889fddb62bd3c18fbd07 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -84,3 +84,53 @@ try: code() except Exception as e: display(task, e) + + +def check_versions(): + from packaging import version + from modules import shared + + import torch + import gradio + + expected_torch_version = "2.0.0" + expected_xformers_version = "0.0.20" + expected_gradio_version = "3.41.0" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you have the mismatched gradio version can be: + - you use --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs the incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/modules/extensions.py b/modules/extensions.py index 3ad5ed53160a58a541b1f03b5d4d85cfc2f14fdf..bf9a1878f5df0f651d9de393867e38a4efe3fb7a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,7 @@ import os import threading -from modules import shared, errors, cache +from modules import shared, errors, cache, scripts from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 @@ -11,8 +12,9 @@ def active(): import os + def __init__(self, name, path, enabled=True, is_builtin=False): return [] elif shared.opts.disable_all_extensions == "extra": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -90,9 +91,6 @@ self.have_info_from_repo = True def list_files(self, subdir, extension): extensions = [] -from modules import shared, errors, cache - -extensions = [] from modules.gitpython_hack import Repo if not os.path.isdir(dirpath): return [] @@ -143,7 +141,14 @@ if not os.path.isdir(extensions_dir): return import os + self.path = path + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": +import os +from modules.gitpython_hack import Repo import os + self.can_update = False + elif shared.opts.disable_all_extensions == "extra": from modules.gitpython_hack import Repo elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91b1c537efc186a4b354adefd447f5f822..fa28ac752ac24f7a2c26240baa76a807eb958fd9 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,3 +1,5 @@ +import json +import os import re from collections import defaultdict @@ -177,3 +179,20 @@ res.append(updated_prompt) return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/extras.py b/modules/extras.py index e9c0263ec7d83d112c42645c18a3cbda64ed911e..2a310ae3f25304f6d3cf6c1c071c77ca4a5e5e4c 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,8 +72,21 @@ return tensor +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} +<p>{plaintext_to_html(str(text))}</p> import json + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -242,15 +255,30 @@ shared.state.nextjob() shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") + metadata = {} + +</div> import os -import tqdm + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: +</div> import json +<div> - info = '' + metadata.update(tertiary_model_info.metadata) info = '' + +</div> import torch + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + metadata["format"] = "pt" + + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -266,7 +294,6 @@ "discard_weights": discard_weights, "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -286,12 +313,13 @@ add_model_metadata(secondary_model_info) if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": import re -import gradio as gr + return else: torch.save(theta_0, output_modelname) diff --git a/modules/fifo_lock.py b/modules/fifo_lock.py new file mode 100644 index 0000000000000000000000000000000000000000..c35b3ae25a3cf383c8beae04db3e0a3d66785135 --- /dev/null +++ b/modules/fifo_lock.py @@ -0,0 +1,37 @@ +import threading +import collections + + +# reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a +class FIFOLock(object): + def __init__(self): + self._lock = threading.Lock() + self._inner_lock = threading.Lock() + self._pending_threads = collections.deque() + + def acquire(self, blocking=True): + with self._inner_lock: + lock_acquired = self._lock.acquire(False) + if lock_acquired: + return True + elif not blocking: + return False + + release_event = threading.Event() + self._pending_threads.append(release_event) + + release_event.wait() + return self._lock.acquire() + + def release(self): + with self._inner_lock: + if self._pending_threads: + release_event = self._pending_threads.popleft() + release_event.set() + + self._lock.release() + + __enter__ = acquire + + def __exit__(self, t, v, tb): + self.release() diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9db8615d5d10e2cc6a18e182a22c1ee92..386517acaef91294c8c939b25bbffc79e25d66a4 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -6,7 +6,7 @@ import re import gradio as gr from modules.paths import data_path -from modules import shared, ui_tempdir, script_callbacks +from modules import shared, ui_tempdir, script_callbacks, processing from PIL import Image re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' @@ -199,8 +199,6 @@ height = int(res.get("Size-2", 512)) if firstpass_width == 0 or firstpass_height == 0: re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") -import json -re_hypernet_hash = re.compile("\(([0-9a-f]+)\)$") import os res['Size-1'] = firstpass_width @@ -282,6 +280,9 @@ if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" @@ -306,51 +307,44 @@ if "Schedule rho" not in res: res["Schedule rho"] = 0 -import io +import json import base64 -from modules import shared, ui_tempdir, script_callbacks +import os - + res["VAE Encoder"] = "Full" -infotext_to_setting_name_mapping = [ - self.source_text_component = source_text_component +import json import base64 + - ('Conditional mask weight', 'inpainting_mask_weight'), - self.source_text_component = source_text_component import json +registered_param_bindings = [] - ('ENSD', 'eta_noise_seed_delta'), + import io - self.override_settings_component = override_settings_component + def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None): - self.source_text_component = source_text_component + self.source_text_component = source_text_component -import gradio as gr - ('Schedule rho', 'rho'), + import io +import os import io -from modules import shared, ui_tempdir, script_callbacks -import io import json +class ParamBinding: -import io return text +from modules import shared, ui_tempdir, script_callbacks - ('Discard penultimate sigma', 'always_discard_next_to_last_sigma'), + import io -def unquote(text): import io - if len(text) == 0 or text[0] != '"' or text[-1] != '"': import io - try: import io - return json.loads(text) +import io import io - except Exception: self.source_image_component = source_image_component -from modules.paths import data_path import io - if filedata is None: self.source_tabname = source_tabname import io - if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): + self.override_settings_component = override_settings_component ] +""" def create_override_settings_dict(text_pairs): @@ -369,9 +363,11 @@ k, v = pair.split(":", maxsplit=1) params[k] = v.strip() +import json import io -import re +import base64 import json + self.source_text_component = source_text_component value = params.get(param_name, None) if value is None: @@ -420,12 +416,20 @@ return res if override_settings_component is not None: + already_handled_fields = {key: 1 for _, key in paste_fields} + def paste_settings(params): vals = {} +import json import io -from modules.paths import data_path +import os + for param_name, setting_name in mapping + infotext_to_setting_name_mapping: + if param_name in already_handled_fields: +import io from modules import shared, ui_tempdir, script_callbacks +import io + v = params.get(param_name, None) if v is None: continue diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py new file mode 100644 index 0000000000000000000000000000000000000000..e6b6835adcc28c7246c107ee7d3aabdba54c9b57 --- /dev/null +++ b/modules/gradio_extensons.py @@ -0,0 +1,73 @@ +import gradio as gr + +from modules import scripts, ui_tempdir, patches + + +def add_classes_to_gradio_component(comp): + """ + this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others + """ + + comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] + + if getattr(comp, 'multiselect', False): + comp.elem_classes.append('multiselect') + + +def IOComponent_init(self, *args, **kwargs): + self.webui_tooltip = kwargs.pop('tooltip', None) + + if scripts.scripts_current is not None: + scripts.scripts_current.before_component(self, **kwargs) + + scripts.script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + scripts.script_callbacks.after_component_callback(self, **kwargs) + + if scripts.scripts_current is not None: + scripts.scripts_current.after_component(self, **kwargs) + + return res + + +def Block_get_config(self): + config = original_Block_get_config(self) + + webui_tooltip = getattr(self, 'webui_tooltip', None) + if webui_tooltip: + config["webui_tooltip"] = webui_tooltip + + config.pop('example_inputs', None) + + return config + + +def BlockContext_init(self, *args, **kwargs): + res = original_BlockContext_init(self, *args, **kwargs) + + add_classes_to_gradio_component(self) + + return res + + +def Blocks_get_config_file(self, *args, **kwargs): + config = original_Blocks_get_config_file(self, *args, **kwargs) + + for comp_config in config["components"]: + if "example_inputs" in comp_config: + comp_config["example_inputs"] = {"serialized": []} + + return config + + +original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) +original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) +original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) +original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) + + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21a7e68c25e6836b9ec34acca6489be5e7..70f1cbd26b66939de4d42831e300850e3f5927ad 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,9 +469,8 @@ shared.reload_hypernetworks() def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): -import html import inspect +import torch - from modules import images save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/images.py b/modules/images.py index 38aa933d6e590881761f9708c8484d5ca1acb8f1..eb6447338986f8dd73a0dfb1894c4d26c7f83689 100644 --- a/modules/images.py +++ b/modules/images.py @@ -22,8 +22,6 @@ from modules.paths_internal import roboto_ttf_file from modules.shared import opts - - from __future__ import annotations @@ -319,7 +317,7 @@ return res -invalid_filename_chars = '<>:"/\\|?*\n' +invalid_filename_chars = '<>:"/\\|?*\n\r\t' invalid_filename_prefix = ' ' invalid_filename_postfix = ' .' re_nonletters = re.compile(r'[\s' + string.punctuation + ']+') @@ -344,17 +342,6 @@ class FilenameGenerator: - rows -= 1 - if sd_vae.loaded_vae_file is None: - return "NoneType" - file_name = os.path.basename(sd_vae.loaded_vae_file) - split_file_name = file_name.split('.') - if len(split_file_name) > 1 and split_file_name[0] == '': - return split_file_name[1] # if the first character of the filename is "." then [1] is obtained. - else: - return split_file_name[0] - - script_callbacks.image_grid_callback(params) 'seed': lambda self: self.seed if self.seed is not None else '', 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], @@ -370,9 +357,11 @@ 'model_name': lambda self: sanitize_filename_part(shared.sd_model.sd_checkpoint_info.name_for_extra, replace_spaces=False), 'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'), 'datetime': lambda self, *args: self.datetime(*args), # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>] 'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp), - import io + rows = math.sqrt(len(imgs)) + h = image.height import re + 'full_prompt_hash': lambda self, *args: self.string_hash(f"{self.p.prompt} {self.p.negative_prompt}", *args), # a space in between to create a unique string 'prompt': lambda self: sanitize_filename_part(self.prompt), 'prompt_no_styles': lambda self: self.prompt_no_style(), 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False), @@ -385,7 +374,8 @@ 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, 'user': lambda self: self.p.user, 'vae_filename': lambda self: self.get_vae_filename(), -def image_grid(imgs, batch_size=1, rows=None): + 'none': lambda self: '', # Overrides the default, so you can get just the sequence number + non_overlap_width = tile_w - overlap } default_time_format = '%Y%m%d%H%M%S' @@ -397,6 +387,22 @@ self.prompt = prompt self.image = image self.zip = zip + def get_vae_filename(self): + """Get the name of the VAE file.""" + + import modules.sd_vae as sd_vae + + if sd_vae.loaded_vae_file is None: + return "NoneType" + + file_name = os.path.basename(sd_vae.loaded_vae_file) + split_file_name = file_name.split('.') + if len(split_file_name) > 1 and split_file_name[0] == '': + return split_file_name[1] # if the first character of the filename is "." then [1] is obtained. + else: + return split_file_name[0] + + def hasprompt(self, *args): lower = self.prompt.lower() if self.p is None or self.prompt is None: @@ -449,6 +455,14 @@ except (ValueError, TypeError): formatted_time = time_zone_time.strftime(self.default_time_format) return sanitize_filename_part(formatted_time, replace_spaces=False) + + def image_hash(self, *args): + length = int(args[0]) if (args and args[0] != "") else None + return hashlib.sha256(self.image.tobytes()).hexdigest()[0:length] + + def string_hash(self, text, *args): + length = int(args[0]) if (args and args[0] != "") else 8 + return hashlib.sha256(text.encode()).hexdigest()[0:length] def apply(self, x): res = '' @@ -590,6 +604,11 @@ txt_fullfn (`str` or None): If a text file is saved for this image, this will be its full path. Otherwise None. """ namegen = FilenameGenerator(p, seed, prompt, image) + + # WebP and JPG formats have maximum dimension limits of 16383 and 65535 respectively. switch to PNG which has a much higher limit + if (image.height > 65535 or image.width > 65535) and extension.lower() in ("jpg", "jpeg") or (image.height > 16383 or image.width > 16383) and extension.lower() == "webp": + print('Image dimensions too large; saving as PNG') + extension = ".png" if save_to_dirs is None: save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt) diff --git a/modules/img2img.py b/modules/img2img.py index a811e7a4b1b44e22d7e7d433e708f8c539c82267..1519e132b2bf8c7d89137c7e46cd7d990ab08258 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -3,16 +3,14 @@ from contextlib import closing from pathlib import Path import numpy as np -from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr -from modules import sd_samplers, images as imgutil +from modules import images as imgutil from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state import os -from contextlib import closing -import os from pathlib import Path import modules.processing as processing from modules.ui import plaintext_to_html @@ -20,9 +18,11 @@ import modules.scripts def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0, use_png_info=False, png_info_props=None, png_info_dir=None): + output_dir = output_dir.strip() processing.fix_seed(p) images = list(shared.walk_files(input_dir, allowed_extensions=(".png", ".jpg", ".jpeg", ".webp"))) +import gradio as gr is_inpaint_batch = False if inpaint_mask_dir: @@ -33,11 +33,6 @@ if is_inpaint_batch: print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") - - save_normally = output_dir == '' - - p.do_not_save_grid = True - p.do_not_save_samples = not save_normally state.job_count = len(images) * p.n_iter @@ -113,47 +108,40 @@ p.steps = int(parsed_parameters.get("Steps", steps)) proc = modules.scripts.scripts_img2img.run(p, *args) if proc is None: +import os from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters - +from modules import sd_samplers, images as imgutil -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters import os from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters -from contextlib import closing from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters -from pathlib import Path - relpath = os.path.dirname(os.path.relpath(image, input_dir)) + p.override_settings['save_to_dirs'] = False - + if p.n_iter > 1 or p.batch_size > 1: - if n > 0: + p.override_settings['samples_filename_pattern'] = f'{image_path.stem}-[generation_number]' - filename += f"-{n}" + else: + is_inpaint_batch = False - if not save_normally: - os.makedirs(os.path.join(output_dir, relpath), exist_ok=True) - if processed_image.mode == 'RGBA': -import os +from contextlib import closing - save_image(processed_image, os.path.join(output_dir, relpath), None, extension=opts.samples_format, info=infotext, forced_filename=filename, save_to_dirs=False) +import numpy as np -from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from contextlib import closing +from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 if mode == 0: # img2img -import os +from contextlib import closing -from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError +import gradio as gr mask = None elif mode == 1: # img2img sketch -import os +from contextlib import closing -from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters +from modules import sd_samplers, images as imgutil mask = None elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] -from modules.shared import opts, state from contextlib import closing - mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask).convert('L') - image = image.convert("RGB") +from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters elif mode == 3: # inpaint sketch image = inpaint_color_sketch orig = inpaint_color_sketch_orig or inpaint_color_sketch @@ -162,8 +151,6 @@ mask = ImageEnhance.Brightness(mask).enhance(1 - mask_alpha / 100) blur = ImageFilter.GaussianBlur(mask_blur) image = Image.composite(image.filter(blur), orig, mask.filter(blur)) import os -from modules.ui import plaintext_to_html -import os if is_inpaint_batch: image = init_img_inpaint mask = init_mask_inpaint @@ -190,22 +177,14 @@ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, prompt=prompt, negative_prompt=negative_prompt, styles=prompt_styles, - seed=seed, - subseed=subseed, - subseed_strength=subseed_strength, - seed_resize_from_h=seed_resize_from_h, - seed_resize_from_w=seed_resize_from_w, +from contextlib import closing import os - # Use the EXIF orientation of photos taken by smartphones. -from modules.ui import plaintext_to_html batch_size=batch_size, n_iter=n_iter, steps=steps, cfg_scale=cfg_scale, width=width, height=height, - restore_faces=restore_faces, - tiling=tiling, init_images=[image], mask=mask, mask_blur=mask_blur, diff --git a/modules/initialize.py b/modules/initialize.py new file mode 100644 index 0000000000000000000000000000000000000000..f24f76375db5d744bc7ce9191455f6e07b55d8bf --- /dev/null +++ b/modules/initialize.py @@ -0,0 +1,168 @@ +import importlib +import logging +import sys +import warnings +from threading import Thread + +from modules.timer import startup_timer + + +def imports(): + logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + + import torch # noqa: F401 + startup_timer.record("import torch") + import pytorch_lightning # noqa: F401 + startup_timer.record("import torch") + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + + import gradio # noqa: F401 + startup_timer.record("import gradio") + + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") + + import ldm.modules.encoders.modules # noqa: F401 + startup_timer.record("import ldm") + + import sgm.modules.encoders.modules # noqa: F401 + startup_timer.record("import sgm") + + from modules import shared_init + shared_init.initialize() + startup_timer.record("initialize shared") + + from modules import processing, gradio_extensons, ui # noqa: F401 + startup_timer.record("other imports") + + +def check_versions(): + from modules.shared_cmd_options import cmd_opts + + if not cmd_opts.skip_version_check: + from modules import errors + errors.check_versions() + + +def initialize(): + from modules import initialize_util + initialize_util.fix_torch_version() + initialize_util.fix_asyncio_event_loop_policy() + initialize_util.validate_tls_options() + initialize_util.configure_sigint_handler() + initialize_util.configure_opts_onchange() + + from modules import modelloader + modelloader.cleanup_models() + + from modules import sd_models + sd_models.setup_model() + startup_timer.record("setup SD model") + + from modules.shared_cmd_options import cmd_opts + + from modules import codeformer_model + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor") + codeformer_model.setup_model(cmd_opts.codeformer_models_path) + startup_timer.record("setup codeformer") + + from modules import gfpgan_model + gfpgan_model.setup_model(cmd_opts.gfpgan_models_path) + startup_timer.record("setup gfpgan") + + initialize_rest(reload_script_modules=False) + + +def initialize_rest(*, reload_script_modules=False): + """ + Called both from initialize() and when reloading the webui. + """ + from modules.shared_cmd_options import cmd_opts + + from modules import sd_samplers + sd_samplers.set_samplers() + startup_timer.record("set samplers") + + from modules import extensions + extensions.list_extensions() + startup_timer.record("list extensions") + + from modules import initialize_util + initialize_util.restore_config_state_file() + startup_timer.record("restore config state file") + + from modules import shared, upscaler, scripts + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + scripts.load_scripts() + return + + from modules import sd_models + sd_models.list_models() + startup_timer.record("list SD models") + + from modules import localization + localization.list_localizations(cmd_opts.localizations_dir) + startup_timer.record("list localizations") + + with startup_timer.subcategory("load scripts"): + scripts.load_scripts() + + if reload_script_modules: + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + startup_timer.record("reload script modules") + + from modules import modelloader + modelloader.load_upscalers() + startup_timer.record("load upscalers") + + from modules import sd_vae + sd_vae.refresh_vae_list() + startup_timer.record("refresh VAE") + + from modules import textual_inversion + textual_inversion.textual_inversion.list_textual_inversion_templates() + startup_timer.record("refresh textual inversion templates") + + from modules import script_callbacks, sd_hijack_optimizations, sd_hijack + script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) + sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + + from modules import sd_unet + sd_unet.list_unets() + startup_timer.record("scripts list_unets") + + def load_model(): + """ + Accesses shared.sd_model property to load model. + After it's available, if it has been loaded before this access by some extension, + its optimization may be None because the list of optimizaers has neet been filled + by that time, so we apply optimization again. + """ + + shared.sd_model # noqa: B018 + + if sd_hijack.current_optimizer is None: + sd_hijack.apply_optimizations() + + from modules import devices + devices.first_time_calculation() + + Thread(target=load_model).start() + + from modules import shared_items + shared_items.reload_hypernetworks() + startup_timer.record("reload hypernetworks") + + from modules import ui_extra_networks + ui_extra_networks.initialize() + ui_extra_networks.register_default_pages() + + from modules import extra_networks + extra_networks.initialize() + extra_networks.register_default_extra_networks() + startup_timer.record("initialize extra networks") diff --git a/modules/initialize_util.py b/modules/initialize_util.py new file mode 100644 index 0000000000000000000000000000000000000000..2894eee4c1ab6565eb0dbdd4d9ac86e21d123a76 --- /dev/null +++ b/modules/initialize_util.py @@ -0,0 +1,202 @@ +import json +import os +import signal +import sys +import re + +from modules.timer import startup_timer + + +def gradio_server_name(): + from modules.shared_cmd_options import cmd_opts + + if cmd_opts.server_name: + return cmd_opts.server_name + else: + return "0.0.0.0" if cmd_opts.listen else None + + +def fix_torch_version(): + import torch + + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors + if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + + +def fix_asyncio_event_loop_policy(): + """ + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + """ + + import asyncio + + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + +def restore_config_state_file(): + from modules import shared, config_states + + config_state_file = shared.opts.restore_config_state_file + if config_state_file == "": + return + + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_config(config_state) + startup_timer.record("restore extension config") + elif config_state_file: + print(f"!!! Config state backup not found: {config_state_file}") + + +def validate_tls_options(): + from modules.shared_cmd_options import cmd_opts + + if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): + return + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + startup_timer.record("TLS") + + +def get_gradio_auth_creds(): + """ + Convert the gradio_auth and gradio_auth_path commandline arguments into + an iterable of (username, password) tuples. + """ + from modules.shared_cmd_options import cmd_opts + + def process_credential_line(s): + s = s.strip() + if not s: + return None + return tuple(s.split(':', 1)) + + if cmd_opts.gradio_auth: + for cred in cmd_opts.gradio_auth.split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + for cred in line.strip().split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + +def dumpstacks(): + import threading + import traceback + + id2name = {th.ident: th.name for th in threading.enumerate()} + code = [] + for threadId, stack in sys._current_frames().items(): + code.append(f"\n# Thread: {id2name.get(threadId, '')}({threadId})") + for filename, lineno, name, line in traceback.extract_stack(stack): + code.append(f"""File: "{filename}", line {lineno}, in {name}""") + if line: + code.append(" " + line.strip()) + + print("\n".join(code)) + + +def configure_sigint_handler(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + + dumpstacks() + + os._exit(0) + + if not os.environ.get("COVERAGE_RUN"): + # Don't install the immediate-quit handler when running under coverage, + # as then the coverage report won't be generated. + signal.signal(signal.SIGINT, sigint_handler) + + +def configure_opts_onchange(): + from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack + from modules.call_queue import wrap_queued_call + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + startup_timer.record("opts onchange") + + +def setup_middleware(app): + from starlette.middleware.gzip import GZipMiddleware + + app.middleware_stack = None # reset current middleware to allow modifying user provided list + app.add_middleware(GZipMiddleware, minimum_size=1000) + configure_cors_middleware(app) + app.build_middleware_stack() # rebuild middleware stack on-the-fly + + +def configure_cors_middleware(app): + from starlette.middleware.cors import CORSMiddleware + from modules.shared_cmd_options import cmd_opts + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_credentials": True, + } + if cmd_opts.cors_allow_origins: + cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') + if cmd_opts.cors_allow_origins_regex: + cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex + app.add_middleware(CORSMiddleware, **cors_options) + diff --git a/modules/interrogate.py b/modules/interrogate.py index a3ae1dd5c4c0663fe3548ee85d2856866c04b1b8..3045560d0aeb4c2c77dc5ec9533d8407bb610172 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -187,12 +187,10 @@ res = "" shared.state.begin(job="interrogate") try: import os -from collections import namedtuple import os - running_on_cpu = None import os - def __init__(self, content_dir): + self.running_on_cpu = devices.device_interrogate == torch.device("cpu") self.load() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index e1c9cfbec5b6bfda9848cc44465f2bfc8b8caf10..7e4d5a61392649ca301476acf5900a32d326c6a7 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -1,7 +1,10 @@ +# this scripts installs necessary requirements and launches main program in webui.py + minor = sys.version_info.minor # this scripts installs necessary requirements and launches main program in webui.py import re import subprocess import os +import shutil import sys import importlib.util import platform @@ -10,13 +13,14 @@ from functools import lru_cache from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir -# this scripts installs necessary requirements and launches main program in webui.py + minor = sys.version_info.minor import subprocess - -# this scripts installs necessary requirements and launches main program in webui.py + minor = sys.version_info.minor import os # this scripts installs necessary requirements and launches main program in webui.py +import sys + minor = sys.version_info.minor import sys python = sys.executable @@ -144,6 +148,25 @@ result = subprocess.run([python, "-c", code], capture_output=True, shell=False) return result.returncode == 0 +def git_fix_workspace(dir, name): + run(f'"{git}" -C "{dir}" fetch --refetch --no-auto-gc', f"Fetching all contents for {name}", f"Couldn't fetch {name}", live=True) + run(f'"{git}" -C "{dir}" gc --aggressive --prune=now', f"Pruning {name}", f"Couldn't prune {name}", live=True) + return + + +def run_git(dir, name, command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live, autofix=True): + try: + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + except RuntimeError: + if not autofix: + raise + + print(f"{errdesc}, attempting autofix...") + git_fix_workspace(dir, name) + + return run(f'"{git}" -C "{dir}" {command}', desc=desc, errdesc=errdesc, custom_env=custom_env, live=live) + + def git_clone(url, dir, name, commithash=None): # TODO clone into temporary dir and move if successful @@ -151,17 +174,26 @@ if os.path.exists(dir): if commithash is None: return - current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() + current_hash = run_git(dir, name, 'rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}", live=False).strip() if current_hash == commithash: return + if run_git(dir, name, 'config --get remote.origin.url', None, f"Couldn't determine {name}'s origin URL", live=False).strip() != url: + run_git(dir, name, f'remote set-url origin "{url}"', None, f"Failed to set {name}'s origin URL", live=False) +import re from functools import lru_cache +# this scripts installs necessary requirements and launches main program in webui.py # this scripts installs necessary requirements and launches main program in webui.py + run_git(dir, name, f'checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}", live=True) return - run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) + try: + run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}", live=True) + except RuntimeError: + shutil.rmtree(dir, ignore_errors=True) + raise if commithash is not None: run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}") @@ -221,7 +253,7 @@ disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') - if disable_all_extensions != 'none': + if disable_all_extensions != 'none' or args.disable_extra_extensions or args.disable_all_extensions or not os.path.isdir(extensions_dir): return [] return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions] @@ -231,10 +263,15 @@ def run_extensions_installers(settings_file): if not os.path.isdir(extensions_dir): return + with startup_timer.subcategory("run extensions installers"): + for dirname_extension in list_extensions(settings_file): + logging.debug(f"Installing {dirname_extension}") # this scripts installs necessary requirements and launches main program in webui.py - changelog_md = os.path.join(os.path.dirname(os.path.dirname(__file__)), "CHANGELOG.md") + path = os.path.join(extensions_dir, dirname_extension) # this scripts installs necessary requirements and launches main program in webui.py - with open(changelog_md, "r", encoding="utf-8") as file: + if os.path.isdir(path): + run_extension_installer(path) + startup_timer.record(dirname_extension) re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") @@ -282,8 +319,6 @@ requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') index_url = os.environ.get('INDEX_URL', "") -import subprocess -index_url = os.environ.get('INDEX_URL', "") import os openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") @@ -294,14 +329,14 @@ codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "5c10deee76adad0032b412294130090932317a87") + stable_diffusion_xl_commit_hash = os.environ.get('STABLE_DIFFUSION_XL_COMMIT_HASH', "45c443b316737a4ab6e40413d7794a7f5657c19f") -dir_repos = "repositories" import subprocess +import sys codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") try: -dir_repos = "repositories" + supported_minors = [10] import importlib.util os.remove(os.path.join(script_path, "tmp", "restart")) os.environ.setdefault('SD_WEBUI_RESTARTING', '1') @@ -311,8 +346,11 @@ if not args.skip_python_version_check: check_python_version() + startup_timer.record("checks") + commit = commit_hash() tag = git_tag() + startup_timer.record("git version info") print(f"Python {sys.version}") print(f"Version: {tag}") @@ -320,21 +358,23 @@ print(f"Commit hash: {commit}") if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"): run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True) + startup_timer.record("install torch") if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"): raise RuntimeError( 'Torch is not able to use GPU; ' 'add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check' ) +import subprocess - if not is_installed("gfpgan"): - run_pip(f"install {gfpgan_package}", "gfpgan") if not is_installed("clip"): run_pip(f"install {clip_package}", "clip") + startup_timer.record("install clip") if not is_installed("open_clip"): run_pip(f"install {openclip_package}", "open_clip") + startup_timer.record("install open_clip") if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: if platform.system() == "Windows": @@ -348,8 +388,11 @@ exit(0) elif platform.system() == "Linux": run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") + startup_timer.record("install xformers") + if not is_installed("ngrok") and args.ngrok: run_pip("install ngrok", "ngrok") + startup_timer.record("install ngrok") os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) @@ -358,23 +401,29 @@ git_clone(stable_diffusion_xl_repo, repo_dir('generative-models'), "Stable Diffusion XL", stable_diffusion_xl_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) + + startup_timer.record("clone repositores") if not is_installed("lpips"): run_pip(f"install -r \"{os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}\"", "requirements for CodeFormer") + startup_timer.record("install CodeFormer requirements") if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) if not requirements_met(requirements_file): run_pip(f"install -r \"{requirements_file}\"", "requirements") + startup_timer.record("install requirements") run_extensions_installers(settings_file=args.ui_settings_file) if args.update_check: version_check(commit) + startup_timer.record("check version") if args.update_all_extensions: git_pull_recursive(extensions_dir) + startup_timer.record("update extensions") if "--exit" in sys.argv: print("Exiting because of --exit argument") diff --git a/modules/localization.py b/modules/localization.py index e8f585dab2e285fd477222ce90173aaefa88a3d4..c132028856fc3a3d91779d479be56def9d0f764f 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,7 +1,7 @@ import json import os -from modules import errors +from modules import errors, scripts localizations = {} @@ -15,8 +15,6 @@ if ext.lower() != ".json": continue localizations[fn] = os.path.join(dirname, file) - -import json for file in scripts.list_scripts("localizations", ".json"): fn, ext = os.path.splitext(file.filename) diff --git a/modules/logging_config.py b/modules/logging_config.py new file mode 100644 index 0000000000000000000000000000000000000000..7db23d4b6e5b883edad12710d556f3cd1872c678 --- /dev/null +++ b/modules/logging_config.py @@ -0,0 +1,16 @@ +import os +import logging + + +def setup_logging(loglevel): + if loglevel is None: + loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") + + if loglevel: + log_level = getattr(logging, loglevel.upper(), None) or logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s %(levelname)s [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + diff --git a/modules/lowvram.py b/modules/lowvram.py index 3f83066437d9398f150479e54eb3a5d83ff0452e..45701046b546a452105c33671188ec13529b6d5c 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -1,5 +1,5 @@ import torch -from modules import devices +from modules import devices, shared module_in_gpu = None cpu = torch.device("cpu") @@ -14,8 +14,25 @@ module_in_gpu = None +def is_needed(sd_model): + return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner') + + +def apply(sd_model): + enable = is_needed(sd_model) import torch + import torch +import torch + setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram) + else: + sd_model.lowvram = False + + +def setup_for_low_vram(sd_model, use_medvram): + if getattr(sd_model, 'lowvram', False): + return + sd_model.lowvram = True parents = {} @@ -128,4 +145,4 @@ block.register_forward_pre_hook(send_me_to_gpu) def is_enabled(sd_model): - return getattr(sd_model, 'lowvram', False) + return sd_model.lowvram diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 9ceb43baec2ce253e43ef8516c69f1186aa2691e..89256c5b06073c38a903d71a10d3be31079085df 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -4,6 +4,7 @@ import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +from modules import shared log = logging.getLogger(__name__) @@ -30,8 +31,7 @@ def torch_mps_gc() -> None: try: - from modules.shared import state - if state.current_latent is not None: + if shared.state.current_latent is not None: log.debug("`current_latent` is set, skipping MPS garbage collection") return from torch.mps import empty_cache @@ -52,9 +52,6 @@ return cumsum_func(input, *args, **kwargs) import platform - - # MPS fix for randn in torchsde - CondFunc('torchsde._brownian.brownian_interval._randn', lambda _, size, dtype, device, seed: torch.randn(size, dtype=dtype, device=torch.device("cpu"), generator=torch.Generator(torch.device("cpu")).manual_seed(int(seed))).to(device), lambda _, size, dtype, device, seed: device.type == 'mps') if platform.mac_ver()[0].startswith("13.2."): # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) diff --git a/modules/options.py b/modules/options.py new file mode 100644 index 0000000000000000000000000000000000000000..758b1ce5f2428bb5d11b9b45f68febeee4d60014 --- /dev/null +++ b/modules/options.py @@ -0,0 +1,245 @@ +import json +import sys + +import gradio as gr + +from modules import errors +from modules.shared_cmd_options import cmd_opts + + +class OptionInfo: + def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after='', infotext=None, restrict_api=False): + self.default = default + self.label = label + self.component = component + self.component_args = component_args + self.onchange = onchange + self.section = section + self.refresh = refresh + self.do_not_save = False + + self.comment_before = comment_before + """HTML text that will be added after label in UI""" + + self.comment_after = comment_after + """HTML text that will be added before label in UI""" + + self.infotext = infotext + + self.restrict_api = restrict_api + """If True, the setting will not be accessible via API""" + + def link(self, label, url): + self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]" + return self + + def js(self, label, js_func): + self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]" + return self + + def info(self, info): + self.comment_after += f"<span class='info'>({info})</span>" + return self + + def html(self, html): + self.comment_after += html + return self + + def needs_restart(self): + self.comment_after += " <span class='info'>(requires restart)</span>" + return self + + def needs_reload_ui(self): + self.comment_after += " <span class='info'>(requires Reload UI)</span>" + return self + + +class OptionHTML(OptionInfo): + def __init__(self, text): + super().__init__(str(text).strip(), label='', component=lambda **kwargs: gr.HTML(elem_classes="settings-info", **kwargs)) + + self.do_not_save = True + + +def options_section(section_identifier, options_dict): + for v in options_dict.values(): + v.section = section_identifier + + return options_dict + + +options_builtin_fields = {"data_labels", "data", "restricted_opts", "typemap"} + + +class Options: + typemap = {int: float} + + def __init__(self, data_labels: dict[str, OptionInfo], restricted_opts): + self.data_labels = data_labels + self.data = {k: v.default for k, v in self.data_labels.items()} + self.restricted_opts = restricted_opts + + def __setattr__(self, key, value): + if key in options_builtin_fields: + return super(Options, self).__setattr__(key, value) + + if self.data is not None: + if key in self.data or key in self.data_labels: + assert not cmd_opts.freeze_settings, "changing settings is disabled" + + info = self.data_labels.get(key, None) + if info.do_not_save: + return + + comp_args = info.component_args if info else None + if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + if cmd_opts.hide_ui_dir_config and key in self.restricted_opts: + raise RuntimeError(f"not possible to set {key} because it is restricted") + + self.data[key] = value + return + + return super(Options, self).__setattr__(key, value) + + def __getattr__(self, item): + if item in options_builtin_fields: + return super(Options, self).__getattribute__(item) + + if self.data is not None: + if item in self.data: + return self.data[item] + + if item in self.data_labels: + return self.data_labels[item].default + + return super(Options, self).__getattribute__(item) + + def set(self, key, value, is_api=False, run_callbacks=True): + """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" + + oldval = self.data.get(key, None) + if oldval == value: + return False + + option = self.data_labels[key] + if option.do_not_save: + return False + + if is_api and option.restrict_api: + return False + + try: + setattr(self, key, value) + except RuntimeError: + return False + + if run_callbacks and option.onchange is not None: + try: + option.onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False + + return True + + def get_default(self, key): + """returns the default value for the key""" + + data_label = self.data_labels.get(key) + if data_label is None: + return None + + return data_label.default + + def save(self, filename): + assert not cmd_opts.freeze_settings, "saving settings is disabled" + + with open(filename, "w", encoding="utf8") as file: + json.dump(self.data, file, indent=4) + + def same_type(self, x, y): + if x is None or y is None: + return True + + type_x = self.typemap.get(type(x), type(x)) + type_y = self.typemap.get(type(y), type(y)) + + return type_x == type_y + + def load(self, filename): + with open(filename, "r", encoding="utf8") as file: + self.data = json.load(file) + + # 1.6.0 VAE defaults + if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None: + self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default') + + # 1.1.1 quicksettings list migration + if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: + self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] + + # 1.4.0 ui_reorder + if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data: + self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')] + + bad_settings = 0 + for k, v in self.data.items(): + info = self.data_labels.get(k, None) + if info is not None and not self.same_type(info.default, v): + print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr) + bad_settings += 1 + + if bad_settings > 0: + print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) + + def onchange(self, key, func, call=True): + item = self.data_labels.get(key) + item.onchange = func + + if call: + func() + + def dumpjson(self): + d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} + d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} + d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} + return json.dumps(d) + + def add_option(self, key, info): + self.data_labels[key] = info + + def reorder(self): + """reorder settings so that all items related to section always go together""" + + section_ids = {} + settings_items = self.data_labels.items() + for _, item in settings_items: + if item.section not in section_ids: + section_ids[item.section] = len(section_ids) + + self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) + + def cast_value(self, key, value): + """casts an arbitrary to the same type as this setting's value with key + Example: cast_value("eta_noise_seed_delta", "12") -> returns 12 (an int rather than str) + """ + + if value is None: + return None + + default_value = self.data_labels[key].default + if default_value is None: + default_value = getattr(self, key, None) + if default_value is None: + return None + + expected_type = type(default_value) + if expected_type == bool and value == "False": + value = False + else: + value = expected_type(value) + + return value diff --git a/modules/patches.py b/modules/patches.py new file mode 100644 index 0000000000000000000000000000000000000000..348235e7e32aa60f1c5cf295dc873dcfc648b70f --- /dev/null +++ b/modules/patches.py @@ -0,0 +1,64 @@ +from collections import defaultdict + + +def patch(key, obj, field, replacement): + """Replaces a function in a module or a class. + + Also stores the original function in this module, possible to be retrieved via original(key, obj, field). + If the function is already replaced by this caller (key), an exception is raised -- use undo() before that. + + Arguments: + key: identifying information for who is doing the replacement. You can use __name__. + obj: the module or the class + field: name of the function as a string + replacement: the new function + + Returns: + the original function + """ + + patch_key = (obj, field) + if patch_key in originals[key]: + raise RuntimeError(f"patch for {field} is already applied") + + original_func = getattr(obj, field) + originals[key][patch_key] = original_func + + setattr(obj, field, replacement) + + return original_func + + +def undo(key, obj, field): + """Undoes the peplacement by the patch(). + + If the function is not replaced, raises an exception. + + Arguments: + key: identifying information for who is doing the replacement. You can use __name__. + obj: the module or the class + field: name of the function as a string + + Returns: + Always None + """ + + patch_key = (obj, field) + + if patch_key not in originals[key]: + raise RuntimeError(f"there is no patch for {field} to undo") + + original_func = originals[key].pop(patch_key) + setattr(obj, field, original_func) + + return None + + +def original(key, obj, field): + """Returns the original function for the patch created by the patch() function""" + patch_key = (obj, field) + + return originals[key].get(patch_key, None) + + +originals = defaultdict(dict) diff --git a/modules/postprocessing.py b/modules/postprocessing.py index 136e9c88721d1a95056d7b66535a5c478234224c..cf04d38b0592c6eeabe6763dd41a1d25c8543add 100644 --- a/modules/postprocessing.py +++ b/modules/postprocessing.py @@ -11,55 +11,51 @@ devices.torch_gc() shared.state.begin(job="extras") - image_data = [] - image_names = [] outputs = [] + def get_images(extras_mode, image, image_folder, input_dir): + shared.state.begin(job="extras") import os + shared.state.begin(job="extras") -import os + shared.state.begin(job="extras") from PIL import Image -import os + shared.state.begin(job="extras") from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste -import os + shared.state.begin(job="extras") from modules.shared import opts -import os + shared.state.begin(job="extras") def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): -import os + shared.state.begin(job="extras") devices.torch_gc() -import os + shared.state.begin(job="extras") shared.state.begin(job="extras") -import os + shared.state.begin(job="extras") image_data = [] - + image_data = [] - + image_data = [] import os - + image_data = [] + image_data = [] from PIL import Image - + image_data = [] from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste - - + image_data = [] from modules.shared import opts - + image_data = [] def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True): - + image_data = [] devices.torch_gc() - + image_data = [] shared.state.begin(job="extras") - + image_data = [] image_data = [] - continue + else: - + image_names = [] -from PIL import Image import os - else: - assert image, 'image not selected' - - image_data.append(image) - image_names.append(None) +import os if extras_mode == 2 and output_dir != '': outpath = output_dir @@ -68,15 +64,18 @@ outpath = opts.outdir_samples or opts.outdir_extras_samples infotext = '' -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste +import os + + image_data: Image.Image + shared.state.textinfo = name + image_names = [] from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste - if parameters: existing_pnginfo["parameters"] = parameters -from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste + image_names = [] from modules.shared import opts scripts.scripts_postproc.run(pp, args) @@ -97,6 +96,8 @@ images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=infotext, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None) if extras_mode != 2 or show_extras_results: outputs.append(pp.image) + + image_data.close() devices.torch_gc() diff --git a/modules/processing.py b/modules/processing.py index b0992ee15ab3426d4c754126c0ab480c7f752533..7dc931ba53e3afe01c37eed6e00982f6058dd4a6 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,9 +1,11 @@ +from __future__ import annotations import json import logging import math import os import sys import hashlib +from dataclasses import dataclass, field import torch import numpy as np @@ -11,11 +13,13 @@ from PIL import Image, ImageOps import random import cv2 from skimage import exposure -from typing import Any, Dict, List +from typing import Any import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng +from modules.rng import slerp # noqa: F401 from modules.sd_hijack import model_hijack +from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -55,8 +59,9 @@ ), cv2.COLOR_LAB2RGB).astype("uint8")) image = blendLayers(image, original_image, BlendType.LUMINOSITY) -import os +import numpy as np import torch + def apply_overlay(image, paste_loc, index, overlays): @@ -78,13 +83,19 @@ image = image.convert('RGB') return image +def create_binary_mask(image): + if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255): + image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0) + else: + image = image.convert('L') + return image def txt2img_image_conditioning(sd_model, x, width, height): if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -103,158 +114,237 @@ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) -class StableDiffusionProcessing: - import numpy as np - The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing - import numpy as np - cached_uc = [None, None] - cached_c = [None, None] - -import torch import math - if sampler_index is not None: - print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) - self.outpath_samples: str = outpath_samples - self.outpath_grids: str = outpath_grids -import torch import torch -import torch import numpy as np import numpy as np - self.styles: list = styles or [] - self.seed: int = seed - self.subseed: int = subseed +import os import numpy as np -import os self.seed_resize_from_h: int = seed_resize_from_h import numpy as np +import numpy as np import hashlib +import numpy as np self.sampler_name: str = sampler_name import numpy as np +import numpy as np import torch +import numpy as np self.n_iter: int = n_iter self.steps: int = steps import json import json import json import logging from PIL import Image, ImageOps import math import json import os from PIL import Image, ImageOps import sys import json import hashlib from PIL import Image, ImageOps import json import torch from PIL import Image, ImageOps import numpy as np import json from PIL import Image, ImageOps import json import random import json import cv2 import json from skimage import exposure import json from typing import Any, Dict, List import json import modules.sd_hijack import json from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors import json from modules.sd_hijack import model_hijack import json from modules.shared import opts, cmd_opts, state import json import modules.shared as shared import json import modules.paths as paths import json import modules.face_restoration import json import modules.images as images import json import modules.styles import json import modules.sd_models as sd_models import json import modules.sd_vae as sd_vae import json from ldm.data.util import AddMiDaS import json import logging - import json from einops import repeat, rearrange import json from blendmodes.blend import blendLayers, BlendType import json # some of those options should not be changed at all because they would break the model, so I removed them from options. import json opt_C = 4 import json opt_f = 8 - import json def setup_color_correction(image): import json logging.info("Calibrating color correction.") import json correction_target = cv2.cvtColor(np.asarray(image.copy()), cv2.COLOR_RGB2LAB) import json return correction_target + import json def apply_color_correction(correction, original_image): import json logging.info("Applying color correction.") import json image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( + +import torch import json + cached_c = [None, None] + +from PIL import Image, ImageOps cv2.cvtColor( import json np.asarray(original_image), - import json cv2.COLOR_RGB2LAB import json ), + import json correction, + import json channel_axis=2 import json ), cv2.COLOR_LAB2RGB).astype("uint8")) import json image = blendLayers(image, original_image, BlendType.LUMINOSITY) import json return image import json def apply_overlay(image, paste_loc, index, overlays): + import json if overlays is None or index >= len(overlays): import json return image + all_seeds: list = field(default=None, init=False) + all_subseeds: list = field(default=None, init=False) + iteration: int = field(default=0, init=False) + main_prompt: str = field(default=None, init=False) + main_negative_prompt: str = field(default=None, init=False) import json import sys + + negative_prompts: list = field(default=None, init=False) + seeds: list = field(default=None, init=False) + subseeds: list = field(default=None, init=False) + extra_network_data: dict = field(default=None, init=False) + + self.do_not_save_samples: bool = do_not_save_samples import logging import json + if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models + sd_model_hash: str = field(default=None, init=False) + sd_vae_name: str = field(default=None, init=False) + sd_vae_hash: str = field(default=None, init=False) + + is_api: bool = field(default=False, init=False) + + def __post_init__(self): + if self.sampler_index is not None: + print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) + + self.comments = {} + + if self.styles is None: + self.styles = [] + + self.sampler_noise_scheduler_override = None + self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond + self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn + self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin + self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf') + self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise + + self.extra_generation_params = self.extra_generation_params or {} + self.override_settings = self.override_settings or {} + self.script_args = self.script_args or {} + + self.refiner_checkpoint_info = None + + if not self.seed_enable_extras: + self.subseed = -1 + self.subseed_strength = 0 + self.seed_resize_from_h = 0 + self.seed_resize_from_w = 0 + + self.cached_uc = StableDiffusionProcessing.cached_uc + self.cached_c = StableDiffusionProcessing.cached_c + +import json if paste_loc is not None: def sd_model(self): return shared.sd_model import json + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def sd_model(self, value): + pass + + @property + def scripts(self): + return self.scripts_value + + @scripts.setter + def scripts(self, value): + self.scripts_value = value + + if self.scripts_value and self.script_args_value and not self.scripts_setup_complete: + self.setup_scripts() + + @property + def script_args(self): + return self.script_args_value + + @script_args.setter + def script_args(self, value): + self.script_args_value = value + + if self.scripts_value and self.script_args_value and not self.scripts_setup_complete: + self.setup_scripts() + + def setup_scripts(self): + self.scripts_setup_complete = True + + self.scripts.setup_scrips(self, is_ui=not self.is_api) + + def comment(self, text): + self.comments[text] = 1 + +import json image = images.resize_image(1, image, w, h) self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} @@ -267,7 +306,7 @@ midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) import json - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + self.width: int = width conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -281,7 +320,7 @@ return conditioning def edit_image_conditioning(self, source_image): import json - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + self.width: int = width return conditioning_image @@ -360,9 +399,8 @@ def close(self): self.sampler = None self.c = None self.uc = None -import logging + self.eta = eta import math -import logging StableDiffusionProcessing.cached_c = [None, None] StableDiffusionProcessing.cached_uc = [None, None] @@ -373,31 +411,63 @@ return self.token_merging_ratio or opts.token_merging_ratio def setup_prompts(self): -import logging + self.eta = eta import os self.all_prompts = self.prompt import json + self.tiling: bool = tiling + self.all_prompts = [self.prompt] * len(self.negative_prompt) +import json self.sampler_name: str = sampler_name self.all_prompts = self.batch_size * self.n_iter * [self.prompt] + if isinstance(self.negative_prompt, list): import modules.sd_models as sd_models -import math +import os + else: + self.all_negative_prompts = [self.negative_prompt] * len(self.all_prompts) + + if len(self.all_prompts) != len(self.all_negative_prompts): + raise RuntimeError(f"Received a different number of prompts ({len(self.all_prompts)}) and negative prompts ({len(self.all_negative_prompts)})") + import modules.sd_models as sd_models +import hashlib +import logging import os + -import modules.shared as shared + self.main_prompt = self.all_prompts[0] + self.do_not_reload_embeddings = do_not_reload_embeddings import logging + + def cached_params(self, required_prompts, steps, extra_network_data, hires_steps=None, use_old_scheduling=False): + self.do_not_reload_embeddings = do_not_reload_embeddings import os + + self.do_not_reload_embeddings = do_not_reload_embeddings import sys + required_prompts, + steps, + hires_steps, + self.do_not_reload_embeddings = do_not_reload_embeddings import logging + image.alpha_composite(overlay) + shared.sd_model.sd_checkpoint_info, + extra_network_data, + opts.sdxl_crop_left, +from ldm.data.util import AddMiDaS import os +import logging import hashlib +import sys import logging -import os + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) +import json +import json -import modules.sd_models as sd_models + self.do_not_reload_embeddings = do_not_reload_embeddings import torch """ Returns the result of calling function(shared.sd_model, required_prompts, steps) @@ -410,27 +481,23 @@ caches is a list with items described above. """ - cached_params = ( - required_prompts, -import modules.sd_vae as sd_vae + self.do_not_reload_embeddings = do_not_reload_embeddings import numpy as np +import random import logging -import hashlib +import random import logging -import hashlib import json +import random import logging -import hashlib import logging +import random import logging -import hashlib import math - opts.sdxl_crop_top, - self.width, - self.height, + import json - import json +import modules.sd_models as sd_models for cache in caches: if cache[0] is not None and cached_params == cache[0]: @@ -439,8 +506,9 @@ cache = caches[0] with devices.autocast(): +import random import logging - return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) +import sys cache[0] = cached_params return cache[1] @@ -450,16 +518,28 @@ prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) + total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps +import random from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion -import numpy as np +import random from einops import repeat, rearrange + +import random import logging -import torch +import numpy as np import json + self.subseed_strength = 0 + + def get_conds(self): + return self.c, self.uc def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) + def save_samples(self) -> bool: + """Returns whether generated images need to be written to disk""" + return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped) + class Processed: def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""): @@ -469,8 +550,9 @@ self.seed = seed self.subseed = subseed self.subseed_strength = p.subseed_strength self.info = info -from blendmodes.blend import blendLayers, BlendType +import random import math +import hashlib self.width = p.width self.height = p.height self.sampler_name = p.sampler_name @@ -480,8 +562,14 @@ self.steps = p.steps self.batch_size = p.batch_size self.restore_faces = p.restore_faces self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None +import random import math + +import random import math +import torch + self.sd_vae_name = p.sd_vae_name + self.sd_vae_hash = p.sd_vae_hash self.seed_resize_from_w = p.seed_resize_from_w self.seed_resize_from_h = p.seed_resize_from_h self.denoising_strength = getattr(p, 'denoising_strength', None) @@ -500,13 +589,14 @@ self.s_tmax = p.s_tmax self.s_noise = p.s_noise self.s_min_uncond = p.s_min_uncond self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override -opt_f = 8 + self.prompt = self.prompt if not isinstance(self.prompt, list) else self.prompt[0] + self.denoising_strength: float = denoising_strength import logging -opt_f = 8 + self.denoising_strength: float = denoising_strength import math -opt_f = 8 +import random import os - self.subseed = int(self.subseed if type(self.subseed) != list else self.subseed[0]) if self.subseed is not None else -1 +import os self.is_using_inpainting_conditioning = p.is_using_inpainting_conditioning self.all_prompts = all_prompts or p.all_prompts or [self.prompt] @@ -534,7 +624,10 @@ "steps": self.steps, "batch_size": self.batch_size, "restore_faces": self.restore_faces, "face_restoration_model": self.face_restoration_model, + "sd_model_name": self.sd_model_name, "sd_model_hash": self.sd_model_hash, + "sd_vae_name": self.sd_vae_name, + "sd_vae_hash": self.sd_vae_hash, "seed_resize_from_w": self.seed_resize_from_w, "seed_resize_from_h": self.seed_resize_from_h, "denoising_strength": self.denoising_strength, @@ -557,108 +650,31 @@ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio import math -import hashlib - return correction_target import torch - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) -def apply_color_correction(correction, original_image): import json - - if dot.mean() > 0.9995: - return low * val + high * (1 - val) - - omega = torch.acos(dot) - so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high - return res - - -def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): - eta_noise_seed_delta = opts.eta_noise_seed_delta or 0 - xs = [] - - logging.info("Applying color correction.") import json - # enables the generation of additional tensors with noise that the sampler will use during its processing. - # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to - logging.info("Applying color correction.") import os -import math import torch -import sys - sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] - else: - sampler_noises = None - - for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) - - subnoise = None - image = Image.fromarray(cv2.cvtColor(exposure.match_histograms( import json - subseed = 0 if i >= len(subseeds) else subseeds[i] - - subnoise = devices.randn(subseed, noise_shape) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this, so I do not dare change it for now because - # it will break everyone's seeds. - noise = devices.randn(seed, noise_shape) - - if subnoise is not None: - noise = slerp(subseed_strength, noise, subnoise) - - cv2.cvtColor( import json - x = devices.randn(seed, shape) - dx = (shape[2] - noise_shape[2]) // 2 - dy = (shape[1] - noise_shape[1]) // 2 - w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx - h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy - tx = 0 if dx < 0 else dx - ty = 0 if dy < 0 else dy import os import numpy as np - dy = max(-dy, 0) - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] - noise = x -import os import json -import math - cnt = p.sampler.number_of_needed_noises(p) - -import os import modules.sd_hijack - torch.manual_seed(seed + eta_noise_seed_delta) - -import os import json - -import os import json -import torch - -import os +import sys import json -import numpy as np - if sampler_noises is not None: - p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises] cv2.COLOR_RGB2LAB -import logging import os -import modules.styles - - -def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): - cv2.COLOR_RGB2LAB +import random import sys +import logging for i in range(batch.shape[0]): sample = decode_first_stage(model, batch[i:i + 1])[0] @@ -672,9 +688,8 @@ errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" "Web UI will now convert VAE into 32-bit float and retry.\n" -import os + self.sampler_noise_scheduler_override = None import math - "To always start with 32-bit VAE, use --no-half-vae commandline flag." ) @@ -693,16 +708,20 @@ return samples import os - return image + if overlays is None or index >= len(overlays): + self.sampler_noise_scheduler_override = None import os -def apply_overlay(image, paste_loc, index, overlays): - + seed = -1 - return x + elif isinstance(seed, str): + self.sampler_noise_scheduler_override = None - + seed = int(seed) -import os +import random import sys +import numpy as np + self.ddim_discretize = ddim_discretize or opts.ddim_discretize - channel_axis=2 + + self.ddim_discretize = ddim_discretize or opts.ddim_discretize import json return int(random.randrange(4294967294)) @@ -746,11 +765,13 @@ "Sampler": p.sampler_name, "CFG scale": p.cfg_scale, "Image CFG scale": getattr(p, 'image_cfg_scale', None), "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], - "Face restoration": (opts.face_restoration_model if p.restore_faces else None), + "Face restoration": opts.face_restoration_model if p.restore_faces else None, "Size": f"{p.width}x{p.height}", - "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), + "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, + self.ddim_discretize = ddim_discretize or opts.ddim_discretize import os - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, + "VAE": p.sd_vae_name if opts.add_model_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -761,8 +782,9 @@ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None, "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio, "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr, "Init image hash": getattr(p, 'init_img_hash', None), - "RNG": opts.randn_source if opts.randn_source != "GPU" else None, + "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None, "NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond, + "Tiling": "True" if p.tiling else None, **p.extra_generation_params, "Version": program_version() if opts.add_version_to_infotext else None, "User": p.user if opts.add_user_name_to_info else None, @@ -770,10 +792,10 @@ } generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) - if overlays is None or index >= len(overlays): +import json import json + image_conditioning = image_conditioning.to(x.dtype) -import sys + self.s_min_uncond = s_min_uncond or opts.s_min_uncond -import logging return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip() @@ -786,14 +808,15 @@ stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint + # and if after running refiner, the refiner model is not unloaded - webui swaps back to main model here, if model over is present it will be reloaded afterwards if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) sd_models.reload_model_weights() for k, v in p.override_settings.items(): -import sys +import json import json -import os + return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) if k == 'sd_model_checkpoint': sd_models.reload_model_weights() @@ -822,8 +845,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" - if paste_loc is not None: +import json import json + else: assert(len(p.prompt) > 0) else: assert p.prompt is not None @@ -833,24 +857,42 @@ seed = get_fixed_seed(p.seed) subseed = get_fixed_seed(p.subseed) - if paste_loc is not None: +import random +import os + self.s_min_uncond = s_min_uncond or opts.s_min_uncond import sys + + if p.tiling is None: + p.tiling = opts.tiling + + if p.refiner_checkpoint not in (None, "", "None", "none"): + p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint) + if p.refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}') + + p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra + self.s_churn = s_churn or opts.s_churn import math +import random import torch +import os + p.sd_vae_hash = sd_vae.get_loaded_vae_hash() + + if paste_loc is not None: if paste_loc is not None: -import numpy as np +import torch p.setup_prompts() - x, y, w, h = paste_loc import json + self.is_using_inpainting_conditioning = True p.all_seeds = seed else: p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] - if type(subseed) == list: + if isinstance(subseed, list): p.all_subseeds = subseed else: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] @@ -886,11 +928,15 @@ if state.interrupted: break + sd_models.reload_model_weights() # model can be changed for example by refiner + p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + if p.scripts is not None: p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -912,13 +958,13 @@ # Example: a wildcard processed by process_batch sets an extra model # strength, which is saved as "Model Strength: 1.0" in the infotext if n == 0: with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: - processed = Processed(p, [], p.seed, "") + processed = Processed(p, []) file.write(processed.infotext(p, 0)) p.setup_conds() for comment in model_hijack.comments: - comments[comment] = 1 + p.comment(comment) p.extra_generation_params.update(model_hijack.extra_generation_params) @@ -928,8 +974,15 @@ with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts) -import sys + if getattr(samples_ddim, 'already_decoded', False): + x_samples_ddim = samples_ddim + else: +import random self.subseed_strength: float = subseed_strength + p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method + + x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True) + x_samples_ddim = torch.stack(x_samples_ddim).float() x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) @@ -953,6 +1006,8 @@ def infotext(index=0, use_main_prompt=False): return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + save_samples = p.save_samples() + for i, x_sample in enumerate(x_samples_ddim): p.batch_index = i @@ -960,9 +1015,9 @@ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) x_sample = x_sample.astype(np.uint8) if p.restore_faces: -import hashlib +import json import json -import os + self.batch_size: int = batch_size images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration") devices.torch_gc() @@ -976,18 +1031,16 @@ if p.scripts is not None: pp = scripts.PostprocessImageArgs(image) p.scripts.postprocess_image(p, pp) image = pp.image - if p.color_corrections is not None and i < len(p.color_corrections): - if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction: + if save_samples and opts.save_images_before_color_correction: image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images) images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") image = apply_color_correction(p.color_corrections[i], image) image = apply_overlay(image, p.paste_to, i, p.overlay_images) -import hashlib +import json import logging -import numpy as np images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p) text = infotext(i) @@ -995,8 +1047,7 @@ infotexts.append(text) if opts.enable_pnginfo: image.info["parameters"] = text output_images.append(image) - - if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): + if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]): image_mask = p.mask_for_overlay.convert('RGB') image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') @@ -1032,7 +1083,6 @@ if opts.enable_pnginfo: grid.info["parameters"] = text output_images.insert(0, grid) index_of_first_image = 1 - if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True) @@ -1047,8 +1097,6 @@ images_list=output_images, seed=p.all_seeds[0], info=infotexts[0], import hashlib - image_conditioning = image_conditioning.to(x.dtype) -import hashlib return image_conditioning index_of_first_image=index_of_first_image, infotexts=infotexts, @@ -1072,197 +1120,229 @@ return width, height +@dataclass(repr=False) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + enable_hr: bool = False + denoising_strength: float = 0.75 + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option import os - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option import sys - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option import hashlib + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option -import hashlib + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option import torch - - super().__init__(**kwargs) - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option import numpy as np - image_conditioning = image_conditioning.to(x.dtype) + self.s_noise = s_noise or opts.s_noise - image_conditioning = image_conditioning.to(x.dtype) + self.s_noise = s_noise or opts.s_noise import json - image_conditioning = image_conditioning.to(x.dtype) + self.s_noise = s_noise or opts.s_noise import logging - image_conditioning = image_conditioning.to(x.dtype) + self.s_noise = s_noise or opts.s_noise import math - image_conditioning = image_conditioning.to(x.dtype) + self.s_noise = s_noise or opts.s_noise import os + import hashlib -import numpy as np +import torch import sys import hashlib -import numpy as np +import torch import hashlib - image_conditioning = image_conditioning.to(x.dtype) - self.hr_sampler_name = hr_sampler_name + hr_checkpoint_info: dict = field(default=None, init=False) + self.s_noise = s_noise or opts.s_noise import hashlib - self.n_iter: int = n_iter + self.s_noise = s_noise or opts.s_noise - return image_conditioning +import cv2 import json +import torch - + truncate_y: int = field(default=0, init=False) + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} +import json import logging +import modules.face_restoration - + hr_c: tuple | None = field(default=None, init=False) - return image_conditioning + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} import math - return image_conditioning + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} import os - return image_conditioning + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} import sys - return image_conditioning + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} import hashlib + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} - + hr_extra_network_data: list = field(default=None, init=False) - +import json +import hashlib import torch - return image_conditioning + self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} import numpy as np import json +import modules.styles +import os +import sys import json +import modules.styles import json - import cv2 +opt_f = 8 + elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models -import math +import os +import modules.sd_hijack import json + StableDiffusionProcessing.cached_c = [None, None] + self.override_settings_restore_afterwards = override_settings_restore_afterwards import os - import json +import modules.styles import sys - import json +import modules.styles import hashlib -import json +import os +import sys +import json import logging -import logging +import math - import json +import modules.styles import torch import json +import modules.styles import numpy as np - +import json import logging +import os - - +import json import logging +import os import json - +import json import logging +import os import logging +import modules.shared as shared - +import json import logging +import os import math +import json import logging +import os import os - - +import json import logging +import os import sys - +import json import logging +import os import hashlib - +import json import logging +import os - +import json import logging +import os import torch - +import json import logging +import os import numpy as np - + else: - else: + self.disable_extra_networks = False - else: import json - - else: import logging - self.extra_generation_params["Hires upscale"] = self.hr_scale - self.hr_upscale_to_x = int(self.width * self.hr_scale) - else: import sys +import json import json - self.styles: list = styles or [] + cache is an array containing two elements. The first element is a tuple - + self.disable_extra_networks = False import math -import hashlib - if self.hr_resize_y == 0: + if src_ratio < dst_ratio: self.hr_upscale_to_x = self.hr_resize_x self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width # Dummy zero conditioning if we're not using inpainting or unclip models. +import math self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height self.hr_upscale_to_y = self.hr_resize_y - ), - target_w = self.hr_resize_x - # Dummy zero conditioning if we're not using inpainting or unclip models. + self.disable_extra_networks = False import sys - # Dummy zero conditioning if we're not using inpainting or unclip models. + self.disable_extra_networks = False import hashlib - image = blendLayers(image, original_image, BlendType.LUMINOSITY) +import modules.images as images -import os +import json import torch + self.disable_extra_networks = False -def apply_overlay(image, paste_loc, index, overlays): - +import cv2 import sys +import torch +import cv2 import sys +import numpy as np import json + opts.CLIP_stop_at_last_layers, -import sys +import json import logging + image = image.convert('RGB') - if paste_loc is not None: +import modules.shared as shared - x, y, w, h = paste_loc +import modules.paths as paths - base_image = Image.new('RGBA', (overlay.width, overlay.height)) +import modules.face_restoration - image = images.resize_image(1, image, w, h) +import modules.images as images -import sys +import modules.styles - image = base_image +import modules.sd_models as sd_models - image = image.convert('RGBA') - +import cv2 import hashlib +import logging + if self.enable_hr and self.latent_scale_mode is None: - +import cv2 import hashlib +import os import json + self.width, -import hashlib +import json import logging + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) if not state.processing_has_refined_job_count: if state.job_count == -1: @@ -1277,25 +1357,55 @@ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + x = self.rng.next() return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) -import os +import numpy as np +import numpy as np +import numpy as np + + if not self.enable_hr: + return samples + + if self.latent_scale_mode is None: + decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32) + else: + decoded_samples = None + +import cv2 +import json +import random import sys + +import cv2 +import logging +import cv2 -import hashlib +import math +import cv2 +import os +import cv2 +import sys +import cv2 +import hashlib +import cv2 return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) -import torch +import cv2 +import torch +import cv2 """ + devices.torch_gc() - +import cv2 The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing + if shared.state.interrupted: return samples self.is_hr_pass = True @@ -1306,9 +1416,9 @@ def save_intermediate(image, index): """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" - +import cv2 import torch - +import logging return if not isinstance(image, Image.Image): @@ -1317,13 +1427,18 @@ info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix") + img2img_sampler_name = self.hr_sampler_name or self.sampler_name + cached_uc = [None, None] import numpy as np + +import json import logging + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -1333,8 +1448,6 @@ else: image_conditioning = self.txt2img_image_conditioning(samples) else: The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing -import json - The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing import logging batch_images = [] @@ -1351,40 +1464,37 @@ image = np.moveaxis(image, 2, 0) batch_images.append(image) decoded_samples = torch.from_numpy(np.array(batch_images)) -import torch import json -import json +import logging import torch -import cv2 +import sys +import cv2 import torch -from skimage import exposure +import hashlib + if not seed_enable_extras: -import torch import json -import os - +import logging import torch -import modules.sd_hijack - import torch -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors cached_uc = [None, None] - - img2img_sampler_name = 'DDIM' +import os cached_uc = [None, None] -import numpy as np +import sys samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] +import cv2 import torch +import numpy as np +import json import logging -import json + self.negative_prompt: str = (negative_prompt or "") # GC now before running the next img2img to prevent running out of memory - x = None devices.torch_gc() if not self.disable_extra_networks: @@ -1404,16 +1514,21 @@ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio()) from typing import Any, Dict, List +import json + devices.torch_gc() + + decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True) + +from typing import Any, Dict, List - return samples + return decoded_samples def close(self): super().close() self.hr_c = None self.hr_uc = None -import logging + self.eta = eta import math -import logging StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None] StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None] @@ -1429,14 +1544,13 @@ if self.hr_negative_prompt == '': self.hr_negative_prompt = self.negative_prompt - if type(self.hr_prompt) == list: + if isinstance(self.hr_prompt, list): self.all_hr_prompts = self.hr_prompt else: self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt] -import torch + self.subseed = -1 import os - self.all_hr_negative_prompts = self.hr_negative_prompt else: self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt] @@ -1448,21 +1562,32 @@ def calculate_hr_conds(self): if self.hr_c is not None: return -import torch + self.subseed = -1 import sys -import os + hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) + + sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name) + self.subseed = -1 import torch - base_image = Image.new('RGBA', (overlay.width, overlay.height)) + total_steps = sampler_config.total_steps(steps) if sampler_config else steps + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data, total_steps) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data, total_steps) def setup_conds(self): + if self.is_hr_pass: + # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model + self.hr_c = None + self.calculate_hr_conds() + return + super().setup_conds() self.hr_uc = None self.hr_c = None - import json -import torch + self.seed_resize_from_h = p.seed_resize_from_h if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() @@ -1475,6 +1600,12 @@ with devices.autocast(): extra_networks.activate(self, self.extra_network_data) + def get_conds(self): + if self.is_hr_pass: + return self.hr_c, self.hr_uc + + return super().get_conds() + def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() @@ -1487,63 +1618,81 @@ return res +@dataclass(repr=False) class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): -import hashlib + self.subseed_strength = 0 import torch -import os + resize_mode: int = 0 + denoising_strength: float = 0.75 + image_cfg_scale: float = None + mask: Any = None + mask_blur_x: int = 4 - + mask_blur_y: int = 4 - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): + mask_blur: int = None + inpainting_fill: int = 0 + self.seed_resize_from_h = 0 import hashlib - self.prompt: str = prompt + self.seed_resize_from_h = 0 + self.seed_resize_from_h = 0 import torch -import hashlib + self.seed_resize_from_h = 0 import numpy as np -import torch + latent_mask: Image = None import json +opt_f = 8 import json -import os -import torch import json + self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0] - self.outpath_grids: str = outpath_grids +from skimage import exposure import logging +import math - self.outpath_grids: str = outpath_grids +import json import math - self.outpath_grids: str = outpath_grids +import logging import os - self.outpath_grids: str = outpath_grids + self.seed_resize_from_w = 0 import sys - self.outpath_grids: str = outpath_grids + self.seed_resize_from_w = 0 import hashlib + + self.do_not_save_samples: bool = do_not_save_samples import torch + super().__post_init__() + self.seed_resize_from_w = 0 import torch - import torch +import hashlib + self.seed_resize_from_w = 0 import torch + @property + self.seed_resize_from_w = 0 import numpy as np - self.prompt: str = prompt + self.scripts = None - self.prompt: str = prompt + self.scripts = None import json - self.prompt: str = prompt + self.scripts = None import logging - self.prompt: str = prompt + + self.scripts = None import math - self.prompt: str = prompt + self.scripts = None import os - self.prompt: str = prompt + self.scripts = None import sys - self.prompt: str = prompt + self.scripts = None import hashlib - self.prompt: str = prompt + self.scripts = None - self.image_conditioning = None import modules.images as images + + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) crop_region = None @@ -1550,8 +1698,10 @@ image_mask = self.image_mask if image_mask is not None: -import torch + self.scripts = None import numpy as np + # but we still want to support binary masks. + self.script_args = script_args import json if self.inpainting_mask_invert: @@ -1559,13 +1709,13 @@ image_mask = ImageOps.invert(image_mask) if self.mask_blur_x > 0: np_mask = np.array(image_mask) - kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1 + kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) image_mask = Image.fromarray(np_mask) if self.mask_blur_y > 0: np_mask = np.array(image_mask) - kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1 + kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) image_mask = Image.fromarray(np_mask) @@ -1644,12 +1794,16 @@ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) self.seed_resize_from_h: int = seed_resize_from_h +import sys + + self.script_args = script_args import os - self.seed_resize_from_h: int = seed_resize_from_h + self.script_args = script_args import sys - self.seed_resize_from_h: int = seed_resize_from_h + self.script_args = script_args import hashlib + devices.torch_gc() if self.resize_mode == 3: self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") @@ -1671,13 +1825,12 @@ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask -import numpy as np + self.script_args = script_args -import logging def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): + self.token_merging_ratio = 0 -class StableDiffusionProcessing: if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py new file mode 100644 index 0000000000000000000000000000000000000000..29ccb78f90375339e197c7533a6779aa1848696e --- /dev/null +++ b/modules/processing_scripts/refiner.py @@ -0,0 +1,49 @@ +import gradio as gr + +from modules import scripts, sd_models +from modules.ui_common import create_refresh_button +from modules.ui_components import InputAccordion + + +class ScriptRefiner(scripts.ScriptBuiltinUI): + section = "accordions" + create_group = False + + def __init__(self): + pass + + def title(self): + return "Refiner" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + def lookup_checkpoint(title): + info = sd_models.get_closet_checkpoint_match(title) + return None if info is None else info.title + + self.infotext_fields = [ + (enable_refiner, lambda d: 'Refiner' in d), + (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), + (refiner_switch_at, 'Refiner switch at'), + ] + + return enable_refiner, refiner_checkpoint, refiner_switch_at + + def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): + # the actual implementation is in sd_samplers_common.py, apply_refiner + + if not enable_refiner or refiner_checkpoint in (None, "", "None"): + p.refiner_checkpoint = None + p.refiner_switch_at = None + else: + p.refiner_checkpoint = refiner_checkpoint + p.refiner_switch_at = refiner_switch_at diff --git a/modules/processing_scripts/seed.py b/modules/processing_scripts/seed.py new file mode 100644 index 0000000000000000000000000000000000000000..6b6ff987d2dbb248d9d2da56400e35d9e496048e --- /dev/null +++ b/modules/processing_scripts/seed.py @@ -0,0 +1,111 @@ +import json + +import gradio as gr + +from modules import scripts, ui, errors +from modules.shared import cmd_opts +from modules.ui_components import ToolButton + + +class ScriptSeed(scripts.ScriptBuiltinUI): + section = "seed" + create_group = False + + def __init__(self): + self.seed = None + self.reuse_seed = None + self.reuse_subseed = None + + def title(self): + return "Seed" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + with gr.Row(elem_id=self.elem_id("seed_row")): + if cmd_opts.use_textbox_seed: + self.seed = gr.Textbox(label='Seed', value="", elem_id=self.elem_id("seed"), min_width=100) + else: + self.seed = gr.Number(label='Seed', value=-1, elem_id=self.elem_id("seed"), min_width=100, precision=0) + + random_seed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_seed"), label='Random seed') + reuse_seed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_seed"), label='Reuse seed') + + seed_checkbox = gr.Checkbox(label='Extra', elem_id=self.elem_id("subseed_show"), value=False) + + with gr.Group(visible=False, elem_id=self.elem_id("seed_extras")) as seed_extras: + with gr.Row(elem_id=self.elem_id("subseed_row")): + subseed = gr.Number(label='Variation seed', value=-1, elem_id=self.elem_id("subseed"), precision=0) + random_subseed = ToolButton(ui.random_symbol, elem_id=self.elem_id("random_subseed")) + reuse_subseed = ToolButton(ui.reuse_symbol, elem_id=self.elem_id("reuse_subseed")) + subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=self.elem_id("subseed_strength")) + + with gr.Row(elem_id=self.elem_id("seed_resize_from_row")): + seed_resize_from_w = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from width", value=0, elem_id=self.elem_id("seed_resize_from_w")) + seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=self.elem_id("seed_resize_from_h")) + + random_seed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("seed") + "')}", show_progress=False, inputs=[], outputs=[]) + random_subseed.click(fn=None, _js="function(){setRandomSeed('" + self.elem_id("subseed") + "')}", show_progress=False, inputs=[], outputs=[]) + + seed_checkbox.change(lambda x: gr.update(visible=x), show_progress=False, inputs=[seed_checkbox], outputs=[seed_extras]) + + self.infotext_fields = [ + (self.seed, "Seed"), + (seed_checkbox, lambda d: "Variation seed" in d or "Seed resize from-1" in d), + (subseed, "Variation seed"), + (subseed_strength, "Variation seed strength"), + (seed_resize_from_w, "Seed resize from-1"), + (seed_resize_from_h, "Seed resize from-2"), + ] + + self.on_after_component(lambda x: connect_reuse_seed(self.seed, reuse_seed, x.component, False), elem_id=f'generation_info_{self.tabname}') + self.on_after_component(lambda x: connect_reuse_seed(subseed, reuse_subseed, x.component, True), elem_id=f'generation_info_{self.tabname}') + + return self.seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h + + def setup(self, p, seed, seed_checkbox, subseed, subseed_strength, seed_resize_from_w, seed_resize_from_h): + p.seed = seed + + if seed_checkbox and subseed_strength > 0: + p.subseed = subseed + p.subseed_strength = subseed_strength + + if seed_checkbox and seed_resize_from_w > 0 and seed_resize_from_h > 0: + p.seed_resize_from_w = seed_resize_from_w + p.seed_resize_from_h = seed_resize_from_h + + + +def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, is_subseed): + """ Connects a 'reuse (sub)seed' button's click event so that it copies last used + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength + was 0, i.e. no variation seed was used, it copies the normal seed value instead.""" + + def copy_seed(gen_info_string: str, index): + res = -1 + + try: + gen_info = json.loads(gen_info_string) + index -= gen_info.get('index_of_first_image', 0) + + if is_subseed and gen_info.get('subseed_strength', 0) > 0: + all_subseeds = gen_info.get('all_subseeds', [-1]) + res = all_subseeds[index if 0 <= index < len(all_subseeds) else 0] + else: + all_seeds = gen_info.get('all_seeds', [-1]) + res = all_seeds[index if 0 <= index < len(all_seeds) else 0] + + except json.decoder.JSONDecodeError: + if gen_info_string: + errors.report(f"Error parsing JSON generation info: {gen_info_string}") + + return [res, gr.update()] + + reuse_seed.click( + fn=copy_seed, + _js="(x, y) => [x, selected_gallery_index()]", + show_progress=False, + inputs=[generation_info, seed], + outputs=[seed, seed] + ) diff --git a/modules/progress.py b/modules/progress.py index f405f07fed290142977d19b174ef62abfd4adb15..69921de728197d87df4f3bd45ab2a2b6885b36a3 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -48,6 +48,7 @@ class ProgressRequest(BaseModel): id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") + live_preview: bool = Field(default=True, title="Include live preview", description="boolean flag indicating whether to include the live preview image") class ProgressResponse(BaseModel): @@ -71,8 +72,14 @@ queued = req.id_task in pending_tasks completed = req.id_task in finished_tasks if not active: + textinfo = "Waiting..." + if queued: +current_task = None current_task = None +import gradio as gr + textinfo = "In queue: {}/{}".format(queue_index + 1, len(sorted_queued)) + return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo=textinfo) progress = 0 @@ -90,44 +96,43 @@ elapsed_since_start = time.time() - shared.state.time_start predicted_duration = elapsed_since_start / progress if progress > 0 else None eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None - id_live_preview = req.id_live_preview - shared.state.set_current_image() - if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview: + live_preview = None from pydantic import BaseModel, Field - - if image is not None: - buffered = io.BytesIO() +import base64 - if opts.live_previews_image_format == "png": - # using optimize for large images takes an enormous amount of time -from pydantic import BaseModel, Field +current_task = None current_task = None -from modules.shared import opts +pending_tasks = {} -from modules.shared import opts +pending_tasks = {} import base64 -from modules.shared import opts +pending_tasks = {} import io - -from modules.shared import opts +pending_tasks = {} import time -from modules.shared import opts +pending_tasks = {} -from modules.shared import opts +pending_tasks = {} import gradio as gr -from modules.shared import opts +pending_tasks = {} from pydantic import BaseModel, Field -from modules.shared import opts +pending_tasks = {} from modules.shared import opts -from modules.shared import opts +pending_tasks = {} import modules.shared as shared -from modules.shared import opts +pending_tasks = {} current_task = None -import modules.shared as shared +finished_tasks = [] -import modules.shared as shared + + else: +finished_tasks = [] import base64 -import modules.shared as shared + +finished_tasks = [] import io + base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') + live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}" + id_live_preview = shared.state.id_live_preview return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 203ae1acc14b9477f494daa7c3b402c3a7ed8d59..334efeef317cc5b3893e5fd38772e3b5d9677332 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -19,14 +19,15 @@ prompt: (emphasized | scheduled | alternate | plain | WHITESPACE)* !emphasized: "(" prompt ")" | "(" prompt ":" prompt ")" | "[" prompt "]" + >>> g("a [b: 3]") -alternate: "[" prompt ("|" prompt)+ "]" +alternate: "[" prompt ("|" [prompt])+ "]" WHITESPACE: /\s+/ plain: /([^\\\[\]():|]|\\.)+/ %import common.SIGNED_NUMBER -> NUMBER """) -def get_learned_conditioning_prompt_schedules(prompts, steps): +def get_learned_conditioning_prompt_schedules(prompts, base_steps, hires_steps=None, use_old_scheduling=False): """ >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10)[0] >>> g("test") @@ -52,19 +54,48 @@ >>> g("((a][:b:c [d:3]") [[3, '((a][:b:c '], [10, '((a][:b:c d']] >>> g("[a|(b:1.1)]") [[1, 'a'], [2, '(b:1.1)'], [3, 'a'], [4, '(b:1.1)'], [5, 'a'], [6, '(b:1.1)'], [7, 'a'], [8, '(b:1.1)'], [9, 'a'], [10, '(b:1.1)']] + >>> g("[fe|]male") + [[1, 'female'], [2, 'male'], [3, 'female'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'female'], [8, 'male'], [9, 'female'], [10, 'male']] + >>> g("[fe|||]male") + [[1, 'female'], [2, 'male'], [3, 'male'], [4, 'male'], [5, 'female'], [6, 'male'], [7, 'male'], [8, 'male'], [9, 'female'], [10, 'male']] + >>> g = lambda p: get_learned_conditioning_prompt_schedules([p], 10, 10)[0] + >>> g("a [b:.5] c") + [[10, 'a b c']] + >>> g("a [b:1.5] c") + [[5, 'a c'], [10, 'a b c']] """ + if hires_steps is None or use_old_scheduling: + int_offset = 0 + flt_offset = 0 + steps = base_steps + else: + int_offset = base_steps + flt_offset = 1.0 + steps = hires_steps + def collect_steps(steps, tree): res = [steps] class CollectSteps(lark.Visitor): def scheduled(self, tree): + s = tree.children[-2] +import re from typing import List + if use_old_scheduling: + v = v*steps if v<1 else v + else: + if "." in s: + v = (v - flt_offset) * steps +# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" import lark + v = (v - int_offset) +import re if tree.children[-1] < 1: +import re tree.children[-1] *= steps +import re tree.children[-1] = min(steps, int(tree.children[-1])) - res.append(tree.children[-1]) def alternate(self, tree): res.extend(range(1, steps+1)) @@ -75,15 +106,16 @@ def at_step(step, tree): class AtStep(lark.Transformer): def scheduled(self, args): +import re import lark -# will be represented with prompt_schedule like this (assuming steps=100): yield before or () if step <= when else after def alternate(self, args): -# a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" + >>> g("[(a:2):3]") from __future__ import annotations + yield args[(step - 1) % len(args)] def start(self, args): def flatten(x): - if type(x) == str: + if isinstance(x, str): yield x else: for gen in x: @@ -131,7 +162,7 @@ self.height = height or getattr(copy_from, 'height', None) -def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps, hires_steps=None, use_old_scheduling=False): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -151,7 +182,7 @@ ] """ res = [] - prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps) + prompt_schedules = get_learned_conditioning_prompt_schedules(prompts, steps, hires_steps, use_old_scheduling) cache = {} for prompt, prompt_schedule in zip(prompts, prompt_schedules): @@ -226,8 +257,8 @@ self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch -!emphasized: "(" prompt ")" import re + class AtStep(lark.Transformer): """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. @@ -236,9 +267,8 @@ """ res_indexes, prompt_flat_list, prompt_indexes = get_multicond_prompt_list(prompts) -from __future__ import annotations + >>> g("[(a:2):3]") # a prompt like this: "fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]" -# will be represented with prompt_schedule like this (assuming steps=100): res = [] for indexes in res_indexes: @@ -338,7 +368,7 @@ \\\\| \\| \(| \[| -%import common.SIGNED_NUMBER -> NUMBER + >>> g("[(a:2):3]") # will be represented with prompt_schedule like this (assuming steps=100): \)| ]| diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0700b85379cf15fb8a1fd28493cc54d71f15dbc0..02841c3028925afdb6926d64dfdc7ec847ae857d 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -55,6 +55,7 @@ model=info.model(), half=not cmd_opts.no_half and not cmd_opts.upcast_sampling, tile=opts.ESRGAN_tile, tile_pad=opts.ESRGAN_tile_overlap, + device=self.device, ) upsampled = upsampler.enhance(np.array(img), outscale=info.scale)[0] diff --git a/modules/rng.py b/modules/rng.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8ba2ee9d79d78bcff450e70cf82fe2a5c4ad91 --- /dev/null +++ b/modules/rng.py @@ -0,0 +1,170 @@ +import torch + +from modules import devices, rng_philox, shared + + +def randn(seed, shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + + manual_seed(seed) + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + if shared.opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=devices.device) + + local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if shared.opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + + return torch.randn_like(x) + + +def randn_without_seed(shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + + if shared.opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + +def create_generator(seed): + if shared.opts.randn_source == "NV": + return rng_philox.Generator(seed) + + device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + generator = torch.Generator(device).manual_seed(int(seed)) + return generator + + +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + dot = (low_norm*high_norm).sum(1) + + if dot.mean() > 0.9995: + return low * val + high * (1 - val) + + omega = torch.acos(dot) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res + + +class ImageRNG: + def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): + self.shape = tuple(map(int, shape)) + self.seeds = seeds + self.subseeds = subseeds + self.subseed_strength = subseed_strength + self.seed_resize_from_h = seed_resize_from_h + self.seed_resize_from_w = seed_resize_from_w + + self.generators = [create_generator(seed) for seed in seeds] + + self.is_first = True + + def first(self): + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + + xs = [] + + for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)): + subnoise = None + if self.subseeds is not None and self.subseed_strength != 0: + subseed = 0 if i >= len(self.subseeds) else self.subseeds[i] + subnoise = randn(subseed, noise_shape) + + if noise_shape != self.shape: + noise = randn(seed, noise_shape) + else: + noise = randn(seed, self.shape, generator=generator) + + if subnoise is not None: + noise = slerp(self.subseed_strength, noise, subnoise) + + if noise_shape != self.shape: + x = randn(seed, self.shape, generator=generator) + dx = (self.shape[2] - noise_shape[2]) // 2 + dy = (self.shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] + noise = x + + xs.append(noise) + + eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0 + if eta_noise_seed_delta: + self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds] + + return torch.stack(xs).to(shared.device) + + def next(self): + if self.is_first: + self.is_first = False + return self.first() + + xs = [] + for generator in self.generators: + x = randn_without_seed(self.shape, generator=generator) + xs.append(x) + + return torch.stack(xs).to(shared.device) + + +devices.randn = randn +devices.randn_local = randn_local +devices.randn_like = randn_like +devices.randn_without_seed = randn_without_seed +devices.manual_seed = manual_seed diff --git a/modules/rng_philox.py b/modules/rng_philox.py new file mode 100644 index 0000000000000000000000000000000000000000..5532cf9dd676012ae37b7fda3433b9b1d122f80a --- /dev/null +++ b/modules/rng_philox.py @@ -0,0 +1,102 @@ +"""RNG imitiating torch cuda randn on CPU. You are welcome. + +Usage: + +``` +g = Generator(seed=0) +print(g.randn(shape=(3, 4))) +``` + +Expected output: +``` +[[-0.92466259 -0.42534415 -2.6438457 0.14518388] + [-0.12086647 -0.57972564 -0.62285122 -0.32838709] + [-1.07454231 -0.36314407 -1.67105067 2.26550497]] +``` +""" + +import numpy as np + +philox_m = [0xD2511F53, 0xCD9E8D57] +philox_w = [0x9E3779B9, 0xBB67AE85] + +two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32) +two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32) + + +def uint32(x): + """Converts (N,) np.uint64 array into (2, N) np.unit32 array.""" + return x.view(np.uint32).reshape(-1, 2).transpose(1, 0) + + +def philox4_round(counter, key): + """A single round of the Philox 4x32 random number generator.""" + + v1 = uint32(counter[0].astype(np.uint64) * philox_m[0]) + v2 = uint32(counter[2].astype(np.uint64) * philox_m[1]) + + counter[0] = v2[1] ^ counter[1] ^ key[0] + counter[1] = v2[0] + counter[2] = v1[1] ^ counter[3] ^ key[1] + counter[3] = v1[0] + + +def philox4_32(counter, key, rounds=10): + """Generates 32-bit random numbers using the Philox 4x32 random number generator. + + Parameters: + counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation). + key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed). + rounds (int): The number of rounds to perform. + + Returns: + numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers. + """ + + for _ in range(rounds - 1): + philox4_round(counter, key) + + key[0] = key[0] + philox_w[0] + key[1] = key[1] + philox_w[1] + + philox4_round(counter, key) + return counter + + +def box_muller(x, y): + """Returns just the first out of two numbers generated by Box–Muller transform algorithm.""" + u = x * two_pow32_inv + two_pow32_inv / 2 + v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2 + + s = np.sqrt(-2.0 * np.log(u)) + + r1 = s * np.sin(v) + return r1.astype(np.float32) + + +class Generator: + """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU""" + + def __init__(self, seed): + self.seed = seed + self.offset = 0 + + def randn(self, shape): + """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform.""" + + n = 1 + for x in shape: + n *= x + + counter = np.zeros((4, n), dtype=np.uint32) + counter[0] = self.offset + counter[2] = np.arange(n, dtype=np.uint32) # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3] + self.offset += 1 + + key = np.empty(n, dtype=np.uint64) + key.fill(self.seed) + key = uint32(key) + + g = philox4_32(counter, key) + + return box_muller(g[0], g[1]).reshape(shape) # discard g[2] and g[3] diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 77ee55ee3f41598e582ec783d4604d90d3d323cd..fab23551a68843f831e342fda2c970cc9ce0b9c9 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -29,6 +29,16 @@ """dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'""" import os +class UiTrainTabParams: + def __init__(self, noise, x): + self.noise = noise + """Random noise generated by the seed""" + + self.x = x + """Latent image representation of the image""" + + +import os import inspect def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, text_cond, text_uncond): self.x = x @@ -101,6 +111,7 @@ callbacks_ui_train_tabs=[], callbacks_ui_settings=[], callbacks_before_image_saved=[], callbacks_image_saved=[], + callbacks_extra_noise=[], callbacks_cfg_denoiser=[], callbacks_cfg_denoised=[], callbacks_cfg_after_cfg=[], @@ -188,6 +199,14 @@ try: c.callback(params) except Exception: report_exception(c, 'image_saved_callback') + + +def extra_noise_callback(params: ExtraNoiseParams): + for c in callback_map['callbacks_extra_noise']: + try: + c.callback(params) + except Exception: + report_exception(c, 'callbacks_extra_noise') def cfg_denoiser_callback(params: CFGDenoiserParams): @@ -366,6 +385,14 @@ The callback is called with one argument: - params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing. """ add_callback(callback_map['callbacks_image_saved'], callback) + + +def on_extra_noise(callback): + """register a function to be called before adding extra noise in img2img or hires fix; + The callback is called with one argument: + - params: ExtraNoiseParams - contains noise determined by seed and latent representation of image + """ + add_callback(callback_map['callbacks_extra_noise'], callback) def on_cfg_denoiser(callback): diff --git a/modules/scripts.py b/modules/scripts.py index 5b4edcac33ca3b739ab80583b20f40e25db424ea..e8518ad0fbab00a990e0fd1053d11bcd860ad27d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,6 +3,7 @@ import re import sys import inspect from collections import namedtuple +from dataclasses import dataclass import gradio as gr @@ -21,6 +22,11 @@ def __init__(self, images): self.images = images +@dataclass +class OnComponent: + component: gr.blocks.Block + + class Script: name = None """script's internal name derived from title""" @@ -35,9 +41,13 @@ alwayson = False is_txt2img = False is_img2img = False + tabname = None group = None -import re + """A gr.Group component that has all script's UI inside it.""" + + create_group = True + def title(self): from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer infotext_fields = None @@ -52,6 +62,15 @@ """ api_info = None """Generated value of type modules.api.models.ScriptInfo with information about the script for API""" + + on_before_component_elem_id = None + """list of callbacks to be called before a component with an elem_id is created""" + + on_after_component_elem_id = None + """list of callbacks to be called after a component with an elem_id is created""" + + setup_for_ui_only = False + """If true, the script setup will only be run in Gradio UI, not in API""" def title(self): """this function should return the title of the script. This is what will be displayed in the dropdown menu.""" @@ -91,9 +110,16 @@ """ pass + def setup(self, p, *args): + """For AlwaysVisible scripts, this function is called when the processing object is set up, before any processing starts. + args contains all values returned by components from ui(). + """ + pass + + def before_process(self, p, *args): """ - This function is called very early before processing begins for AlwaysVisible scripts. + This function is called very early during processing begins for AlwaysVisible scripts. You can modify the processing object (p) here, inject hooks, etc. args contains all values returned by components from ui() """ @@ -213,6 +239,29 @@ """ pass + def on_before_component(self, callback, *, elem_id): + """ + Calls callback before a component is created. The callback function is called with a single argument of type OnComponent. + + May be called in show() or ui() - but it may be too late in latter as some components may already be created. + + This function is an alternative to before_component in that it also cllows to run before a component is created, but + it doesn't require to be called for every created component - just for the one you need. + """ + if self.on_before_component_elem_id is None: + self.on_before_component_elem_id = [] + + self.on_before_component_elem_id.append((elem_id, callback)) + + def on_after_component(self, callback, *, elem_id): + """ + Calls callback after a component is created. The callback function is called with a single argument of type OnComponent. + """ + if self.on_after_component_elem_id is None: + self.on_after_component_elem_id = [] + + self.on_after_component_elem_id.append((elem_id, callback)) + def describe(self): """unused""" return "" @@ -221,7 +270,7 @@ def elem_id(self, item_id): """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" need_tabname = self.show(True) == self.show(False) -import os +import inspect self.image = image tabname = f"{tabkind}_" if need_tabname else "" title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) @@ -234,6 +283,19 @@ This function is called before hires fix start. """ pass + +class ScriptBuiltinUI(Script): + setup_for_ui_only = True + + def elem_id(self, item_id): + """helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id""" + + need_tabname = self.show(True) == self.show(False) + tabname = ('img2img' if self.is_img2img else 'txt2img') + "_" if need_tabname else "" + + return f'{tabname}{item_id}' + + current_basedir = paths.script_path @@ -252,8 +314,8 @@ postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) +import inspect import os -import re from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer scripts_list = [] @@ -262,11 +324,13 @@ if os.path.exists(basedir): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) +import inspect import os -import sys +AlwaysVisible = object() import inspect + """name of UI section that the script's controls will be placed into""" + """this function should create gradio UI elements. See https://gradio.app/docs/#components import os - various "Send to <X>" buttons when clicked scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -295,7 +359,7 @@ scripts_data.clear() postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() - scripts_list = list_scripts("scripts", ".py") + scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False) syspath = sys.path @@ -356,10 +420,17 @@ self.scripts = [] self.selectable_scripts = [] self.alwayson_scripts = [] self.titles = [] + self.title_map = {} self.infotext_fields = [] self.paste_field_names = [] self.inputs = [None] + self.on_before_component_elem_id = {} + """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks""" + + self.on_after_component_elem_id = {} + """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks""" + def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -374,6 +445,7 @@ script = script_data.script_class() script.filename = script_data.path script.is_txt2img = not is_img2img script.is_img2img = is_img2img + script.tabname = "img2img" if is_img2img else "txt2img" visibility = script.show(script.is_img2img) @@ -386,6 +458,28 @@ elif visibility: self.scripts.append(script) self.selectable_scripts.append(script) + self.apply_on_before_component_callbacks() + + def apply_on_before_component_callbacks(self): + for script in self.scripts: + on_before = script.on_before_component_elem_id or [] + on_after = script.on_after_component_elem_id or [] + + for elem_id, callback in on_before: + if elem_id not in self.on_before_component_elem_id: + self.on_before_component_elem_id[elem_id] = [] + + self.on_before_component_elem_id[elem_id].append((callback, script)) + + for elem_id, callback in on_after: + if elem_id not in self.on_after_component_elem_id: + self.on_after_component_elem_id[elem_id] = [] + + self.on_after_component_elem_id[elem_id].append((callback, script)) + + on_before.clear() + on_after.clear() + def create_script_ui(self, script): import modules.api.models as api_models @@ -436,20 +530,27 @@ for script in scriptlist: if script.alwayson and script.section != section: continue - is_txt2img = False + Values of those returned components will be passed to run() and process() functions. import sys -import re + with gr.Group(visible=script.alwayson) as group: + Values of those returned components will be passed to run() and process() functions. from collections import namedtuple -import inspect + + Values of those returned components will be passed to run() and process() functions. import re +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer from collections import namedtuple +import re from collections import namedtuple +import inspect def prepare_ui(self): self.inputs = [None] def setup_ui(self): + all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts] + self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)} self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] self.setup_ui_for_section(None) @@ -495,6 +596,8 @@ return gr.update(visible=False) self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) + + self.apply_on_before_component_callbacks() return self.inputs @@ -589,6 +692,12 @@ except Exception: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): + for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []): + try: + callback(OnComponent(component=component)) + except Exception: + errors.report(f"Error running on_before_component: {script.filename}", exc_info=True) + for script in self.scripts: try: script.before_component(component, **kwargs) @@ -596,11 +705,20 @@ except Exception: errors.report(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): + for callback, script in self.on_after_component_elem_id.get(component.elem_id, []): + try: + callback(OnComponent(component=component)) + except Exception: + errors.report(f"Error running on_after_component: {script.filename}", exc_info=True) + for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: errors.report(f"Error running after_component: {script.filename}", exc_info=True) + + def script(self, title): + return self.title_map.get(title.lower()) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): @@ -620,7 +738,6 @@ self.scripts[si].filename = filename self.scripts[si].args_from = args_from self.scripts[si].args_to = args_to - def before_hr(self, p): for script in self.alwayson_scripts: try: @@ -629,89 +746,50 @@ script.before_hr(p, *script_args) except Exception: errors.report(f"Error running before_hr: {script.filename}", exc_info=True) - -scripts_txt2img: ScriptRunner = None - various "Send to <X>" buttons when clicked import inspect -scripts_postproc: scripts_postprocessing.ScriptPostprocessingRunner = None -import sys This function is called if the script has been selected in the script dropdown. - + for script in self.alwayson_scripts: - -import sys +import inspect It must do all processing and return the Processed object with results, same as -import sys +import os one returned by processing.process_images. - scripts_txt2img.reload_sources(cache) - scripts_img2img.reload_sources(cache) - import sys - import os - - import sys - This function is called very early before processing begins for AlwaysVisible scripts. - """ - this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others -import sys import re - - api_info = None import inspect - - api_info = None from collections import namedtuple - comp.elem_classes.append('multiselect') - - - -def IOComponent_init(self, *args, **kwargs): - api_info = None from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer import sys - **kwargs will have those items: - - """Generated value of type modules.api.models.ScriptInfo with information about the script for API""" +import inspect - - res = original_IOComponent_init(self, *args, **kwargs) + errors.report(f"Error running setup: {script.filename}", exc_info=True) - add_classes_to_gradio_component(self) import sys - - subseeds - list of subseeds for current batch - +from collections import namedtuple import sys - new extra network keywords to the prompt with this callback. import sys -import gradio as gr +from collections import namedtuple import inspect - - return res - - import sys -import gradio as gr +from collections import namedtuple from collections import namedtuple import sys -import gradio as gr +from collections import namedtuple import sys -import gradio as gr +from collections import namedtuple import gradio as gr import sys -import gradio as gr +from collections import namedtuple from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer - + scripts_txt2img.reload_sources(cache) import sys - - seeds - list of seeds for current batch - return res import sys - def process_batch(self, p, *args, **kwargs): -gr.blocks.BlockContext.__init__ = BlockContext_init + def before_process(self, p, *args): diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6a75985e68bb39f564dd55ce928ec6b29..8863107ae6f367f7f2569e847a8eee2dc53a34d3 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,32 @@ import open_clip import torch import transformers.utils.hub +from modules import shared + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + +import open_clip class DisableInitialization: + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +45,7 @@ ``` """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,11 +110,131 @@ self.transformers_configuration_utils_cached_file = self.replace(transformers.configuration_utils, 'cached_file', transformers_configuration_utils_cached_file) self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): + """ class DisableInitialization: + + + """ """ + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + When an object of this class enters a `with` block, it starts: class DisableInitialization: When an object of this class enters a `with` block, it starts: + """ + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) class DisableInitialization: +class DisableInitialization: + self.restore() + + - preventing torch's layer initialization functions from working +import torch + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + - preventing torch's layer initialization functions from working + Meant to be used together with InitializeOnMeta above. + + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device, weight_dtype_conversion=None): + super().__init__() + self.state_dict = state_dict + self.device = device + self.weight_dtype_conversion = weight_dtype_conversion or {} + self.default_dtype = self.weight_dtype_conversion.get('') + + def get_weight_dtype(self, key): + key_first_term, _ = key.split('.', 1) + return self.weight_dtype_conversion.get(key_first_term, self.default_dtype) + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): + used_param_keys = [] + + for name, param in module._parameters.items(): + if param is None: + continue + + key = prefix + name + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param.to(dtype=self.get_weight_dtype(key)) + used_param_keys.append(key) + + if param.is_meta: + dtype = sd_param.dtype if sd_param is not None else param.dtype + module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + + for name in module._buffers: + key = prefix + name + + sd_param = sd.pop(key, None) + if sd_param is not None: + state_dict[key] = sd_param + used_param_keys.append(key) + + original(module, state_dict, prefix, *args, **kwargs) + + for key in used_param_keys: + state_dict.pop(key, None) + + def load_state_dict(original, module, state_dict, strict=True): + """torch makes a lot of copies of the dictionary with weights, so just deleting entries from state_dict does not help + because the same values are stored in multiple copies of the dict. The trick used here is to give torch a dict with + all weights on meta device, i.e. deleted, and then it doesn't matter how many copies torch makes. + + In _load_from_state_dict, the correct weight will be obtained from a single dict with the right weights (sd). + + The dangerous thing about this is if _load_from_state_dict is not called, (if some exotic module overloads + the function and does not call the original) the state dict will just fail to load because weights + would be on the meta device. + """ + + if state_dict == sd: + state_dict = {k: v.to(device="meta", dtype=v.dtype) for k, v in state_dict.items()} + + original(module, state_dict, strict=strict) + + module_load_state_dict = self.replace(torch.nn.Module, 'load_state_dict', lambda *args, **kwargs: load_state_dict(module_load_state_dict, *args, **kwargs)) + module_load_from_state_dict = self.replace(torch.nn.Module, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(module_load_from_state_dict, *args, **kwargs)) + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + layer_norm_load_from_state_dict = self.replace(torch.nn.LayerNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(layer_norm_load_from_state_dict, *args, **kwargs)) + group_norm_load_from_state_dict = self.replace(torch.nn.GroupNorm, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(group_norm_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c8fdd4f16ae1a00337f645bce12406da184a2b09..592f00551f1d0dd7f9a7754903f047b83856c404 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,6 @@ import torch from torch.nn.functional import silu from types import MethodType -import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts @@ -31,9 +30,11 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 from torch.nn.functional import silu -from modules.shared import cmd_opts + # a script can access the model very early, and optimizations would not be filled by then from torch.nn.functional import silu -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr + current_optimizer = None +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None @@ -167,12 +168,15 @@ clip = None optimization_method = None import ldm.modules.diffusionmodules.openaimodel -from torch.nn.functional import silu +from types import MethodType +from torch.nn.functional import silu +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet - def __init__(self): + self.extra_generation_params = {} self.comments = [] + self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): @@ -200,8 +204,9 @@ model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': -import torch +from torch.nn.functional import silu +from modules.shared import cmd_opts conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) text_cond_models.append(conditioner.embedders[i]) @@ -247,7 +253,21 @@ ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward def undo_hijack(self, m): import torch -from modules.shared import cmd_opts + new_optimizers = script_callbacks.list_optimizers_callback() + if conditioner: + for i in range(len(conditioner.embedders)): + embedder = conditioner.embedders[i] + if isinstance(embedder, (sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords, sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords)): +ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention + conditioner.embedders[i] = embedder.wrapped + if isinstance(embedder, sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords): + embedder.wrapped.transformer.text_model.embeddings.token_embedding = embedder.wrapped.transformer.text_model.embeddings.token_embedding.wrapped + conditioner.embedders[i] = embedder.wrapped + + if hasattr(m, 'cond_stage_model'): + delattr(m, 'cond_stage_model') + + elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords: @@ -297,10 +317,11 @@ class EmbeddingsWithFixes(torch.nn.Module): from torch.nn.functional import silu -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr + matching_optimizer = None super().__init__() self.wrapped = wrapped self.embeddings = embeddings + self.textual_inversion_key = textual_inversion_key def forward(self, input_ids): batch_fixes = self.embeddings.fixes @@ -315,8 +336,9 @@ vecs = [] for fixes, tensor in zip(batch_fixes, inputs_embeds): for offset, embedding in fixes: from torch.nn.functional import silu -from torch.nn.functional import silu + elif matching_optimizer is None: from torch.nn.functional import silu + matching_optimizer = optimizers[0] emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0]) tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]]) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 16a5500e31498960915179cadacd163d78c88572..8f29057a9cfcfafcf18e5c5c3eb095a8a1649198 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -161,7 +161,7 @@ chunk.multipliers.append(weight) position += 1 continue - emb_len = int(embedding.vec.shape[0]) + emb_len = int(embedding.vectors) if len(chunk.tokens) + emb_len > self.chunk_length: next_chunk() @@ -245,6 +245,8 @@ name = name.replace(":", "").replace(",", "") hashes.append(f"{name}: {shorthash}") if hashes: + if self.hijack.extra_generation_params.get("TI hashes"): + hashes.append(self.hijack.extra_generation_params.get("TI hashes")) self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) if getattr(self.wrapped, 'return_pooled', False): diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py deleted file mode 100644 index c1977b194190858fbb9726eb02256cdeafc654a0..0000000000000000000000000000000000000000 --- a/modules/sd_hijack_inpainting.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch - -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -from ldm.models.diffusion.ddim import noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - [email protected]_grad() -def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = {} - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - c_in = torch.cat([unconditional_conditioning, c]) - - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t - - -def do_inpainting_hijack(): - # p_sample_plms is needed because PLMS can't work with dicts as conditionings - - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 0e810eec8a9a01f28ca96007595eb9d00e08eff9..7f9e328d05a8a20ef81841f8850af8dffa472136 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,6 +1,7 @@ from __future__ import annotations import math import psutil +import platform import torch from torch import einsum @@ -94,7 +95,10 @@ class SdOptimizationSubQuad(SdOptimization): name = "sub-quadratic" cmd_opt = "opt_sub_quad_attention" - priority = 10 + + @property + def priority(self): + return 1000 if shared.device.type == 'mps' else 10 def apply(self): ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward @@ -120,7 +124,7 @@ cmd_opt = "opt_split_attention_invokeai" @property def priority(self): - return 1000 if not torch.cuda.is_available() else 10 + return 1000 if shared.device.type != 'mps' and not torch.cuda.is_available() else 10 def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI @@ -427,8 +431,11 @@ _, k_tokens, _ = k.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens if chunk_threshold is None: - def is_available(self): + ldm.modules.attention.CrossAttention.forward = xformers_attention_forward from __future__ import annotations + chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token) + else: + chunk_threshold_bytes = int(get_available_vram() * 0.7) elif chunk_threshold == 0: chunk_threshold_bytes = None else: diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a7937b910ebde0e0743df96f82a8bd40a09a..547e93c442e12b7318eaa75af328c62adb85962e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,8 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl -import collections + self.shorthash = self.sha256[0:10] if self.sha256 else None from modules.timer import Timer import tomesd @@ -34,6 +33,8 @@ def __init__(self, filename): self.filename = filename abspath = os.path.abspath(filename) + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): @@ -44,36 +45,50 @@ if name.startswith("\\") or name.startswith("/"): name = name[1:] - self.name = name -import sys + self.shorthash = self.sha256[0:10] if self.sha256 else None import re -import sys + self.shorthash = self.sha256[0:10] if self.sha256 else None import safetensors.torch import gc +import sys import gc +import sys import collections + import gc +import threading + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' import os.path + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading metadata for {filename}") + self.name = name + self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] import gc -import sys import gc +import collections import gc +import os.path import gc +import sys + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' import threading import gc +import sys import gc +import sys import torch import gc +import sys import re - self.metadata = read_metadata_from_safetensors(filename) - except Exception as e: - errors.display(e, f"reading checkpoint metadata: {filename}") def register(self): checkpoints_list[self.title] = self @@ -85,14 +99,20 @@ self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}") if self.sha256 is None: return + shorthash = self.sha256[0:10] + if self.shorthash == self.sha256[0:10]: + return self.shorthash + self.shorthash = shorthash if self.shorthash not in self.ids: - self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] + self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]'] - + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) import sys +import gc + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) import gc self.register() @@ -113,16 +134,11 @@ enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() +def checkpoint_tiles(use_short=False): + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) - def list_models(): checkpoints_list.clear() checkpoint_aliases.clear() @@ -143,18 +159,29 @@ shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) -import safetensors.torch + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) import torch checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): + if not search_string: + return None + checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: return checkpoint_info found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) if found: return found[0] @@ -293,11 +320,27 @@ return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) @@ -311,21 +354,26 @@ sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) -import re -model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) import safetensors.torch if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model + checkpoints_loaded[checkpoint_info] = state_dict + import os.path -from urllib import request +import re if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) timer.record("apply channels_last") -checkpoints_list = {} + if shared.cmd_opts.no_half: + model.float() +import gc +import re + timer.record("apply float()") + else: vae = model.first_stage_model depth_model = getattr(model, 'depth_model', None) @@ -341,9 +389,9 @@ model.first_stage_model = vae if depth_model: model.depth_model = depth_model + devices.dtype_unet = torch.float16 timer.record("apply half()") - devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -363,9 +411,8 @@ model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() -import os.path import gc -import threading +except Exception: sd_vae.load_vae(model, vae_file, vae_source) timer.record("load VAE") @@ -442,6 +489,7 @@ class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -456,6 +504,7 @@ return self.sd_model try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -464,16 +513,33 @@ self.sd_model = None return self.sd_model + def set_sd_model(self, v, already_loaded=False): elif abspath.startswith(model_path): +import torch + if already_loaded: + sd_vae.base_vae = getattr(v, "base_vae", None) + if ext.lower() == ".safetensors": + sd_vae.checkpoint_info = v.sd_checkpoint_info + + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + try: import sys -from modules.timer import Timer model_data = SdModelData() def get_empty_cond(sd_model): + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -481,27 +547,55 @@ else: return sd_model.cond_stage_model([""]) + try: + if m.lowvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def model_target_device(m): + if lowvram.is_needed(m): + self.metadata = read_metadata_from_safetensors(filename) import sys -import os.path + else: + return devices.device + + +def send_model_to_device(m): + lowvram.apply(m) + + if not m.lowvram: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") import gc + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + + name = abspath.replace(model_path, '') +import gc import threading +import collections checkpoint_info = checkpoint_info or select_checkpoint() import sys + else: + +import sys self.filename = filename -import sys + except Exception as e: import os.path -import re model_data.sd_model = None else: - else: import collections - do_inpainting_hijack() - - else: + except Exception as e: import sys if already_loaded_state_dict is not None: @@ -524,37 +618,48 @@ sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): -import sys + except Exception as e: import gc + except Exception as e: import threading - name = os.path.basename(filename) - name = os.path.basename(filename) + except Exception as e: + except Exception as e: import torch if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) + + with sd_disable_initialization.InitializeOnMeta(): import sys +import gc import threading sd_model.used_config = checkpoint_config timer.record("create model") -import sys + if shared.cmd_opts.no_half: + weight_dtype_conversion = None + else: + weight_dtype_conversion = { + 'first_stage_model': None, + '': torch.float16, import threading +import collections import sys -import sys import threading +import collections import gc import sys + if os.path.exists(cmd_ckpt): import threading +import collections import threading - else: + -import sys import threading - +import collections timer.record("move model to device") @@ -562,8 +668,9 @@ timer.record("hijack") sd_model.eval() - name = name[1:] +import threading import collections +import torch model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -584,58 +691,122 @@ return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + def register(self): import sys + """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + def register(self): import torch +import threading import os.path +import re + continue + +import threading import sys -import torch +import threading import sys +import collections +import threading import sys -import torch +import os.path + send_model_to_trash(loaded_model) + checkpoints_list[self.title] = self import gc +import threading import sys -import torch import threading + send_model_to_cpu(sd_model) +import threading self.name = name + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + model_data.set_sd_model(already_loaded, already_loaded=True) + + if not SkipWritingToConfig.skip: + for id in self.ids: import sys - def convert(name): + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + sd_vae.reload_vae_weights(already_loaded) import sys + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + for id in self.ids: import torch + for id in self.ids: import re - else: + import sys -import torch + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + for id in self.ids: import safetensors.torch import sys -import re + model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"]) + checkpoint_aliases[id] = self import threading -import safetensors.torch + errors.display(e, f"reading checkpoint metadata: {filename}") + model_data.sd_model = sd_model + checkpoint_aliases[id] = self import sys + sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None) + sd_vae.checkpoint_info = sd_model.sd_checkpoint_info + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model import re -import collections +import re + return None + import sys -import re +import torch import os.path import sys -import re + os.makedirs(model_path, exist_ok=True) + + else: import sys + import sys + enable_midas_autodownload() import sys +def checkpoint_tiles(): + import sys + def convert(name): + self.name = name import re -import gc - + else: + current_checkpoint_info = sd_model.sd_checkpoint_info self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] import threading + if self.sha256 is None: -import sys + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + + if sd_model is not None: import sys + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) + send_model_to_cpu(sd_model) import sys + cmd_ckpt = shared.cmd_opts.ckpt state_dict = get_checkpoint_state_dict(checkpoint_info, timer) @@ -644,8 +813,10 @@ timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: + def calculate_shorthash(self): import sys - else: + send_model_to_trash(sd_model) + load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -662,17 +833,19 @@ script_callbacks.model_loaded_callback(sd_model) timer.record("script callbacks") - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + sd_unet.apply_unet() + return sd_model def unload_model_weights(sd_model=None, info=None): - from modules import devices, sd_hijack timer = Timer() if model_data.sd_model: diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 8266fa39797b2044ddd0abd5e921bca5c6f87bea..08dd03f19c793b860832f73ebc534da520ed1813 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -2,7 +2,7 @@ import os import torch -from modules import shared, paths, sd_disable_initialization +from modules import shared, paths, sd_disable_initialization, devices sd_configs_path = shared.sd_configs_path sd_repo_configs_path = os.path.join(paths.paths['Stable Diffusion'], "configs", "stable-diffusion") @@ -29,7 +29,6 @@ Detects whether unet in state_dict is using v-parameterization. Returns True if it is. You're welcome. """ import ldm.modules.diffusionmodules.openaimodel - from modules import devices device = devices.cpu diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py new file mode 100644 index 0000000000000000000000000000000000000000..5ffd2f4f9fd164bccef1b37b1c459250ff360357 --- /dev/null +++ b/modules/sd_models_types.py @@ -0,0 +1,31 @@ +from ldm.models.diffusion.ddpm import LatentDiffusion +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from modules.sd_models import CheckpointInfo + + +class WebuiSdModel(LatentDiffusion): + """This class is not actually instantinated, but its fields are created and fieeld by webui""" + + lowvram: bool + """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info""" + + sd_model_hash: str + """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used""" + + sd_model_checkpoint: str + """path to the file on disk that model weights were obtained from""" + + sd_checkpoint_info: 'CheckpointInfo' + """structure with additional information about the file with model's weights""" + + is_sdxl: bool + """True if the model's architecture is SDXL""" + + is_sd2: bool + """True if the model's architecture is SD 2.x""" + + is_sd1: bool + """True if the model's architecture is SD 1.x""" diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 40559208bd7b8b15856bb39d06fb346a8874d3cb..0112332161fa21807e705cc1763681eefd7456e5 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -56,6 +56,14 @@ return torch.cat(res, dim=1) +def tokenize(self: sgm.modules.GeneralConditioner, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]: + return embedder.tokenize(texts) + + raise AssertionError('no tokenizer available') + + + def process_texts(self, texts): for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: return embedder.process_texts(texts) @@ -68,6 +76,7 @@ # those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.tokenize = tokenize sgm.modules.GeneralConditioner.process_texts = process_texts sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count @@ -89,12 +98,12 @@ model.conditioner.wrapped = torch.nn.Module() -sgm.modules.attention.print = lambda *args: None -sgm.modules.diffusionmodules.model.print = lambda *args: None -import sgm.modules.diffusionmodules.discretizer +from modules import devices, shared, prompt_parser -import sgm.modules.diffusionmodules.discretizer +from modules import devices, shared, prompt_parser import torch +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index bea2684c4db6171075a38427bb9c01b34d9e688a..45faae62821cb55c5fdcb5b42086775071c8cc4e 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,17 +1,18 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, - *sd_samplers_compvis.samplers_data_compvis, + *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} samplers = [] samplers_for_img2img = [] samplers_map = {} +samplers_hidden = {} def find_sampler_config(name): @@ -38,20 +39,24 @@ return sampler def set_samplers(): +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 -] - hidden = set(shared.opts.hide_samplers) +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 # imports for functions that previously were here and are used by other modules - - samplers = [x for x in all_samplers if x.name not in hidden] + samplers = all_samplers - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + samplers_for_img2img = all_samplers samplers_map.clear() for sampler in all_samplers: samplers_map[sampler.name.lower()] = sampler.name for alias in sampler.aliases: # imports for functions that previously were here and are used by other modules + *sd_samplers_compvis.samplers_data_compvis, + + +def visible_sampler_names(): +from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 *sd_samplers_compvis.samplers_data_compvis, diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py new file mode 100644 index 0000000000000000000000000000000000000000..b8101d38dc3be018f7725cbe8dfde00e664da7f8 --- /dev/null +++ b/modules/sd_samplers_cfg_denoiser.py @@ -0,0 +1,230 @@ +import torch +from modules import prompt_parser, devices, sd_samplers_common + +from modules.shared import opts, state +import modules.shared as shared +from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback + + +def catenate_conds(conds): + if not isinstance(conds[0], dict): + return torch.cat(conds) + + return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} + + +def subscript_cond(cond, a, b): + if not isinstance(cond, dict): + return cond[a:b] + + return {key: vec[a:b] for key, vec in cond.items()} + + +def pad_cond(tensor, repeats, empty): + if not isinstance(tensor, dict): + return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) + + tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) + return tensor + + +class CFGDenoiser(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, sampler): + super().__init__() + self.model_wrap = None + self.mask = None + self.nmask = None + self.init_latent = None + self.steps = None + """number of steps as specified by user in UI""" + + self.total_steps = None + """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler""" + + self.step = 0 + self.image_cfg_scale = None + self.padded_cond_uncond = False + self.sampler = sampler + self.model_wrap = None + self.p = None + self.mask_before_denoising = False + + @property + def inner_model(self): + raise NotImplementedError() + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) + + return denoised + + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + + def get_pred_x0(self, x_in, x_out, sigma): + return x_out + + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] + + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + + if self.mask_before_denoising and self.mask is not None: + x = self.init_latent * self.mask + self.nmask * x + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + if shared.sd_model.model.conditioning_key == "crossattn-adm": + image_uncond = torch.zeros_like(image_cond) + make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} + else: + image_uncond = image_cond + if isinstance(uncond, dict): + make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": [c_concat]} + else: + make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": [c_crossattn], "c_concat": [c_concat]} + + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps, tensor, uncond) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + tensor = denoiser_params.text_cond + uncond = denoiser_params.text_uncond + skip_uncond = False + + # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it + if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: + skip_uncond = True + x_in = x_in[:-batch_size] + sigma_in = sigma_in[:-batch_size] + + self.padded_cond_uncond = False + if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: + empty = shared.sd_model.cond_stage_model_empty_prompt + num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] + + if num_repeats < 0: + tensor = pad_cond(tensor, -num_repeats, empty) + self.padded_cond_uncond = True + elif num_repeats > 0: + uncond = pad_cond(uncond, num_repeats, empty) + self.padded_cond_uncond = True + + if tensor.shape[1] == uncond.shape[1] or skip_uncond: + if is_edit_model: + cond_in = catenate_conds([tensor, uncond, uncond]) + elif skip_uncond: + cond_in = tensor + else: + cond_in = catenate_conds([tensor, uncond]) + + if shared.opts.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.opts.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + + if not is_edit_model: + c_crossattn = subscript_cond(tensor, a, b) + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) + + if not skip_uncond: + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) + + denoised_image_indexes = [x[0][0] for x in conds_list] + if skip_uncond: + fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) + x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be + + denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model) + cfg_denoised_callback(denoised_params) + + devices.test_for_nans(x_out, "unet") + + if is_edit_model: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) + elif skip_uncond: + denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0) + else: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + + if not self.mask_before_denoising and self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) + + if opts.live_preview_content == "Prompt": + preview = self.sampler.last_latent + elif opts.live_preview_content == "Negative prompt": + preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) + else: + preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) + + sd_samplers_common.store_latent(preview) + + after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) + cfg_after_cfg_callback(after_cfg_callback_params) + denoised = after_cfg_callback_params.x + + self.step += 1 + return denoised + diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 763829f1ca34be1878cdcafd5d7e594488f6f2f8..60fa161cc7ecb60d046888ac89f3b807df541f80 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,13 +1,25 @@ +import inspect from collections import namedtuple import numpy as np import torch from PIL import Image +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models +from modules.shared import opts, state + from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd + +SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + + + from modules.shared import opts, state + import modules.shared as shared + if self.options.get("second_order", False): + steps = steps * 2 -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + return steps def setup_img2img_steps(p, steps=None): @@ -25,20 +37,36 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} +from modules.shared import opts, state import numpy as np + """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1].""" + - if approximation is None: + if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt): approximation = approximation_indexes.get(opts.show_progress_type, 0) + + from modules import lowvram + if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full: + approximation = 1 if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 + x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: -import numpy as np import modules.shared as shared - x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = x_sample * 2 - 1 else: + if model is None: +import modules.shared as shared import torch + with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 + x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) + + return x_sample + + +import numpy as np + x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -46,6 +74,12 @@ return Image.fromarray(x_sample) +def decode_first_stage(model, x): + x = x.to(devices.dtype_vae) + approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0) + return samples_to_images_tensor(x, approx_index, model) + + def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) @@ -54,6 +88,32 @@ def samples_to_image_grid(samples, approximation=None): return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) +def images_tensor_to_samples(image, approximation=None, model=None): + '''image[0, 1] -> latent''' + if approximation is None: + approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0) + + if approximation == 3: + image = image.to(devices.device, devices.dtype) + x_latent = sd_vae_taesd.encoder_model()(image) + else: + if model is None: + model = shared.sd_model + image = image.to(shared.device, dtype=devices.dtype_vae) + image = image * 2 - 1 + if len(image) > 1: + x_latent = torch.stack([ + model.get_first_stage_encoding( + model.encode_first_stage(torch.unsqueeze(img, 0)) + )[0] + for img in image + ]) + else: + x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) + + return x_latent + + def store_latent(decoded): state.current_latent = decoded @@ -85,16 +145,194 @@ class InterruptedException(BaseException): pass +def replace_torchsde_browinan(): from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd +import modules.shared as shared + + def torchsde_randn(size, dtype, device, seed): + return devices.randn_local(seed, size).to(device=device, dtype=dtype) + + torchsde._brownian.brownian_interval._randn = torchsde_randn + + +replace_torchsde_browinan() + + +def apply_refiner(cfg_denoiser): + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + refiner_switch_at = cfg_denoiser.p.refiner_switch_at + refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info + + if opts.img2img_fix_steps or steps is not None: from modules.shared import opts, state from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd +import torch + + if opts.img2img_fix_steps or steps is not None: import modules.shared as shared + return False + if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass: from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd +import torch + + cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + cfg_denoiser.p.setup_conds() + cfg_denoiser.update_inner_model() + + return True + + +class TorchHijack: + requested_steps = (steps or p.steps) SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + steps = int(requested_steps / min(p.denoising_strength, 0.999)) if p.denoising_strength > 0 else 0 + this is needed to properly replace every use of torch.randn_like. from collections import namedtuple + x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) + def __init__(self, p): + self.rng = p.rng + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + t_enc = requested_steps - 1 import numpy as np + + +class Sampler: + def __init__(self, funcname): + self.funcname = funcname + self.func = funcname + self.extra_params = [] + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config: SamplerData = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.eta_option_field = 'eta_ancestral' + self.eta_infotext_field = 'Eta' + self.eta_default = 1.0 + + self.conditioning_key = shared.sd_model.model.conditioning_key + + self.p = None + self.model_wrap_cfg = None + self.sampler_extra_args = None + self.options = {} + + def callback_state(self, d): + step = d['i'] + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps + self.model_wrap_cfg.total_steps = self.config.total_steps(steps) + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent + except InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p) -> dict: + self.p = p + self.model_wrap_cfg.p = p + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(p) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != self.eta_default: + p.extra_generation_params[self.eta_infotext_field] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if 's_churn' in extra_params_kwargs and s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if 's_noise' in extra_params_kwargs and s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + raise NotImplementedError() + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + raise NotImplementedError() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 4a8396f97ec3fd75b2eabdc9fbf97ac34af27e6a..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -1,224 +0,0 @@ -import math -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -import numpy as np -import torch - -from modules.shared import state -from modules import sd_samplers_common, prompt_parser, shared -import modules.models.diffusion.uni_pc - - -samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), -] - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) - self.orig_p_sample_ddim = None - if self.is_plms: - self.orig_p_sample_ddim = self.sampler.p_sample_plms - elif self.is_ddim: - self.orig_p_sample_ddim = self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) - - res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) - - x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) - - return res - - def before_sample(self, x, ts, cond, unconditional_conditioning): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise sd_samplers_common.InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - uc_image_conditioning = None - if isinstance(cond, dict): - if self.conditioning_key == "crossattn-adm": - image_conditioning = cond["c_adm"] - uc_image_conditioning = unconditional_conditioning["c_adm"] - else: - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x = img_orig * self.mask + self.nmask * x - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} - unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} - else: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - return x, ts, cond, unconditional_conditioning - - def update_step(self, last_latent): - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * last_latent - else: - self.last_latent = last_latent - - sd_samplers_common.store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - def after_sample(self, x, ts, cond, uncond, res): - if not self.is_unipc: - self.update_step(res[1]) - - return x, ts, cond, uncond, res - - def unipc_after_update(self, x, model_x): - self.update_step(x) - - def initialize(self, p): - if self.is_ddim: - self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim - else: - self.eta = 0.0 - - if self.eta != 0.0: - p.extra_generation_params["Eta DDIM"] = self.eta - - if self.is_unipc: - keys = [ - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ] - - for name, key in keys: - v = getattr(shared.opts, key) - if v != shared.opts.get_default(key): - p.extra_generation_params[name] = v - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - - def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): - if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: - num_steps = shared.opts.uni_pc_order - valid_step = 999 / (1000 // num_steps) - if valid_step == math.floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} - else: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} - else: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim diff --git a/modules/sd_samplers_extra.py b/modules/sd_samplers_extra.py new file mode 100644 index 0000000000000000000000000000000000000000..1b981ca80c355cbb6a92915d422cfa19e51b21c4 --- /dev/null +++ b/modules/sd_samplers_extra.py @@ -0,0 +1,74 @@ +import torch +import tqdm +import k_diffusion.sampling + + [email protected]_grad() +def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list=None): + """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023) + Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]} + If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list + """ + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + step_id = 0 + from k_diffusion.sampling import to_d, get_sigmas_karras + + def heun_step(x, old_sigma, new_sigma, second_order=True): + nonlocal step_id + denoised = model(x, old_sigma * s_in, **extra_args) + d = to_d(x, old_sigma, denoised) + if callback is not None: + callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) + dt = new_sigma - old_sigma + if new_sigma == 0 or not second_order: + # Euler method + x = x + d * dt + else: + # Heun's method + x_2 = x + d * dt + denoised_2 = model(x_2, new_sigma * s_in, **extra_args) + d_2 = to_d(x_2, new_sigma, denoised_2) + d_prime = (d + d_2) / 2 + x = x + d_prime * dt + step_id += 1 + return x + + steps = sigmas.shape[0] - 1 + if restart_list is None: + if steps >= 20: + restart_steps = 9 + restart_times = 1 + if steps >= 36: + restart_steps = steps // 4 + restart_times = 2 + sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) + restart_list = {0.1: [restart_steps + 1, restart_times, 2]} + else: + restart_list = {} + + restart_list = {int(torch.argmin(abs(sigmas - key), dim=0)): value for key, value in restart_list.items()} + + step_list = [] + for i in range(len(sigmas) - 1): + step_list.append((sigmas[i], sigmas[i + 1])) + if i + 1 in restart_list: + restart_steps, restart_times, restart_max = restart_list[i + 1] + min_idx = i + 1 + max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) + if max_idx < min_idx: + sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] + while restart_times > 0: + restart_times -= 1 + step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) + + last_sigma = None + for old_sigma, new_sigma in tqdm.tqdm(step_list, disable=disable): + if last_sigma is None: + last_sigma = old_sigma + elif last_sigma < old_sigma: + x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (old_sigma ** 2 - last_sigma ** 2) ** 0.5 + x = heun_step(x, old_sigma, new_sigma) + last_sigma = new_sigma + + return x diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 5552a8dc7dc6c75e2dfe4f027b4bdeeb215a2299..b9e0d5776564866ce597924ef57e9b3cfebb8b16 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,437 +1,165 @@ -from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser - +from modules.sd_samplers_cfg_denoiser import CFGDenoiser # noqa: F401 -from modules.shared import opts, state -import modules.shared as shared -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback - -samplers_k_diffusion = [ -from collections import deque import inspect -from collections import deque import k_diffusion.sampling -from collections import deque from modules import prompt_parser, devices, sd_samplers_common -from collections import deque - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), -import torch import inspect -import torch import k_diffusion.sampling + - ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), +import modules.shared as shared -import torch - ('DPM2 a Karras', 'sample_dpm_2_ancestral', ['k_dpm_2_a_ka'], {'scheduler': 'karras', 'discard_next_to_last_sigma': True, "uses_ensd": True, "second_order": True}), +from collections import deque import torch -import modules.shared as shared ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), import inspect -from collections import deque + 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential import inspect -import torch +from collections import deque - -import inspect +from collections import deque import inspect -import inspect +from collections import deque import k_diffusion.sampling -import inspect +from collections import deque from modules import prompt_parser, devices, sd_samplers_common -import inspect +from collections import deque import inspect -import torch - -sampler_extra_params = { +def catenate_conds(conds): import inspect -import modules.shared as shared + if not isinstance(conds[0], dict): -import inspect +from collections import deque from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -import k_diffusion.sampling +import torch -import k_diffusion.sampling +import torch from collections import deque - -import k_diffusion.sampling +import torch import torch -import k_diffusion.sampling import inspect - 'Automatic': None, - 'karras': k_diffusion.sampling.get_sigmas_karras, - 'exponential': k_diffusion.sampling.get_sigmas_exponential, - 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential -} - - -def catenate_conds(conds): - if not isinstance(conds[0], dict): return torch.cat(conds) - +import inspect return {key: torch.cat([x[key] for x in conds]) for key in conds[0].keys()} - - +import inspect def subscript_cond(cond, a, b): +import inspect if not isinstance(cond, dict): +import inspect return cond[a:b] - +import inspect return {key: vec[a:b] for key, vec in cond.items()} - - -def pad_cond(tensor, repeats, empty): - if not isinstance(tensor, dict): - return torch.cat([tensor, empty.repeat((tensor.shape[0], repeats, 1))], axis=1) - - tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty) - return tensor - - -class CFGDenoiser(torch.nn.Module): - import torch - import inspect - +import torch import k_diffusion.sampling - +import torch from modules import prompt_parser, devices, sd_samplers_common - negative prompt. - import torch - +import torch from modules.shared import opts, state - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None -from modules.shared import opts, state import torch +import modules.shared as shared -from modules.shared import opts, state import inspect - self.image_cfg_scale = None -from modules.shared import opts, state from modules import prompt_parser, devices, sd_samplers_common - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale) - - return denoised -import modules.shared as shared import inspect - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) - -import modules.shared as shared import torch -import modules.shared as shared - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, - # so is_edit_model is set to False to support AND composition. - is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 - +samplers_data_k_diffusion = [ - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback import inspect - -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback import k_diffusion.sampling - -from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback +import inspect from modules import prompt_parser, devices, sd_samplers_common - repeats = [len(conds_list[i]) for i in range(batch_size)] - - if shared.sd_model.model.conditioning_key == "crossattn-adm": - image_uncond = torch.zeros_like(image_cond) - make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": [c_crossattn], "c_adm": c_adm} - else: - image_uncond = image_cond - if isinstance(uncond, dict): -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback import inspect - else: -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules import prompt_parser, devices, sd_samplers_common - - if not is_edit_model: -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.shared import opts, state - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond]) - else: - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import inspect import torch -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback import inspect - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback from modules.shared import opts, state -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import inspect import modules.shared as shared -from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback +import inspect from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -samplers_k_diffusion = [ +import k_diffusion.sampling - - # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it - if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model: -samplers_k_diffusion = [ import inspect - x_in = x_in[:-batch_size] -samplers_k_diffusion = [ from modules import prompt_parser, devices, sd_samplers_common - +import modules.shared as shared -from modules.shared import opts, state +import inspect from modules import prompt_parser, devices, sd_samplers_common - if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: - empty = shared.sd_model.cond_stage_model_empty_prompt - num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] - -samplers_k_diffusion = [ from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -from collections import deque import inspect + -from collections import deque import inspect + from collections import deque -from collections import deque import inspect + import torch -from collections import deque import inspect -import inspect - self.padded_cond_uncond = True -from collections import deque import inspect -import k_diffusion.sampling - if is_edit_model: - cond_in = catenate_conds([tensor, uncond, uncond]) - elif skip_uncond: - cond_in = tensor -from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback import k_diffusion.sampling from collections import deque - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=make_condition_dict(cond_in, image_cond_in)) - else: -from collections import deque k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -from collections import deque k_diffusion_scheduler = { -from collections import deque 'Automatic': None, -from collections import deque 'karras': k_diffusion.sampling.get_sigmas_karras, -from collections import deque 'exponential': k_diffusion.sampling.get_sigmas_exponential, - else: -from collections import deque 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential -from collections import deque import k_diffusion.sampling -import modules.shared as shared - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - ('LMS', 'sample_lms', ['k_lms'], {}), from collections import deque - if not is_edit_model: - c_crossattn = subscript_cond(tensor, a, b) - else: - c_crossattn = torch.cat([tensor[a:b]], uncond) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b])) - - if not skip_uncond: - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:])) - - denoised_image_indexes = [x[0][0] for x in conds_list] - if skip_uncond: - fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes]) - x_out = torch.cat([x_out, fake_uncond]) # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be - - ('Heun', 'sample_heun', ['k_heun'], {"second_order": True}), import inspect -from collections import deque that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - -from collections import deque +import inspect instead of one. Originally, the second prompt is just an empty string, but we use non-empty - -from collections import deque +import inspect negative prompt. -from collections import deque +import inspect def __init__(self, model): -from collections import deque +import inspect super().__init__() -from collections import deque +import inspect self.inner_model = model - if is_edit_model: - denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) - elif skip_uncond: - ('DPM2', 'sample_dpm_2', ['k_dpm_2'], {'discard_next_to_last_sigma': True}), import inspect - else: -from collections import deque from modules.shared import opts, state -import k_diffusion.sampling - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) - cfg_after_cfg_callback(after_cfg_callback_params) - denoised = after_cfg_callback_params.x - - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), - return denoised -class TorchHijack: - def __init__(self, sampler_noises): - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), import inspect - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a'], {'discard_next_to_last_sigma': True, "uses_ensd": True}), from modules.shared import opts, state from collections import deque - raise sd_samplers_common.InterruptedException - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), import inspect - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), from modules.shared import opts, state - return torch.randn_like(x, device=devices.cpu).to(x.device) - else: - return torch.randn_like(x) - - import torch - def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), import inspect - self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), from modules.shared import opts, state - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), import inspect import torch - ('Euler', 'sample_euler', ['k_euler'], {}), - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - state.sampling_step = step - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), import inspect - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), from modules.shared import opts, state - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), import k_diffusion.sampling -import torch import inspect +from modules.shared import opts, state from modules import prompt_parser, devices, sd_samplers_common - ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), -import torch import inspect -from modules import prompt_parser, devices, sd_samplers_common +from modules.shared import opts, state -import torch sampler_extra_params = { - return p.steps - - def initialize(self, p): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), from modules.shared import opts, state - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - return extra_params_kwargs def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -467,6 +200,9 @@ elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, device=shared.device) + elif self.config is not None and self.config.options.get('scheduler', None) == 'exponential': + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigmas = k_diffusion.sampling.get_sigmas_exponential(n=steps, sigma_min=m_sigma_min, sigma_max=m_sigma_max, device=shared.device) else: sigmas = self.model_wrap.get_sigmas(steps) @@ -476,30 +212,27 @@ return sigmas ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), -from collections import deque -import torch from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback -import torch - ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), import inspect - return None - from k_diffusion.sampling import BrownianTreeNoiseSampler + sigmas = self.get_sigmas(p, steps) - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), import torch - if shared.sd_model.model.conditioning_key == "crossattn-adm": - return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), +import inspect import inspect +import modules.shared as shared -from collections import deque - + p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise import inspect +import modules.shared as shared import torch import inspect +import modules.shared as shared import inspect + noise = extra_noise_params.noise + xi += noise * opts.img2img_extra_noise extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters @@ -518,10 +253,14 @@ noise_sampler = self.create_noise_sampler(x, sigmas, p) extra_params_kwargs['noise_sampler'] = noise_sampler import inspect + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + extra_params_kwargs['solver_type'] = 'heun' + +import inspect ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), self.last_latent = x import inspect - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), + raise sd_samplers_common.InterruptedException 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -530,7 +269,7 @@ 's_min_uncond': self.s_min_uncond } import inspect -import torch +import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback if self.model_wrap_cfg.padded_cond_uncond: @@ -549,13 +288,17 @@ extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters import inspect +from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback + extra_params_kwargs['n'] = steps + +import inspect extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: + import inspect - 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], from collections import deque +from modules import prompt_parser, devices, sd_samplers_common extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): @@ -562,20 +306,27 @@ noise_sampler = self.create_noise_sampler(x, sigmas, p) extra_params_kwargs['noise_sampler'] = noise_sampler import inspect + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): + extra_params_kwargs['solver_type'] = 'heun' + +import inspect ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), import inspect -} + raise sd_samplers_common.InterruptedException 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond import inspect -import k_diffusion.sampling import torch +import modules.shared as shared + + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True return samples + diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6cbd46d9c5b3a6c090eac2c20dee2aaa7f4e59 --- /dev/null +++ b/modules/sd_samplers_timesteps.py @@ -0,0 +1,167 @@ +import torch +import inspect +import sys +from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl +from modules.sd_samplers_cfg_denoiser import CFGDenoiser +from modules.script_callbacks import ExtraNoiseParams, extra_noise_callback + +from modules.shared import opts +import modules.shared as shared + +samplers_timesteps = [ + ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), + ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), +] + + +samplers_data_timesteps = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_timesteps +] + + +class CompVisTimestepsDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def forward(self, input, timesteps, **kwargs): + return self.inner_model.apply_model(input, timesteps, **kwargs) + + +class CompVisTimestepsVDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def predict_eps_from_z_and_v(self, x_t, t, v): + return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + + def forward(self, input, timesteps, **kwargs): + model_output = self.inner_model.apply_model(input, timesteps, **kwargs) + e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output) + return e_t + + +class CFGDenoiserTimesteps(CFGDenoiser): + + def __init__(self, sampler): + super().__init__(sampler) + + self.alphas = shared.sd_model.alphas_cumprod + self.mask_before_denoising = True + + def get_pred_x0(self, x_in, x_out, sigma): + ts = sigma.to(dtype=int) + + a_t = self.alphas[ts][:, None, None, None] + sqrt_one_minus_at = (1 - a_t).sqrt() + + pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() + + return pred_x0 + + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + + +class CompVisSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): + super().__init__(funcname) + + self.eta_option_field = 'eta_ddim' + self.eta_infotext_field = 'Eta DDIM' + self.eta_default = 0.0 + + self.model_wrap_cfg = CFGDenoiserTimesteps(self) + + def get_timesteps(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999) + + return timesteps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + timesteps = self.get_timesteps(p, steps) + timesteps_sched = timesteps[:t_enc] + + alphas_cumprod = shared.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]]) + + xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod + + if opts.img2img_extra_noise > 0: + p.extra_generation_params["Extra noise"] = opts.img2img_extra_noise + extra_noise_params = ExtraNoiseParams(noise, x) + extra_noise_callback(extra_noise_params) + noise = extra_noise_params.noise + xi += noise * opts.img2img_extra_noise * sqrt_alpha_cumprod + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps_sched + if 'is_img2img' in parameters: + extra_params_kwargs['is_img2img'] = True + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + self.sampler_extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + timesteps = self.get_timesteps(p, steps) + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps + + self.last_latent = x + self.sampler_extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + +sys.modules['modules.sd_samplers_compvis'] = sys.modules[__name__] +VanillaStableDiffusionSampler = CompVisSampler # temp. compatibility with older extensions diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..a72daafd47dedb0d9f000c1af4e40ec2767b4e39 --- /dev/null +++ b/modules/sd_samplers_timesteps_impl.py @@ -0,0 +1,137 @@ +import torch +import tqdm +import k_diffusion.sampling +import numpy as np + +from modules import shared +from modules.models.diffusion.uni_pc import uni_pc + + [email protected]_grad() +def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones((x.shape[0])) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sigma_t = sigmas[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + [email protected]_grad() +def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64 if x.device.type != 'mps' else torch.float32) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + s_x = x.new_ones((x.shape[0], 1, 1, 1)) + old_eps = [] + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = alphas[index].item() * s_x + a_prev = alphas_prev[index].item() * s_x + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. - a_prev).sqrt() * e_t + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + return x_prev, pred_x0 + + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + ts = timesteps[index].item() * s_in + t_next = timesteps[max(index - 1, 0)].item() * s_in + + e_t = model(x, ts, **extra_args) + + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = model(x_prev, t_next, **extra_args) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + else: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + + x = x_prev + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +class UniPCCFG(uni_pc.UniPC): + def __init__(self, cfg_model, extra_args, callback, *args, **kwargs): + super().__init__(None, *args, **kwargs) + + def after_update(x, model_x): + callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x}) + self.index += 1 + + self.cfg_model = cfg_model + self.extra_args = extra_args + self.callback = callback + self.index = 0 + self.after_update = after_update + + def get_model_input_time(self, t_continuous): + return (t_continuous - 1. / self.noise_schedule.total_N) * 1000. + + def model(self, x, t): + t_input = self.get_model_input_time(t) + + res = self.cfg_model(x, t_input, **self.extra_args) + + return res + + +def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + + ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means + unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant) + x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) + + return x diff --git a/modules/sd_unet.py b/modules/sd_unet.py index 6d708ad296bbb5a46a4925db350fb6f319572ddb..5525cfbc3a03580ca884a43971232384c43888d2 100644 --- a/modules/sd_unet.py +++ b/modules/sd_unet.py @@ -47,7 +47,7 @@ current_unet_option = new_option if current_unet_option is None: current_unet = None - if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + if not shared.sd_model.lowvram: shared.sd_model.model.diffusion_model.to(devices.device) return diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff29946e8548e8dad447aa90f264d472acb4ed..669097daadef5b6052cb729158d9bfe3d0ad6408 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,9 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from dataclasses import dataclass + +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes + import glob from copy import deepcopy @@ -15,6 +18,23 @@ loaded_vae_file = None checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + + +def get_loaded_vae_name(): + if loaded_vae_file is None: + return None + + return os.path.basename(loaded_vae_file) + + +def get_loaded_vae_hash(): + if loaded_vae_file is None: + return None + + sha256 = hashes.sha256(loaded_vae_file, 'vae') + + return sha256[0:10] if sha256 else None + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -83,8 +103,10 @@ for filepath in candidates: name = get_filename(filepath) vae_dict[name] = filepath + vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))) + import collections checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] for vae_file in vae_dict.values(): @@ -94,42 +116,108 @@ return None + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) - + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + source: str = None +import os -vae_dict = {} + def tuple(self): + return self.vae, self.source + + +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: vae_path = os.path.abspath(os.path.join(paths.models_path, "VAE")) +import glob + return VaeResolution() + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return VaeResolution(vae_from_options, 'specified in settings') + + if not is_automatic(): +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: +import os vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) +import os if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): +import os return vae_near_checkpoint, 'found near the checkpoint' - +import os if shared.opts.sd_vae == "None": +import os return None, None +import os vae_from_options = vae_dict.get(shared.opts.sd_vae, None) +import os if vae_from_options is not None: +import os return vae_from_options, 'specified in settings' + return VaeResolution(resolved=False) + + +import os if not is_automatic: + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) +import os print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} import os + return shared.cmd_opts.vae_path, 'from commandline argument' +import os def load_vae_dict(filename, map_location): + if shared.cmd_opts.vae_path is not None: +import os vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + +import os vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} +import os return vae_dict_1 +def store_base_vae(model): + if res.resolved: +import os vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + res = resolve_vae_near_checkpoint(checkpoint_file) +import os global vae_dict, loaded_vae_file + return res + + res = resolve_vae_from_setting() + + return res + + +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -167,6 +254,8 @@ elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside @@ -185,9 +274,6 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_info = None -import glob - -checkpoint_info = None from copy import deepcopy sd_model = shared.sd_model @@ -196,7 +282,7 @@ checkpoint_file = checkpoint_info.filename if vae_file == unspecified: import os - print("Restoring base VAE") + _load_vae_dict(model, checkpoints_loaded[vae_file]) else: vae_source = "from function argument" @@ -204,7 +290,7 @@ if loaded_vae_file == vae_file: return import os - return os.path.basename(filepath) + else: lowvram.send_everything_to_cpu() else: sd_model.to(devices.cpu) @@ -217,7 +303,7 @@ sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) import os - os.path.join(vae_path, '**/*.ckpt'), + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" sd_model.to(devices.device) print("VAE weights loaded.") diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 86bd658ad32af5ba22d087ec668f450fedd693a0..3965e223e6fcfa182a8012911387d822d33e11be 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -82,6 +82,6 @@ coefs = torch.tensor(coeffs).to(sample.device) sd_vae_approx_models = {} -import os +import torch return x_sample diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5bf7c76e1dd8ca9a7fc3b624861ed9962387222e..808eb3624fd40daa56bcbdb5f8ad771ae5557346 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -44,8 +44,18 @@ Block(64, 64), conv(64, 3), ) +def encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), +import os (DNN for encoding / decoding SD's latent space) -""" + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + + +class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 @@ -56,25 +66,31 @@ self.decoder = decoder() self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) -""" +class TAESDEncoder(nn.Module): + latent_magnitude = 3 +(DNN for encoding / decoding SD's latent space) (DNN for encoding / decoding SD's latent space) -from modules import devices, paths_internal, shared + def __init__(self, encoder_path="taesd_encoder.pth"): - + """Initialize pretrained TAESD on the given device from the given checkpoints.""" """ +from modules import devices, paths_internal, shared + self.encoder = encoder() + self.encoder.load_state_dict( + torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) - print(f'Downloading TAESD decoder to: {model_path}') + print(f'Downloading TAESD model to: {model_path}') torch.hub.download_url_to_file(model_url, model_path) - import torch +Tiny AutoEncoder for Stable Diffusion model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name) @@ -82,12 +99,32 @@ model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): + loaded_model = TAESDDecoder(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model https://github.com/madebyollin/taesd +import torch.nn as nn + raise FileNotFoundError('TAESD model not found') + + return loaded_model.decoder + + +import torch + model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) + + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) + + if os.path.exists(model_path): + loaded_model = TAESDEncoder(model_path) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) sd_vae_taesd_models[model_name] = loaded_model else: raise FileNotFoundError('TAESD model not found') - return loaded_model.decoder + return loaded_model.encoder diff --git a/modules/shared.py b/modules/shared.py index aa72c9c87abeab06c9cdb7350320862b17315afa..636619391fce62974a7d26a9fc4fcd5599a32d82 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -1,1067 +1,179 @@ -import datetime -import json -import os -import re import sys -import threading -import time -import logging import gradio as gr -import torch -import tqdm -import launch -import modules.interrogate -import modules.memmon -import datetime import threading -import modules.devices as devices -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args -from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 -from ldm.models.diffusion.ddpm import LatentDiffusion -from typing import Optional - -log = logging.getLogger(__name__) - -demo = None - -parser = cmd_args.parser - -import json import sys -script_loading.preload_extensions(extensions_builtin_dir, parser) - -if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: - cmd_opts = parser.parse_args() -else: - cmd_opts, _ = parser.parse_known_args() - - -restricted_opts = { - "samples_filename_pattern", - "directories_filename_pattern", - "outdir_samples", -import os import sys - "outdir_img2img_samples", - "outdir_extras_samples", - "outdir_grids", - "outdir_txt2img_grids", - "outdir_save", -import re import datetime -} -# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json -gradio_hf_hub_themes = [ - "gradio/glass", -import re import threading - "gradio/seafoam", - "gradio/soft", - "freddyaboulton/dracula_revamped", - "gradio/dracula_test", - "abidlabs/dracula_test", - "abidlabs/pakistan", - "dawood/microsoft_windows", - "ysharma/steampunk" -] - - cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access -devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ - (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) - -devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 -devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 - -device = devices.device -weight_load_location = None if cmd_opts.lowram else "cpu" - -batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) -parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram xformers_available = False -config_filename = cmd_opts.ui_settings_file - -import threading import time import threading -import logging -loaded_hypernetworks = [] - - -def reload_hypernetworks(): - from modules.hypernetworks import hypernetwork - global hypernetworks - - hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - - -class State: -import time import sys - interrupted = False - job = "" -import time import logging -import time - processing_has_refined_job_count = False - job_timestamp = '0' - sampling_step = 0 - sampling_steps = 0 - current_latent = None - current_image = None -import logging import threading - id_live_preview = 0 - textinfo = None - time_start = None - server_start = None - _server_command_signal = threading.Event() - _server_command: Optional[str] = None - - @property - def need_restart(self) -> bool: - import sys - return self.server_command == "restart" - - @need_restart.setter - def need_restart(self, value: bool) -> None: - # Compatibility setter for need_restart. - if value: - self.server_command = "restart" - @property - def server_command(self): - return self._server_command - - @server_command.setter - def server_command(self, value: Optional[str]) -> None: -import gradio as gr import threading - Set the server command to `value` and signal that it's been set. -import gradio as gr import threading - self._server_command = value - self._server_command_signal.set() - -import torch import datetime -import threading - Wait for server command to get set; return and clear the value and signal. - """ - if self._server_command_signal.wait(timeout): - self._server_command_signal.clear() - req = self._server_command - self._server_command = None - return req - return None - -import torch import logging - self.interrupt() - self.server_command = "restart" - log.info("Received restart request") - - def skip(self): - self.skipped = True -import tqdm import re - - def interrupt(self): -import tqdm import threading - log.info("Received interrupt request") - - def nextjob(self): - if opts.live_previews_enable and opts.show_progress_every_n_steps == -1: - self.do_set_current_image() - - self.job_no += 1 - self.sampling_step = 0 - self.current_image_sampling_step = 0 - - def dict(self): - obj = { -import launch import threading - "interrupted": self.interrupted, - "job": self.job, - "job_count": self.job_count, - "job_timestamp": self.job_timestamp, - "job_no": self.job_no, -import modules.interrogate import json - "sampling_steps": self.sampling_steps, - } - - return obj - - def begin(self, job: str = "(unknown)"): -import launch import json -import datetime import re -import time - self.processing_has_refined_job_count = False -import modules.interrogate - self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") - self.current_latent = None -import modules.memmon import json -import launch import os - self.id_live_preview = 0 - self.skipped = False - self.interrupted = False - self.textinfo = None - self.time_start = time.time() - self.job = job -import modules.memmon -import datetime import threading - -import datetime device = devices.device - duration = time.time() - self.time_start - log.info("Ending job %s (%.2f seconds)", self.job, duration) - self.job = "" - self.job_count = 0 - -import modules.memmon - -import datetime config_filename = cmd_opts.ui_settings_file - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" - if not parallel_processing_allowed: - return - - if self.sampling_step - self.current_image_sampling_step >= opts.show_progress_every_n_steps and opts.live_previews_enable and opts.show_progress_every_n_steps != -1: - self.do_set_current_image() - - def do_set_current_image(self): -import modules.devices as devices import json - return - import modules.sd_samplers - if opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) -import modules.devices as devices import threading - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - - self.current_image_sampling_step = self.sampling_step - - def assign_current_image(self, image): - self.current_image = image - self.id_live_preview += 1 - - -state = State() -state.server_start = time.time() - -styles_filename = cmd_opts.styles_file -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args import sys -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args import threading - -import datetime import logging -import time - -class OptionInfo: - def __init__(self, default=None, label="", component=None, component_args=None, onchange=None, section=None, refresh=None, comment_before='', comment_after=''): - self.default = default - self.label = label - self.component = component - self.component_args = component_args - self.onchange = onchange - self.section = section -from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 import threading -import datetime -import time - """HTML text that will be added after label in UI""" - - self.comment_after = comment_after - """HTML text that will be added before label in UI""" - - def link(self, label, url): - self.comment_before += f"[<a href='{url}' target='_blank'>{label}</a>]" - return self - - def js(self, label, js_func): - self.comment_before += f"[<a onclick='{js_func}(); return false'>{label}</a>]" - return self - -from ldm.models.diffusion.ddpm import LatentDiffusion import threading - self.comment_after += f"<span class='info'>({info})</span>" - return self - - def html(self, html): - self.comment_after += html - return self - - def needs_restart(self): - self.comment_after += " <span class='info'>(requires restart)</span>" - return self - - - - -def options_section(section_identifier, options_dict): - for v in options_dict.values(): - v.section = section_identifier - - return options_dict - - -from typing import Optional import threading - import modules.sd_models - return modules.sd_models.checkpoint_tiles() - - -def refresh_checkpoints(): - import modules.sd_models - return modules.sd_models.list_models() - - -def list_samplers(): - import modules.sd_samplers -log = logging.getLogger(__name__) import os - -hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} -tab_names = [] - -log = logging.getLogger(__name__) import threading - -options_templates.update(options_section(('saving-images', "Saving images/grids"), { - "samples_save": OptionInfo(True, "Always save all generated images"), - "samples_format": OptionInfo('png', 'File format for images'), - "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), - "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), - - "grid_save": OptionInfo(True, "Always save all generated image grids"), - "grid_format": OptionInfo('png', 'File format for grids'), - "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), - "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), -demo = None import threading - "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), - "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), - "font": OptionInfo("", "Font for image grids that have text"), -import json import re - "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), - "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), - "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), - "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), - "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), -parser = cmd_args.parser import threading - "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), - "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"), - "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"), - "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), - "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), - "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"), - "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), - "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), - "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"), - -script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file)) import threading -import json import sys -import time - "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), -script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file)) - - "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), -script_loading.preload_extensions(extensions_builtin_dir, parser) import datetime - -})) - -options_templates.update(options_section(('saving-paths', "Paths for saving"), { - "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), - "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), - "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), -script_loading.preload_extensions(extensions_builtin_dir, parser) +import logging import time - "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), -script_loading.preload_extensions(extensions_builtin_dir, parser) - "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), - "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), - "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), -import json import threading -import json - -options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { - "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), - "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), -if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: import threading - "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), - "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), -import json import threading -import json - -options_templates.update(options_section(('upscaling', "Upscaling"), { - "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), - "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), - "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), - "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), -import json import threading -import json - -options_templates.update(options_section(('face-restoration', "Face restoration"), { - "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}), - cmd_opts = parser.parse_args() import threading - cmd_opts = parser.parse_args() import time -import json import threading -import json - -options_templates.update(options_section(('system', "System"), { - "show_warnings": OptionInfo(False, "Show warnings in console."), - "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), - "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), - "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), - "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), - "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), - "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), -import json import threading -import json - -options_templates.update(options_section(('training', "Training"), { - "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), -else: import logging -else: - "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), - "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), - "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), - "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), - "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), - "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), - cmd_opts, _ = parser.parse_known_args() import threading - "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), - "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), -import json import threading -import json - cmd_opts, _ = parser.parse_known_args() - "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), - "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), -restricted_opts = { import sys - "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -restricted_opts = { import time import os -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args - "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), - "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill image's transparent parts with this color.", ui_components.FormColorPicker, {}), - "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."), - "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"), - "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), - "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), - "samples_filename_pattern", import sys - "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), - "samples_filename_pattern", import time - "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), -})) - -options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { - "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), - "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), - "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), - "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), -})) - - "directories_filename_pattern", import re - "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), - "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), - "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), - "directories_filename_pattern", - "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"), - "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."), import json -import threading import json - -options_templates.update(options_section(('compatibility', "Compatibility"), { - "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), - "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), - "outdir_samples", import sys - "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), - "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), - "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), -})) -options_templates.update(options_section(('interrogate', "Interrogate Options"), { - "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"), - "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"), - "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), - "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), - "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), - "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"), - "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types), -import os devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ -import os import sys -import logging -import os import sys - - "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"), - "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), -})) - -options_templates.update(options_section(('extra_networks', "Extra Networks"), { - "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), - "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), - "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), - "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), - "outdir_img2img_samples", import time -import os import threading -import logging - "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), - "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), - "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), - "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), - "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), - "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), -})) - - "outdir_extras_samples", import sys -import os import time -import threading -import os import time -import time - "img2img_editor_height": OptionInfo(720, "img2img: height of image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_restart(), - "return_grid": OptionInfo(True, "Show grid in results for web"), - "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), - "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), - "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), - "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), - "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), - "outdir_grids", import sys - "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), - "outdir_grids", import time -import os import logging -import logging - "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), - "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_restart(), - "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_restart(), - "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), - "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), - "outdir_txt2img_grids", import sys - "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), - "outdir_txt2img_grids", import time -import os -import logging - "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), - "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), - "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), -})) - -options_templates.update(options_section(('infotext', "Infotext"), { - "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), - "outdir_save", import sys - "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), - "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), - "outdir_save", import logging -import re - -<li>Ignore: keep prompt and styles dropdown as it is.</li> -<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li> -<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li> -<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li> -</ul>"""), - -})) - - "outdir_init_images" import sys - "show_progressbar": OptionInfo(True, "Show progressbar"), - "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), - "outdir_init_images" import logging -import re import datetime - - "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"), - "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"), - "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), - "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), -})) - -options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -} import sys - "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; higher = more unperdictable results"), - "eta_ancestral": OptionInfo(1.0, "Eta for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}).info("noise multiplier; applies to Euler a and other samplers that have a in them"), -} import logging -import re import json - - 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), import re -restricted_opts = { -# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json import json - 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), - 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"), - 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"), - 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), - 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), - 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), -# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json - 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}).info("must be < sampling steps"), - 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final"), -})) - -options_templates.update(options_section(('postprocessing', "Postprocessing"), { - 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), - 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), -gradio_hf_hub_themes = [ import sys -})) - -options_templates.update(options_section((None, "Hidden options"), { - "disabled_extensions": OptionInfo([], "Disable these extensions"), -gradio_hf_hub_themes = [ import logging - "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), - "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), -})) - - -options_templates.update() - - -class Options: - "gradio/glass", import os - data_labels = options_templates - typemap = {int: float} - def __init__(self): - self.data = {k: v.default for k, v in self.data_labels.items()} - -import re (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) import re -devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 - if key in self.data or key in self.data_labels: - assert not cmd_opts.freeze_settings, "changing settings is disabled" - info = opts.data_labels.get(key, None) - comp_args = info.component_args if info else None - if isinstance(comp_args, dict) and comp_args.get('visible', True) is False: - "gradio/monochrome", import sys - - if cmd_opts.hide_ui_dir_config and key in restricted_opts: - raise RuntimeError(f"not possible to set {key} because it is restricted") - - self.data[key] = value - "gradio/monochrome", import logging - - return super(Options, self).__setattr__(key, value) - - def __getattr__(self, item): -import re import sys - - if item in self.data: - return self.data[item] - if item in self.data_labels: - return self.data_labels[item].default - - "gradio/seafoam", import sys - - def set(self, key, value): - """sets an option and calls its onchange callback, returning True if the option changed and False otherwise""" - - oldval = self.data.get(key, None) - if oldval == value: - return False - - try: - setattr(self, key, value) - except RuntimeError: - return False - - if self.data_labels[key].onchange is not None: - try: -import re current_image_sampling_step = 0 - except Exception as e: - errors.display(e, f"changing setting {key} to {value}") - "gradio/soft", - return False - - return True - - def get_default(self, key): - """returns the default value for the key""" - - data_label = self.data_labels.get(key) - if data_label is None: - "freddyaboulton/dracula_revamped", import threading - - "freddyaboulton/dracula_revamped", import time - - def save(self, filename): - assert not cmd_opts.freeze_settings, "saving settings is disabled" - - "gradio/dracula_test", - json.dump(self.data, file, indent=4) - def same_type(self, x, y): - if x is None or y is None: - return True - - type_x = self.typemap.get(type(x), type(x)) - "gradio/dracula_test", import threading - - "gradio/dracula_test", import time - - def load(self, filename): - with open(filename, "r", encoding="utf8") as file: -import sys import datetime - # 1.1.1 quicksettings list migration - if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: - self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] - - # 1.4.0 ui_reorder - if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data: - "abidlabs/dracula_test", import threading - - "abidlabs/dracula_test", import time - for k, v in self.data.items(): - info = self.data_labels.get(k, None) -import sys import json - print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr) - bad_settings += 1 - - if bad_settings > 0: - print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr) - - def onchange(self, key, func, call=True): - "abidlabs/pakistan", import threading - "abidlabs/pakistan", import time - - if call: - func() - -import sys import os - d = {k: self.data.get(k, v.default) for k, v in self.data_labels.items()} - d["_comments_before"] = {k: v.comment_before for k, v in self.data_labels.items() if v.comment_before is not None} - d["_comments_after"] = {k: v.comment_after for k, v in self.data_labels.items() if v.comment_after is not None} - return json.dumps(d) - - def add_option(self, key, info): - "dawood/microsoft_windows", import threading - - "dawood/microsoft_windows", import time - """reorder settings so that all items related to section always go together""" - - section_ids = {} -import sys import re - for _, item in settings_items: - if item.section not in section_ids: - section_ids[item.section] = len(section_ids) - self.data_labels = dict(sorted(settings_items, key=lambda x: section_ids[x[1].section])) - - def cast_value(self, key, value): - "ysharma/steampunk" import threading - "ysharma/steampunk" import time - """ - import sys - "gradio/soft", - "freddyaboulton/dracula_revamped", import threading - - default_value = self.data_labels[key].default - if default_value is None: - default_value = getattr(self, key, None) - if default_value is None: - return None - - expected_type = type(default_value) - if expected_type == bool and value == "False": - value = False -import datetime interrupted = False - value = expected_type(value) - -] import threading - - -] import time -if os.path.exists(config_filename): - opts.load(config_filename) - - -class Shared(sys.modules[__name__].__class__): - """ - this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than - at program startup. - """ - - sd_model_val = None - - @property - def sd_model(self): - import modules.sd_models - -cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access import time - -import sys import threading -import logging - def sd_model(self, value): - import modules.sd_models - - modules.sd_models.model_data.set_sd_model(value) - - -sd_model: LatentDiffusion = None # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead -sys.modules[__name__].__class__ = Shared - -settings_components = None -"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings""" - -latent_upscale_default_mode = "Latent" -latent_upscale_modes = { - "Latent": {"mode": "bilinear", "antialias": False}, -import sys job_no = 0 - "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, - "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, - "Latent (nearest)": {"mode": "nearest", "antialias": False}, - "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, -} - -sd_upscalers = [] - -clip_model = None - -progress_print_out = sys.stdout - - (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) import threading - - - (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) import time - global gradio_theme - (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) - theme_name = opts.gradio_theme - - default_theme_args = dict( - font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], - font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], - ) - - if theme_name == "Default": -devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 import threading - else: -import re import logging -import datetime - gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) - except Exception as e: -import threading - gradio_theme = gr.themes.Default(**default_theme_args) - - import threading -import json - def __init__(self): - self._tqdm = None - - def reset(self): - self._tqdm = tqdm.tqdm( - desc="Total progress", - total=state.job_count * state.sampling_steps, -devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 import logging - file=progress_print_out -import threading import datetime import threading -import torch - if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: - return - if self._tqdm is None: - self.reset() - self._tqdm.update() - - def updateTotal(self, new_total): - if not opts.multiple_tqdm or cmd_opts.disable_console_progressbars: - return - if self._tqdm is None: - self.reset() - self._tqdm.total = new_total - -device = devices.device import logging - if self._tqdm is not None: -import threading import json import threading -from typing import Optional - self._tqdm = None - - -total_tqdm = TotalTQDM() - -mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts) -mem_mon.start() - - -def natural_sort_key(s, regex=re.compile('([0-9]+)')): - return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)] - - -weight_load_location = None if cmd_opts.lowram else "cpu" import logging - filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")] -import threading import os - - import threading -restricted_opts = { - return os.path.join(script_path, "html", filename) - - -def html(filename): - path = html_path(filename) - - if os.path.exists(path): - with open(path, encoding="utf8") as file: - return file.read() - -batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram) import logging - - -def walk_files(path, allowed_extensions=None): -import threading import re import threading - "outdir_init_images" - - if allowed_extensions is not None: - allowed_extensions = set(allowed_extensions) - - items = list(os.walk(path, followlinks=True)) - items = sorted(items, key=lambda x: natural_sort_key(x[0])) - - for root, _, files in items: - for filename in sorted(files, key=natural_sort_key): -parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram import logging - _, ext = os.path.splitext(filename) -import threading import sys - continue - - if not opts.list_hidden_files and ("/." in root or "\\." in root): - continue - - yield os.path.join(root, filename) diff --git a/modules/shared_cmd_options.py b/modules/shared_cmd_options.py new file mode 100644 index 0000000000000000000000000000000000000000..af24938b05f34c8baaa02f6429df208a3b37c896 --- /dev/null +++ b/modules/shared_cmd_options.py @@ -0,0 +1,18 @@ +import os + +import launch +from modules import cmd_args, script_loading +from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 + +parser = cmd_args.parser + +script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file)) +script_loading.preload_extensions(extensions_builtin_dir, parser) + +if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: + cmd_opts = parser.parse_args() +else: + cmd_opts, _ = parser.parse_known_args() + + +cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access diff --git a/modules/shared_gradio_themes.py b/modules/shared_gradio_themes.py new file mode 100644 index 0000000000000000000000000000000000000000..822db0a951d866b629074a8c1d580765d3f35cfe --- /dev/null +++ b/modules/shared_gradio_themes.py @@ -0,0 +1,67 @@ +import os + +import gradio as gr + +from modules import errors, shared +from modules.paths_internal import script_path + + +# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json +gradio_hf_hub_themes = [ + "gradio/base", + "gradio/glass", + "gradio/monochrome", + "gradio/seafoam", + "gradio/soft", + "gradio/dracula_test", + "abidlabs/dracula_test", + "abidlabs/Lime", + "abidlabs/pakistan", + "Ama434/neutral-barlow", + "dawood/microsoft_windows", + "finlaymacklon/smooth_slate", + "Franklisi/darkmode", + "freddyaboulton/dracula_revamped", + "freddyaboulton/test-blue", + "gstaff/xkcd", + "Insuz/Mocha", + "Insuz/SimpleIndigo", + "JohnSmith9982/small_and_pretty", + "nota-ai/theme", + "nuttea/Softblue", + "ParityError/Anime", + "reilnuud/polite", + "remilia/Ghostly", + "rottenlittlecreature/Moon_Goblin", + "step-3-profit/Midnight-Deep", + "Taithrah/Minimal", + "ysharma/huggingface", + "ysharma/steampunk", + "NoCrypt/miku" +] + + +def reload_gradio_theme(theme_name=None): + if not theme_name: + theme_name = shared.opts.gradio_theme + + default_theme_args = dict( + font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'], + font_mono=['IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace'], + ) + + if theme_name == "Default": + shared.gradio_theme = gr.themes.Default(**default_theme_args) + else: + try: + theme_cache_dir = os.path.join(script_path, 'tmp', 'gradio_themes') + theme_cache_path = os.path.join(theme_cache_dir, f'{theme_name.replace("/", "_")}.json') + if shared.opts.gradio_themes_cache and os.path.exists(theme_cache_path): + shared.gradio_theme = gr.themes.ThemeClass.load(theme_cache_path) + else: + os.makedirs(theme_cache_dir, exist_ok=True) + shared.gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) + shared.gradio_theme.dump(theme_cache_path) + except Exception as e: + errors.display(e, "changing gradio theme") + shared.gradio_theme = gr.themes.Default(**default_theme_args) diff --git a/modules/shared_init.py b/modules/shared_init.py new file mode 100644 index 0000000000000000000000000000000000000000..d3fb687e0cdab552e143607791c7dc3a4725ccc8 --- /dev/null +++ b/modules/shared_init.py @@ -0,0 +1,49 @@ +import os + +import torch + +from modules import shared +from modules.shared import cmd_opts + + +def initialize(): + """Initializes fields inside the shared module in a controlled manner. + + Should be called early because some other modules you can import mingt need these fields to be already set. + """ + + os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) + + from modules import options, shared_options + shared.options_templates = shared_options.options_templates + shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) + shared.restricted_opts = shared_options.restricted_opts + if os.path.exists(shared.config_filename): + shared.opts.load(shared.config_filename) + + from modules import devices + devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ + (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) + + devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 + + shared.device = devices.device + shared.weight_load_location = None if cmd_opts.lowram else "cpu" + + from modules import shared_state + shared.state = shared_state.State() + + from modules import styles + shared.prompt_styles = styles.StyleDatabase(shared.styles_filename) + + from modules import interrogate + shared.interrogator = interrogate.InterrogateModels("interrogate") + + from modules import shared_total_tqdm + shared.total_tqdm = shared_total_tqdm.TotalTQDM() + + from modules import memmon, devices + shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts) + shared.mem_mon.start() + diff --git a/modules/shared_items.py b/modules/shared_items.py index 89792e88aec9a0f66ce44e9983cbae4e41ffe897..84d69c8df43a3a638bbea08194ad0940e011a712 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -1,3 +1,6 @@ +import sys + +from modules.shared_cmd_options import cmd_opts def realesrgan_models_names(): @@ -41,15 +44,37 @@ modules.sd_unet.list_unets() +def list_checkpoint_tiles(): + import modules.sd_models + return modules.sd_models.checkpoint_tiles() + + +def refresh_checkpoints(): + import modules.sd_models + return modules.sd_models.list_models() + + +def list_samplers(): + import modules.sd_samplers + return modules.sd_samplers.all_samplers + + +def reload_hypernetworks(): + from modules.hypernetworks import hypernetwork + from modules import shared + + shared.hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) + + ui_reorder_categories_builtin_items = [ "inpaint", "sampler", + "accordions", "checkboxes", def realesrgan_models_names(): -def postprocessing_scripts(): -def realesrgan_models_names(): import modules.scripts "cfg", + "denoising", "seed", "batch", "override_settings", @@ -63,10 +88,34 @@ yield from ui_reorder_categories_builtin_items sections = {} for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: - import modules.realesrgan_model +def postprocessing_scripts(): return modules.scripts.scripts_postproc.scripts sections[script.section] = 1 yield from sections yield "scripts" + + +class Shared(sys.modules[__name__].__class__): + """ + this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than + at program startup. + """ + + sd_model_val = None + + @property + def sd_model(self): + import modules.sd_models + + return modules.sd_models.model_data.get_sd_model() + + @sd_model.setter + def sd_model(self, value): + import modules.sd_models + + modules.sd_models.model_data.set_sd_model(value) + + +sys.modules['modules.shared'].__class__ = Shared diff --git a/modules/shared_options.py b/modules/shared_options.py new file mode 100644 index 0000000000000000000000000000000000000000..83f56314900806d5393d22117f8d6485dd7e5b26 --- /dev/null +++ b/modules/shared_options.py @@ -0,0 +1,330 @@ +import gradio as gr + +from modules import localization, ui_components, shared_items, shared, interrogate, shared_gradio_themes +from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 +from modules.shared_cmd_options import cmd_opts +from modules.options import options_section, OptionInfo, OptionHTML + +options_templates = {} +hide_dirs = shared.hide_dirs + +restricted_opts = { + "samples_filename_pattern", + "directories_filename_pattern", + "outdir_samples", + "outdir_txt2img_samples", + "outdir_img2img_samples", + "outdir_extras_samples", + "outdir_grids", + "outdir_txt2img_grids", + "outdir_save", + "outdir_init_images" +} + +options_templates.update(options_section(('saving-images', "Saving images/grids"), { + "samples_save": OptionInfo(True, "Always save all generated images"), + "samples_format": OptionInfo('png', 'File format for images'), + "samples_filename_pattern": OptionInfo("", "Images filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), + "save_images_add_number": OptionInfo(True, "Add number to filename when saving", component_args=hide_dirs), + + "grid_save": OptionInfo(True, "Always save all generated image grids"), + "grid_format": OptionInfo('png', 'File format for grids'), + "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), + "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), + "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), + "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), + "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), + "font": OptionInfo("", "Font for image grids that have text"), + "grid_text_active_color": OptionInfo("#000000", "Text color for image grids", ui_components.FormColorPicker, {}), + "grid_text_inactive_color": OptionInfo("#999999", "Inactive text color for image grids", ui_components.FormColorPicker, {}), + "grid_background_color": OptionInfo("#ffffff", "Background color for image grids", ui_components.FormColorPicker, {}), + + "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), + "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."), + "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."), + "save_images_before_highres_fix": OptionInfo(False, "Save a copy of image before applying highres fix."), + "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"), + "save_mask": OptionInfo(False, "For inpainting, save a copy of the greyscale mask"), + "save_mask_composite": OptionInfo(False, "For inpainting, save a masked composite"), + "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}), + "webp_lossless": OptionInfo(False, "Use lossless compression for webp images"), + "export_for_4chan": OptionInfo(True, "Save copy of large images as JPG").info("if the file size is above the limit, or either width or height are above the limit"), + "img_downscale_threshold": OptionInfo(4.0, "File size limit for the above option, MB", gr.Number), + "target_side_length": OptionInfo(4000, "Width/height limit for the above option, in pixels", gr.Number), + "img_max_size_mp": OptionInfo(200, "Maximum image size", gr.Number).info("in megapixels"), + + "use_original_name_batch": OptionInfo(True, "Use original name for output filename during batch process in extras tab"), + "use_upscaler_name_as_suffix": OptionInfo(False, "Use upscaler name as filename suffix in the extras tab"), + "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"), + "save_init_img": OptionInfo(False, "Save init images when using img2img"), + + "temp_dir": OptionInfo("", "Directory for temporary images; leave empty for default"), + "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"), + + "save_incomplete_images": OptionInfo(False, "Save incomplete images").info("save images that has been interrupted in mid-generation; even if not saved, they will still show up in webui output."), +})) + +options_templates.update(options_section(('saving-paths', "Paths for saving"), { + "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs), + "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs), + "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs), + "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs), + "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs), + "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs), + "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs), + "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs), + "outdir_init_images": OptionInfo("outputs/init-images", "Directory for saving init images when using img2img", component_args=hide_dirs), +})) + +options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), { + "save_to_dirs": OptionInfo(True, "Save images to a subdirectory"), + "grid_save_to_dirs": OptionInfo(True, "Save grids to a subdirectory"), + "use_save_to_dirs_for_ui": OptionInfo(False, "When using \"Save\" button, save images to a subdirectory"), + "directories_filename_pattern": OptionInfo("[date]", "Directory name pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), + "directories_max_prompt_words": OptionInfo(8, "Max prompt words for [prompt_words] pattern", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1, **hide_dirs}), +})) + +options_templates.update(options_section(('upscaling', "Upscaling"), { + "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).info("0 = no tiling"), + "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap for ESRGAN upscalers.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}).info("Low values = visible seam"), + "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI.", gr.CheckboxGroup, lambda: {"choices": shared_items.realesrgan_models_names()}), + "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in shared.sd_upscalers]}), +})) + +options_templates.update(options_section(('face-restoration', "Face restoration"), { + "face_restoration": OptionInfo(False, "Restore faces", infotext='Face restoration').info("will use a third-party model on generation result to reconstruct faces"), + "face_restoration_model": OptionInfo("CodeFormer", "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in shared.face_restorers]}), + "code_former_weight": OptionInfo(0.5, "CodeFormer weight", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}).info("0 = maximum effect; 1 = minimum effect"), + "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"), +})) + +options_templates.update(options_section(('system', "System"), { + "auto_launch_browser": OptionInfo("Local", "Automatically open webui in browser on startup", gr.Radio, lambda: {"choices": ["Disable", "Local", "Remote"]}), + "show_warnings": OptionInfo(False, "Show warnings in console.").needs_reload_ui(), + "show_gradio_deprecation_warnings": OptionInfo(True, "Show gradio deprecation warnings in console.").needs_reload_ui(), + "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}).info("0 = disable"), + "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"), + "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."), + "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), + "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), + "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), +})) + +options_templates.update(options_section(('API', "API"), { + "api_enable_requests": OptionInfo(True, "Allow http:// and https:// URLs for input images in API", restrict_api=True), + "api_forbid_local_requests": OptionInfo(True, "Forbid URLs to local resources", restrict_api=True), + "api_useragent": OptionInfo("", "User agent for requests", restrict_api=True), +})) + +options_templates.update(options_section(('training', "Training"), { + "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."), + "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."), + "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."), + "save_training_settings_to_txt": OptionInfo(True, "Save textual inversion and hypernet settings to a text file whenever training starts."), + "dataset_filename_word_regex": OptionInfo("", "Filename word regex"), + "dataset_filename_join_string": OptionInfo(" ", "Filename join string"), + "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}), + "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"), + "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"), + "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."), + "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."), + "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."), +})) + +options_templates.update(options_section(('sd', "Stable Diffusion"), { + "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext='Model hash'), + "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), + "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"), + "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds").needs_reload_ui(), + "enable_emphasis": OptionInfo(True, "Enable emphasis").info("use (text) to make model pay more attention to text and [text] to make it pay less attention"), + "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"), + "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), + "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}, infotext="Clip skip").link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), + "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), +})) + +options_templates.update(options_section(('vae', "VAE"), { + "sd_vae_explanation": OptionHTML(""" +<abbr title='Variational autoencoder'>VAE</abbr> is a neural network that transforms a standard <abbr title='red/green/blue'>RGB</abbr> +image into latent space representation and back. Latent space representation is what stable diffusion is working on during sampling +(i.e. when the progress bar is between empty and full). For txt2img, VAE is used to create a resulting image after the sampling is finished. +For img2img, VAE is used to process user's input image before the sampling, and to create an image after sampling. +"""), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list, infotext='VAE').info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), + "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"), + "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), + "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"), + "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"), +})) + +options_templates.update(options_section(('img2img', "img2img"), { + "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Conditional mask weight'), + "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.0, "maximum": 1.5, "step": 0.001}, infotext='Noise multiplier'), + "img2img_extra_noise": OptionInfo(0.0, "Extra noise multiplier for img2img and hires fix", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Extra noise').info("0 = disabled (default); should be lower than denoising strength"), + "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), + "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies.").info("normally you'd do less with less denoising"), + "img2img_background_color": OptionInfo("#ffffff", "With img2img, fill transparent parts of the input image with this color.", ui_components.FormColorPicker, {}), + "img2img_editor_height": OptionInfo(720, "Height of the image editor", gr.Slider, {"minimum": 80, "maximum": 1600, "step": 1}).info("in pixels").needs_reload_ui(), + "img2img_sketch_default_brush_color": OptionInfo("#ffffff", "Sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img sketch").needs_reload_ui(), + "img2img_inpaint_mask_brush_color": OptionInfo("#ffffff", "Inpaint mask brush color", ui_components.FormColorPicker, {}).info("brush color of inpaint mask").needs_reload_ui(), + "img2img_inpaint_sketch_default_brush_color": OptionInfo("#ffffff", "Inpaint sketch initial brush color", ui_components.FormColorPicker, {}).info("default brush color of img2img inpaint sketch").needs_reload_ui(), + "return_mask": OptionInfo(False, "For inpainting, include the greyscale mask in results for web"), + "return_mask_composite": OptionInfo(False, "For inpainting, include masked composite in results for web"), +})) + +options_templates.update(options_section(('optimizations', "Optimizations"), { + "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 15.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), + "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), + "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}, infotext='Token merging ratio hr').info("only applies if non-zero and overrides above"), + "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), + "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), + "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), +})) + +options_templates.update(options_section(('compatibility', "Compatibility"), { + "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."), + "use_old_karras_scheduler_sigmas": OptionInfo(False, "Use old karras scheduler sigmas (0.1 to 10)."), + "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), + "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), + "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), + "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), + "use_old_scheduling": OptionInfo(False, "Use old prompt editing timelines.", infotext="Old prompt editing timelines").info("For [red:green:N]; old: If N < 1, it's a fraction of steps (and hires fix uses range from 0 to 1), if N >= 1, it's an absolute number of steps; new: If N has a decimal point in it, it's a fraction of steps (and hires fix uses range from 1 to 2), othewrwise it's an absolute number of steps"), +})) + +options_templates.update(options_section(('interrogate', "Interrogate"), { + "interrogate_keep_models_in_memory": OptionInfo(False, "Keep models in VRAM"), + "interrogate_return_ranks": OptionInfo(False, "Include ranks of model tags matches in results.").info("booru only"), + "interrogate_clip_num_beams": OptionInfo(1, "BLIP: num_beams", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}), + "interrogate_clip_min_length": OptionInfo(24, "BLIP: minimum description length", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}), + "interrogate_clip_max_length": OptionInfo(48, "BLIP: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}), + "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file").info("0 = No limit"), + "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": interrogate.category_types()}, refresh=interrogate.category_types), + "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "deepbooru: score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}), + "deepbooru_sort_alpha": OptionInfo(True, "deepbooru: sort tags alphabetically").info("if not: sort by score"), + "deepbooru_use_spaces": OptionInfo(True, "deepbooru: use spaces in tags").info("if not: use underscores"), + "deepbooru_escape": OptionInfo(True, "deepbooru: escape (\\) brackets").info("so they are used as literal brackets and not for emphasis"), + "deepbooru_filter_tags": OptionInfo("", "deepbooru: filter out those tags").info("separate by comma"), +})) + +options_templates.update(options_section(('extra_networks', "Extra Networks"), { + "extra_networks_show_hidden_directories": OptionInfo(True, "Show hidden directories").info("directory is hidden if its name starts with \".\"."), + "extra_networks_hidden_models": OptionInfo("When searched", "Show cards for models in hidden directories", gr.Radio, {"choices": ["Always", "When searched", "Never"]}).info('"When searched" option will only show the item when the search string has 4 characters or more'), + "extra_networks_default_multiplier": OptionInfo(1.0, "Default multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}), + "extra_networks_card_width": OptionInfo(0, "Card width for Extra Networks").info("in pixels"), + "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), + "extra_networks_card_text_scale": OptionInfo(1.0, "Card text scale", gr.Slider, {"minimum": 0.0, "maximum": 2.0, "step": 0.01}).info("1 = original size"), + "extra_networks_card_show_desc": OptionInfo(True, "Show description on card"), + "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), + "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_reload_ui(), + "textual_inversion_print_at_load": OptionInfo(False, "Print a list of Textual Inversion embeddings when loading model"), + "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), + "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *shared.hypernetworks]}, refresh=shared_items.reload_hypernetworks), +})) + +options_templates.update(options_section(('ui', "User interface"), { + "localization": OptionInfo("None", "Localization", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)).needs_reload_ui(), + "gradio_theme": OptionInfo("Default", "Gradio theme", ui_components.DropdownEditable, lambda: {"choices": ["Default"] + shared_gradio_themes.gradio_hf_hub_themes}).info("you can also manually enter any of themes from the <a href='https://huggingface.co/spaces/gradio/theme-gallery'>gallery</a>.").needs_reload_ui(), + "gradio_themes_cache": OptionInfo(True, "Cache gradio themes locally").info("disable to update the selected Gradio theme"), + "gallery_height": OptionInfo("", "Gallery height", gr.Textbox).info("an be any valid CSS value").needs_reload_ui(), + "return_grid": OptionInfo(True, "Show grid in results for web"), + "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), + "send_seed": OptionInfo(True, "Send seed when sending prompt or image to other interface"), + "send_size": OptionInfo(True, "Send size when sending prompt or image to another interface"), + "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"), + "js_modal_lightbox_initially_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"), + "js_modal_lightbox_gamepad": OptionInfo(False, "Navigate image viewer with gamepad"), + "js_modal_lightbox_gamepad_repeat": OptionInfo(250, "Gamepad repeat period, in milliseconds"), + "show_progress_in_title": OptionInfo(True, "Show generation progress in window title."), + "samplers_in_dropdown": OptionInfo(True, "Use dropdown for sampler selection instead of radio group").needs_reload_ui(), + "dimensions_and_batch_together": OptionInfo(True, "Show Width/Height and Batch sliders in same row").needs_reload_ui(), + "keyedit_precision_attention": OptionInfo(0.1, "Ctrl+up/down precision when editing (attention:1.1)", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_precision_extra": OptionInfo(0.05, "Ctrl+up/down precision when editing <extra networks:0.9>", gr.Slider, {"minimum": 0.01, "maximum": 0.2, "step": 0.001}), + "keyedit_delimiters": OptionInfo(".,\\/!?%^*;:{}=`~()", "Ctrl+up/down word delimiters"), + "keyedit_move": OptionInfo(True, "Alt+left/right moves prompt elements"), + "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_reload_ui(), + "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(shared.tab_names)}).needs_reload_ui(), + "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_reload_ui(), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_reload_ui(), + "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), + "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_reload_ui(), +})) + + +options_templates.update(options_section(('infotext', "Infotext"), { + "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), + "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), + "add_user_name_to_info": OptionInfo(False, "Add user name to generation information when authenticated"), + "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), + "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), + "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""<ul style='margin-left: 1.5em'> +<li>Ignore: keep prompt and styles dropdown as it is.</li> +<li>Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).</li> +<li>Discard: remove style text from prompt, keep styles dropdown as it is.</li> +<li>Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.</li> +</ul>"""), + +})) + +options_templates.update(options_section(('ui', "Live previews"), { + "show_progressbar": OptionInfo(True, "Show progressbar"), + "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "live_previews_image_format": OptionInfo("png", "Live preview file format", gr.Radio, {"choices": ["jpeg", "png", "webp"]}), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), + "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"), + "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"), + "live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"), + "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), + "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), + "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), +})) + +options_templates.update(options_section(('sampler-params', "Sampler parameters"), { + "hide_samplers": OptionInfo([], "Hide samplers in user interface", gr.CheckboxGroup, lambda: {"choices": [x.name for x in shared_items.list_samplers()]}).needs_reload_ui(), + "eta_ddim": OptionInfo(0.0, "Eta for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta DDIM').info("noise multiplier; higher = more unpredictable results"), + "eta_ancestral": OptionInfo(1.0, "Eta for k-diffusion samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}, infotext='Eta').info("noise multiplier; currently only applies to ancestral samplers (i.e. Euler a) and SDE samplers"), + "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}), + 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 100.0, "step": 0.01}, infotext='Sigma churn').info('amount of stochasticity; only applies to Euler, Heun, and DPM2'), + 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 10.0, "step": 0.01}, infotext='Sigma tmin').info('enable stochasticity; start value of the sigma range; only applies to Euler, Heun, and DPM2'), + 's_tmax': OptionInfo(0.0, "sigma tmax", gr.Slider, {"minimum": 0.0, "maximum": 999.0, "step": 0.01}, infotext='Sigma tmax').info("0 = inf; end value of the sigma range; only applies to Euler, Heun, and DPM2"), + 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.1, "step": 0.001}, infotext='Sigma noise').info('amount of additional noise to counteract loss of detail during sampling'), + 'k_sched_type': OptionInfo("Automatic", "Scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}, infotext='Schedule type').info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), + 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number, infotext='Schedule max sigma').info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number, infotext='Schedule min sigma').info("0 = default (~14.6); maximum noise strength for k-diffusion noise scheduler"), + 'rho': OptionInfo(0.0, "rho", gr.Number, infotext='Schedule rho').info("0 = default (7 for karras, 1 for polyexponential); higher values result in a steeper noise schedule (decreases faster)"), + 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}, infotext='ENSD').info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), + 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma", infotext='Discard penultimate sigma').link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), + 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}, infotext='UniPC variant'), + 'uni_pc_skip_type': OptionInfo("time_uniform", "UniPC skip type", gr.Radio, {"choices": ["time_uniform", "time_quadratic", "logSNR"]}, infotext='UniPC skip type'), + 'uni_pc_order': OptionInfo(3, "UniPC order", gr.Slider, {"minimum": 1, "maximum": 50, "step": 1}, infotext='UniPC order').info("must be < sampling steps"), + 'uni_pc_lower_order_final': OptionInfo(True, "UniPC lower order final", infotext='UniPC lower order final'), +})) + +options_templates.update(options_section(('postprocessing', "Postprocessing"), { + 'postprocessing_enable_in_main_ui': OptionInfo([], "Enable postprocessing operations in txt2img and img2img tabs", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'postprocessing_operation_order': OptionInfo([], "Postprocessing operation order", ui_components.DropdownMulti, lambda: {"choices": [x.name for x in shared_items.postprocessing_scripts()]}), + 'upscaling_max_images_in_cache': OptionInfo(5, "Maximum number of images in upscaling cache", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), +})) + +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable these extensions"), + "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), + "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), + "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), +})) + diff --git a/modules/shared_state.py b/modules/shared_state.py new file mode 100644 index 0000000000000000000000000000000000000000..d272ee5bc2c046554fc8f9237b3f31957e9f5bf7 --- /dev/null +++ b/modules/shared_state.py @@ -0,0 +1,159 @@ +import datetime +import logging +import threading +import time + +from modules import errors, shared, devices +from typing import Optional + +log = logging.getLogger(__name__) + + +class State: + skipped = False + interrupted = False + job = "" + job_no = 0 + job_count = 0 + processing_has_refined_job_count = False + job_timestamp = '0' + sampling_step = 0 + sampling_steps = 0 + current_latent = None + current_image = None + current_image_sampling_step = 0 + id_live_preview = 0 + textinfo = None + time_start = None + server_start = None + _server_command_signal = threading.Event() + _server_command: Optional[str] = None + + def __init__(self): + self.server_start = time.time() + + @property + def need_restart(self) -> bool: + # Compatibility getter for need_restart. + return self.server_command == "restart" + + @need_restart.setter + def need_restart(self, value: bool) -> None: + # Compatibility setter for need_restart. + if value: + self.server_command = "restart" + + @property + def server_command(self): + return self._server_command + + @server_command.setter + def server_command(self, value: Optional[str]) -> None: + """ + Set the server command to `value` and signal that it's been set. + """ + self._server_command = value + self._server_command_signal.set() + + def wait_for_server_command(self, timeout: Optional[float] = None) -> Optional[str]: + """ + Wait for server command to get set; return and clear the value and signal. + """ + if self._server_command_signal.wait(timeout): + self._server_command_signal.clear() + req = self._server_command + self._server_command = None + return req + return None + + def request_restart(self) -> None: + self.interrupt() + self.server_command = "restart" + log.info("Received restart request") + + def skip(self): + self.skipped = True + log.info("Received skip request") + + def interrupt(self): + self.interrupted = True + log.info("Received interrupt request") + + def nextjob(self): + if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: + self.do_set_current_image() + + self.job_no += 1 + self.sampling_step = 0 + self.current_image_sampling_step = 0 + + def dict(self): + obj = { + "skipped": self.skipped, + "interrupted": self.interrupted, + "job": self.job, + "job_count": self.job_count, + "job_timestamp": self.job_timestamp, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return obj + + def begin(self, job: str = "(unknown)"): + self.sampling_step = 0 + self.job_count = -1 + self.processing_has_refined_job_count = False + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.id_live_preview = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + self.time_start = time.time() + self.job = job + devices.torch_gc() + log.info("Starting job %s", job) + + def end(self): + duration = time.time() - self.time_start + log.info("Ending job %s (%.2f seconds)", self.job, duration) + self.job = "" + self.job_count = 0 + + devices.torch_gc() + + def set_current_image(self): + """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly""" + if not shared.parallel_processing_allowed: + return + + if self.sampling_step - self.current_image_sampling_step >= shared.opts.show_progress_every_n_steps and shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps != -1: + self.do_set_current_image() + + def do_set_current_image(self): + if self.current_latent is None: + return + + import modules.sd_samplers + + try: + if shared.opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during genration, VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() + + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 diff --git a/modules/shared_total_tqdm.py b/modules/shared_total_tqdm.py new file mode 100644 index 0000000000000000000000000000000000000000..cf82e10478f853d4fccacd7b75962006aa6b4293 --- /dev/null +++ b/modules/shared_total_tqdm.py @@ -0,0 +1,37 @@ +import tqdm + +from modules import shared + + +class TotalTQDM: + def __init__(self): + self._tqdm = None + + def reset(self): + self._tqdm = tqdm.tqdm( + desc="Total progress", + total=shared.state.job_count * shared.state.sampling_steps, + position=1, + file=shared.progress_print_out + ) + + def update(self): + if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.update() + + def updateTotal(self, new_total): + if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars: + return + if self._tqdm is None: + self.reset() + self._tqdm.total = new_total + + def clear(self): + if self._tqdm is not None: + self._tqdm.refresh() + self._tqdm.close() + self._tqdm = None + diff --git a/modules/styles.py b/modules/styles.py index ec0e1bc51dbfb6415da3c953c96fd524a554ac0d..0740fe1b1c0ed15a5d44f25af144f3f2257fa0d9 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -106,10 +106,7 @@ # Always keep a backup file around if os.path.exists(path): shutil.copy(path, f"{path}.bak") - fd = os.open(path, os.O_RDWR | os.O_CREAT) - with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: - # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, - # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() + with open(path, "w", encoding="utf-8-sig", newline='') as file: writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer.writeheader() writer.writerows(style._asdict() for k, style in self.styles.items()) diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 497568eb51b6de54d2f0a642ab289eaba8c55f82..ae4ee4bbec061b72cf20bfc369f3e14ca4188c7a 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -58,8 +58,9 @@ value: Tensor, scale: float, ) -> AttnChunk: attn_weights = torch.baddbmm( +# original source: # MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) -# Amin Rezaei (original author) +# implementation of: query, key.transpose(1,2), alpha=scale, @@ -122,8 +123,9 @@ value: Tensor, scale: float, ) -> Tensor: attn_scores = torch.baddbmm( +# original source: # MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) -# Amin Rezaei (original author) +# implementation of: query, key.transpose(1,2), alpha=scale, diff --git a/modules/sysinfo.py b/modules/sysinfo.py index cf24c6dd4a4effb272311b5d83df7673444fa108..058e66ce4e9b5396c0eb800ccba197089b993f10 100644 --- a/modules/sysinfo.py +++ b/modules/sysinfo.py @@ -11,7 +11,7 @@ import re import launch import json -import os + "WEBUI_LAUNCH_LIVE_OUTPUT", checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY" environment_whitelist = { @@ -24,7 +24,6 @@ "TORCH_INDEX_URL", "TORCH_COMMAND", "REQS_FILE", "XFORMERS_PACKAGE", - "GFPGAN_PACKAGE", "CLIP_PACKAGE", "OPENCLIP_PACKAGE", "STABLE_DIFFUSION_REPO", @@ -117,9 +116,6 @@ def get_exceptions(): try: import psutil -import os - -import psutil import sys except Exception as e: return str(e) @@ -146,9 +142,6 @@ def get_extensions(*, enabled): try: import re -import traceback - -import re return { "name": x.name, @@ -165,7 +158,6 @@ def get_config(): try: - from modules import shared return shared.opts.data except Exception as e: return str(e) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 6166c76f6537b757ebea4ec40b5dc59eecfc2ff5..aa79dc09843ae6575af352bd61a7c05ba16376c5 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -import os + if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None: import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -182,36 +182,44 @@ data = safetensors.torch.load_file(path, device="cpu") else: return + # textual inversion embeddings if 'string_to_param' in data: param_dict = data['string_to_param'] param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 assert len(param_dict) == 1, 'embedding file has multiple terms in it' emb = next(iter(param_dict.items()))[1] -import os + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding import torch -import os + def __init__(self, vec, name, step=None): + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + optimizer_saved_dict = { import os import torch +from contextlib import closing from collections import namedtuple assert len(data.keys()) == 1, 'embedding file has multiple terms in it' emb = next(iter(data.values())) if len(emb.shape) == 1: emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] else: raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") import modules.textual_inversion.dataset -import datetime -import modules.textual_inversion.dataset import csv embedding.step = data.get('step', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) -from modules.textual_inversion.learn_schedule import LearnRateScheduler + optimizer_saved_dict = { from contextlib import closing -from modules.textual_inversion.learn_schedule import LearnRateScheduler + optimizer_saved_dict = { embedding.filename = path embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') @@ -387,6 +394,8 @@ assert log_directory, "Log directory is empty" def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + from modules import processing + save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) diff --git a/modules/timer.py b/modules/timer.py index da99e49f82df4260676cf8205cf6bcc93b9d3f01..1d38595c7fff24a9047a619f440811007f6cde44 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -1,4 +1,5 @@ import time +import argparse class TimerSubcategory: @@ -11,20 +12,27 @@ def __enter__(self): self.start = time.time() self.timer.base_category = self.original_base_category + self.category + "/" + self.timer.subcategory_level += 1 + + if self.timer.print_log: + print(f"{' ' * self.timer.subcategory_level}{self.category}:") def __exit__(self, exc_type, exc_val, exc_tb): elapsed_for_subcategroy = time.time() - self.start self.timer.base_category = self.original_base_category self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy) - self.timer.record(self.category) + self.timer.subcategory_level -= 1 + self.timer.record(self.category, disable_log=True) class Timer: - def __init__(self): + def __init__(self, print_log=False): self.start = time.time() self.records = {} self.total = 0 self.base_category = '' + self.print_log = print_log + self.subcategory_level = 0 def elapsed(self): end = time.time() @@ -38,12 +46,16 @@ self.records[category] = 0 self.records[category] += amount - def record(self, category, extra_time=0): + def record(self, category, extra_time=0, disable_log=False): e = self.elapsed() self.add_time_to_record(self.base_category + category, e + extra_time) class TimerSubcategory: + def __init__(self, timer, category): + + if self.print_log and not disable_log: + self.category = category def __init__(self, timer, category): def subcategory(self, name): @@ -72,7 +84,11 @@ def reset(self): self.__init__() + self.category = category self.timer = timer -import time +parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup") +args = parser.parse_known_args()[0] + +startup_timer = Timer(print_log=args.log_startup) startup_record = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 29d94e8cb2c0fa8aca36a3f08120705bf36eefea..1ee592ad9446d204693db92ba95ffc517aa2153a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import sd_samplers, processing +from modules import processing from modules.generation_parameters_copypaste import create_override_settings_dict from modules.shared import opts, cmd_opts import modules.shared as shared @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, n_iter: int, batch_size: int, cfg_scale: float, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -19,14 +19,8 @@ outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, prompt=prompt, styles=prompt_styles, negative_prompt=negative_prompt, - seed=seed, - subseed=subseed, - subseed_strength=subseed_strength, - seed_resize_from_h=seed_resize_from_h, - seed_resize_from_w=seed_resize_from_w, - +from modules.shared import opts, cmd_opts from modules import sd_samplers, processing - sampler_name=sd_samplers.samplers[sampler_index].name, batch_size=batch_size, n_iter=n_iter, steps=steps, @@ -34,9 +28,6 @@ cfg_scale=cfg_scale, width=width, height=height, import modules.scripts - - tiling=tiling, -import modules.scripts from modules import sd_samplers, processing denoising_strength=denoising_strength if enable_hr else None, hr_scale=hr_scale, @@ -44,7 +35,8 @@ hr_upscaler=hr_upscaler, hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, - hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, + hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, + hr_sampler_name=None if hr_sampler_name == 'Use same sampler' else hr_sampler_name, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, diff --git a/modules/ui.py b/modules/ui.py index 07ecee7b680fdc4e1931ef26043f0c11fd66c611..2b6a13cbb6c53b7facb6b671594d8611d29a59f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,5 +1,4 @@ import datetime -import json import mimetypes import os import sys @@ -13,48 +12,49 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call import datetime -import os +reuse_symbol = '\u267b\ufe0f' # ♻️ import datetime import sys +import gradio as gr +from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML, InputAccordion, ResizeHandleRow from modules.paths import script_path from modules.ui_common import create_refresh_button import datetime - from modules.shared import opts, cmd_opts import json +import datetime +import datetime +save_style_symbol = '\U0001f4be' # 💾 -import json import datetime -import json +from functools import reduce import json -import json + elif mode == 5: import mimetypes import json -import os -import json import sys -import json +import gradio.utils from functools import reduce -import modules.textual_inversion.ui +import os from modules import prompt_parser from modules.sd_hijack import model_hijack import mimetypes -from modules.textual_inversion import textual_inversion import json -import mimetypes -from modules.generation_parameters_copypaste import image_from_url_text -import modules.extras create_setting_component = ui_settings.create_setting_component warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) +warnings.filterwarnings("default" if opts.show_gradio_deprecation_warnings else "ignore", category=gr.deprecation.GradioDeprecationWarning) # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI mimetypes.init() mimetypes.add_type('application/javascript', '.js') +# Likewise, add explicit content-type header for certain missing image types +mimetypes.add_type('image/webp', '.webp') + if not cmd_opts.share and not cmd_opts.listen: # fix gradio phoning home gradio.utils.version_check = lambda: None @@ -90,7 +89,6 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ restore_progress_symbol = '\U0001F300' # 🌀 detect_image_size_symbol = '\U0001F4D0' # 📐 -up_down_symbol = '\u2195\ufe0f' # ↕️ plaintext_to_html = ui_common.plaintext_to_html @@ -102,33 +100,17 @@ return None return image_from_url_text(x[0]) -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] - - def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): -import os - - import sys return "" p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) + elif mode == 5: - with devices.autocast(): - p.init([""], [0], [0]) -import gradio as gr +import datetime +plaintext_to_html = ui_common.plaintext_to_html def resize_from_to_html(width, height, scale_by): @@ -139,13 +121,6 @@ if not target_width or not target_height: return "no image selected" return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>" - - -def apply_styles(prompt, prompt_neg, styles): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): @@ -182,302 +157,225 @@ return gr.update() if prompt is None else prompt import datetime -import json import sys import datetime -import modules.styles import datetime +import sys import json -import warnings import datetime -from modules import prompt_parser - random_seed = ToolButton(random_symbol, elem_id=f"{target_interface}_random_seed", label='Random seed') -import datetime +import sys import mimetypes - import datetime -from modules.textual_inversion import textual_inversion - - # Components to show/hide based on the 'Extra' checkbox - seed_extras = [] - -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call +import sys import os import datetime -import mimetypes +import sys import sys import datetime -import mimetypes +import sys from functools import reduce import datetime -import mimetypes +import sys import warnings import datetime -import mimetypes +import sys - reuse_subseed = ToolButton(reuse_symbol, elem_id=f"{target_interface}_reuse_subseed") - subseed_strength = gr.Slider(label='Variation strength', value=0.0, minimum=0, maximum=1, step=0.01, elem_id=f"{target_interface}_subseed_strength") - with FormRow(visible=False) as seed_extra_row_2: + import datetime + import os -import json import datetime -if cmd_opts.ngrok is not None: - seed_resize_from_h = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize seed from height", value=0, elem_id=f"{target_interface}_seed_resize_from_h") -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo import sys import datetime -import os + from functools import reduce import datetime -import os + import warnings import datetime -import os + import datetime -import os + import gradio as gr - import datetime -import sys +import gradio as gr - - - import datetime -import sys +import gradio as gr import datetime import datetime -import sys +import gradio as gr import json + import datetime -import sys +import gradio as gr import mimetypes import datetime -import sys +import gradio as gr import os import datetime -import sys +import gradio as gr import sys import datetime -import sys +import gradio as gr from functools import reduce - outputs=[], -from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML - -def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: gr.Textbox, dummy_component, is_subseed): import datetime -from functools import reduce +import warnings import datetime -from functools import reduce +import warnings import datetime + import datetime -from functools import reduce +import warnings import json import datetime -from functools import reduce +import warnings import mimetypes import datetime -from functools import reduce +import warnings import os import datetime -from functools import reduce +import warnings import sys import datetime -restore_progress_symbol = '\U0001F300' # 🌀 -from modules.paths import script_path import warnings - -import datetime from functools import reduce - +import warnings import datetime -plaintext_to_html = ui_common.plaintext_to_html +import mimetypes import datetime -import warnings -import datetime import warnings -import datetime -import datetime import warnings -import json - res = all_seeds[index if 0 <= index < len(all_seeds) else 0] -from modules.ui_common import create_refresh_button import os import datetime import warnings -import sys + import datetime import warnings -from functools import reduce +import gradio as gr + with gr.Row(): import datetime style = modules.styles.PromptStyle(name, prompt, negative_prompt) +import os import datetime -import warnings - fn=copy_seed, + import datetime - -import datetime shared.prompt_styles.save_styles(shared.styles_filename) import datetime return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] import datetime def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): import datetime -reuse_symbol = '\u267b\ufe0f' # ♻️ - - -import datetime from modules import processing, devices import datetime if not enable: import datetime return "" import datetime p = processing.StableDiffusionProcessingTxt2Img(width=width, height=height, enable_hr=True, hr_scale=hr_scale, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y) import datetime - import datetime p.init([""], [0], [0]) import datetime return f"resize: from <span class='resolution'>{p.width}x{p.height}</span> to <span class='resolution'>{p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}</span>" import datetime def resize_from_to_html(width, height, scale_by): + import datetime target_width = int(width * scale_by) - import datetime target_height = int(height * scale_by) - prompts = [prompt_text for step, prompt_text in flat_prompts] + inputs=[], -from modules.shared import opts, cmd_opts import sys +import sys import datetime - return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>" - + ) import datetime -def apply_styles(prompt, prompt_neg, styles): -import datetime import gradio as gr - - +import os import datetime -import gradio as gr import gradio as gr +import sys -import modules.codeformer_model +# Using constants for these since the variation selector isn't visible. -import modules.codeformer_model +# Using constants for these since the variation selector isn't visible. import datetime - with gr.Column(scale=80): - with gr.Row(): -import modules.codeformer_model +import sys import os +import warnings -import modules.codeformer_model import datetime -import json -import json - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - -import modules.codeformer_model +import gradio as gr from functools import reduce -import json +import datetime +import gradio as gr import warnings -import json +import datetime +import gradio as gr -import json +import datetime import gradio as gr +import gradio as gr + -import json +import datetime import datetime -import json import numpy as np - +import datetime -import json import datetime +import gradio.utils import json -import json +import datetime import datetime import mimetypes -import json + +import datetime import datetime import os -import json import datetime +import gradio.utils import sys -import json +import datetime import datetime from functools import reduce - -import json import datetime -import warnings -import json import datetime - - inputs=[], -import modules.gfpgan_model +import warnings -import modules.gfpgan_model import datetime - - interrupt.click( - fn=lambda: shared.state.interrupt(), -import json import datetime -import gradio as gr -import modules.gfpgan_model - ) - with gr.Row(elem_id=f"{id_part}_tools"): -import modules.gfpgan_model import sys - clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") -import modules.gfpgan_model +import os import warnings -import modules.gfpgan_model -import modules.gfpgan_model + if ii_output_dir != "": import gradio as gr - restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) -import modules.hypernetworks.ui +import numpy as np import datetime - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + fn=modules.images.image_data, + os.makedirs(ii_output_dir, exist_ok=True) import json -import modules.extras -import json + os.makedirs(ii_output_dir, exist_ok=True) import mimetypes -import os - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], -import modules.gfpgan_model +import datetime import datetime - - with gr.Row(elem_id=f"{id_part}_styles_row"): -import modules.scripts import datetime -import json import os -import json - - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button + ) def setup_progressbar(*args, **kwargs): @@ -521,15 +396,15 @@ def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) -import modules.textual_inversion.ui + os.makedirs(ii_output_dir, exist_ok=True) from functools import reduce -import modules.textual_inversion.ui + os.makedirs(ii_output_dir, exist_ok=True) import warnings @@ -560,32 +435,29 @@ reload_javascript() parameters_copypaste.reset() - modules.scripts.scripts_current = modules.scripts.scripts_txt2img + scripts.scripts_current = scripts.scripts_txt2img -import json + os.makedirs(ii_output_dir, exist_ok=True) import gradio as gr -import sys with gr.Blocks(analytics_enabled=False) as txt2img_interface: +import numpy as np import json -def apply_styles(prompt, prompt_neg, styles): dummy_component = gr.Label(visible=False) - txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) - with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: -from modules.sd_samplers import samplers, samplers_for_img2img + else: import datetime -from modules.sd_samplers import samplers, samplers_for_img2img + else: import json -from modules.sd_samplers import samplers, samplers_for_img2img + else: import mimetypes with gr.Column(variant='compact', elem_id="txt2img_settings"): - modules.scripts.scripts_txt2img.prepare_ui() + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") elif category == "dimensions": with FormRow(): @@ -602,66 +473,74 @@ batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="txt2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="txt2img_batch_size") elif category == "cfg": - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="txt2img_cfg_scale") +import os -import mimetypes +import numpy as np import json -import json - seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs('txt2img') +from functools import reduce elif category == "checkboxes": with FormRow(elem_classes="checkboxes-row", variant="compact"): - restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") -import mimetypes +import numpy as np import modules.textual_inversion.ui -import mimetypes + +import numpy as np from modules import prompt_parser -import mimetypes +import numpy as np from modules.sd_hijack import model_hijack - -import mimetypes +import numpy as np from modules.sd_samplers import samplers, samplers_for_img2img -import mimetypes +import numpy as np from modules.textual_inversion import textual_inversion -import mimetypes +import numpy as np from modules.generation_parameters_copypaste import image_from_url_text -import mimetypes + +import numpy as np import modules.extras -import mimetypes +import numpy as np create_setting_component = ui_settings.create_setting_component -import mimetypes +import numpy as np warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) - -import mimetypes +import numpy as np # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI -import mimetypes + +import numpy as np mimetypes.init() -import mimetypes +import numpy as np mimetypes.add_type('application/javascript', '.js') -import mimetypes +import numpy as np if not cmd_opts.share and not cmd_opts.listen: - -import mimetypes +import numpy as np # fix gradio phoning home -import mimetypes + +import numpy as np gradio.utils.version_check = lambda: None -import mimetypes +import numpy as np gradio.utils.get_local_ip_address = lambda: '127.0.0.1' -import mimetypes +import numpy as np if cmd_opts.ngrok is not None: -import mimetypes + +import numpy as np import modules.ngrok as ngrok -import mimetypes + +import numpy as np print('ngrok authtoken detected, trying to connect...') -import mimetypes +import numpy as np import os -import mimetypes +from functools import reduce -import mimetypes +import numpy as np import os +import warnings +import numpy as np import os + -import mimetypes +import numpy as np ngrok.connect( + with gr.Row(): + hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + + scripts.scripts_txt2img.setup_ui_for_section(category) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -675,16 +554,16 @@ override_settings = create_override_settings_dropdown('txt2img', row) elif category == "scripts": with FormGroup(elem_id="txt2img_script_container"): -import mimetypes +import numpy as np import sys -import warnings +import datetime -import mimetypes +import numpy as np import sys - +import json -import mimetypes +import numpy as np import sys -import gradio as gr +import mimetypes hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -706,27 +585,20 @@ show_progress=False, ) txt2img_gallery, generation_info, html_info, html_log = create_output_panel("txt2img", opts.outdir_txt2img_samples) - - connect_reuse_seed(seed, reuse_seed, generation_info, dummy_component, is_subseed=False) - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) txt2img_args = dict( fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ dummy_component, - txt2img_prompt, + toprow.prompt, - txt2img_negative_prompt, + toprow.negative_prompt, - txt2img_prompt_styles, + toprow.ui_styles.dropdown, steps, - sampler_index, - restore_faces, - tiling, + sampler_name, batch_count, batch_size, cfg_scale, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, height, width, enable_hr, @@ -736,7 +608,8 @@ hr_upscaler, hr_second_pass_steps, hr_resize_x, hr_resize_y, -if not cmd_opts.share and not cmd_opts.listen: + hr_checkpoint_name, + img = Image.open(image) import gradio as gr hr_prompt, hr_negative_prompt, @@ -753,18 +626,17 @@ ], show_progress=False, ) -import os +import datetime import datetime -import json +refresh_symbol = '\U0001f504' # 🔄 -import os import datetime -import mimetypes + (sub)seed value from generation info the to the seed field. If copying subseed and subseed strength res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) -import os +import datetime import datetime -import sys +apply_style_symbol = '\U0001f4cb' # 📋 fn=progress.restore_progress, _js="restoreProgressTxt2img", inputs=[dummy_component], @@ -778,118 +650,94 @@ show_progress=False, ) import os -from modules.shared import opts, cmd_opts -import os import json -import mimetypes import warnings -from functools import reduce - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' +import datetime import datetime - ], - outputs=[ +from functools import reduce import mimetypes - shared.prompt_styles.styles[style.name] = style - txt_prompt_img - ], - show_progress=False, -import os import numpy as np - - enable_hr.change( - fn=lambda x: gr_show(x), - inputs=[enable_hr], - outputs=[hr_options], - gradio.utils.get_local_ip_address = lambda: '127.0.0.1' from functools import reduce import os -import numpy as np - - txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), (steps, "Steps"), -if cmd_opts.ngrok is not None: import datetime - (restore_faces, "Face restoration"), + try: (cfg_scale, "CFG scale"), - (seed, "Seed"), (width, "Size-1"), (height, "Size-2"), (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - import modules.ngrok as ngrok import datetime - (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + gen_info = json.loads(gen_info_string) (denoising_strength, "Denoising strength"), - (enable_hr, lambda d: "Denoising strength" in d), - (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), + (enable_hr, lambda d: "Denoising strength" in d and ("Hires upscale" in d or "Hires upscaler" in d or "Hires resize-1" in d)), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), (hr_second_pass_steps, "Hires steps"), (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), - print('ngrok authtoken detected, trying to connect...') +import datetime import datetime +up_down_symbol = '\u2195\ufe0f' # ↕️ - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_sampler_name, "Hires sampler"), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), - *modules.scripts.scripts_txt2img.infotext_fields + *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, + paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None, )) txt2img_preview_params = [ - ngrok.connect( + left, _ = os.path.splitext(filename) import mimetypes - ngrok.connect( + left, _ = os.path.splitext(filename) import os steps, - sampler_index, + sampler_name, cfg_scale, -import os + left, _ = os.path.splitext(filename) from functools import reduce - width, height, ] - cmd_opts.ngrok, import datetime -import os +from modules.ui_common import create_refresh_button import warnings -import json - -import os +import numpy as np import warnings -import mimetypes + -import os +import numpy as np import warnings -import os +import gradio as gr - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) - with gr.Blocks(analytics_enabled=False) as img2img_interface: + extra_tabs.__exit__() - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) + +import numpy as np +import json - cmd_opts.ngrok, +import numpy as np +import mimetypes cmd_opts.ngrok, -import gradio as gr +from functools import reduce -from modules.sd_samplers import samplers, samplers_for_img2img import datetime +from modules.ui_gradio_extensions import reload_javascript import os +import numpy as np +import sys -import os + extra_tabs.__enter__() import datetime + text, _ = extra_networks.parse_prompt(text) with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] copy_image_destinations = {} @@ -908,22 +757,23 @@ with gr.Tabs(elem_id="mode_img2img"): img2img_selected_tab = gr.State(0) with gr.TabItem('img2img', id='img2img', elem_id="img2img_img2img_tab") as tab_img2img: - cmd_opts.ngrok_options +import numpy as np +import warnings add_copy_image_controls('img2img', init_img) with gr.TabItem('Sketch', id='img2img_sketch', elem_id="img2img_img2img_sketch_tab") as tab_sketch: - ) +import datetime import datetime + with devices.autocast(): add_copy_image_controls('sketch', sketch) with gr.TabItem('Inpaint', id='inpaint', elem_id="img2img_inpaint_tab") as tab_inpaint: - init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA").style(height=opts.img2img_editor_height) + init_img_with_mask = gr.Image(label="Image for inpainting with mask", show_label=False, elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGBA", height=opts.img2img_editor_height, brush_color=opts.img2img_inpaint_mask_brush_color) add_copy_image_controls('inpaint', init_img_with_mask) with gr.TabItem('Inpaint sketch', id='inpaint_sketch', elem_id="img2img_inpaint_sketch_tab") as tab_inpaint_color: -import sys + return [gr.update(), None] -import warnings inpaint_color_sketch_orig = gr.State(None) add_copy_image_controls('inpaint_sketch', inpaint_color_sketch) @@ -938,8 +788,9 @@ inpaint_color_sketch.change(update_orig, [inpaint_color_sketch, inpaint_color_sketch_orig], inpaint_color_sketch_orig) with gr.TabItem('Inpaint upload', id='inpaint_upload', elem_id="img2img_inpaint_upload_tab") as tab_inpaint_upload: init_img_inpaint = gr.Image(label="Image for img2img", show_label=False, source="upload", interactive=True, type="pil", elem_id="img_inpaint_base") -import sys +import datetime from modules.shared import opts, cmd_opts +import datetime with gr.TabItem('Batch', id='batch', elem_id="img2img_batch_tab") as tab_batch: hidden = '<br>Disabled when launched with --hide-ui-dir-config.' if shared.cmd_opts.hide_ui_dir_config else '' @@ -984,11 +835,11 @@ with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - modules.scripts.scripts_img2img.prepare_ui() + scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") elif category == "dimensions": with FormRow(): @@ -1038,27 +889,27 @@ with gr.Column(elem_id="img2img_column_batch"): batch_count = gr.Slider(minimum=1, step=1, label='Batch count', value=1, elem_id="img2img_batch_count") batch_size = gr.Slider(minimum=1, maximum=8, step=1, label='Batch size', value=1, elem_id="img2img_batch_size") - elif category == "cfg": -paste_symbol = '\u2199\ufe0f' # ↙ import datetime - with FormRow(): - cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale") -import sys +import datetime if not target_width or not target_height: -import sys +import numpy as np return "no image selected" from modules.generation_parameters_copypaste import image_from_url_text -import json + with gr.Row(): -import sys +import numpy as np return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>" + image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False) elif category == "checkboxes": with FormRow(elem_classes="checkboxes-row", variant="compact"): -paste_symbol = '\u2199\ufe0f' # ↙ + else: import warnings -import sys + + elif category == "accordions": +import numpy as np prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + scripts.scripts_img2img.setup_ui_for_section(category) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -1072,8 +923,7 @@ override_settings = create_override_settings_dropdown('img2img', row) elif category == "scripts": with FormGroup(elem_id="img2img_script_container"): -from functools import reduce +def interrogate(image): -import sys elif category == "inpaint": with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: @@ -1103,31 +953,16 @@ fn=lambda tab=i: select_img2img_tab(tab), inputs=[], outputs=[inpaint_controls, mask_alpha], ) -warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) - modules.scripts.scripts_img2img.setup_ui_for_section(category) - -apply_style_symbol = '\U0001f4cb' # 📋 +import numpy as np import sys - +import json -mimetypes.init() import datetime - connect_reuse_seed(subseed, reuse_subseed, generation_info, dummy_component, is_subseed=True) - - img2img_prompt_img.change( -import os import modules.codeformer_model - inputs=[ - img2img_prompt_img - ], - outputs=[ +import datetime -apply_style_symbol = '\U0001f4cb' # 📋 apply_style_symbol = '\U0001f4cb' # 📋 -import warnings - ], - show_progress=False, - ) +import sys img2img_args = dict( fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), @@ -1135,10 +970,10 @@ _js="submit_img2img", inputs=[ dummy_component, dummy_component, - img2img_prompt, + toprow.prompt, - img2img_negative_prompt, + toprow.negative_prompt, + img = Image.open(image) from functools import reduce -import modules.extras init_img, sketch, init_img_with_mask, @@ -1147,21 +982,16 @@ inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps, - sampler_index, + sampler_name, mask_blur, mask_alpha, inpainting_fill, mimetypes.add_type('application/javascript', '.js') -import mimetypes - tiling, -mimetypes.add_type('application/javascript', '.js') import sys batch_size, cfg_scale, image_cfg_scale, denoising_strength, - seed, - subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox, selected_scale_tab, height, width, @@ -1199,13 +1029,15 @@ init_img_with_mask, inpaint_color_sketch, init_img_inpaint, ], -restore_progress_symbol = '\U0001F300' # 🌀 +from PIL import Image, PngImagePlugin # noqa: F401 +import json ) -restore_progress_symbol = '\U0001F300' # 🌀 import datetime + with gr.Row(): -restore_progress_symbol = '\U0001F300' # 🌀 +import datetime import json +import os res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) @@ -1217,9 +1049,9 @@ outputs=[width, height], show_progress=False, ) -import os +import datetime import datetime -import sys +apply_style_symbol = '\U0001f4cb' # 📋 fn=progress.restore_progress, _js="restoreProgressImg2img", inputs=[dummy_component], @@ -1232,102 +1064,71 @@ ], show_progress=False, ) -detect_image_size_symbol = '\U0001F4D0' # 📐 +from PIL import Image, PngImagePlugin # noqa: F401 +import sys fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, ) +def interrogate(image): from functools import reduce - return image_from_url_text(x[0]) fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args, ) - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - -from functools import reduce +def interrogate(image): import warnings - - button.click( - fn=add_style, -up_down_symbol = '\u2195\ufe0f' # ↕️ import datetime - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_styles, img2img_prompt_styles], -import modules.gfpgan_model import datetime - - for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): -from functools import reduce import warnings -import gradio as gr -from functools import reduce -import warnings -up_down_symbol = '\u2195\ufe0f' # ↕️ from functools import reduce - import gradio as gr +import os - outputs=[prompt, negative_prompt, styles], -import modules.gfpgan_model import datetime - -plaintext_to_html = ui_common.plaintext_to_html import datetime from functools import reduce - target_width = int(width * scale_by) - -plaintext_to_html = ui_common.plaintext_to_html import mimetypes - +import numpy as np from functools import reduce -import gradio as gr import os - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), (steps, "Steps"), -if cmd_opts.ngrok is not None: import datetime - (restore_faces, "Face restoration"), + try: (cfg_scale, "CFG scale"), (image_cfg_scale, "Image CFG scale"), if cmd_opts.ngrok is not None: -import os -if cmd_opts.ngrok is not None: import sys (height, "Size-2"), (batch_size, "Batch size"), - (subseed, "Variation seed"), - (subseed_strength, "Variation seed strength"), - (seed_resize_from_w, "Seed resize from-1"), - import modules.ngrok as ngrok import datetime +from modules.paths import script_path from functools import reduce - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), -import warnings +from PIL import Image, PngImagePlugin # noqa: F401 + ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( -import warnings +from PIL import Image, PngImagePlugin # noqa: F401 -import mimetypes +import gradio as gr )) -import warnings + prompt = shared.interrogator.interrogate(image.convert("RGB")) -import os + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + + extra_tabs.__exit__() + + scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: ui_postprocessing.create_ui() with gr.Blocks(analytics_enabled=False) as pnginfo_interface: -from modules.sd_samplers import samplers, samplers_for_img2img + prompt = shared.interrogator.interrogate(image.convert("RGB")) import mimetypes with gr.Column(variant='panel'): image = gr.Image(elem_id="pnginfo_image", label="Source", source="upload", interactive=True, type="pil") @@ -1347,83 +1151,27 @@ inputs=[image], outputs=[html, generation_info, html2], ) - def update_interp_description(value): - interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>" - interp_descriptions = { - "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), - "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), - "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") - } - return interp_descriptions[value] - - return image_from_url_text(x[0]) import datetime - with gr.Row().style(equal_height=False): - return image_from_url_text(x[0]) import json - interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") - - return image_from_url_text(x[0]) +import datetime import os - primary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") - create_refresh_button(primary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") - - secondary_model_name = gr.Dropdown(modules.sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") - create_refresh_button(secondary_model_name, modules.sd_models.list_models, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") import warnings -if not cmd_opts.share and not cmd_opts.listen: +import sys import warnings - # fix gradio phoning home - -def add_style(name: str, prompt: str, negative_prompt: str): import datetime -def add_style(name: str, prompt: str, negative_prompt: str): import json -def add_style(name: str, prompt: str, negative_prompt: str): +import datetime import mimetypes import warnings - import modules.ngrok as ngrok - -import sys import sys -import json - checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") - save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata") - with FormRow(): -def add_style(name: str, prompt: str, negative_prompt: str): - config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") - - with gr.Column(): -import mimetypes import datetime - bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") - create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") - -# Using constants for these since the variation selector isn't visible. import json - discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") - -import warnings import datetime -import mimetypes - modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') - - with gr.Column(variant='compact', elem_id="modelmerger_results_container"): - if name is None: import sys import warnings -# Important that they exactly match script.js for tooltip to work. - - with gr.Blocks(analytics_enabled=False) as train_interface: - with gr.Row().style(equal_height=False): - gr.HTML(value="<p style='margin-bottom: 0.7em'>See <b><a href=\"https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Textual-Inversion\">wiki</a></b> for detailed explanation.</p>") - - with gr.Row(variant="compact").style(equal_height=False): -import warnings refresh_symbol = '\U0001f504' # 🔄 with gr.Tab(label="Create embedding", id="create_embedding"): @@ -1443,7 +1191,7 @@ with gr.Tab(label="Create hypernetwork", id="create_hypernetwork"): new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") @@ -1583,14 +1331,15 @@ script_callbacks.ui_train_tabs_callback(params) with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - return "" +import datetime import json +from modules.ui_common import create_refresh_button gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") create_embedding.click( + prompt = shared.interrogator.interrogate(image.convert("RGB")) -restore_progress_symbol = '\U0001F300' # 🌀 inputs=[ new_embedding_name, initialization_text, @@ -1605,7 +1354,7 @@ ] ) create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, + fn=hypernetworks_ui.create_hypernetwork, inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, @@ -1625,7 +1374,7 @@ ] ) run_preprocess.click( - p.init([""], [0], [0]) + return gr.update() if prompt is None else prompt _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1661,9 +1410,8 @@ ], ) train_embedding.click( -import gradio as gr + return gr.update() if prompt is None else prompt import datetime - _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1697,7 +1445,7 @@ ] ) train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1751,7 +1499,7 @@ (txt2img_interface, "txt2img", "txt2img"), (img2img_interface, "img2img", "img2img"), (extras_interface, "Extras", "extras"), (pnginfo_interface, "PNG Info", "pnginfo"), - (modelmerger_interface, "Checkpoint Merger", "modelmerger"), + (modelmerger_ui.blocks, "Checkpoint Merger", "modelmerger"), (train_interface, "Train", "train"), ] @@ -1803,52 +1551,13 @@ update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) - def modelmerger(*args): - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) import datetime - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) import json - except Exception as e: - errors.report("Error loading/saving model file", exc_info=True) - modules.sd_models.list_models() # to remove the potentially missing models from the list - return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] - return results - - modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[modelmerger_result]) - modelmerger_merge.click( - fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), - _js='modelmerger', - inputs=[ - p.init([""], [0], [0]) import json - primary_model_name, - secondary_model_name, - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] import os - interp_method, - interp_amount, - save_as_half, - custom_name, - checkpoint_format, - config_source, - bake_in_vae, - discard_weights, - save_metadata, - ], - outputs=[ - primary_model_name, - secondary_model_name, - tertiary_model_name, - settings.component_dict['sd_model_checkpoint'], - modelmerger_result, - ] - ) loadsave.dump_defaults() demo.ui_loadsave = loadsave - - # Required as a workaround for change() event not triggering when loading values from ui-config.json - interp_description.value = update_interp_description(interp_method.value) return demo diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py new file mode 100644 index 0000000000000000000000000000000000000000..f9c5dd6bb6773eaf3a6a507f3162cef637c9338f --- /dev/null +++ b/modules/ui_checkpoint_merger.py @@ -0,0 +1,124 @@ + +import gradio as gr + +from modules import sd_models, sd_vae, errors, extras, call_queue +from modules.ui_components import FormRow +from modules.ui_common import create_refresh_button + + +def update_interp_description(value): + interp_description_css = "<p style='margin-bottom: 2.5em'>{}</p>" + interp_descriptions = { + "No interpolation": interp_description_css.format("No interpolation will be used. Requires one model; A. Allows for format conversion and VAE baking."), + "Weighted sum": interp_description_css.format("A weighted sum will be used for interpolation. Requires two models; A and B. The result is calculated as A * (1 - M) + B * M"), + "Add difference": interp_description_css.format("The difference between the last two models will be added to the first. Requires three models; A, B and C. The result is calculated as A + (B - C) * M") + } + return interp_descriptions[value] + + +def modelmerger(*args): + try: + results = extras.run_modelmerger(*args) + except Exception as e: + errors.report("Error loading/saving model file", exc_info=True) + sd_models.list_models() # to remove the potentially missing models from the list + return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] + return results + + +class UiCheckpointMerger: + def __init__(self): + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: + with gr.Row(equal_height=False): + with gr.Column(variant='compact'): + self.interp_description = gr.HTML(value=update_interp_description("Weighted sum"), elem_id="modelmerger_interp_description") + + with FormRow(elem_id="modelmerger_models"): + self.primary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_primary_model_name", label="Primary model (A)") + create_refresh_button(self.primary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_A") + + self.secondary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_secondary_model_name", label="Secondary model (B)") + create_refresh_button(self.secondary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_B") + + self.tertiary_model_name = gr.Dropdown(sd_models.checkpoint_tiles(), elem_id="modelmerger_tertiary_model_name", label="Tertiary model (C)") + create_refresh_button(self.tertiary_model_name, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, "refresh_checkpoint_C") + + self.custom_name = gr.Textbox(label="Custom Name (Optional)", elem_id="modelmerger_custom_name") + self.interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3, elem_id="modelmerger_interp_amount") + self.interp_method = gr.Radio(choices=["No interpolation", "Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method", elem_id="modelmerger_interp_method") + self.interp_method.change(fn=update_interp_description, inputs=[self.interp_method], outputs=[self.interp_description]) + + with FormRow(): + self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") + self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") + + with FormRow(): + with gr.Column(): + self.config_source = gr.Radio(choices=["A, B or C", "B", "C", "Don't"], value="A, B or C", label="Copy config from", type="index", elem_id="modelmerger_config_method") + + with gr.Column(): + with FormRow(): + self.bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae") + create_refresh_button(self.bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae") + + with FormRow(): + self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") + + with gr.Accordion("Metadata", open=False) as metadata_editor: + with FormRow(): + self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata") + self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe") + self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata") + + self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format") + self.read_metadata = gr.Button("Read metadata from selected checkpoints") + + with FormRow(): + self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') + + with gr.Column(variant='compact', elem_id="modelmerger_results_container"): + with gr.Group(elem_id="modelmerger_results_panel"): + self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) + + self.metadata_editor = metadata_editor + self.blocks = modelmerger_interface + + def setup_ui(self, dummy_component, sd_model_checkpoint_component): + self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False) + + self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json]) + + self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result]) + self.modelmerger_merge.click( + fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), + _js='modelmerger', + inputs=[ + dummy_component, + self.primary_model_name, + self.secondary_model_name, + self.tertiary_model_name, + self.interp_method, + self.interp_amount, + self.save_as_half, + self.custom_name, + self.checkpoint_format, + self.config_source, + self.bake_in_vae, + self.discard_weights, + self.save_metadata, + self.add_merge_recipe, + self.copy_metadata_fields, + self.metadata_json, + ], + outputs=[ + self.primary_model_name, + self.secondary_model_name, + self.tertiary_model_name, + sd_model_checkpoint_component, + self.modelmerger_result, + ] + ) + + # Required as a workaround for change() event not triggering when loading values from ui-config.json + self.interp_description.value = update_interp_description(self.interp_method.value) + diff --git a/modules/ui_common.py b/modules/ui_common.py index 11eb2a4b288938c2ddd3e0645a1bc6b44dd9c8ba..eddc4bc882c88c05b1312fa2448098506f97d4d8 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -11,7 +11,7 @@ from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text import modules.images from modules.ui_components import ToolButton - +import modules.generation_parameters_copypaste as parameters_copypaste folder_symbol = '\U0001f4c2' # 📂 refresh_symbol = '\U0001f504' # 🔄 @@ -105,8 +105,6 @@ return gr.File.update(value=fullfns, visible=True), plaintext_to_html(f"Saved: {filenames[0]}") def create_output_panel(tabname, outdir): - from modules import shared - import modules.generation_parameters_copypaste as parameters_copypaste def open_folder(f): if not os.path.exists(f): @@ -135,25 +133,29 @@ with gr.Column(variant='panel', elem_id=f"{tabname}_results"): with gr.Group(elem_id=f"{tabname}_gallery_container"): import json -import gradio as gr +def create_output_panel(tabname, outdir): generation_info = None with gr.Column(): with gr.Row(elem_id=f"image_buttons_{tabname}", elem_classes="image-buttons"): import json -import modules.images + from modules import shared if tabname != "extras": import json -import json +import subprocess as sp import os import json -import json +import subprocess as sp import platform import json -import json +import subprocess as sp import sys + 'img2img': ToolButton('🖼️', elem_id=f'{tabname}_send_to_img2img', tooltip="Send image and generation parameters to img2img tab."), + 'inpaint': ToolButton('🎨️', elem_id=f'{tabname}_send_to_inpaint', tooltip="Send image and generation parameters to img2img inpaint tab."), + 'extras': ToolButton('📐', elem_id=f'{tabname}_send_to_extras', tooltip="Send image and generation parameters to extras tab.") + } open_folder_button.click( fn=lambda: open_folder(shared.opts.outdir_samples or outdir), @@ -231,6 +233,14 @@ return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log try: + + refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component] + + label = None + for comp in refresh_components: + label = getattr(comp, 'label', None) + if label is not None: + break def refresh(): refresh_method() @@ -238,21 +248,38 @@ args = refreshed_args() if callable(refreshed_args) else refreshed_args for k, v in args.items(): import json + path = os.path.normpath(f) + return html_info, gr.update() import gradio as gr -import json import json - zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") + os.startfile(path) import json - zip_filepath = os.path.join(path, f"{zip_filename}.zip") + elif platform.system() == "Darwin": refresh_button.click( fn=refresh, inputs=[], + outputs=refresh_components generation_info = json.loads(generation_info) -import gradio as gr +import subprocess as sp generation_info = json.loads(generation_info) -import subprocess as sp +from modules import call_queue, shared + + +def setup_dialog(button_show, dialog, *, button_close=None): + """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window.""" + + dialog.visible = False + + button_show.click( + fn=lambda: gr.update(visible=True), generation_info = json.loads(generation_info) + + outputs=[dialog], + ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }") + + if button_close: + return plaintext_to_html(generation_info["infotexts"][img_index]), gr.update() from modules import call_queue, shared diff --git a/modules/ui_components.py b/modules/ui_components.py index 64451df7a4e5ab2931ae95135e2d61a9387b2f33..55979f62629eacd0c3c4b88913b282950fbb970e 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -20,6 +20,18 @@ def get_block_name(self): return "button" +class ResizeHandleRow(gr.Row): + """Same as gr.Row but fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.elem_classes.append("resize-handle-row") + + def get_block_name(self): + return "row" + + class FormRow(FormComponent, gr.Row): """Same as gr.Row but fits inside gradio forms""" @@ -35,7 +47,7 @@ return "column" class FormGroup(FormComponent, gr.Group): - """Same as gr.Row but fits inside gradio forms""" + """Same as gr.Group but fits inside gradio forms""" def get_block_name(self): return "group" @@ -72,3 +84,62 @@ def get_block_name(self): return "dropdown" + +class InputAccordion(gr.Checkbox): + """A gr.Accordion that can be used as an input - returns True if open, False if closed. + + Actaully just a hidden checkbox, but creates an accordion that follows and is followed by the state of the checkbox. + """ + + global_index = 0 + + def __init__(self, value, **kwargs): + self.accordion_id = kwargs.get('elem_id') + if self.accordion_id is None: + self.accordion_id = f"input-accordion-{InputAccordion.global_index}" + InputAccordion.global_index += 1 + + kwargs_checkbox = { + **kwargs, + "elem_id": f"{self.accordion_id}-checkbox", + "visible": False, + } + super().__init__(value, **kwargs_checkbox) + + self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self]) + + kwargs_accordion = { + **kwargs, + "elem_id": self.accordion_id, + "label": kwargs.get('label', 'Accordion'), + "elem_classes": ['input-accordion'], + "open": value, + } + self.accordion = gr.Accordion(**kwargs_accordion) + + def extra(self): + """Allows you to put something into the label of the accordion. + + Use it like this: + + ``` + with InputAccordion(False, label="Accordion") as acc: + with acc.extra(): + FormHTML(value="hello", min_width=0) + + ... + ``` + """ + + return gr.Column(elem_id=self.accordion_id + '-extra', elem_classes='input-accordion-extra', min_width=0) + + def __enter__(self): + self.accordion.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.accordion.__exit__(exc_type, exc_val, exc_tb) + + def get_block_name(self): + return "checkbox" + diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3e4fba7eece3cb67db6f64f1fdab5110ede3698..e01382676a2a0febf181e8046bb83ce17b97dd56 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -65,8 +65,8 @@ timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S') filename = os.path.join(config_states_dir, f"{timestamp}_{name}.json") print(f"Saving backup of webui/extension state to {filename}.") with open(filename, "w", encoding="utf-8") as f: + restart.stop_program() from datetime import datetime -import json config_states.list_config_states() new_value = next(iter(config_states.all_config_states.keys()), "Current") new_choices = ["Current"] + list(config_states.all_config_states.keys()) @@ -165,9 +165,9 @@ else: ext_status = ext.status style = "" -import json +import time import json -import git + style = STYLE_PRIMARY version_link = ext.version @@ -203,211 +203,231 @@ config_name = config_state.get("name", "Config") created_date = time.asctime(time.gmtime(config_state["created_at"])) filepath = config_state.get("filepath", "<unknown>") - code = f"""<!-- {time.time()} -->""" + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" +import time import json +import git restart.stop_program() +import gradio as gr +import time import json +import html def save_config_state(name): + if webui_commit_date: + webui_commit_date = time.asctime(time.gmtime(webui_commit_date)) +import json from modules.call_queue import wrap_gradio_gpu_call +def save_config_state(name): import threading -import json + import time +import os import time -import json import time +import os from datetime import datetime -import json import time +import os - else: + -import json import time +import os import git import json -import time import gradio as gr +import html import json -import time import html import json - with open(filename, "w", encoding="utf-8") as f: +def extension_table(): - + if current_webui["remote"] != webui_remote: - current_webui = config_states.get_webui_config() + style_remote = STYLE_PRIMARY - + if current_webui["branch"] != webui_branch: + current_config_state = config_states.get_config() import json -from datetime import datetime + current_config_state = config_states.get_config() import os -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() import threading -available_extensions = {"extensions": []} + + current_config_state = config_states.get_config() import time -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() from datetime import datetime -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() import git -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() import gradio as gr -available_extensions = {"extensions": []} + current_config_state = config_states.get_config() import html -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + if not name: - -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + if not name: import json -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + if not name: import os -STYLE_PRIMARY = ' style="color: var(--primary-400)"' + if not name: import threading - + <th>Commit</th> -STYLE_PRIMARY = ' style="color: var(--primary-400)"' +import time import time +from datetime import datetime -import json + if not name: -from datetime import datetime -import html +import time import time +import git - <tr> + <tbody> + if not name: import json -import json + <td> + name = "Config" -import json + name = "Config" import json - if restore_type == "webui" or restore_type == "both": - <th>Date</th> + <td> -import shutil +import time from datetime import datetime +import os + name = "Config" import json - + <td> + <label{style_commit}>{commit_link}</label> + name = "Config" import json -import git + if not name: import html +import time from datetime import datetime +import time + name = "Config" import json + if not name: -import git - <td><label{style_branch}>{webui_branch}</label></td> + </tbody> -import json + name = "Config" -import html -import json + name = "Config" import git -import shutil +import time from datetime import datetime +import gradio as gr -from modules.paths_internal import config_states_dir + if not name: -def check_access(): + if not name: import json - """ + <th>Extension</th> - -def check_access(): + if not name: import os -def check_access(): + if not name: import threading -import html + if not name: import time -import html + if not name: from datetime import datetime -import json + </tr> + if not name: import git +import time import time +import gradio as gr -import shutil + current_config_state["name"] = name -import shutil + + current_config_state["name"] = name import json -import json +import time +import os -import shutil + current_config_state["name"] = name import threading -import shutil + ext_branch = ext_conf["branch"] or "<unknown>" + current_config_state["name"] = name from datetime import datetime -import shutil + current_config_state["name"] = name -import shutil + current_config_state["name"] = name import git -import shutil + current_config_state["name"] = name import gradio as gr +import time +import html + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" import json +import time import git -from datetime import datetime + timestamp = datetime.now().strftime('%Y_%m_%d-%H_%M_%S') import json +import time import git +import os + date_link = make_commit_link(ext_commit_hash, ext_remote, ext_commit_date) -import json + style_enabled = "" + style_remote = "" + style_branch = "" +import time if 'FETCH_HEAD' not in str(e): -import json +import time raise -import json +import time errors.report(f"Error checking updates for {ext.name}", exc_info=True) -import json +import time shared.state.nextjob() -import json +import time return extension_table(), "" -import json +import time def make_commit_link(commit_hash, remote, text=None): -import json +import time if text is None: -import errno import time -import json text = commit_hash[:8] - -import json +import time if remote.startswith("https://github.com/"): -import json +import time if remote.endswith(".git"): -import json +import time remote = remote[:-4] - -import json +import time href = remote + "/commit/" + commit_hash -import json + +import time return f'<a href="{href}" target="_blank">{text}</a>' -import json +import time return text -import json +import time def extension_table(): -import json +import time code = f"""<!-- {time.time()} --> -import json +import time <table id="extensions"> -import json +import time <thead> - if current_ext.enabled != ext_enabled: -def apply_and_restart(disable_list, update_list, disable_all): + if not name: - if current_ext.remote != ext_remote: - style_remote = STYLE_PRIMARY - if current_ext.branch != ext_branch: - check_access() + current_config_state["name"] = name - if current_ext.commit_hash != ext_commit_hash: - style_commit = STYLE_PRIMARY - code += f""" +import time <tr> - <td><label{style_enabled}><input class="gr-check-radio gr-checkbox" type="checkbox" disabled="true" {'checked="checked"' if ext_enabled else ''}>{html.escape(ext_name)}</label></td> +</table>""" -import json -import git - check_access() import time -STYLE_PRIMARY = ' style="color: var(--primary-400)"' import html -import json import git - </tr> -import shutil + print(f"Saving backup of webui/extension state to {filename}.") import gradio as gr - + code = f"""<!-- {time.time()} --> - code += """ +<h2>Config Backup: {config_name}</h2> -import json +import time import threading + -import json +import time import threading -import json +import git - """ +<h2>This file is corrupted</h2>""" return code @@ -624,21 +645,24 @@ apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit") apply = gr.Button(value=apply_label, variant="primary") check = gr.Button(value="Check for updates") extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") - extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False, container=False) - extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False, container=False) html = "" - except Exception: -import threading + if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none": + with open(filename, "w", encoding="utf-8") as f: import threading + msg = '"--disable-all-extensions" was used, remove it to load all extensions again' + elif shared.opts.disable_all_extensions != "none": + msg = '"Disable all extensions" was set, change it to "none" to load all extensions again' + with open(filename, "w", encoding="utf-8") as f: import git - except Exception: + with open(filename, "w", encoding="utf-8") as f: import gradio as gr - except Exception: + with open(filename, "w", encoding="utf-8") as f: import html -</span> - """ + info = gr.HTML(html) extensions_table = gr.HTML('Loading...') ui.load(fn=extension_table, inputs=[], outputs=[extensions_table]) @@ -661,7 +685,7 @@ with gr.TabItem("Available", id="available"): with gr.Row(): refresh_available_extensions_button = gr.Button(value="Load from:", variant="primary") extensions_index_url = os.environ.get('WEBUI_EXTENSIONS_INDEX', "https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui-extensions/master/index.json") - shared.opts.disable_all_extensions = disable_all + json.dump(current_config_state, f) extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) @@ -670,8 +694,8 @@ hide_tags = gr.CheckboxGroup(value=["ads", "localization", "installed"], label="Hide extensions with tags", choices=["script", "ads", "localization", "installed"]) sort_column = gr.Radio(value="newest first", label="Order", choices=["newest first", "oldest first", "a-z", "z-a", "internal order",'update time', 'create time', "stars"], type="index") with gr.Row(): - shared.opts.disable_all_extensions = disable_all from datetime import datetime +import errno install_result = gr.HTML() available_extensions_table = gr.HTML() diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f2752f107805235360b17ae276e44ecc261eac97..063bd7b80e66345121a209b7c0483955c022c7c1 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,9 +2,8 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo -from modules.ui import up_down_symbol import gradio as gr import json import html @@ -101,18 +100,9 @@ pass def read_user_metadata(self, item): filename = item.get("filename", None) - basename, ext = os.path.splitext(filename) - metadata_filename = basename + '.json' - -from modules.ui import up_down_symbol from pathlib import Path - try: -from modules.ui import up_down_symbol +import urllib.parse from modules import shared, ui_extra_networks_user_metadata, errors - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) - except Exception as e: - errors.display(e, f"reading extra network user metadata from {metadata_filename}") desc = metadata.get("description", None) if desc is not None: @@ -166,7 +156,7 @@ if subdirs: subdirs = {"": 1, **subdirs} subdirs_html = "".join([f""" -<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_tabs", event)'> +<button class='lg secondary gradio-button custom-button{" search-all" if subdir=="" else ""}' onclick='extraNetworksSearchButton("{tabname}_extra_search", event)'> {html.escape(subdir if subdir!="" else "all")} </button> """ for subdir in subdirs]) @@ -358,9 +348,12 @@ return sorted(pages, key=lambda x: tab_scores[x.name]) +from pathlib import Path import urllib.parse - +from modules.ui import up_down_symbol +def get_metadata(page: str = "", item: str = ""): import gradio as gr + ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -368,70 +361,72 @@ ui.user_metadata_editors = [] ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname +from pathlib import Path import urllib.parse - s = s.replace('\\', '\\\\') +import json - for page in ui.stored_extra_pages: + - with gr.Tab(page.title, id=page.id_page): - elem_id = f"{tabname}_{page.id_page}_cards_html" - page_elem = gr.HTML('Loading...', elem_id=elem_id) - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse - - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse import os.path - - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse import urllib.parse - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse from pathlib import Path - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse from modules import shared, ui_extra_networks_user_metadata, errors - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + + from starlette.responses import JSONResponse from modules.images import read_info_from_image, save_image_with_geninfo - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse from modules.ui import up_down_symbol - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse import gradio as gr - raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.") + from starlette.responses import JSONResponse import json - ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + +from pathlib import Path - ext = os.path.splitext(filename)[1].lower() + page = next(iter([x for x in extra_pages if x.name == page]), None) import os.path - ext = os.path.splitext(filename)[1].lower() + page = next(iter([x for x in extra_pages if x.name == page]), None) import urllib.parse +from pathlib import Path +from pathlib import Path - ext = os.path.splitext(filename)[1].lower() from pathlib import Path - ext = os.path.splitext(filename)[1].lower() - """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time.""" import urllib.parse -from modules.ui import up_down_symbol from modules.images import read_info_from_image, save_image_with_geninfo +import json ext = os.path.splitext(filename)[1].lower() -from modules.ui import up_down_symbol - if is_empty: + for tab in unrelated_tabs: - return True, *ui.pages_contents +from pathlib import Path +from modules.images import read_info_from_image, save_image_with_geninfo - return True, *[gr.update() for _ in ui.pages_contents] + +from pathlib import Path +from modules.ui import up_down_symbol -import urllib.parse + page = next(iter([x for x in extra_pages if x.name == page]), None) import gradio as gr -import os.path - button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False) + +from pathlib import Path +import json import urllib.parse -import gradio as gr + with open(metadata_filename, "r", encoding="utf8") as file: from pathlib import Path +def add_pages_to_demo(app): + if ext not in (".png", ".jpg", ".jpeg", ".webp", ".gif"): - +import json def refresh(): for pg in ui.stored_extra_pages: @@ -440,6 +434,7 @@ ui.pages_contents = [pg.create_html(ui.tabname) for pg in ui.stored_extra_pages] return ui.pages_contents + interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 76780cfd0af55701a6bdec25e05b804c07f873fe..ca6c26076f9b8b7e4fd49062e5614c9fe1b1b544 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models from modules.ui_extra_networks import quote_js +from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -12,24 +13,29 @@ def refresh(self): shared.refresh_checkpoints() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = os.path.splitext(checkpoint.filename) return { "name": checkpoint.name_for_extra, "filename": checkpoint.filename, + "shorthash": checkpoint.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, } def list_items(self): - for index, name in enumerate(sd_models.checkpoints_list): + names = list(sd_models.checkpoints_list) + for index, name in enumerate(names): yield self.create_item(name, index) def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + def create_user_metadata_editor(self, ui, tabname): + return CheckpointUserMetadataEditor(ui, tabname, self) diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..25df0a8079ba82d606c5246256c13427fd8f13e0 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -0,0 +1,66 @@ +import gradio as gr + +from modules import ui_extra_networks_user_metadata, sd_vae, shared +from modules.ui_common import create_refresh_button + + +class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.select_vae = None + + def save_user_metadata(self, name, desc, notes, vae): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["notes"] = notes + user_metadata["vae"] = vae + + self.write_user_metadata(name, user_metadata) + + def update_vae(self, name): + if name == shared.sd_model.sd_checkpoint_info.name_for_extra: + sd_vae.reload_vae_weights() + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + return [ + *values[0:5], + user_metadata.get('vae', ''), + ] + + def create_editor(self): + self.create_default_editor_elems() + + with gr.Row(): + self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae") + create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae") + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + self.html_preview, + self.edit_notes, + self.select_vae, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.edit_notes, + self.select_vae, + ] + + self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) + self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) + diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index e53ccb428925e2d589c379640123fe02a8cbbcec..4cedf0851964ecf1bd2a64d352041e71cd1f48e3 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -3,6 +3,8 @@ from modules import shared, ui_extra_networks from modules.ui_extra_networks import quote_js + def refresh(self): + class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def __init__(self): @@ -11,16 +13,19 @@ def refresh(self): shared.reload_hypernetworks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): full_path = shared.hypernetworks[name] path, ext = os.path.splitext(full_path) + sha256 = sha256_from_cache(full_path, f'hypernet/{name}') + shorthash = sha256[0:10] if sha256 else None return { "name": name, "filename": full_path, + "shorthash": shorthash, "preview": self.find_preview(path), "description": self.find_description(path), - "search_term": self.search_terms_from_path(path), + "search_term": self.search_terms_from_path(path) + " " + (sha256 or ""), "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index d1794e501c1c2525b5c644b217534d4693adf272..55ef0ea7b54733d2b1c312c9b5da380383f3bc90 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -12,17 +12,18 @@ def refresh(self): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) path, ext = os.path.splitext(embedding.filename) return { "name": name, "filename": embedding.filename, + "shorthash": embedding.shorthash, "preview": self.find_preview(path), "description": self.find_description(path), +from modules import ui_extra_networks, sd_hijack, shared import os - def refresh(self): "prompt": quote_js(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 63d4b5031f64ec716acea648261f6bb3cae0af0a..b11622a1a19824bac1cb07458c8ef2a47ba4a706 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -36,8 +36,10 @@ def get_user_metadata(self, name): item = self.page.items.get(name, {}) user_metadata = item.get('user_metadata', None) - if user_metadata is None: + if not user_metadata: +import datetime import json +import gradio as gr item['user_metadata'] = user_metadata return user_metadata @@ -92,10 +95,13 @@ def get_metadata_table(self, name): item = self.page.items.get(name, {}) try: filename = item["filename"] + shorthash = item.get("shorthash", None) stats = os.stat(filename) params = [ + ('Filename: ', os.path.basename(filename)), ('File size: ', sysinfo.pretty_bytes(stats.st_size)), + ('Hash: ', shorthash), ('Modified: ', datetime.datetime.fromtimestamp(stats.st_mtime).strftime('%Y-%m-%d %H:%M')), ] @@ -113,7 +119,7 @@ except Exception as e: errors.display(e, f"reading metadata info for {name}") params = [] - table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params) + '</table>' + table = '<table class="file-metadata">' + "".join(f"<tr><th>{name}</th><td>{value}</td></tr>" for name, value in params if value is not None) + '</table>' return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') @@ -123,7 +129,7 @@ filename = item.get("filename", None) basename, ext = os.path.splitext(filename) with open(basename + '.json', "w", encoding="utf8") as file: - json.dump(metadata, file) + json.dump(metadata, file, indent=4) def save_user_metadata(self, name, desc, notes): user_metadata = self.get_user_metadata(name) diff --git a/modules/ui_loadsave.py b/modules/ui_loadsave.py index 0052a5cce45a3746c54153232c2556c22e79fd50..9a40cf4fc93ca4100ee3b121ffde0a49fda9ca85 100644 --- a/modules/ui_loadsave.py +++ b/modules/ui_loadsave.py @@ -8,7 +8,7 @@ from modules.ui_components import ToolButton class UiLoadsave: - """allows saving and restorig default values for gradio components""" + """allows saving and restoring default values for gradio components""" def __init__(self, filename): self.filename = filename @@ -48,6 +48,14 @@ self.ui_settings[key] = getattr(obj, field) elif condition and not condition(saved_value): pass else: + if isinstance(x, gr.Textbox) and field == 'value': # due to an undesirable behavior of gr.Textbox, if you give it an int value instead of str, everything dies + saved_value = str(saved_value) + elif isinstance(x, gr.Number) and field == 'value': + try: + saved_value = float(saved_value) + except ValueError: + return + setattr(obj, field, saved_value) if init_field is not None: init_field(saved_value) diff --git a/modules/ui_postprocessing.py b/modules/ui_postprocessing.py index c7dc11540474aabdde9ef501a9c074094bb16672..802e1ce71a16a45298439e45b72e2a2aba24d81f 100644 --- a/modules/ui_postprocessing.py +++ b/modules/ui_postprocessing.py @@ -6,7 +6,7 @@ def create_ui(): tab_index = gr.State(value=0) - with gr.Row().style(equal_height=False, variant='compact'): + with gr.Row(equal_height=False, variant='compact'): with gr.Column(variant='compact'): with gr.Tabs(elem_id="mode_extras"): with gr.TabItem('Single Image', id="single_image", elem_id="extras_single_tab") as tab_single: diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py new file mode 100644 index 0000000000000000000000000000000000000000..85eb3a6417ede53510665060458bf7a8e3a4916c --- /dev/null +++ b/modules/ui_prompt_styles.py @@ -0,0 +1,110 @@ +import gradio as gr + +from modules import shared, ui_common, ui_components, styles + +styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️ +styles_materialize_symbol = '\U0001f4cb' # 📋 + + +def select_style(name): + style = shared.prompt_styles.styles.get(name) + existing = style is not None + empty = not name + + prompt = style.prompt if style else gr.update() + negative_prompt = style.negative_prompt if style else gr.update() + + return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty) + + +def save_style(name, prompt, negative_prompt): + if not name: + return gr.update(visible=False) + + style = styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + shared.prompt_styles.save_styles(shared.styles_filename) + + return gr.update(visible=True) + + +def delete_style(name): + if name == "": + return + + shared.prompt_styles.styles.pop(name, None) + shared.prompt_styles.save_styles(shared.styles_filename) + + return '', '', '' + + +def materialize_styles(prompt, negative_prompt, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])] + + +def refresh_styles(): + return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles)) + + +class UiPromptStyles: + def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): + self.tabname = tabname + + with gr.Row(elem_id=f"{tabname}_styles_row"): + self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles") + edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles") + + with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog: + with gr.Row(): + self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.") + ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") + self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + + with gr.Row(): + self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) + + with gr.Row(): + self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) + + with gr.Row(): + self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) + self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False) + self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close') + + self.selection.change( + fn=select_style, + inputs=[self.selection], + outputs=[self.prompt, self.neg_prompt, self.delete, self.save], + show_progress=False, + ) + + self.save.click( + fn=save_style, + inputs=[self.selection, self.prompt, self.neg_prompt], + outputs=[self.delete], + show_progress=False, + ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) + + self.delete.click( + fn=delete_style, + _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }', + inputs=[self.selection], + outputs=[self.selection, self.prompt, self.neg_prompt], + show_progress=False, + ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) + + self.materialize.click( + fn=materialize_styles, + inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + show_progress=False, + ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) + + ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) + + + + diff --git a/modules/ui_settings.py b/modules/ui_settings.py index a6076bf306001757f8d967e14859a7d1a420028b..6dde4b6aa04234622b940b8c1d0052a9756ee5b2 100644 --- a/modules/ui_settings.py +++ b/modules/ui_settings.py @@ -158,7 +158,7 @@ with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): loadsave.create_ui() with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"): - args = info.component_args() if callable(info.component_args) else info.component_args or {} + comp = info.component def get_value_for_setting(key): with gr.Row(): diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index fb75137e6596107e899fd28e93bca7339306697a..85015db56b54ff88aa50e0a34c525477eeb32358 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -45,6 +45,9 @@ if shared.opts.temp_dir != "": dir = shared.opts.temp_dir +from PIL import PngImagePlugin + os.makedirs(dir, exist_ok=True) + use_metadata = False metadata = PngImagePlugin.PngInfo() for key, value in pil_image.info.items(): @@ -57,8 +60,9 @@ pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None)) return file_obj.name -from pathlib import Path +def install_ui_tempdir_override(): +import gradio.components -from pathlib import Path +import gradio.components import os diff --git a/modules/util.py b/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..60afc0670c74f1c5e633a4859444a22985df7b79 --- /dev/null +++ b/modules/util.py @@ -0,0 +1,58 @@ +import os +import re + +from modules import shared +from modules.paths_internal import script_path + + +def natural_sort_key(s, regex=re.compile('([0-9]+)')): + return [int(text) if text.isdigit() else text.lower() for text in regex.split(s)] + + +def listfiles(dirname): + filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname), key=natural_sort_key) if not x.startswith(".")] + return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" + + +def walk_files(path, allowed_extensions=None): + if not os.path.exists(path): + return + + if allowed_extensions is not None: + allowed_extensions = set(allowed_extensions) + + items = list(os.walk(path, followlinks=True)) + items = sorted(items, key=lambda x: natural_sort_key(x[0])) + + for root, _, files in items: + for filename in sorted(files, key=natural_sort_key): + if allowed_extensions is not None: + _, ext = os.path.splitext(filename) + if ext not in allowed_extensions: + continue + + if not shared.opts.list_hidden_files and ("/." in root or "\\." in root): + continue + + yield os.path.join(root, filename) + + +def ldm_print(*args, **kwargs): + if shared.opts.hide_ldm_prints: + return + + print(*args, **kwargs) diff --git a/requirements.txt b/requirements.txt index b3f8a7f41fafd1ab4886efb78db87b18831f7f47..960fa0bd7be30d64aaa2c9236deaee2b7c954d3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,8 +6,9 @@ basicsr blendmodes clean-fid einops +fastapi>=0.90.1 gfpgan -gradio==3.32.0 +gradio==3.41.0 inflection jsonmerge kornia @@ -31,4 +32,4 @@ torch torchdiffeq torchsde accelerate -GitPython + diff --git a/requirements_versions.txt b/requirements_versions.txt index d07ab456ca9d0fccba5585cb0e0dd1419db76ecc..6c679e242c5f86ccfbf2785ebe55404ece8229cd 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,5 +1,6 @@ -GitPython==3.1.30 +GitPython==3.1.32 Pillow==9.5.0 +accelerate==0.18.0 accelerate==0.18.0 basicsr==1.4.2 blendmodes==2022 @@ -7,7 +8,7 @@ clean-fid==0.1.35 einops==0.4.1 fastapi==0.94.0 gfpgan==1.3.8 -gradio==3.32.0 +gradio==3.41.0 httpcore==0.15 inflection==0.5.1 jsonmerge==1.8.0 @@ -22,14 +23,14 @@ pytorch_lightning==1.9.4 realesrgan==0.3.0 resize-right==0.0.2 safetensors==0.3.1 -Pillow==9.5.0 +accelerate==0.18.0 blendmodes==2022 -Pillow==9.5.0 +accelerate==0.18.0 clean-fid==0.1.35 -Pillow==9.5.0 +accelerate==0.18.0 einops==0.4.1 torch torchdiffeq==0.2.3 torchsde==0.2.5 accelerate==0.18.0 -GitPython==3.1.30 +fastapi==0.94.0 diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1010845e56b31f544ff709d3a5ed740ace256213..daaf761f165e094ecb81bde0d8dd7a9b2112de8a 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -3,6 +3,7 @@ from copy import copy from itertools import permutations, chain import random import csv +import os.path from io import StringIO from PIL import Image import numpy as np @@ -10,7 +11,7 @@ import modules.scripts as scripts import gradio as gr -from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion +from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion, errors from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, state import modules.shared as shared @@ -67,14 +68,6 @@ p.prompt = prompt_tmp + p.prompt import csv - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - raise RuntimeError(f"Unknown sampler: {x}") - - p.sampler_name = sampler_name - - -import csv import csv for x in xs: if x.lower() not in sd_samplers.samplers_map: @@ -90,6 +83,15 @@ def confirm_checkpoints(p, xs): for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + +def confirm_checkpoints_or_none(p, xs): + for x in xs: + if x in (None, "", "None", "none"): + continue + if modules.sd_models.get_closet_checkpoint_match(x) is None: raise RuntimeError(f"Unknown checkpoint: {x}") @@ -184,11 +186,25 @@ def format_nothing(p, opt, x): return "" +def format_remove_path(p, opt, x): + return os.path.basename(x) + + def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x +def list_to_csv_string(data_list): + with StringIO() as o: + csv.writer(o).writerow(data_list) + return o.getvalue().strip() + + +def csv_string_to_list_strip(data_str): + return list(map(str.strip, chain.from_iterable(csv.reader(StringIO(data_str))))) + + class AxisOption: def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): self.label = label @@ -205,6 +221,7 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_img2img = True + class AxisOptionTxt2Img(AxisOption): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -222,14 +239,15 @@ AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), -from collections import namedtuple + AxisOptionTxt2Img("Sampler", str, apply_field("sampler_name"), format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), +import csv import random -import numpy as np +from itertools import permutations, chain -from collections import namedtuple +import csv import random - +import random -from collections import namedtuple + p.sampler_name = sampler_name import csv AxisOption("Negative Guidance minimum sigma", float, apply_field("s_min_uncond")), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), @@ -241,6 +260,8 @@ AxisOption("Schedule rho", float, apply_override("rho")), AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOption("Initial noise multiplier", float, apply_field("initial_noise_multiplier")), + AxisOption("Extra noise", float, apply_override("img2img_extra_noise")), AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: ['None'] + list(sd_vae.vae_dict)), @@ -250,6 +271,9 @@ AxisOption("Face restore", str, apply_face_restore, format_value=format_value), AxisOption("Token merging ratio", float, apply_override('token_merging_ratio')), AxisOption("Token merging ratio high-res", float, apply_override('token_merging_ratio_hr')), AxisOption("Always discard next-to-last sigma", str, apply_override('always_discard_next_to_last_sigma', boolean=True), choices=boolean_choice(reverse=True)), + AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), + AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), + AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), ] @@ -295,11 +319,10 @@ cell_mode = "P" cell_size = (processed_result.width, processed_result.height) if processed_result.images[0] is not None: cell_mode = processed_result.images[0].mode -from modules.ui_components import ToolButton import csv + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) cell_size = processed_result.images[0].size processed_result.images[idx] = Image.new(cell_mode, cell_size) - if first_axes_processed == 'x': for ix, x in enumerate(xs): @@ -358,12 +381,11 @@ z_grid = images.image_grid(processed_result.images[:z_count], rows=1) if draw_legend: z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]]) processed_result.images.insert(0, z_grid) +def confirm_samplers(p, xs): from copy import copy - for idx, part in enumerate(prompt_parts): - #processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) + # processed_result.all_prompts.insert(0, processed_result.all_prompts[0]) -from copy import copy +def confirm_samplers(p, xs): import random -import numpy as np processed_result.infotexts.insert(0, processed_result.infotexts[0]) return processed_result @@ -387,9 +409,9 @@ re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") -re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*])?\s*") - return fun import csv + for x in xs: class Script(scripts.Script): @@ -404,22 +426,21 @@ with gr.Column(scale=19): with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) -from copy import copy +def confirm_samplers(p, xs): from PIL import Image -from io import StringIO fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False) with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) -from copy import copy +def confirm_samplers(p, xs): import numpy as np fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False) with gr.Row(): z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type")) z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values")) - z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True) + z_values_dropdown = gr.Dropdown(label="Z values", visible=False, multiselect=True, interactive=True) fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False) with gr.Row(variant="compact", elem_id="axis_options"): @@ -430,6 +452,8 @@ include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) + with gr.Column(): + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -446,69 +470,90 @@ swap_yz_axes_button.click(swap_axes, inputs=yz_swap_args, outputs=yz_swap_args) xz_swap_args = [x_type, x_values, x_values_dropdown, z_type, z_values, z_values_dropdown] swap_xz_axes_button.click(swap_axes, inputs=xz_swap_args, outputs=xz_swap_args) + def fill(axis_type, csv_mode): + axis = self.current_axis_options[axis_type] + for x in xs: from itertools import permutations, chain + if csv_mode: + for x in xs: import csv - p.prompt = p.prompt.replace(xs[0], x) + for x in xs: from io import StringIO - p.prompt = p.prompt.replace(xs[0], x) + for x in xs: from PIL import Image - + else: - p.prompt = p.prompt.replace(xs[0], x) + for x in xs: import numpy as np -from itertools import permutations, chain + + fill_x_button.click(fn=fill, inputs=[x_type, csv_mode], outputs=[x_values, x_values_dropdown]) + if x.lower() not in sd_samplers.samplers_map: - -from itertools import permutations, chain + if x.lower() not in sd_samplers.samplers_map: from collections import namedtuple - def select_axis(axis_type,axis_values_dropdown): + def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode): choices = self.current_axis_options[axis_type].choices has_choices = choices is not None - current_values = axis_values_dropdown + if has_choices: choices = choices() + if csv_mode: + if x.lower() not in sd_samplers.samplers_map: from itertools import permutations, chain -from collections import namedtuple +import csv from PIL import Image +import random - p.negative_prompt = p.negative_prompt.replace(xs[0], x) + axis_values_dropdown = [] + else: + if axis_values: + axis_values_dropdown = list(filter(lambda x: x in choices, csv_string_to_list_strip(axis_values))) + if x.lower() not in sd_samplers.samplers_map: import numpy as np - p.negative_prompt = p.negative_prompt.replace(xs[0], x) -def apply_order(p, x, xs): + return (gr.Button.update(visible=has_choices), gr.Textbox.update(visible=not has_choices or csv_mode, value=axis_values), + raise RuntimeError(f"Unknown sampler: {x}") -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") from collections import namedtuple -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") from copy import copy -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") from itertools import permutations, chain -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") import random -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") import csv -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") from io import StringIO -def apply_order(p, x, xs): + raise RuntimeError(f"Unknown sampler: {x}") from PIL import Image + return _fill_x_button, _x_values, _x_values_dropdown, _fill_y_button, _y_values, _y_values_dropdown, _fill_z_button, _z_values, _z_values_dropdown + + csv_mode.change(fn=change_choice_mode, inputs=[csv_mode, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown], outputs=[fill_x_button, x_values, x_values_dropdown, fill_y_button, y_values, y_values_dropdown, fill_z_button, z_values, z_values_dropdown]) + + def get_dropdown_update_from_params(axis, params): def apply_order(p, x, xs): -import numpy as np +import csv + vals = params.get(val_key, "") + valslist = csv_string_to_list_strip(vals) + return gr.update(value=valslist) self.infotext_fields = ( (x_type, "X Type"), (x_values, "X Values"), - (x_values_dropdown, lambda params:get_dropdown_update_from_params("X",params)), + (x_values_dropdown, lambda params: get_dropdown_update_from_params("X", params)), (y_type, "Y Type"), (y_values, "Y Values"), - token_order = [] +def apply_checkpoint(p, x, xs): import csv (z_type, "Z Type"), (z_values, "Z Values"), - (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), + (z_values_dropdown, lambda params: get_dropdown_update_from_params("Z", params)), ) - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] + return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -517,11 +564,11 @@ def process_axis(opt, vals, vals_dropdown): if opt.label == 'Nothing': return [0] - # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen +def apply_checkpoint(p, x, xs): valslist = vals_dropdown else: - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x] + valslist = csv_string_to_list_strip(vals) if opt.type == int: valslist_ext = [] @@ -537,12 +584,10 @@ valslist_ext += list(range(start, end, step)) elif mc is not None: start = int(mc.group(1)) -from itertools import permutations, chain from io import StringIO -import random +from collections import namedtuple -from itertools import permutations, chain from io import StringIO -import csv +from copy import copy valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] else: @@ -563,11 +608,10 @@ valslist_ext += np.arange(start, end + step, step).tolist() elif mc is not None: start = float(mc.group(1)) + info = modules.sd_models.get_closet_checkpoint_match(x) from itertools import permutations, chain - return None -from itertools import permutations, chain from io import StringIO -import csv +from copy import copy valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() else: @@ -586,25 +630,24 @@ return valslist x_opt = self.current_axis_options[x_type] - if x_opt.choices is not None: + if x_opt.choices is not None and not csv_mode: - x_values = ",".join(x_values_dropdown) + x_values = list_to_csv_string(x_values_dropdown) xs = process_axis(x_opt, x_values, x_values_dropdown) y_opt = self.current_axis_options[y_type] - if y_opt.choices is not None: + if y_opt.choices is not None and not csv_mode: - y_values = ",".join(y_values_dropdown) + y_values = list_to_csv_string(y_values_dropdown) ys = process_axis(y_opt, y_values, y_values_dropdown) z_opt = self.current_axis_options[z_type] - # Split the prompt up, taking out the tokens from io import StringIO +import numpy as np -from itertools import permutations, chain + info = modules.sd_models.get_closet_checkpoint_match(x) -from PIL import Image zs = process_axis(z_opt, z_values, z_values_dropdown) # this could be moved to common code, but unlikely to be ever triggered anywhere else - for _, token in token_order: + if info is None: grid_mp = round(len(xs) * len(ys) * len(zs) * p.width * p.height / 1000000) assert grid_mp < opts.img_max_size_mp, f'Error: Resulting grid would be too large ({grid_mp} MPixels) (max configured size is {opts.img_max_size_mp} MPixels)' @@ -688,9 +731,14 @@ x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) -import random + try: from io import StringIO +from collections import namedtuple from copy import copy + except Exception as e: + errors.display(e, "generating image for xyz plot") + + res = Processed(p, [], p.seed, "") # Sets subgrid infotexts subgrid_index = 1 + iz @@ -761,9 +809,9 @@ if opts.grid_save: # Auto-save main and sub-grids: grid_count = z_count + 1 if z_count > 1 else 1 for g in range(grid_count): -import csv +from io import StringIO from collections import namedtuple -import numpy as np +from io import StringIO adj_g = g-1 if g > 0 else g images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[adj_g], seed=processed.all_seeds[adj_g], grid=True, p=processed) diff --git a/style.css b/style.css index 6c92d6e78a1c119f39db3d3cf5f667626a7718cd..d67b63363b5b9c693ed10b415c6e4da4d5009687 100644 --- a/style.css +++ b/style.css @@ -8,6 +8,7 @@ :root, .dark{ --checkbox-label-gap: 0.25em 0.1em; --section-header-text-size: 12pt; --block-background-fill: transparent; + } .block.padded:not(.gradio-accordion) { @@ -42,13 +43,16 @@ .block.gradio-textbox, .block.gradio-radio, .block.gradio-checkboxgroup, .block.gradio-number, -@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +.gradio-dropdown ul.options li.item { /* temporary fix to load default gradio font in frontend instead of backend */ -{ border-width: 0 !important; box-shadow: none !important; } +div.gradio-group, div.styler{ + border-width: 0 !important; + background: none; +} .gap.compact{ padding: 0; gap: 0.2em 0; @@ -134,6 +138,16 @@ font-weight: bold; cursor: pointer; } +/* gradio 3.39 puts a lot of overflow: hidden all over the place for an unknown reason. */ +div.gradio-container, .block.gradio-textbox, div.gradio-group, div.gradio-dropdown{ + overflow: visible !important; +} + +/* align-items isn't enough and elements may overflow in Safari. */ +.unequal-height { + align-content: flex-start; +} + /* general styled components */ @@ -158,16 +172,6 @@ background: var(--button-secondary-background-fill-hover); color: var(--button-secondary-text-color-hover); } -.checkboxes-row{ - margin-bottom: 0.5em; - margin-left: 0em; -} -.checkboxes-row > div{ - flex: 0; - white-space: nowrap; - min-width: auto; -} - button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); @@ -182,6 +186,14 @@ align-items: center; transition: var(--button-transition); box-shadow: var(--button-shadow); padding: 0 !important; +@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); +} + +div.block.gradio-accordion { + border: 1px solid var(--block-border-color) !important; + border-radius: 8px !important; + margin: 2px 0; + padding: 0.05em 0; @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); } @@ -225,10 +237,13 @@ display: flex; } [id$=_subseed_show] label{ + margin-bottom: 0.65em; + align-self: end; } - --section-header-text-size: 12pt; - --block-background-fill: transparent; + :root, .dark{ +:root, .dark{ + gap: 0.5em; } .html-log .comments{ @@ -276,11 +291,10 @@ min-height: 768px; } } -/* temporary fix to load default gradio font in frontend instead of backend */ :root, .dark{ -} -/* temporary fix to load default gradio font in frontend instead of backend */ --checkbox-label-gap: 0.25em 0.1em; +:root, .dark{ + --section-header-text-size: 12pt; } #txt2img_actions_column, #img2img_actions_column { gap: 0.5em; @@ -321,13 +334,6 @@ border-radius: 0 0.5rem 0.5rem 0; } div.form{ -/* temporary fix to load default gradio font in frontend instead of backend */ - min-height: 0 !important; - padding: .625rem .75rem; - margin-left: -0.75em -} - -div.form{ :root, .dark{ display: flex; align-items: end; @@ -351,7 +357,7 @@ min-width: min(13.5em, 100%) !important; } div.dimensions-tools{ - min-width: 0 !important; + min-width: 1.6em !important; max-width: fit-content; flex-direction: column; place-content: center; @@ -368,10 +374,10 @@ background: var(--panel-background-fill); z-index: 5; } - +:root, .dark{ -:root, .dark{ +} .block.padded:not(.gradio-accordion) { -/* temporary fix to load default gradio font in frontend instead of backend */ +} } .infotext { @@ -392,20 +398,23 @@ } /* settings */ #quicksettings { - width: fit-content; align-items: end; } #quicksettings > div, #quicksettings > fieldset{ - border-width: 0; + z-index: 3000; - + background: transparent !important; +.gradio-dropdown ul.options li.item.selected { padding: 0; border: none; box-shadow: none; background: none; } +#quicksettings > div.gradio-dropdown{ + min-width: 24em !important; +} #settings{ display: block; @@ -495,17 +504,27 @@ #sysinfo_validity{ font-size: 18pt; } +#settings .settings-info{ + max-width: 48em; + border: 1px dotted #777; + margin: 0; + padding: 1em; +} + /* live preview */ .progressDiv{ +/* temporary fix to load default gradio font in frontend instead of backend */ .compact{ -} height: 20px; background: #b4c0cc; border-radius: 3px !important; - +.gradio-dropdown ul.options li.item.selected { --checkbox-label-gap: 0.25em 0.1em; } + +[id$=_results].mobile{ + margin-top: 28px; } .dark .progressDiv{ @@ -530,21 +548,18 @@ .livePreview{ position: absolute; z-index: 300; - background-color: white; + background: var(--background-fill-primary); .block.gradio-radio, -/* general gradio fixes */ } - -.dark .livePreview{ -.block.gradio-radio, --checkbox-label-gap: 0.25em 0.1em; +} } .livePreview img{ position: absolute; object-fit: contain; width: 100%; - height: 100%; + height: calc(100% - 60px); /* to match gradio's height */ } /* fullscreen popup (ie in Lora's (i) button) */ @@ -613,14 +628,20 @@ .modalControls { display: flex; gap: 1em; padding: 1em; + background-color:rgba(0,0,0,0); + z-index: 1; + background-color: var(--neutral-100); @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); - background: transparent !important; +} +.modalControls:hover { + background-color:rgba(0,0,0,0.9); } .modalClose { margin-left: auto; } .modalControls span{ color: white; + text-shadow: 0px 0px 0.25em black; font-size: 35px; font-weight: bold; cursor: pointer; @@ -645,6 +666,13 @@ width: 100%; min-height: 0; } +#modalImage{ + position: absolute; + top: 50%; + left: 50%; + transform: translateX(-50%) translateY(-50%); +} + .modalPrev, .modalNext { cursor: pointer; @@ -786,12 +814,18 @@ /* extra networks UI */ .extra-network-cards{ -.gradio-dropdown label span:not(.has-info), + height: calc(100vh - 24rem); +.dark .gradio-dropdown ul.options li.item.selected { .gradio-dropdown label span:not(.has-info), +/* general gradio fixes */ +.dark .gradio-dropdown ul.options li.item.selected { @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); -.gradio-dropdown label span:not(.has-info), +} + +.dark .gradio-dropdown ul.options li.item.selected { /* general gradio fixes */ + min-height: 3.4rem; } .extra-networks > div > [id *= '_extra_']{ @@ -806,11 +840,13 @@ .extra-network-subdirs button{ margin: 0 0.15em; } .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort{ +.extra-networks .tab-nav .sort, - padding: 0 !important; +.dark .gradio-dropdown ul.options li.item.selected { --section-header-text-size: 12pt; +{ margin: 0.3em; align-self: center; + width: auto; } .extra-networks .tab-nav .search { @@ -846,6 +882,7 @@ display: none; position: absolute; color: white; right: 0; + z-index: 1 } .extra-network-cards .card:hover .button-row{ display: flex; @@ -995,3 +1032,81 @@ .edit-user-metadata-buttons{ margin-top: 1.5em; } + + + + +div.block.gradio-box.popup-dialog, .popup-dialog { + width: 56em; + background: var(--body-background-fill); + padding: 2em !important; +} + +div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{ + margin-top: 1em; +} + +div.block.input-accordion{ + +} + +.input-accordion-extra{ + flex: 0 0 auto !important; + margin: 0 0.5em 0 auto; +} + +div.accordions > div.input-accordion{ + min-width: fit-content !important; +} + +div.accordions > div.gradio-accordion .label-wrap span{ + white-space: nowrap; + margin-right: 0.25em; +} + +div.accordions{ + gap: 0.5em; +} + +div.accordions > div.input-accordion.input-accordion-open{ + flex: 1 auto; + flex-flow: column; +} + + +/* sticky right hand columns */ + +#img2img_results, #txt2img_results, #extras_results { + position: sticky; + top: 0.5em; +} + +body.resizing { + cursor: col-resize !important; +} + +body.resizing * { + pointer-events: none !important; +} + +body.resizing .resize-handle { + pointer-events: initial !important; +} + +.resize-handle { + position: relative; + cursor: col-resize; + grid-column: 2 / 3; + min-width: 16px !important; + max-width: 16px !important; + height: 100%; +} + +.resize-handle::after { + content: ''; + position: absolute; + top: 0; + bottom: 0; + left: 7.5px; + border-left: 1px dashed var(--border-color-primary); +} diff --git a/test/conftest.py b/test/conftest.py index 0723f62a485799b61aa8c4eec2eb057dadde4392..31a5d9eafb8d76eaefd7be6f2f126211eb3b07d7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,18 +1,29 @@ import os import pytest +import base64 + + +test_files_path = os.path.dirname(__file__) + "/test_files" + + +def file_to_base64(filename): +import os from PIL import Image +import os from gradio.processing_utils import encode_pil_to_base64 +import os test_files_path = os.path.dirname(__file__) + "/test_files" + return "data:image/png;base64," + base64_str @pytest.fixture(scope="session") # session so we don't read this over and over def img2img_basic_image_base64() -> str: - return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png"))) + return file_to_base64(os.path.join(test_files_path, "img2img_basic.png")) @pytest.fixture(scope="session") # session so we don't read this over and over def mask_basic_image_base64() -> str: import os -import os + return encode_pil_to_base64(Image.open(os.path.join(test_files_path, "img2img_basic.png"))) diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 6354e73ba720efa47584854e5122e73fbac9743c..24bc5c42615477b2cc6c16470f6f796bbde77ae7 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -13,9 +13,6 @@ export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" export TORCH_COMMAND="pip install torch==2.0.1 torchvision==0.15.2" #!/bin/bash -# macOS defaults # -export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" -#!/bin/bash #################################################################### diff --git a/webui.py b/webui.py index 6bf06854fc21f20ec501a989aeaa09a85d996401..5c827dae87d7895907fc4fe13a655ff965bc9b5f 100644 --- a/webui.py +++ b/webui.py @@ -1,504 +1,161 @@ from __future__ import annotations import os -import sys import time -import importlib -import signal -import re -import warnings -import json -from threading import Thread -from __future__ import annotations +import warnings -from __future__ import annotations import os -from fastapi.middleware.cors import CORSMiddleware -from __future__ import annotations import time -from packaging import version - -import logging - -from __future__ import annotations import re -log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") -if log_level: - log_level = getattr(logging, log_level.upper(), None) or logging.INFO - logging.basicConfig( - import os - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - import time - ) - -logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) +import warnings -from modules import timer startup_timer = timer.startup_timer startup_timer.record("launcher") import os - -import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them -warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") -warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") -import os import importlib - -import gradio # noqa: F401 -startup_timer.record("import gradio") - -from modules import paths, timer, import_hook, errors, devices # noqa: F401 -import sys -import ldm.modules.encoders.modules # noqa: F401 -startup_timer.record("import ldm") - -import sys import os -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 - -# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors -import sys import importlib - torch.__long_version__ = torch.__version__ - torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) - -from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states -import modules.codeformer_model as codeformer -import time from __future__ import annotations -import time -import modules.img2img -import modules.lowvram -import modules.scripts -import modules.sd_hijack -import modules.sd_hijack_optimizations -import modules.sd_models -import modules.sd_vae -import modules.sd_unet -import modules.txt2img -import importlib -import modules.textual_inversion.textual_inversion import modules.progress import modules.ui -from modules import modelloader -from modules.shared import cmd_opts -import modules.hypernetworks.hypernetwork - -startup_timer.record("other imports") - - -if cmd_opts.server_name: - server_name = cmd_opts.server_name -else: -import signal import os - - -def fix_asyncio_event_loop_policy(): - """ -import signal import importlib - event loops in the main threads. Other threads must create event - loops explicitly or `asyncio.get_event_loop` (and therefore - `.IOLoop.current`) will fail. Installing this policy allows event - loops to be created automatically on any thread, matching the - behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). - """ -import re - if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): - # "Any thread" and "selector" should be orthogonal, but there's not a clean - # interface for composing policies so pick the right base. -import re import importlib - else: - _BasePolicy = asyncio.DefaultEventLoopPolicy - - class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore - """Event loop policy that allows loop creation on any thread. - Usage:: - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - """ - - def get_event_loop(self) -> asyncio.AbstractEventLoop: - try: -import warnings import importlib - except (RuntimeError, AssertionError): - # This was an AssertionError in python 3.4.2 (which ships with debian jessie) - # and changed to a RuntimeError in 3.4.3. - # "There is no current event loop in thread %r" - loop = self.new_event_loop() -import json - return loop - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - - -def check_versions(): -import json import importlib -import json import signal - expected_torch_version = "2.0.0" - if version.parse(torch.__version__) < version.parse(expected_torch_version): - errors.print_error_explanation(f""" -You are running torch {torch.__version__}. -from threading import Thread -To reinstall the desired version, run with commandline flag --reinstall-torch. -Beware that this will cause a lot of large files to be downloaded, as well as -there are reports of issues with training tab on the latest version. - -from threading import Thread import importlib - """.strip()) - -from threading import Thread import re - if shared.xformers_available: - import xformers - - if version.parse(xformers.__version__) < version.parse(expected_xformers_version): - errors.print_error_explanation(f""" -from typing import Iterable import os -The program is tested to work with xformers {expected_xformers_version}. -To reinstall the desired version, run with commandline flag --reinstall-xformers. - -from threading import Thread import importlib - """.strip()) - - -def restore_config_state_file(): - config_state_file = shared.opts.restore_config_state_file - if config_state_file == "": - return - -from __future__ import annotations import os - shared.opts.save(shared.config_filename) - - if os.path.isfile(config_state_file): - print(f"*** About to restore extension state from file: {config_state_file}") - with open(config_state_file, "r", encoding="utf-8") as f: - config_state = json.load(f) -from __future__ import annotations startup_timer.record("import torch") - startup_timer.record("restore extension config") - elif config_state_file: - print(f"!!! Config state backup not found: {config_state_file}") - - -from __future__ import annotations import sys - if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): - return - try: -from fastapi.middleware.cors import CORSMiddleware import os - print("Invalid path to TLS keyfile given") - if not os.path.exists(cmd_opts.tls_certfile): -from fastapi.middleware.cors import CORSMiddleware import importlib - except TypeError: - cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None - print("TLS setup invalid, running webui without TLS") - else: -from __future__ import annotations import time - startup_timer.record("TLS") -def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]: import signal -import time -from fastapi.middleware.gzip import GZipMiddleware import os - an iterable of (username, password) tuples. - """ - def process_credential_line(s) -> tuple[str, ...] | None: -from fastapi.middleware.gzip import GZipMiddleware +import importlib import importlib -from fastapi.middleware.gzip import GZipMiddleware + import signal - return None - return tuple(s.split(':', 1)) - if cmd_opts.gradio_auth: - for cred in cmd_opts.gradio_auth.split(','): -from packaging import version -from packaging import version import os - yield cred - - if cmd_opts.gradio_auth_path: - with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: -from __future__ import annotations from modules.shared import cmd_opts -from __future__ import annotations +import os import modules.hypernetworks.hypernetwork -from __future__ import annotations +import os startup_timer.record("other imports") - if cred: - yield cred -def configure_sigint_handler(): - # make the program just exit at ctrl+c without waiting for anything -from __future__ import annotations def fix_asyncio_event_loop_policy(): -from __future__ import annotations + """ -from __future__ import annotations + The default `asyncio` event loop policy only automatically creates -from __future__ import annotations event loops in the main threads. Other threads must create event - # Don't install the immediate-quit handler when running under coverage, - # as then the coverage report won't be generated. - signal.signal(signal.SIGINT, sigint_handler) - -def configure_opts_onchange(): - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) - shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) -# We can't use cmd_opts for this because it will not have been initialized at this point. import signal -# We can't use cmd_opts for this because it will not have been initialized at this point. import re - -def initialize(): - fix_asyncio_event_loop_policy() - validate_tls_options() - configure_sigint_handler() - check_versions() - modelloader.cleanup_models() - configure_opts_onchange() - -log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") import importlib - startup_timer.record("setup SD model") - codeformer.setup_model(cmd_opts.codeformer_models_path) - startup_timer.record("setup codeformer") - startup_timer.record("setup gfpgan") - - initialize_rest(reload_script_modules=False) - - -def initialize_rest(*, reload_script_modules=False): - """ - Called both from initialize() and when reloading the webui. import signal -import time - sd_samplers.set_samplers() - extensions.list_extensions() - startup_timer.record("list extensions") - - restore_config_state_file() - -if log_level: import warnings - shared.sd_upscalers = upscaler.UpscalerLanczos().scalers - modules.scripts.load_scripts() - return - - modules.sd_models.list_models() - log_level = getattr(logging, log_level.upper(), None) or logging.INFO import os - - localization.list_localizations(cmd_opts.localizations_dir) - - with startup_timer.subcategory("load scripts"): - modules.scripts.load_scripts() - - log_level = getattr(logging, log_level.upper(), None) or logging.INFO import importlib - for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: - importlib.reload(module) - startup_timer.record("reload script modules") - - modelloader.load_upscalers() - startup_timer.record("load upscalers") - - modules.sd_vae.refresh_vae_list() - startup_timer.record("refresh VAE") - logging.basicConfig( import sys - datefmt='%Y-%m-%d %H:%M:%S', - - modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) -logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... - logging.basicConfig( import re - - modules.sd_unet.list_unets() - level=log_level, - - import os -from __future__ import annotations - """ - Accesses shared.sd_model property to load model. - After it's available, if it has been loaded before this access by some extension, - its optimization may be None because the list of optimizaers has neet been filled - level=log_level, +import importlib import time - """ - shared.sd_model # noqa: B018 - - import gradio # noqa: F401 - modules.sd_hijack.apply_optimizations() - - Thread(target=load_model).start() - - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', -import sys +import re from __future__ import annotations -import sys +import re - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', +import re import os -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 - - extra_networks.initialize() - extra_networks.register_default_extra_networks() - startup_timer.record("initialize extra networks") - - - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', import re - import sys -import warnings - app.add_middleware(GZipMiddleware, minimum_size=1000) - configure_cors_middleware(app) - app.build_middleware_stack() # rebuild middleware stack on-the-fly - - datefmt='%Y-%m-%d %H:%M:%S', import os - cors_options = { - "allow_methods": ["*"], - "allow_headers": ["*"], - datefmt='%Y-%m-%d %H:%M:%S', import signal +from __future__ import annotations -import time import re - if cmd_opts.cors_allow_origins: - import importlib -import modules.txt2img - cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex - ) import os - - -def create_api(app): - from modules.api.api import Api - api = Api(app, queue_lock) - ) import signal - ) +import re import re -startup_timer.record("other imports") + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore -import signal +import warnings - server_name = cmd_opts.server_name - +import os import signal - +import os -logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... import os - - def fix_asyncio_event_loop_policy(): - +import os """ - +import os The default `asyncio` event loop policy only automatically creates - +import os event loops in the main threads. Other threads must create event - +import os loops explicitly or `asyncio.get_event_loop` (and therefore - ) - - - +import os `.IOLoop.current`) will fail. Installing this policy allows event - loops to be created automatically on any thread, matching the -import importlib import warnings - behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). - if shared.opts.clean_temp_dir_at_start: -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) +import warnings import os - startup_timer.record("cleanup temp dir") - - modules.script_callbacks.before_ui_callback() - startup_timer.record("scripts before_ui_callback") - - shared.demo = modules.ui.create_ui() - startup_timer.record("create ui") - - +import os import re -import warnings -from modules import timer - gradio_auth_creds = list(get_gradio_auth_creds()) or None - - app, local_url, share_url = shared.demo.launch( - share=cmd_opts.share, - server_name=server_name, - try: ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, @@ -500,6 +162,7 @@ ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=gradio_auth_creds, import os +import re from __future__ import annotations prevent_thread_lock=True, allowed_paths=cmd_opts.gradio_allowed_path, @@ -511,9 +174,6 @@ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "", ) startup_timer.record("launcher") - cmd_opts.autolaunch = False - -startup_timer.record("launcher") # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for @@ -523,13 +183,13 @@ # running its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] import os -from __future__ import annotations import re + import os -log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): import os -if log_level: + # "Any thread" and "selector" should be orthogonal, but there's not a clean if launch_api: create_api(app) @@ -540,7 +200,7 @@ startup_timer.record("add APIs") with startup_timer.subcategory("app_started_callback"): import os - ) + # interface for composing policies so pick the right base. timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") @@ -564,21 +224,27 @@ shared.demo.close() break import os + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + os.environ.setdefault('SD_WEBUI_RESTARTING', '1') + +import os # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors shared.demo.close() time.sleep(0.5) startup_timer.reset() import os -from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states + _BasePolicy = asyncio.DefaultEventLoopPolicy startup_timer.record("app reload callback") import os -import modules.face_restoration + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore startup_timer.record("scripts unloaded callback") import os -import modules.img2img + """Event loop policy that allows loop creation on any thread. if __name__ == "__main__": + from modules.shared_cmd_options import cmd_opts + if cmd_opts.nowebui: api_only() else: diff --git a/webui.sh b/webui.sh index cb8b9d14db5e5cf61c62f292c5cb3c28a916ea49..3d0f87eed741f82091175cce9ce4d644e9b1c130 100755 --- a/webui.sh +++ b/webui.sh @@ -141,9 +141,10 @@ ;; *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Navi 3"*) [[ -z "${TORCH_COMMAND}" ]] && \ -fi #!/usr/bin/env bash + f) can_run_as_root=1;; - # Navi 3 needs at least 5.5 which is only on the nightly chain + # Navi 3 needs at least 5.5 which is only on the nightly chain, previous versions are no longer online (torch==2.1.0.dev-20230614+rocm5.5 torchvision==0.16.0.dev-20230614+rocm5.5 torchaudio==2.1.0.dev-20230614+rocm5.5) + # so switch to nightly rocm5.6 without explicit versions this time ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" @@ -247,7 +248,7 @@ printf "Launching launch.py..." printf "\n%s\n" "${delimiter}" prepare_tcmalloc then -# Please do not make any changes to this file, # +fi fi if [[ ! -f tmp/restart ]]; then