Diffusion扩散模型学习3——Stable Diffusion结构解析-以图像生成图像(图

[复制链接]
查看1005 | 回复0 | 2023-8-10 17:01:29 来自手机 | 显示全部楼层 |阅读模式
学习前言

用了很久的Stable Diffusion,但从来没有好好解析过它内部的结构,写个博客记录一下,嘿嘿。

源码下载地址

https://github.com/bubbliiiing/stable-diffusion
喜欢的可以点个star噢。
网络构建

一、什么是Stable Diffusion(SD)

Stable Diffusion是比较新的一个扩散模型,翻译过来是稳定扩散,虽然名字叫稳定扩散,但实际上换个seed生成的结果就完全不一样,非常不稳定哈。
Stable Diffusion最开始的应用应该是文本生成图像,即文生图,随着技术的发展Stable Diffusion不仅支持image2image图生图的生成,还支持ControlNet等各种控制方法来定制生成的图像。
Stable Diffusion基于扩散模型,所以不免包含不断去噪的过程,如果是图生图的话,还有不断加噪的过程,此时离不开DDPM那张老图,如下:

Stable Diffusion相比于DDPM,使用了DDIM采样器,使用了隐空间的扩散,另外使用了非常大的LAION-5B数据集进行预训练。
直接Finetune Stable Diffusion大多数同学应该是无法cover住成本的,不过Stable Diffusion有很多轻量Finetune的方案,比如Lora、Textual Inversion等,但这是后话。
本文主要是解析一下整个SD模型的结构组成,一次扩散,多次扩散的流程。
大模型、AIGC是当前行业的趋势,不会的话容易被淘汰,hh。
txt2img的原理如博文
Diffusion扩散模型学习2——Stable Diffusion结构解析-以文本生成图像(txt2img)为例
所示。
二、Stable Diffusion的组成

Stable Diffusion由四大部分组成。
1、Sampler采样器。
2、Variational Autoencoder (VAE) 变分自编码器。
3、UNet 主网络,噪声预测器。
4、CLIPEmbedder文本编码器。
每一部分都很重要,我们以图像生成图像为例进行解析。既然是图像生成图像,那么我们的输入有两个,一个是文本,另外一个是图片。
三、img2img生成流程


生成流程分为四个部分:
1、对图片进行VAE编码,根据denoise数值进行加噪声。
2、Prompt文本编码。
3、根据denoise数值进行若干次采样。
4、使用VAE进行解码。
相比于文生图,图生图的输入发生了变化,不再以Gaussian noise作为初始化,而是以加噪后的图像特征为初始化这样便以图像的方式为模型注入了信息。
详细来讲,如上图所示:


  • 第一步为对输入的图像利用VAE编码,获得输入图像的Latent特征;然后使用该Latent特征基于DDIM Sampler进行加噪,此时获得输入图片加噪后的特征。假设我们设置denoise数值为0.8,总步数为20步,那么第一步中,我们会对输入图片进行0.8x20次的加噪声,剩下4步不加,可理解为打乱了80%的特征,保留20%的特征。
  • 第二步是对输入的文本进行编码,获得文本特征;
  • 第三步是根据denoise数值对 第一步中获得的 加噪后的特征 进行若干次采样。还是以第一步中denoise数值为0.8为例,我们只加了0.8x20次噪声那么我们也只需要进行0.8x20次采样就可以恢复出图片了
  • 第四步是将采样后的图片利用VAE的Decoder进行恢复。
  1. with torch.no_grad():
  2.     if seed == -1:
  3.         seed = random.randint(0, 65535)
  4.     seed_everything(seed)
  5.     # ----------------------- #
  6.     #   对输入图片进行编码并加噪
  7.     # ----------------------- #
  8.     if image_path is not None:
  9.         img = HWC3(np.array(img, np.uint8))
  10.         img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
  11.         img = torch.stack([img for _ in range(num_samples)], dim=0)
  12.         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
  13.         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
  14.         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
  15.         z = model.get_first_stage_encoding(model.encode_first_stage(img))
  16.         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
  17.     # ----------------------- #
  18.     #   获得编码后的prompt
  19.     # ----------------------- #
  20.     cond    = {"c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)]}
  21.     un_cond = {"c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)]}
  22.     H, W    = input_shape
  23.     shape   = (4, H // 8, W // 8)
  24.     if image_path is not None:
  25.         samples = ddim_sampler.decode(z_enc, cond, t_enc, unconditional_guidance_scale=scale, unconditional_conditioning=un_cond)
  26.     else:
  27.         # ----------------------- #
  28.         #   进行采样
  29.         # ----------------------- #
  30.         samples, intermediates = ddim_sampler.sample(ddim_steps, num_samples,
  31.                                                         shape, cond, verbose=False, eta=eta,
  32.                                                         unconditional_guidance_scale=scale,
  33.                                                         unconditional_conditioning=un_cond)
  34.     # ----------------------- #
  35.     #   进行解码
  36.     # ----------------------- #
  37.     x_samples = model.decode_first_stage(samples)
  38.     x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
复制代码
1、输入图片编码


在图生图中,我们首先要指定一张参考的图像,然后在这个参考图像上开始工作:
1、利用VAE编码器对这张参考图像进行编码,使其进入隐空间,只有进入了隐空间,网络才知道这个图像是什么
2、然后使用该Latent特征基于DDIM Sampler进行加噪,此时获得输入图片加噪后的特征。加噪的逻辑如下:


  • denoise可认为是重建的比例,1代表全部重建,0代表不重建;
  • 假设我们设置denoise数值为0.8,总步数为20步;我们会对输入图片进行0.8x20次的加噪声,剩下4步不加,可理解为80%的特征,保留20%的特征;不过就算加完20步噪声原始输入图片的信息还是有一点保留的,不是完全不保留。
此时我们便获得在隐空间加噪后的图像,后续会在这个 隐空间加噪后的图像 的基础上进行采样。
  1. with torch.no_grad():
  2.     if seed == -1:
  3.         seed = random.randint(0, 65535)
  4.     seed_everything(seed)
  5.     # ----------------------- #
  6.     #   对输入图片进行编码并加噪
  7.     # ----------------------- #
  8.     if image_path is not None:
  9.         img = HWC3(np.array(img, np.uint8))
  10.         img = torch.from_numpy(img.copy()).float().cuda() / 127.0 - 1.0
  11.         img = torch.stack([img for _ in range(num_samples)], dim=0)
  12.         img = einops.rearrange(img, 'b h w c -> b c h w').clone()
  13.         ddim_sampler.make_schedule(ddim_steps, ddim_eta=eta, verbose=True)
  14.         t_enc = min(int(denoise_strength * ddim_steps), ddim_steps - 1)
  15.         z = model.get_first_stage_encoding(model.encode_first_stage(img))
  16.         z_enc = ddim_sampler.stochastic_encode(z, torch.tensor([t_enc] * num_samples).to(model.device))
复制代码
2、文本编码


文本编码的思路比较简单,直接使用CLIP的文本编码器进行编码就可以了,在代码中定义了一个FrozenCLIPEmbedder类别,使用了transformers库的CLIPTokenizer和CLIPTextModel。
在前传过程中,我们对输入进来的文本首先利用CLIPTokenizer进行编码,然后使用CLIPTextModel进行特征提取,通过FrozenCLIPEmbedder,我们可以获得一个[batch_size, 77, 768]的特征向量。
[code]class FrozenCLIPEmbedder(AbstractEncoder):    """Uses the CLIP transformer encoder for text (from huggingface)"""    LAYERS = [        "last",        "pooled",        "hidden"    ]    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,                 freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32        super().__init__()        assert layer in self.LAYERS        # 定义文本的tokenizer和transformer        self.tokenizer      = CLIPTokenizer.from_pretrained(version)        self.transformer    = CLIPTextModel.from_pretrained(version)        self.device         = device        self.max_length     = max_length        # 冻结模型参数        if freeze:            self.freeze()        self.layer = layer        self.layer_idx = layer_idx        if layer == "hidden":            assert layer_idx is not None            assert 0

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则