AI Часть 3: Diffusion Transformer (DiT) — Stable Diffusion 3 как она есть

AI

Редактор
Регистрация
23 Август 2023
Сообщения
2 819
Лучшие ответы
0
Реакции
0
Баллы
51
Offline
#1
Обо мне


Привет, меня зовут Василий Техин. В первой статье мы разобрали ResNet, во второй — ViT. Теперь погрузимся в мир генерации изображений с Diffusion Transformer (DiT) — сердцем Stable Diffusion 3.

Пролог: От распознавания к созданию


Представьте нейросеть как художника. Раньше она только анализировала картины ("Это Ван Гог!"). Теперь она создаёт шедевры в стиле Ван Гога и не только!


Изображения из статьи

Ключевые этапы работы DiT:


  1. Обучение:

    • Сжимаем изображение в латентное пространство через VAE (256х256х3 → 32х32х4)


    • Добавляем шум за 1000 шагов (чтобы модель училась удалять шум постепенно)


    • DiT учится предсказывать шум на каждом шаге

  2. Генерация (инференс):

    • Начинаем с чистого шума


    • Постепенно удаляем шум за 1000 шагов


    • Декодируем результат через VAE
Пайплайн обучения и генерации

1. Подготовка данных (VAE)


VAE (Variational Autoencoder) сжимает изображение:

# Для изображения 256x256:
original = (3, 256, 256) → latent = (4, 32, 32) # Сжатие в 64 раза


Зачем? DiT работает с 32×32×4 латентными векторами — экономия вычислений!

2. Прямой процесс (добавление шума)


Процесс зашумления

1000 шагов постепенного зашумления по формуле:

def forward_diffusion(z0, t, T=1000):
alpha_t = cos((t/T + 0.008) / 1.008 * π/2)**2
noise = torch.randn_like(z0) # Случайный шум
z_t = sqrt(alpha_t) * z0 + sqrt(1-alpha_t) * noise # Зашумленная версия
return z_t, noise


Где:


  • z0 — исходный латентный вектор изображения


  • t — текущий шаг (1-1000)


  • noise — добавленный шум
3. Обратный процесс (обучение DiT)


Ключевые шаги обучения:


  1. Выбираем случайный шаг t (1-1000)


  2. Зашумляем латентный вектор: z_t, real_noise = forward_diffusion(z0, t)


  3. Подаем в DiT: pred_noise = DiT(z_t, t, text_embed) и получаем предсказанный шум


  4. Считаем MSE-лосс: loss = (real_noise - pred_noise).square().mean()


  5. Обновляем веса через backpropagation

Обратите внимание: DiT учится предсказывать оригинальный шум, а не изображение!

4. Генерация изображений (инференс)


Пошаговый процесс для Stable Diffusion 3(отличается от DiT из оригинальной статьи тем, что подается эмбеддинг текста вместо метки класса):

def generate(prompt, steps=1000):
# 1. Текстовый эмбеддинг
text_embed = text_encoder(prompt) # [1, 768]

# 2. Начальный шум
z = torch.randn(1, 4, 32, 32) # z_T

# 3. Итеративное удаление шума
for t in range(steps, 0, -1):
# a) Предсказание шума DiT
pred_noise = DiT(z, t, text_embed)

# b) Classifier-Free Guidance (CFG) - усиление текстового влияния
if cfg_scale > 1.0:
uncond_embed = text_encoder("") # Пустой промпт
uncond_noise = DiT(z, t, uncond_embed)
pred_noise = uncond_noise + cfg_scale * (pred_noise - uncond_noise)

# c) Формула обратного шага (DDIM)
alpha_t = cos((t/steps + 0.008)/1.008 * π/2)**2
alpha_prev = cos(((t-1)/steps + 0.008)/1.008 * π/2)**2
z = (z - (1 - alpha_t)/sqrt(1 - alpha_t) * pred_noise) / sqrt(alpha_t)
z += sqrt(1 - alpha_prev) * torch.randn_like(z) # Стохастичность

# 4. Декодирование через VAE
return VAE.decode(z) # [1, 3, 256, 256]

DiT в деталях: Отличия от ViT

1. Patchify: Работа с латентами


Мы нарезаем на патчи не оригинальное изображение, а латентный вектор

# Для латента 32x32x4 с патчами 2x2:
self.patch_embed = nn.Conv2d(4, dim, kernel_size=2, stride=2)
# → [batch, 256, dim] (16*16=256 патчей)


Сравнение с ViT: ViT работает с пикселями, DiT — с латентными векторами.

2. Classifier-Free Guidance (CFG)


Механизм усиления текста мы хотим, чтобы изображение из шума соответсвовало тексту, который мы передали:

pred_noise = uncond_noise + guidance_scale * (text_noise - uncond_noise)


