Qter 发表于 2023-12-24 12:49:19

深入浅出完整解析Stable Diffusion(SD)核心基础知识

https://zhuanlan.zhihu.com/p/632809634Rocky持续在撰写Stable Diffusion XL全方位解析文章,希望大家能多多点赞,让Rocky有更多坚持的动力:深入浅出完整解析Stable Diffusion XL(SDXL)核心基础知识627 赞同 · 103 评论文章https://pic1.zhimg.com/v2-687b4e4e6ffd58c770154894bec34720_180x120.jpg
2023.11.16最新消息,本文已经发布Stable Diffusion中VAE,U-Net和CLIP三大模型的可视化网络结构图,大家可以下载用于学习!关注Rocky的公众号WeThinkIn,后台回复:SD网络结构,即可获得SD网络结构图资源链接。2023.10.06最新消息,本文已经发布Stable Diffusion V1-2系列的完整模型微调训练教程和对应的LoRA模型训练教程,并增加对Stable Diffusion微调训练与LoRA训练的解读与思考。同时Rocky也在持续完善补充本文,后续会将Stable Diffusion各个模块的网络结构图,Stable Diffusion的从0到1训练流程,从0到1搭建Stable Diffusion进行AI绘画流程,以及Stable Diffusion各个模块的融合优化流程完整呈现在本文中,大家敬请期待,希望能多多点赞!2023.09.05最新消息:本文已经完成Stable Diffusion核心基础原理部分内容。
大家好,我是Rocky。2022年,Stable Diffusion模型横空出世,其成为AI行业从传统深度学习时代走向AIGC时代的标志性模型之一,并为工业界,投资界,学术界以及竞赛界都注入了新的AI想象空间,让AI再次性感。Stable Diffusion是计算机视觉领域的一个生成式大模型,能够进行文生图(txt2img)和图生图(img2img)等图像生成任务。与Midjourney不同的是,Stable Diffusion是一个完全开源的项目(模型,代码,训练数据,论文等),这使得其快速构建了强大繁荣的上下游生态(AI绘画社区,基于SD的自训练模型,丰富的辅助AI绘画工具与插件等),并且吸引了越来越多的AI绘画爱好者也加入其中,与AI行业从业者一起不断推动AIGC行业的发展与普惠。也正是Stable Diffusion的开源属性,繁荣的上下游生态以及各行各业AI绘画爱好者的参与,使得AI绘画火爆出圈,让大部分人都能非常容易地进行AI绘画。可以说,本次AI科技浪潮的ToC普惠在AIGC时代的早期就已经显现,这是之前的传统深度学习时代从未有过的。而这也是最让Rocky振奋的AIGC属性,让Rocky相信未来的十年会是像移动互联网时代那样,充满科技变革与机会的时代。Rocky从传统深度学习时代走来,与图像分类领域的ResNet系列,图像分割领域的U-Net系列以及目标检测领域的YOLO系列模型打过交道,Rocky相信Stable Diffusion会是图像生成领域的“YOLO”。https://pic4.zhimg.com/80/v2-e58e32d08e58f1c8f113ae10129f172b_720w.webp
Stable Diffusion生成图片因此本文中,Rocky将以社区中最为火爆的SDv1.5为例,对Stable Diffusion的各个细节做一个深入浅出的分析总结(模型结构,应用场景,性能优化,从0到1训练教程,从0到1搭建推理流程,最新SD资源汇总,相关插件工具使用等),希望我们能更好的入门Stable Diffusion及其背后的AIGC领域,在AIGC时代更好地融合和从容。1. Stable Diffusion资源分享【1】Stable Diffusion V1.5模型资源
[*]官方项目:https://huggingface.co/runwayml/stable-diffusion-v1-5
[*]SD V1.5模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDV1.5模型,即可获得资源链接,包含SD V1.5的diffusers格式(FP32和FP16精度)模型权重,SD V1.5的safetensors格式模型权重和ckpt格式模型权重。
【2】Stable Diffusion V2.1-base(512x512)模型资源
[*]官方项目:https://huggingface.co/stabilityai/stable-diffusion-2-1-base
[*]SD V2.1-base模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDV2.1-base模型,即可获得资源链接,包含SD V2.1-base的diffusers格式(FP32和FP16精度)模型权重,SD V2.1的safetensors格式模型权重和ckpt格式模型权重。
【3】Stable Diffusion V2.1(768x768)模型资源
[*]官方项目:https://huggingface.co/stabilityai/stable-diffusion-2-1
[*]SD V2.1模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDV2.1模型,即可获得资源链接,包含SD V2.1的diffusers格式(FP32和FP16精度)模型权重,SD V2.1的safetensors格式模型权重和ckpt格式模型权重。
【4】Stable Diffusion Inpainting模型资源
[*]官方项目:https://huggingface.co/runwayml/stable-diffusion-inpainting
[*]Stable Diffusion Inpainting模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDInpainting模型,即可获得资源链接,包含SDInpainting的diffusers格式(FP32和FP16精度)模型权重,SDInpainting的safetensors格式模型权重和ckpt格式模型权重。
【5】Stable Diffusion x4-Upscaler(超分)模型资源
[*]官方项目:stabilityai/stable-diffusion-x4-upscaler
[*]Stable Diffusion x4-Upscaler模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDx4-Upscaler模型,即可获得资源链接,包含SDx4-Upscaler的diffusers格式(FP32和FP16精度)模型权重,SDx4-Upscaler的safetensors格式模型权重和ckpt格式模型权重。
【6】Stable Diffusion热门社区&&第三方模型资源
[*]https://huggingface.co/models(huggingface模型网站)
[*]https://civitai.com/(全球最全的SD模型资源库)
[*]https://www.reddit.com/r/StableDiffusion/(全球讨论最激烈的SD资讯论坛)
【7】Dall·E3同款解码器:consistency-decoderOpenAI开源的一致性解码器(consistency-decoder)。它既能用于Dall·E 3模型,同时它也支持作为Stable Diffusion V1.4/1.5的VAE模型。它能让图像生成质量更高、更稳定,比如多人脸、带文字图像以及线条控制方面。
[*]官方项目:openai/consistency-decoder
[*]consistency-decoder模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:一致性解码器,即可获得资源链接,包含consistency-decoder的safetensors格式(FP32和FP16精度)模型权重。
【8】SD Turbo模型资源SD Turbo模型权重百度云网盘:关注Rocky的公众号WeThinkIn,后台回复:SDTurbo模型,即可获得资源链接,包含SD Turbo模型权重。2. 零基础深入理解Stable Diffusion核心基础原理2.1 通俗讲解Stable Diffusion模型工作流程Stable Diffusion(SD)模型是由Stability AI和LAION等公司共同开发的生成式模型,总共有1B左右的参数量,可以用于文生图,图生图,图像inpainting,ControlNet控制生成,图像超分等丰富的任务,本节中我们以文生图(txt2img)和图生图(img2img)任务展开对Stable Diffusion模型的工作流程进行通俗的讲解。文生图任务是指将一段文本输入到SD模型中,经过一定的迭代次数,SD模型输出一张符合输入文本描述的图片。比如下图中输入了“天堂,巨大的,海滩”,于是SD模型生成了一个美丽沙滩的图片。https://pic3.zhimg.com/80/v2-20425eeeeed2f7d69d51a1182255c33e_720w.webp
SD模型的文生图(txt2img)过程而图生图任务在输入本文的基础上,再输入一张图片,SD模型将根据文本的提示,将输入图片进行重绘以更加符合文本的描述。比如下图中,SD模型将“海盗船”添加在之前生成的那个美丽的沙滩图片上。https://pic4.zhimg.com/80/v2-0bab3b3c51305d9d2b9856d66f6a9807_720w.webp
SD模型的图生图(img2img)过程那么输入的人类文本信息如何成为SD模型能够理解的机器数学信息呢?很简单,我们需要给SD模型一个文本信息与机器数据信息之间互相转换的“桥梁”——CLIP Text Encoder模型。如下图所示,我们使用CLIP Text Encoder模型作为SD模型的前置模块,将输入的人类文本信息进行编码,输出特征矩阵,这个特征矩阵与文本信息相匹配,并且能够使得SD模型理解:https://pic3.zhimg.com/80/v2-6d5793d623b4a241c40774e8d8bc76d6_720w.webp
蓝色框就是CLIP Text Encoder模型,能够将输入文本信息进行编码,输出SD能够理解的特征矩阵完成对文本信息的编码后,就会输入到SD模型的“图像优化模块”中对图像的优化进行“控制”。如果是图生图任务,我们在输入文本信息的同时,还需要将原图片通过图像编码器(VAE Encoder)生成Latent Feature(隐空间特征)作为输入。如果是文生图任务,我们只需要输入文本信息,再用random函数生成一个高斯噪声矩阵作为Latent Feature的“替代”输入到SD模型的“图像优化模块”中。“图像优化模块”作为SD模型中最为重要的模块,其工作流程是什么样的呢?首先,“图像优化模块”是由一个U-Net网络和一个Schedule算法共同组成,U-Net网络负责预测噪声,不断优化生成过程,在预测噪声的同时不断注入文本语义信息。而schedule算法对每次U-Net预测的噪声进行优化处理(动态调整预测的噪声,控制U-Net预测噪声的强度),从而统筹生成过程的进度。在SD中,U-Net的迭代优化步数大概是50或者100次,在这个过程中Latent Feature的质量不断的变好(纯噪声减少,图像语义信息增加,文本语义信息增加)。整个过程如下图所示:https://pic3.zhimg.com/80/v2-6267e80bfe5730f52aa20f8f4f248672_720w.webp
U-Net网络+Schedule算法的迭代去噪过程U-Net网络和Schedule算法的工作完成以后,SD模型会将优化迭代后的Latent Feature输入到图像解码器(VAE Decoder)中,将Latent Feature重建成像素级图像。我们对比一下文生图任务中,初始Latent Feature和经过SD的“图像优化模块”处理后,再用图像解码器重建出来的图片之间的区别:https://pic2.zhimg.com/80/v2-15b2711566c14063237e1f7ec5fdc055_720w.webp
初始Latent Feature和经过SD的“图像优化模块”处理后的图像内容区别可以看到,上图左侧是初始Latent Feature经过图像解码器重建后的图片,显然是一个纯噪声图片;上图右侧是经过SD的“图像优化模块”处理后,再用图像解码器重建出来的图片,可以看到是一个张包含丰富内容信息的有效图片。我们再将U-Net网络+Schedule算法的迭代去噪过程的每一步结果都用图像解码器进行重建,我们可以直观的感受到从纯噪声到有效图片的全过程:https://pic4.zhimg.com/80/v2-f71e47876a8dccf514167c52d247980b_720w.webp
U-Net网络+Schedule算法的迭代去噪过程的每一步结果以上就是SD模型工作的完整流程,下面Rocky再将其进行总结归纳制作成完整的Stable Diffusion前向推理流程图,方便大家更好的理解SD模型的前向推理过程:https://pic1.zhimg.com/80/v2-cbc067b9d12ad2c25aff103be299bf94_720w.webp
SD模型文生图和图生图的前向推理流程图2.2 从0到1读懂Stable Diffusion模型核心基础原理在传统深度学习时代,凭借生成器与判别器对抗训练的开创性哲学思想,GAN(Generative adversarial networks)可谓是在生成式模型中一枝独秀。同样的,在AIGC时代,以SD模型为代表的扩散模型接过GAN的衣钵,在图像生成领域一路“狂飙”。与GAN等生成式模型一致的是,SD模型同样学习拟合训练集分布,并能够生成与训练集分布相似的输出结果,但与GAN相比,SD模型训练过程更稳定,而且具备更强的泛化性能。这些都归功于扩散模型中核心的前向扩散过程(forward diffusion process)和反向生成过程(reverse generation process)。在前向扩散过程中,SD模型持续对一张图像添加高斯噪声直至变成随机噪声矩阵。而在反向生成过程中,SD模型进行去噪声过程,将一个随机噪声矩阵逐渐去噪声直至生成一张图像。【1】扩散模型的基本原理在Stable Diffusion这个扩散模型中,无论是前向扩散过程还是反向生成过程都是一个参数化的马尔可夫链(Markov chain),如下图所示:https://pic3.zhimg.com/80/v2-4a4d117454ba571fa8af830b1bf4d572_720w.webp
扩散模型的前向扩散过程和反向生成过程看到这里,大家是不是感觉概念有点复杂了,don‘t worry,Rocky在本文不会讲太多复杂难懂的公式,大家只要知道Stable Diffusion模型的整个流程遵循参数化的马尔可夫链,前向扩散过程是对图像增加噪声,反向生成过程是去噪过程即可,对于面试,工业界应用,竞赛界厮杀来说,已经足够了。如果有想要深入理解扩散模型数学原理的读者,Rocky这里推荐阅读原论文:Denoising Diffusion Probabilistic ModelsRocky再从AI绘画应用角度解释一下扩散模型的基本原理,让大家能够对扩散模型有更多通俗易懂的认识:“如果从艺术和美学的角度来理解扩散模型,我们可以将其视为一种创作过程。想象这种情况,艺术家在画布的一角开始创作,颜色和形状逐渐扩散到整个画布。每一次画笔的触碰都可能对画布的其他部分产生影响,形成新的颜色和形状的组合。这就像是信息或行为在网络中的传播,每一个节点的改变都可能影响到其他节点。此外,扩散过程也可以看作是一种艺术表达。例如,抽象派艺术家可能会利用颜色和形状的扩散来表达他们的想法和感情。这种扩散过程可以看作是一种元素间的动态交互,就像在社会网络中,人们通过交流和互动来传播信息和影响他人的行为。美学角度则可以从扩散模型展现的和谐、平衡、动态等特性来解读。扩散过程中的动态平衡,反映了美学中的对称和平衡的原则。同时,扩散过程的不确定性和随机性,也反映了现代美学中对创新和突破的追求。总的来说,从艺术和美学的角度来看,扩散模型可以被理解为一种创作和表达过程,其中的元素通过互动和影响,形成一种动态的、有机的整体结构。”【2】前向扩散过程详解接下来,我们再详细分析一下前向扩散过程,其是一个不断加噪声的过程。我们举个例子,如下图所示,我们在猫的图片中多次增加高斯噪声直至图片变成随机噪音矩阵。可以看到,对于初始数据,我们设置K步的扩散步数,每一步增加一定的噪声,如果我们设置的K足够大,那么我们就能够将初始数据转化成随机噪音矩阵。https://pic3.zhimg.com/80/v2-e8ec5f55d7ea46c506f709219e9c6eb6_720w.webp
扩散模型的前向扩散过程一般来说,扩散过程是固定的,由上节中提到的Schedule算法进行统筹控制。同时扩散过程也有一个重要的性质:我们可以基于初始数据 <span class="MathJax_SVG" id="MathJax-Element-2-Frame" tabindex="0" data-mathml="X0" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�0 和任意的扩散步数 <span class="MathJax_SVG" id="MathJax-Element-1-Frame" tabindex="0" data-mathml="Ki" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�� ,采样得到对应的数据 <span class="MathJax_SVG" id="MathJax-Element-3-Frame" tabindex="0" data-mathml="Xi" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�� 。【3】反向生成过程详解反向生成过程和前向扩散过程正好相反,是一个不断去噪的过程。下面是一个直观的例子,将随机高斯噪声矩阵通过扩散模型的Inference过程,预测噪声并逐步去噪,最后生成一个小别墅的有效图片。https://pic1.zhimg.com/80/v2-0066d71357c57323b870759c45b385b0_720w.webp
扩散模型的反向生成过程其中每一步预测并去除的噪声分布,都需要扩散模型在训练中学习。讲好了扩散模型的前向扩散过程和反向生成过程,他们的目的都是服务于扩散模型的训练,训练目标也非常简单:扩散模型每次预测出的噪声和每次实际加入的噪声做回归,让扩散模型能够准确的预测出每次实际加入的真实噪声。关于SD模型具体的训练过程,大家可以阅读本文2.3节中了解。【4】Latent让SD模型彻底“破圈”如果说前面讲到的扩散模型相关基础知识是为SD模型打下地基的话,引入Latent思想则让SD模型“一遇风雨便化龙”,成为了AIGC时代的图像生成式模型的佼佼者。那么Latent又是什么呢?为了Latent有如此魔力呢?首先,我们已经知道了扩散模型会设置一个迭代次数,并不会像GAN网络那样一次输入一次输出,虽然这样输出效果会更好更稳定,但是会导致生成过程非常耗时。再者,不管是训练还是前向推理,常规的扩散模型在实际像素空间进行前向扩散过程和反向生成过程,而基于Latent的扩散模型可以将这些过程压缩在低维的Latent隐空间,这样一来大大降低了显存占用和计算复杂性,这是常规扩散模型和基于Latent的扩散模型之间的主要区别,也是SD模型火爆出圈的关键一招。我们举个形象的例子理解一下,如果SD模型将输入数据压缩的倍数设为8,那么原本尺寸为<span class="MathJax_SVG" id="MathJax-Element-4-Frame" tabindex="0" data-mathml="" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">的数据就会进入<span class="MathJax_SVG" id="MathJax-Element-5-Frame" tabindex="0" data-mathml="" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">的Latent隐空间中,显存和计算量直接缩小64倍,整体效率大大提升。也正是因为这样,SD模型能够在2080Ti级别的显卡上进行AI绘画,大大推动了SD模型的普惠与生态的繁荣。到这里,大家应该对SD模型的核心基础原理有一个清晰的认识了,Rocky这里再帮大家总结一下:
[*]SD模型是生成式模型,输入可以是图片,文本以及两者的结合,输出是生成的图片。
[*]SD模型属于扩散模型,扩散模型的整体逻辑的特点是过程分步化与可迭代,这给整个生成过程引入更多约束与优化提供了可能。
[*]SD模型是基于Latent的扩散模型,将输入数据压缩到Latent隐空间中,比起常规扩散模型,大幅提高计算效率的同时,降低了显存占用,成为了SD模型破圈的关键一招。
[*]站在CTO视角,将维度拉到最高维,Rocky认为SD模型是一个优化噪声的AI艺术工具。
2.3 零基础读懂Stable Diffusion训练全过程Stable Diffusion的整个训练过程在最高维度上可以看成是如何加噪声和如何去噪声的过程,并在针对噪声的“对抗与攻防”中学习到生成图片的能力。Stable Diffusion整体的训练逻辑也非常清晰:
[*]从数据集中随机选择一个训练样本
[*]从K个噪声量级随机抽样一个timestep
[*]产生随机噪声
[*]计算当前所产生的噪声数据
[*]将噪声输入U-Net预测噪声
[*]计算产生的噪声和预测的噪声的L2损失
[*]计算梯度并更新SD模型参数
下面Rocky再对SD模型训练过程中的一些关键环节进行详细的讲解。【1】SD训练集加入噪声SD模型训练时,我们需要输入加噪的数据集,每一次迭代我们用random函数生成从强到弱各个强度的噪声,通常来说会生成0-1000一共1001种不同的噪声强度,通过Time Embedding嵌入到SD的训练过程中。下图是一个简单的加噪声流程,可以帮助大家更好地理解SD训练时数据是如何加噪声的。首先从数据集中选择一张干净样本,然后再用random函数生成0-3一共4种强度的噪声,然后每次迭代中随机一种强度的噪声,增加到干净图片上,完成图片的加噪流程。https://pic1.zhimg.com/80/v2-8cf1a63bd06f0fa7b4b05d331f663550_720w.webp
SD训练集的加噪声流程【2】SD训练中加噪与去噪具体地,在训练过程中,我们首先对干净样本进行加噪处理,采用多次逐步增加噪声的方式,直至干净样本转变成为纯噪声。https://pic4.zhimg.com/80/v2-4632cb6be013a8b0aa27e812d620cd5f_720w.webp
SD训练时的加噪过程接着,让SD模型学习去噪过程,最后抽象出一个高维函数,这个函数能在纯噪声中“优化”噪声,得到一个干净样本。其中,将去噪过程具像化,就得到使用U-Net预测噪声,并结合Schedule算法逐步去噪的过程。https://pic4.zhimg.com/80/v2-709e01c5cbf610cd85d52bd63ce0df4f_720w.webp
SD训练时的去噪过程我们可以看到,加噪和去噪过程都是逐步进行的,我们假设进行<span class="MathJax_SVG" id="MathJax-Element-6-Frame" tabindex="0" data-mathml="K" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�步,那么每一步,SD都要去预测噪声,从而形成“小步快跑的稳定去噪”,类似于移动互联网时代的产品逻辑,这是足够伟大的关键一招。与此同时,在加噪过程中,每次增加的噪声量级可以不同,假设有5种噪声量级,那么每次都可以取一种量级的噪声,增加噪声的多样性。https://pic2.zhimg.com/80/v2-2ebe90e61b3f05839db96efa637b5b75_720w.webp
多量级噪声那么怎么让网络知道目前处于<span class="MathJax_SVG" id="MathJax-Element-7-Frame" tabindex="0" data-mathml="K" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�的哪一步呢?SD模型其实需要K个噪声预测模型,我们可以增加一个Time Embedding(类似Positional embeddings)进行处理,通过将timestep编码进网络中,从而只需要训练一个共享的U-Net模型,就让网络知道现在处于哪一步。我们了解了训练中的加噪和去噪过程,SD训练的具体过程就是对每个加噪和去噪过程进行计算,从而优化SD模型参数,如下图所示分为四个步骤:从训练集中选取一张加噪过的图片和噪声强度(timestep),然后将其输入到U-Net中,让U-Net预测噪声(下图中的Unet Prediction),接着再计算预测噪声与真实噪声的误差(loss),最后通过反向传播更新U-Net的参数。https://pic4.zhimg.com/80/v2-086fc4db6ac0ec2a76f0e5d0c5c37b4f_720w.webp完成U-Net的训练,我们就可以用U-Net对噪声图片进行去噪,逐步重建出有效图像的Latent Feature了!在噪声图上逐步减去被U-Net预测出来的噪声,从而得到一个我们想要的高质量的图像隐特征,去噪流程如下图所示:https://pic1.zhimg.com/80/v2-ec1bdb95dd416b574bfaa9360a6a54f0_720w.webp【3】语义信息对图片生成的控制SD模型在生成图片时,需要输入prompt,那么这些语义信息是如何影响图片的生成呢?答案非常简单:注意力机制。在SD模型的训练中,每个训练样本都会对应一个标签,我们将对应标签通过CLIP Text Encoder输出Text Embeddings,并将Text Embeddings以Cross Attention的形式与U-Net结构耦合,使得每次输入的图片信息与文字信息进行融合训练,如下图所示:https://pic3.zhimg.com/80/v2-ca66e9375558b5e04cd9c76fa0f6d122_720w.webp
Noise与Text Embeddings通过CrossAttention与U-Net结构耦合上图中的token是NLP领域的一个基础概念,可以理解为最小语意单元。与之对应的分词操作为tokenization。Rocky举一个简单的例子来帮助大家理解:“WeThinkIn是伟大的自媒体”是一个句子,我们需要将其切分成一个token序列,这个操作就是tokenization。经过tokenization操作后,我们获得["WeThinkIn", "是", "伟大的", "自媒体"]这个句子的token序列,从而完成对文本信息的预处理。【4】SD模型训练时的输入有了上面的介绍,我们在这里可以小结一下SD模型训练时的输入,一共有三个部分组成:图片,文本,噪声强度。其中图片和文本是固定的,而噪声强度在每一次训练参数更新时都会随机选择一个进行叠加。https://pic3.zhimg.com/80/v2-065321f6d060da5503ff8311de1e6b5a_720w.webp
SD模型训练时需要的数据配置2.4 其他主流生成式模型介绍在AIGC时代中,虽然SD模型已经成为核心的生成式模型之一,但是曾在传统深度学习时代火爆的GAN,VAE,Flow-based model等模型也跨过周期在SD模型身边作为辅助,发挥了巨大的作用。下面是主流生成式模型各自的生成逻辑:https://pic2.zhimg.com/80/v2-5d8ffc165938db059017b96397627a65_720w.webp
生成式模型的主流架构GAN网络在AIGC时代依然发挥了巨大的作用,配合SD模型完成了很多算法工作流,比如:图像超分,脸部修复,风格迁移,图像编辑,图像fix,图像定权等。所以Rocky在这里简单讲解一下GAN的基本原理,让大家做个了解。GAN由生成器<span class="MathJax_SVG" id="MathJax-Element-8-Frame" tabindex="0" data-mathml="G" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�和判别器<span class="MathJax_SVG" id="MathJax-Element-9-Frame" tabindex="0" data-mathml="D" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�组成。其中,生成器主要负责生成相应的样本数据,输入一般是由高斯分布随机采样得到的噪声<span class="MathJax_SVG" id="MathJax-Element-10-Frame" tabindex="0" data-mathml="Z" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�。而判别器的主要职责是区分生成器生成的样本与<span class="MathJax_SVG" id="MathJax-Element-12-Frame" tabindex="0" data-mathml="gt&#xFF08;GroundTruth&#xFF09;" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">()��(����������ℎ)样本,输入一般是<span class="MathJax_SVG" id="MathJax-Element-11-Frame" tabindex="0" data-mathml="gt" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">��样本与相应的生成样本,我们想要的是对<span class="MathJax_SVG" id="MathJax-Element-13-Frame" tabindex="0" data-mathml="gt" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">��样本输出的置信度越接近<span class="MathJax_SVG" id="MathJax-Element-14-Frame" tabindex="0" data-mathml="1" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">1越好,而对生成样本输出的置信度越接近<span class="MathJax_SVG" id="MathJax-Element-15-Frame" tabindex="0" data-mathml="0" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">0越好。与一般神经网络不同的是,GAN在训练时要同时训练生成器与判别器,所以其训练难度是比较大的。我们可以将GAN中的生成器比喻为印假钞票的犯罪分子,判别器则被当作警察。犯罪分子努力让印出的假钞看起来逼真,警察则不断提升对于假钞的辨识能力。二者互相博弈,随着时间的进行,都会越来越强。在图像生成任务中也是如此,生成器不断生成尽可能逼真的假图像。判别器则判断图像是<span class="MathJax_SVG" id="MathJax-Element-16-Frame" tabindex="0" data-mathml="gt" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">��图像,还是生成的图像。二者不断博弈优化,最终生成器生成的图像使得判别器完全无法判别真假。关于Flow-based models,其在AIGC时代的作用还未显现,可以持续关注。最后,VAE将在本文后面的章节详细讲解,因为正是VAE将输入数据压缩至Latent隐空间中,故其成为了SD模型的核心结构之一。3. Stable Diffusion核心网络结构解析(全网最详细)3.1 SD模型整体架构初识Stable Diffusion模型整体上是一个End-to-End模型,主要由VAE(变分自编码器,Variational Auto-Encoder),U-Net以及CLIP Text Encoder三个核心组件构成。在FP16精度下Stable Diffusion模型大小2G(FP32:4G),其中U-Net大小1.6G,VAE模型大小160M以及CLIP Text Encoder模型大小235M。其中U-Net结构包含约860M参数,以FP32精度下大小为3.4G左右。https://pic1.zhimg.com/80/v2-a643ee39e80807d6b7236d15f1c289a8_720w.webp
Stable Diffusion整体架构图3.2 VAE模型在Stable Diffusion中,VAE(变分自编码器,Variational Auto-Encoder)的Encoder(编码器)结构能将输入图像转换为低维Latent特征,并作为U-Net的输入。VAE的Decoder(解码器)结构能将低维Latent特征重建还原成像素级图像。https://pic2.zhimg.com/80/v2-ca6cb91a11a4f0694a3672f45b82b5fd_720w.webp
VAE在Stable Diffusion中的主要功能为什么VAE可以将图像压缩到一个非常小的Latent space(潜空间)后能再次对图像还原呢?因为虽然整个过程可以看作是一个有损压缩,但自然图像并不是随机的,它们具有很高的规律性:比如说一张脸上的眼睛、鼻子、脸颊和嘴巴之间遵循特定的空间关系,又比如说一只猫有四条腿,并且是一个特定的生物结构。所以如果我们生成的图像尺寸在<span class="MathJax_SVG" id="MathJax-Element-17-Frame" tabindex="0" data-mathml="512&#x00D7;512" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">512×512之上时,其实特征损失带来的影响非常小。同时如果我们在SD模型中切换不同的VAE结构,能够发现生成图片的细节与整体颜色也会随之改变。下图是Rocky梳理的Stable Diffusion VAE的完整结构图,大家可以感受一下其魅力,看着这个完整结构图学习Stable Diffusion VAE部分,相信大家脑海中的思路也会更加清晰:https://pic1.zhimg.com/80/v2-a390d53cc59c0e76b0bbc86864f226ac_720w.webp
Stable Diffusion VAE完整结构图3.3 U-Net模型在Stable Diffusion中,U-Net模型是一个关键核心,作为扩散模型其主要是预测噪声残差,并结合Sampling method(调度算法:PNDM,DDIM,K-LMS等)对输入的特征矩阵进行重构,逐步将其从随机高斯噪声转化成图片的Latent Feature。具体来说,在前向推理过程中,SD模型通过反复调用 U-Net,将预测出的噪声残差从原噪声矩阵中去除,得到逐步去噪后的图像Latent Feature,再通过VAE的Decoder结构将Latent Feature重建成像素级图像,如下图所示:
https://pica.zhimg.com/v2-189013e6a420a0f7c350ab7140bb93eb.jpg?source=25ab7b06重新播放






