Примечание: Этот проект НЕ требует постоянных обновлений кода.
Статус Проекта: 🟩 Актуален.
Целью данной работы является реализация взаимодействия с моделью генеративно-состязательной нейросети (GAN2) через веб-сайт. Данная модель (её Generator) уже предварительно обучена.
Работа состоит из двух простых страниц. Подробнее о них ниже.
На главной странице есть форма, состоящая из одного поля и кнопки submit. Пользователь вводит целое число — количество изображений, которое требуется сгенерировать. Если отправить пустую форму, то сгенерируется 1 изображение.
Для этого поля есть следующие ограничения:
- число изображений не может быть меньше 1;
- число изображений не может быть больше 100.
Результат генерации изображений:
Здесь находится небольшая информация о том, что из себя представляет эта работа.
Внешний вид страницы:
Генеративно-состязательная нейросеть (Generative adversarial network, GAN) — архитектура, состоящая из генератора и дискриминатора, настроенных на работу друг против друга. Отсюда GAN и получила название генеративно-созтязательная.
Как работает генеративно-состязательная нейросеть?
- Одна нейронная сеть, называемая генератором, генерирует новые экземпляры данных, а другая — дискриминатор, оценивает их на подлинность; т.е. дискриминатор решает, относится ли каждый экземпляр данных, который он рассматривает, к набору тренировочных данных или нет.
Дискриминатор: Дискриминационные алгоритмы пытаются классифицировать входные данные. Учитывая особенности полученных данных, они стараются определить категорию, к которой они относятся.
Генератор: Генеративные алгоритмы заняты обратным. Вместо того, чтобы предсказывать категорию по имеющимся образам, они пытаются подобрать образы к данной категории.
Шаги, которые проходит GAN:
- Генератор получает рандомное число и возвращает изображение.
- Это сгенерированное изображение подается в дискриминатор наряду с потоком изображений, взятых из фактического набора данных.
- Дискриминатор принимает как реальные, так и поддельные изображения и возвращает вероятности, числа от 0 до 1, причем 1 представляет собой подлинное изображение и 0 представляет фальшивое.
🤢 GAN требуют много времени на тренировку. На одном GPU тренировка может занимать часы, а на одном CPU — более одного дня.
Примечание: данная модель не является достаточно хорошей. На данный момент нет более хорошей альтернативы.
Генерация изображений происходит с помощью файла "generator.pth", который представляет собой обученную модель генеративно-состязательной нейросети (GAN2) — это нейросеть для генератора. Скачать этот файл можно здесь: https://drive.google.com/drive/folders/1Uuw-Kl-AzuB2i6lJpQUnIBrjapzghwiY?usp=sharing
Архив с данными для обучения я скачал с Kaggle (https://www.kaggle.com/datasets/kimbosoek/cosmos-images). Его нужно скачать себе на компьютер. Также этот архив можно скачать, используя ссылку на Google Drive выше.
❗ Вот код, который я использовал для обучения GAN2 (сделано в Google Colab):
Для Colab:
!pip install torch torchvision numpy matplotlib scipy pytorch-fid scipy
Импорты:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import torch.nn.functional as F
from google.colab import files
import zipfile
import os
from PIL import Image
from pytorch_fid import fid_score
import matplotlib.pyplot as plt
from scipy.linalg import sqrtm
Загрузка архива (НУЖНО ВЫБРАТЬ СКАЧЕННЫЙ ФАЙЛ НА КОМПЬЮТЕРЕ):
uploaded = files.upload()
archive_name = list(uploaded.keys())[0]
with zipfile.ZipFile(archive_name, 'r') as zip_ref:
zip_ref.extractall('/content/dataset')
print(os.listdir('/content/dataset'))
Подготовка dataset и dataloader:
class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.image_paths = [os.path.join(root_dir, fname) for fname in os.listdir(root_dir) if fname.endswith('.jpg')]
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
try:
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image
except Exception as e:
print(f"Error loading image {img_path}: {e}")
return None
# Data transformations
transform = transforms.Compose([
transforms.Resize((128, 128)),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# Load dataset
dataset = CustomDataset(root_dir='/content/dataset/data/img_align_celeba', transform=transform)
filtered_dataset = [data for data in dataset if data is not None]
dataloader = DataLoader(filtered_dataset, batch_size=32, shuffle=True)
Архитектура моделей для GAN2 (и другие важные для дальнейшего обучения вещи):
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.main = nn.Sequential(
nn.Linear(100, 256),
nn.BatchNorm1d(256),
nn.ReLU(True),
nn.Linear(256, 512),
nn.BatchNorm1d(512),
nn.ReLU(True),
nn.Linear(512, 1024),
nn.BatchNorm1d(1024),
nn.ReLU(True),
nn.Linear(1024, 128*128*3),
nn.Tanh()
)
def forward(self, input):
return self.main(input).view(-1, 3, 128, 128)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.main = nn.Sequential(
nn.Linear(128*128*3, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input.view(-1, 128*128*3))
# Создание директории для хранения сгенерированных изображений
os.makedirs('/content/generated_images', exist_ok=True)
# Создание директории для хранения losses
os.makedirs('/content/G_D_losses', exist_ok=True)
# Initialize models
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Loss and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Learning rate schedulers
scheduler_g = optim.lr_scheduler.StepLR(optimizer_g, step_size=30, gamma=0.1)
scheduler_d = optim.lr_scheduler.StepLR(optimizer_d, step_size=30, gamma=0.1)
Функция для сохранения графиков изменения loss:
def save_losses(g_losses, d_losses, epoch):
plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(g_losses, label="G")
plt.plot(d_losses, label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.savefig(f"/content/G_D_losses/losses_epoch_{epoch}.png")
plt.close()
Обучение генератора и дискриминатора:
g_losses = []
d_losses = []
DL_size = len(dataloader);
num_epochs = 50
latent_dim = 100
for epoch in range(num_epochs):
for i, imgs in enumerate(dataloader):
batch_size = imgs.size(0)
real_labels = torch.full((batch_size, 1), 0.9, device=device) # Label smoothing для реальных данных
fake_labels = torch.full((batch_size, 1), 0.1, device=device) # Label smoothing для фейковых данных
real_imgs = imgs.to(device)
# Train discriminator
optimizer_d.zero_grad()
z = torch.randn(batch_size, latent_dim).to(device)
generated_imgs = generator(z)
real_loss = criterion(discriminator(real_imgs), real_labels)
fake_loss = criterion(discriminator(generated_imgs.detach()), fake_labels)
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_d.step()
# Train generator
optimizer_g.zero_grad()
fake_labels.fill_(0.9) # Fake labels are real for generator cost
g_loss = criterion(discriminator(generated_imgs), fake_labels)
g_loss.backward()
optimizer_g.step()
g_losses.append(g_loss.item())
d_losses.append(d_loss.item())
if i == 0 or i % (DL_size-1) == 0:
print(f"Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)-1} \
Loss D: {d_loss.item()}, Loss G: {g_loss.item()}")
# Step learning rate schedulers
scheduler_g.step()
scheduler_d.step()
save_image(generated_imgs.data[:25], f"/content/generated_images/epoch_{epoch}.png", nrow=5, normalize=True)
save_losses(g_losses, d_losses, epoch)
Генерация изображений (проверка обученной модели генератора):
num_images = 500
latent_dim = 100
desired_height = 128
desired_width = 128
resize_transform = transforms.Resize((desired_width, desired_height))
generator.eval()
os.makedirs('/content/generated_images_testing', exist_ok=True)
with torch.no_grad():
for i in range(num_images // 32): # Пакеты по 32 изображения
z = torch.randn(32, latent_dim).to(device)
generated_imgs = generator(z).cpu()
for j, img in enumerate(generated_imgs):
resized_img = resize_transform(img)
save_image(resized_img, f'/content/generated_images_testing/img_{i*32 + j}.png', normalize=True)
Вычисление метрики "Frechet Inception Distance" (FID) для оценки качества сгенерированных изображений (НЕ РАБОТАЕТ!!!):
# Пути к директориям с реальными и сгенерированными изображениями
path_real = '/content/dataset/data/img_align_celeba'
path_generated = '/content/generated_images_testing'
# Проверка и удаление NaN и бесконечных значений
def clean_data(path_real):
for fname in os.listdir(path_real):
try:
img = Image.open(os.path.join(path_real, fname))
img.verify()
except (IOError, SyntaxError) as e:
print(f'Bad file {fname}: {e}')
os.remove(os.path.join(path_real, fname))
clean_data(path_real)
clean_data(path_generated)
# Рассчет FID
fid_value = fid_score.calculate_fid_given_paths([path_real, path_generated], batch_size=32, device='cuda', dims=2048)
print(f'FID: {fid_value}')
В проекте был использован следующий стек технологий:
Клонируем удалённый репозиторий на локальную машину:
git clone https://github.com/Nico-kun123/vue-python-gan2
Устанавливаем все необходимые компоненты, создаём виртуальную среду для Python:
cd vue-python-gan2
npm install
python3 -m venv venv
venv/Scripts/activate
pip install -r requirements.txt
Затем нужно запустить Flask-приложение с моделью нейросети:
npm run startServer
Потом нужно запустить сайт:
npm run dev
После этого нужно открыть страницу http://localhost:5173/ в браузере.
Содержание package.json
:
{
"name": "vue-python-ai",
"version": "0.0.0",
"private": true,
"author": {
"name": "Nicolay Kudryavtsev"
},
"type": "module",
"scripts": {
"dev": "vite",
"startServer": "python app.py",
"format": "prettier --write src/",
"build": "vite build",
"preview": "npm run build && vite preview"
},
"dependencies": {
"vue": "^3.4.21",
"vue-router": "^4.3.0"
},
"devDependencies": {
"@vitejs/plugin-vue": "^5.0.4",
"axios": "^1.7.2",
"prettier": "^3.2.5",
"sass": "^1.77.4",
"vite": "^5.2.8"
}
}
В проекте есть следующие разделы:
- dependencies: Это пакеты, необходимые для работы приложения.
- devDependencies: Это пакеты, которые нужны только для разработки и тестирования приложения. Они не будут включены в окончательную сборку приложения.
В проекте есть следующие скрипты (только первые 3 важны):
- "dev". Этот скрипт запускает сервер разработки Vite на локальной машине (http://localhost:5173/);
- "startServer". Этот скрипт запускает приложение Flask на локальной машине (http://127.0.0.1:5000);
- "format". Автоматически форматирует код, используя "Prettier";
- "build". Этот скрипт используется для сборки проекта для production. Он минимизирует и оптимизирует код для лучшей производительности в production;
- "preview". Этот скрипт предназначен для предварительного просмотра собранного проекта. Он запускает сервер, который позволяет увидеть, как он будет выглядеть и работать в production.
Кудрявцев Николай (Электронная почта: nicolay.kudryavtsev@gmail.com)