/Autoencoder_Latent_Space_GUI

Решение вступительной задачи для стажировки Generative Neural Networks at Computational Arts в JetBrains

Primary LanguagePython

Autoencoder_Latent_Space_GUI

Решение вступительной задачи для стажировки Generative Neural Networks at Computational Arts в JetBrains.

Ссылка на видео с примером работы:

https://youtu.be/bkHdb-MVesg

Зависимости

  • matplotlib
  • plotly
  • numpy
  • tensorflow
  • keras
  • PyQt5

Запуск

  • Windows: файл setup.bat (должен быть установлен Python 3.8)
  • Linux: файл setup.sh (должен быть установлен python3)
sudo chmod +x setup.sh
sudo ./setup.sh

Возможности

  • Кнопка Latent Space (Plotly) открывает в браузере визуализацию Latent Space с возможностью просмотра ее на каждой эпохе (с помощью слайдера)
  • Кнопка Latent Space (Pyplot) открывает окно с визуализацией Latent Space и возможностью просмотра исходного изображения и декодированного (требуется навести мышкой на желаемую точку и нажать на любую клавишу на клавиатуре). Если мышка наведена не на точку, программа автоматически находит ближайшую.
  • Кнопка Latent Space Freeroam открывает окно с визуализацией Latent Space и возможностью просмотреть декодированное изображение в любой точке пространства (действия аналогичны действиям в предыдущем окне).
  • Кнопка Make model запускает заново обучение модели, в результате чего обновляются файлы моделей .h5 и массив векторов encoder_res.npy

История создания

Был выбран датасет изображений для автоэнкодера: fashion_mnist. Причины, по которым был выбран именно этот датасет:

  • Более интересный набор изображений, чем в MNIST
  • Изображения маленького размера и монохромные - на моем компьютере GPU от AMD, поэтому скорость обучения нейронной сети невысока, так что это ключевой фактор

И энкодер, и декодер были созданы как обычные нейронные сети со слоями Input, Dense.

Энкодер:

Input(784)->Dense(256, relu)->Dense(128, relu)->Dense(64, relu)->Dense(3, linear)

Декодер:

Input(3)->Dense(64, relu)->Dense(128, relu)->Dense(256, relu)->Dense(784, sigmoid)

Оптимизатор был выбран - adam, функция потерь - binary_crossentropy, число эпох - 20, batch_size - 64.

Энкодер-декодер

Изначально была идея кодировать в двумерное пространство, однако практика показала, что при таком кодировании нейронная сеть не может корректно различить все 10 классов. Поэтому изображение кодируется в трехмерное пространство, соответственно, графики Latent space будут также трехмерными.

Сохранение модели

Стало понятно, что необходимо сохранять модель, чтобы при каждом запуске не обучать ее заново. Сделал сохранение в h5. Также, чтобы в будущем реализовать анимацию изменения Latent space, после каждой эпохи прогоняю энкодер по данным для проверки и сохраняю то, что возвращает нейронка, в файл.

Визуализация

Latent space после обучения

Сначала я попробовал визуализировать Latent Space после обучения. Также хотелось сделать так, чтобы можно было посмотреть на декодированное изображение, выбрав на графике любую точку с координатами закодированного вектора (не только те точки, что были во входных данных).

Было достаточно понятно, как это сделать с помощью Matplotlib. Вот скриншот того, как это выглядит:

Достаточно навести мышкой на желаемую чать графика и нажать на любую клавишу клавиатуры - применяется декодер на выбранный вектор и выводится результат в виде изображения.

Сравнение входных и выходных изображений

Немного модифицируем предыдущий код таким образом, чтобы при наведении мышкой на график и нажатия клавиши происходил поиск ближайшей точки среди результатов кодирования тестовых данных. Таким образом появляется возможность выводить не только результат декодера, но и то, как выглядело исходное изображение. Вот так это выглядит:

pyplot3

Plotly

График, рисуемый через matplotlib, сильно тормозит при вращении или приближении. Я стал искать другие библиотеки для отображения графики, и нашел Plotly. Посмотрим, как это выглядит там:

Отлично! Вращение и приближение отрабатывают очень плавно. К сожалению, опыта не хватило, чтобы в данном варианте также позволить пользователю смотреть на изображение по клику на график.

Анимация изменения Latent space в течение эпох

Основная работа выполнена, но хочется чего-то более интересного. Так как я сохраняю выход из энкодера на каждой эпохе, то у меня есть история изменений Latent space во время обучения. Будет классно как-то это визуализировать.

В Plotly есть слайдеры - отлично, именно это и будет полезно. Хочется слайдером двигать эпоху, а на графике будут видны изменения. Пришлось повозиться с тем, чтобы преобразовать данные в нужный для Plotly формат, но в итоге получилось: на рисунках 1 и 10 эпоха:

На вид, будто цветок раскрылся =)

Создание GUI

Стало понятно, что необходим графический интерфейс, из которого можно будет:

  • Открывать разные режимы просмотра Latent space
  • Пересоздавать модель

Для этого использовал PyQt5, в котором нарисовал простенький интерфейс.

Приятного использования!