" class="_1rs4xzm1" style="display: block; visibility: hidden; width: auto; height: 386.4px;">https://pic1.zhimg.com/v2-afa6eca8b0c435cb2c1d2f9cf6592cd0.png?source=382ee89ahttps://pic3.zhimg.com/80/v2-8aeffb8be728db6723830640ad5dc3dd_1440w.png













Rocky再从AI绘画应用视角解释一下SD中U-Net的原理与作用。其实大家在使用Stable Diffusion WebUI时,点击Generate按钮后,页面右下角图片生成框中展示的从噪声到图片的生成过程,其中就是U-Net在不断的为大家去除噪声的过程。到这里大家应该都能比较清楚的理解U-Net的作用了。好了,我们再回到AIGC算法工程师视角。Stable Diffusion中的U-Net,在传统深度学习时代的Encoder-Decoder结构的基础上,增加了ResNetBlock(包含Time Embedding)模块,Spatial Transformer(SelfAttention + CrossAttention + FeedForward)模块以及CrossAttnDownBlock,CrossAttnUpBlock和CrossAttnMidBlock模块。那么各个模块都有什么作用呢?不着急,咱们先看看SD U-Net的整体架构(AIGC算法工程师面试核心考点)。下图是Rocky梳理的Stable Diffusion U-Net的完整结构图,大家可以感受一下其魅力,看着这个完整结构图学习Stable Diffusion U-Net部分,相信大家脑海中的思路也会更加清晰:https://pic2.zhimg.com/80/v2-8fafb5695089ea1d9fa8a5217877bd65_720w.webp
Stable Diffusion U-Net完整结构图上图中包含Stable Diffusion U-Net的十四个基本模块:
[*]GSC模块:Stable Diffusion U-Net中的最小组件之一,由GroupNorm+SiLU+Conv三者组成。
[*]DownSample模块:Stable Diffusion U-Net中的下采样组件,使用了Conv(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))进行采下采样。
[*]UpSample模块:Stable Diffusion U-Net中的上采样组件,由插值算法(nearest)+Conv组成。
[*]ResNetBlock模块:借鉴ResNet模型的“残差结构”,让网络能够构建的更深的同时,将Time Embedding信息嵌入模型。
[*]CrossAttention模块:将文本的语义信息与图像的语义信息进行Attention机制,增强输入文本Prompt对生成图片的控制。
[*]SelfAttention模块:SelfAttention模块的整体结构与CrossAttention模块相同,这是输入全部都是图像信息,不再输入文本信息。
[*]FeedForward模块:Attention机制中的经典模块,由GeGlU+Dropout+Linear组成。
[*]BasicTransformer Block模块:由LayerNorm+SelfAttention+CrossAttention+FeedForward组成,是多重Attention机制的级联,并且也借鉴ResNet模型的“残差结构”。通过加深网络和多Attention机制,大幅增强模型的学习能力与图文的匹配能力。
[*]Spatial Transformer模块:由GroupNorm+Conv+BasicTransformer Block+Conv构成,ResNet模型的“残差结构”依旧没有缺席。
[*]DownBlock模块:由两个ResNetBlock模块组成。
[*]UpBlock_X模块:由X个ResNetBlock模块和一个UpSample模块组成。
[*]CrossAttnDownBlock_X模块:是Stable Diffusion U-Net中Encoder部分的主要模块,由X个(ResNetBlock模块+Spatial Transformer模块)+DownSample模块组成。
[*]CrossAttnUpBlock_X模块:是Stable Diffusion U-Net中Decoder部分的主要模块,由X个(ResNetBlock模块+Spatial Transformer模块)+UpSample模块组成。
[*]CrossAttnMidBlock模块:是Stable Diffusion U-Net中Encoder和ecoder连接的部分,由ResNetBlock+Spatial Transformer+ResNetBlock组成。