Где:


  • uncond_noise — предсказание для пустого промпта


  • text_noise — предсказание для целевого промпта


  • guidance_scale (7-10) — сила влияния текста
3. Cross-Attention Block


В SD3 (не в оригинальном DiT):

class CrossAttentionBlock(nn.Module):
def forward(self, x, text_emb):
# Проекция текста
q = self.wq(x) # [batch, tokens, dim]
k = self.wk(text_emb) # [batch, text_tokens, dim]
v = self.wv(text_emb)

# Attention
attn = softmax(q @ k.transpose(-2,-1) / sqrt(dim))
return attn @ v # Текст-условные признаки


Зачем? Точнее связывает текст и визуальные патчи.

4. In-Context Conditioning


Механизм в DiT-XL:


  • Ввод текстовых токенов как патчей


  • Пример: [IMAGE_PATCH1, TEXT_TOKEN1, IMAGE_PATCH2, ...]


  • Позволяет смешивать текст и изображение на входе
5. AdaLN-Zero


Улучшение в DiT-2:


  • Инициализация параметров γ в AdaLN нулями


  • Первые шаги обучения: AdaLN = Identity Function


  • Стабилизирует раннее обучение
Разберём на простом примере

Оригинальный DiT (класс-условный)


Uncurated 512 × 512 DiT-XL/2 samples. Classifier-free guidance scale = 2.0 Class label = “panda” (388)

# Генерация "собаки" (класс 207)
class_label = 207
z = torch.randn(1, 4, 32, 32) # Начальный шум

for t in range(1000, 0, -1):
pred_noise = DiT(z, t, class_label) # Прямой вызов
z = update_step(z, pred_noise, t) # Обновление латента


Особенности:


  • Простой ввод класса вместо текста


  • Нет CFG и Cross-Attention
Stable Diffusion 3 (текст-условный)


prompt = "пудель в розовой шапочке"
text_embed = text_encoder(prompt) # [1, 768]

for t in range(1000, 0, -1):
# 1. Основное предсказание
pred_noise = DiT(z, t, text_embed)

# 2. Classifier-Free Guidance (усиление текста)
uncond_noise = DiT(z, t, text_encoder(""))
pred_noise = uncond_noise + 7.5 * (pred_noise - uncond_noise)

# 3. Cross-Attention в некоторых блоках
# (см. архитектуру ниже)


Нововведения SD3(относительно DiT):


  • Текст через T5 вместо классов


  • CFG с масштабом 7.5 для точного следования промпту
Оценка качества: Метрики

1. FID (Fréchet Inception Distance)


Как работает:


  1. Берем 50k реальных и 50k сгенерированных изображений


  2. Пропускаем через Inception-v3 (получаем признаки)


  3. Считаем "расстояние" между распределениями:

FID = ||μ_real - μ_gen||^2 + Tr(Σ_real + Σ_gen - 2(Σ_real Σ_gen)^{1/2})


Интерпретация:


  • FID = 0 — идеальное совпадение


  • FID < 5 — фотореалистичные изображения


  • DiT-XL: FID = 2.27 (ImageNet 256x256)
2. IS (Inception Score)


IS = exp(E_x[KL(p(y|x) || p(y))])


Где:


  • p(y|x) — распределение классов для изображения


  • p(y) — общее распределение классов


  • Высокий IS = разнообразные и узнаваемые изображения
Почему DiT — это будущее Stable Diffusion?

✅ Преимущества перед U-Net:

Параметр

U-Net (SD 2.1)

DiT (SD 3)

Качество (FID)​

3.85​

2.27

Масштабируемость​

Ограничена​

Линейный рост​

Разрешение​

768x768​

1024x1024

Текстовая привязка​

Средняя​

Точная


❌ Ограничения:


  1. Ресурсы: Обучение DiT-XL требует 500,000 GPU-hours


  2. Память: Генерация 1024px требует 48GB VRAM
Философский итог


DiT объединяет три революции ИИ:


  1. Сжатие данных (VAE)


  2. Трансформеры (ViT)


  3. Диффузионные процессы
Проверь себя


  1. Почему DiT работает с 32x32, а не 256x256?


  2. Как Classifier-Free Guidance улучшает генерацию?
Резюме


Diffusion Transformer (DiT):


  • Работает в латентном пространстве VAE (32x32x4)


  • Заменяет U-Net на трансформер с AdaLN


  • Оригинал: класс-условная генерация

Stable Diffusion 3:


  • Текст через текст энкодер и Cross-Attention


  • Classifier-Free Guidance для точности


  • Поддержка 1024px изображений


  • FID 2.27 — новый стандарт качества

Ссылки:


  1. Оригинальная статья DiT


  2. Stable Diffusion 3


  3. CFG в диффузионных моделях
 
Сверху Снизу