接下来,Rocky将为大家全面分析SD模型中U-Net结构的核心知识,码字实在不易,希望大家能多多点赞,谢谢!(1)ResNetBlock模块在传统深度学习时代,ResNet的残差结构在图像分类,图像分割,目标检测等主流方向中几乎是不可或缺,其简洁稳定有效的“残差思想”终于在AIGC时代跨过周期,在SD模型的U-Net结构中继续繁荣。值得注意的是,Time Embedding正是输入到ResNetBlock模块中,为U-Net引入了时间信息(时间步长T,T的大小代表了噪声扰动的强度),模拟一个随时间变化不断增加不同强度噪声扰动的过程,让SD模型能够更好地理解时间相关性。同时,在SD模型调用U-Net重复迭代去噪的过程中,我们希望在迭代的早期,能够先生成整幅图片的轮廓与边缘特征,随着迭代的深入,再补充生成图片的高频和细节特征信息。由于在每个ResNetBlock模块中都有Time Embedding,就能告诉U-Net现在是整个迭代过程的哪一步,并及时控制U-Net够根据不同的输入特征和迭代阶段而预测不同的噪声残差。Rocky再从AI绘画应用视角解释一下Time Embedding的作用。Time Embedding能够让SD模型在生成图片时考虑时间的影响,使得生成的图片更具有故事性、情感和沉浸感等艺术效果。并且Time Embedding可以帮助SD模型在不同的时间点将生成的图片添加完善不同情感和主题的内容,从而增加了AI绘画的多样性和表现力。定义Time Embedding的代码如下所示,可以看到Time Embedding的生成方式,主要通过sin和cos函数再经过Linear层进行变换。:def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):    half = self.channels // 2    frequencies = torch.exp(            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half      ).to(device=time_steps.device)    args = time_steps[:, None].float() * frequencies[None]    return torch.cat(, dim=-1)
讲完Time Embedding的相关核心知识,我们再学习一下ResNetBlock模块的整体知识。在上面的Stable Diffusion U-Net完整结构图中展示了完整的ResNetBlock模块,其输入包括Latent Feature和 Time Embedding。首先Latent Feature经过GSC(GroupNorm+SiLU激活函数+卷积)模块后和Time Embedding(经过SiLU激活函数+全连接层处理)做加和操作,之后再经过GSC模块和Skip Connection而来的输入Latent Feature做加和操作,进行两次特征融合后最终得到ResNetBlock模块的Latent Feature输出,增强SD模型的特征学习能力。同时,和传统深度学习时代的U-Net结构一样,Decoder结构中的ResNetBlock模块不单单要接受来自上一层的Latent Feature,还要与Encoder结构中对应层的ResNetBlock模块的输出Latent Feature进行concat操作。举个例子,如果Decoder结构中ResNetBlock Structure上一层的输出结果的尺寸为 ,Encoder结构对应 ResNetBlock Structure的输出结果的尺寸为 ,那么这个Decoder结构中ResNeBlock Structure得到的Latent Feature的尺寸为 。(2)CrossAttention模块CrossAttention模块是我们使用输入文本Prompt控制SD模型图片内容生成的关键一招。上面的Stable Diffusion U-Net完整结构图中展示了Spatial Transformer(Cross Attention)模块的结构。Spatial Transformer模块和ResNetBlock模块一样接受两个输入:一个是ResNetBlock模块的输出,另外一个是输入文本Prompt经过CLIP Text Encoder模型编码后的Context Embedding。两个输入首先经过Attention机制(将Context Embedding对应的语义信息与图片中对应的语义信息相耦合),输出新的Latent Feature,再将新输出的Latent Feature与输入的Context Embedding再做一次Attention机制,从而使得SD模型学习到了文本与图片之间的特征对应关系。Spatial Transformer模块不改变输入输出的尺寸,只在图片对应的位置上融合了语义信息,所以不管是在传统深度学习时代,还是AIGC时代,Spatial Transformer都是将本文与图像结合的一个“万金油”模块。看CrossAttention模块的结构图,大家可能会疑惑为什么Context Embedding用来生成K和V,Latent Feature用来生成Q呢?原因也非常简单:因为在Stable Diffusion中,主要的目的是想把文本信息注入到图像信息中里,所以用图片token对文本信息做 Attention实现逐步的文本特征提取和耦合。Rocky再从AI绘画应用视角解释一下CrossAttention模块的作用。CrossAttention模块在AI绘画应用中可以被视为一种连接和表达的工具,它有助于在输入文本和生成图片之间建立联系,创造更具深度和多样性的艺术作品,引发观众的思考和情感共鸣。CrossAttention模块可以将图像和文本信息关联起来,就像艺术家可以将不同的元素融合到一幅作品中,这有助于在创作中实现不同信息之间的协同和互动,产生更具创意性的艺术作品。再者CrossAttention模块可以用于将文本中的情感元素传递到生成图片中,这种情感的交互可以增强艺术作品的表现力和观众的情感共鸣。(3)BasicTransformer Block模块BasicTransformer Block模块是在CrossAttention子模块的基础上,增加了SelfAttention子模块和Feedforward子模块共同组成的,并且每个子模块都是一个残差结构,这样除了能让文本的语义信息与图像的语义信息更好的融合之外,还能通过SelfAttention机制让模型更好的学习图像数据的特征。写到这里,可能还有读者会问,Stable Diffusion U-Net中的SelfAttention到底起了什么作用呀?首先,在Stable Diffusion U-Net的SelfAttention模块中,输入只有图像信息,所以SelfAttention主要是为了让SD模型更好的学习图像数据的整体特征。再者,SelfAttention可以将输入图像的不同部分(像素或图像Patch)进行交互,从而实现特征的整合和全局上下文的引入,能够让模型建立捕捉图像全局关系的能力,有助于模型理解不同位置的像素之间的依赖关系,以更好地理解图像的语义。在此基础上,SelfAttention还能减少平移不变性问题,SelfAttention模块可以在不考虑位置的情况下捕捉特征之间的关系,因此具有一定的平移不变性。Rocky再从AI绘画应用视角解释一下SelfAttention的作用。SelfAttention模块可以让SD模型在图片生成过程中捕捉内在关系、创造性表达情感和思想、突出重要元素,并创造出丰富多彩、具有深度和层次感的艺术作品。(4)Spatial Transformer模块更进一步的,在BasicTransformer Block模块基础上,加入GroupNorm和两个卷积层就组成Spatial Transformer模块。Spatial Transformer模块是SD U-Net中的核心Base结构,Encoder中的CrossAttnDownBlock模块,Decoder中的CrossAttnUpBlock模块以及CrossAttnMidBlock模块都包含了大量的Spatial Transformer子模块。在生成式模型中,GroupNorm的效果一般会比BatchNorm更好,生成式模型通常比较复杂,因此需要更稳定和适应性强的归一化方法。而GroupNorm主要有以下一些优势,让其能够成为生成式模型的标配:1. 对训练中不同Batch-Size的适应性:在生成式模型中,通常需要使用不同的Batch-Size进行训练和微调。这会导致 BatchNorm在训练期间的不稳定性,而GroupNorm不受Batch-Size的影响,因此更适合生成式模型。2. 能适应通道数变化:GroupNorm 是一种基于通道分组的归一化方法,更适应通道数的变化,而不需要大量调整。3. 更稳定的训练:生成式模型的训练通常更具挑战性,存在训练不稳定性的问题。GroupNorm可以减轻训练过程中的梯度问题,有助于更稳定的收敛。4. 能适应不同数据分布:生成式模型通常需要处理多模态数据分布,GroupNorm 能够更好地适应不同的数据分布,因为它不像 Batch Normalization那样依赖于整个批量的统计信息。(5)CrossAttnDownBlock/CrossAttnUpBlock/CrossAttnMidBlock模块在Stable Diffusion U-Net的Encoder部分中,使用了三个CrossAttnDownBlock模块,其由ResNetBlock Structure+BasicTransformer Block+Downsample构成。Downsample通过使用一个卷积(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))来实现。在Decoder部分中,使用了三个CrossAttnUpBlock模块,其由ResNetBlock Structure+BasicTransformer Block+Upsample构成。Upsample使用插值算法+卷积来实现,插值算法将输入的Latent Feature尺寸扩大一倍,同时通过一个卷积(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))改变Latent Feature的通道数,以便于输入后续的模块中。在CrossAttnMidBlock模块中,包含ResNetBlock Structure+BasicTransformer Block+ResNetBlock Structure,作为U-Net的Encoder与Decoder之间的媒介。(6)Stable Diffusion U-Net整体宏观角度小结从整体上看,不管是在训练过程还是前向推理过程,Stable Diffusion中的U-Net在每次循环迭代中Content Embedding部分始终保持不变,而Time Embedding每次都会发生变化。和传统深度学习时代的U-Net一样,Stable Diffusion中的U-Net也是不限制输入图片的尺寸,因为这是个基于Transformer和卷积的模型结构。到这里,Stable Diffusion U-Net的完整核心基础知识就介绍好了,欢迎大家在评论区发表自己的观点,也希望大家能多多点赞,Rocky会持续完善本文的全部内容,大家敬请期待!3.4 CLIP Text Encoder模型作为文生图模型,Stable Diffusion中的文本编码模块直接决定了语义信息的优良程度,从而影响到最后图片生成的多样性和可控性。在这里,多模态领域的神器——CLIP(Contrastive Language-Image Pre-training),跨过了周期,从传统深度学习时代进入AIGC时代,成为了SD系列模型中文本和图像之间的连接通道。并且从某种程度上讲,正是因为CLIP模型的前置出现,更加快速地推动了AI绘画领域的繁荣。那么,什么是CLIP呢?CLIP有哪些优良的性质呢?为什么是CLIP呢?首先,CLIP模型是一个基于对比学习的多模态模型,主要包含Text Encoder和Image Encoder两个模型。其中Text Encoder用来提取文本的特征,可以使用NLP中常用的text transformer模型作为Text Encoder;而Image Encoder主要用来提取图像的特征,可以使用CNN/vision transformer模型(ResNet和ViT)作为Image Encoder。与此同时,他直接使用4亿个图片与标签文本对数据集进行训练,来学习图片与本文内容的对应关系。与U-Net的Encoder和Decoder一样,CLIP的Text Encoder和Image Encoder也能非常灵活的切换;其庞大图片与标签文本数据的预训练赋予了CLIP强大的zero-shot分类能力。灵活的结构,简洁的思想,让CLIP不仅仅是个模型,也给我们一个很好的借鉴,往往伟大的产品都是大道至简的。更重要的是,CLIP把自然语言领域的抽象概念带到了计算机视觉领域。https://pic4.zhimg.com/80/v2-c876c26f91e7ed3df060c0bd2116b357_720w.webp
CLIP模型训练使用的图片-文本对数据CLIP在训练时,从训练集中随机取出一张图片和标签文本。CLIP模型的任务主要是通过Text Encoder和Image Encoder分别将标签文本和图片提取embedding向量,然后用余弦相似度(cosine similarity)来比较两个embedding向量的相似性,以判断随机抽取的标签文本和图片是否匹配,并进行梯度反向传播,不断进行优化训练。https://pic3.zhimg.com/80/v2-17b7c75d9f4a693f3711d602d8e971ca_720w.webp
CLIP模型训练示意图上面讲了Batch为1时的情况,当我们把训练的Batch提高到 <span class="MathJax_SVG" id="MathJax-Element-19-Frame" tabindex="0" data-mathml="N" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">� 时,其实整体的训练流程是不变的。只是现在CLIP模型需要将<span class="MathJax_SVG" id="MathJax-Element-18-Frame" tabindex="0" data-mathml="N" role="presentation" style="display: inline-block; font-weight: normal; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�个标签文本和<span class="MathJax_SVG" id="MathJax-Element-22-Frame" tabindex="0" data-mathml="N" role="presentation" style="display: inline-block; font-weight: normal; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�个图片的两两组合预测出<span class="MathJax_SVG" id="MathJax-Element-20-Frame" tabindex="0" data-mathml="N2" role="presentation" style="display: inline-block; font-weight: normal; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�2个可能的文本-图片对的余弦相似性,即下图所示的矩阵。这里共有<span class="MathJax_SVG" id="MathJax-Element-21-Frame" tabindex="0" data-mathml="N" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�个正样本,即真正匹配的文本和图片(矩阵中的对角线元素),而剩余的<span class="MathJax_SVG" id="MathJax-Element-24-Frame" tabindex="0" data-mathml="N2&#x2212;N" role="presentation" style="display: inline-block; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�2−�个文本-图片对为负样本,这时CLIP模型的训练目标就是最大<span class="MathJax_SVG" id="MathJax-Element-23-Frame" tabindex="0" data-mathml="N" role="presentation" style="display: inline-block; font-weight: normal; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�个正样本的余弦相似性,同时最小化<span class="MathJax_SVG" id="MathJax-Element-25-Frame" tabindex="0" data-mathml="N2&#x2212;N" role="presentation" style="display: inline-block; font-weight: normal; line-height: normal; font-size: 16px; word-spacing: normal; overflow-wrap: normal; text-wrap: nowrap; float: none; direction: ltr; max-width: none; max-height: none; min-width: 0px; min-height: 0px; border: 0px; position: relative;">�2−�个负样本的余弦相似性。https://pic4.zhimg.com/80/v2-6fcd9e16204fd7a457b61adada425883_720w.webp
Batch为N时的CLIP训练示意图完成CLIP的训练后,输入配对的图片和标签文本,则Text Encoder和Image Encoder可以输出相似的embedding向量,计算余弦相似度就可以得到接近1的结果。同时对于不匹配的图片和标签文本,输出的embedding向量计算余弦相似度则会接近0。就这样,CLIP成为了计算机视觉和自然语言处理这两大AI方向的“桥梁”,AI领域的多模态应用有了经典的基石模型。上面我们讲到CLIP模型主要包含Text Encoder和Image Encoder两个模型,在Stable Diffusion中主要使用了Text Encoder模型。CLIP Text Encoder模型将输入的文本Prompt进行编码,转换成Text Embeddings(文本的语义信息),通过前面一章节提到的U-Net网络中的CrossAttention模块嵌入Stable Diffusion中作为Condition,对生成图像的内容进行一定程度上的控制与引导,目前SD模型使用的的是CLIP ViT-L/14中的Text Encoder模型。CLIP ViT-L/14 中的Text Encoder是只包含Transformer结构的模型,一共由12个CLIPEncoderLayer模块组成,模型参数大小是123M,具体CLIP Text Encoder模型结构如下图所示。其中特征维度为768,token数量是77,所以输出的Text Embeddings的维度为77x768。CLIPEncoderLayer(    (self_attn): CLIPAttention(      (k_proj): Linear(in_features=768, out_features=768, bias=True)      (v_proj): Linear(in_features=768, out_features=768, bias=True)      (q_proj): Linear(in_features=768, out_features=768, bias=True)      (out_proj): Linear(in_features=768, out_features=768, bias=True)      )    (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)    (mlp): CLIPMLP(      (activation_fn): QuickGELUActivation()      (fc1): Linear(in_features=768, out_features=3072, bias=True)            (fc2): Linear(in_features=3072, out_features=768, bias=True)          )          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)      )
下图是Rocky梳理的Stable Diffusion CLIP Encoder的完整结构图,大家可以感受一下其魅力,看着这个完整结构图学习Stable Diffusion CLIP Encoder部分,相信大家脑海中的思路也会更加清晰:https://pic3.zhimg.com/80/v2-46fcafb5a14d108cd29d2751e453a142_720w.webp下面Rocky将使用transofmers库演示调用CLIP Text Encoder,给大家一个更加直观的SD模型的文本编码全过



页: [1]
查看完整版本: 深入浅出完整解析Stable Diffusion(SD)核心基础知识