Решение вступительной задачи для стажировки Generative Neural Networks at Computational Arts в JetBrains.
Ссылка на видео с примером работы:
- matplotlib
- plotly
- numpy
- tensorflow
- keras
- PyQt5
- Windows: файл setup.bat (должен быть установлен Python 3.8)
- Linux: файл setup.sh (должен быть установлен python3)
sudo chmod +x setup.sh
sudo ./setup.sh
- Кнопка Latent Space (Plotly) открывает в браузере визуализацию Latent Space с возможностью просмотра ее на каждой эпохе (с помощью слайдера)
- Кнопка Latent Space (Pyplot) открывает окно с визуализацией Latent Space и возможностью просмотра исходного изображения и декодированного (требуется навести мышкой на желаемую точку и нажать на любую клавишу на клавиатуре). Если мышка наведена не на точку, программа автоматически находит ближайшую.
- Кнопка Latent Space Freeroam открывает окно с визуализацией Latent Space и возможностью просмотреть декодированное изображение в любой точке пространства (действия аналогичны действиям в предыдущем окне).
- Кнопка Make model запускает заново обучение модели, в результате чего обновляются файлы моделей
.h5
и массив векторовencoder_res.npy
Был выбран датасет изображений для автоэнкодера: fashion_mnist. Причины, по которым был выбран именно этот датасет:
- Более интересный набор изображений, чем в MNIST
- Изображения маленького размера и монохромные - на моем компьютере GPU от AMD, поэтому скорость обучения нейронной сети невысока, так что это ключевой фактор
И энкодер, и декодер были созданы как обычные нейронные сети со слоями Input, Dense.
Энкодер:
Input(784)->Dense(256, relu)->Dense(128, relu)->Dense(64, relu)->Dense(3, linear)
Декодер:
Input(3)->Dense(64, relu)->Dense(128, relu)->Dense(256, relu)->Dense(784, sigmoid)
Оптимизатор был выбран - adam
, функция потерь - binary_crossentropy
, число эпох - 20, batch_size
- 64.
Изначально была идея кодировать в двумерное пространство, однако практика показала, что при таком кодировании нейронная сеть не может корректно различить все 10 классов. Поэтому изображение кодируется в трехмерное пространство, соответственно, графики Latent space будут также трехмерными.
Стало понятно, что необходимо сохранять модель, чтобы при каждом запуске не обучать ее заново. Сделал сохранение в h5. Также, чтобы в будущем реализовать анимацию изменения Latent space, после каждой эпохи прогоняю энкодер по данным для проверки и сохраняю то, что возвращает нейронка, в файл.
Сначала я попробовал визуализировать Latent Space после обучения. Также хотелось сделать так, чтобы можно было посмотреть на декодированное изображение, выбрав на графике любую точку с координатами закодированного вектора (не только те точки, что были во входных данных).
Было достаточно понятно, как это сделать с помощью Matplotlib. Вот скриншот того, как это выглядит:
Достаточно навести мышкой на желаемую чать графика и нажать на любую клавишу клавиатуры - применяется декодер на выбранный вектор и выводится результат в виде изображения.
Немного модифицируем предыдущий код таким образом, чтобы при наведении мышкой на график и нажатия клавиши происходил поиск ближайшей точки среди результатов кодирования тестовых данных. Таким образом появляется возможность выводить не только результат декодера, но и то, как выглядело исходное изображение. Вот так это выглядит:
График, рисуемый через matplotlib, сильно тормозит при вращении или приближении. Я стал искать другие библиотеки для отображения графики, и нашел Plotly. Посмотрим, как это выглядит там:
Отлично! Вращение и приближение отрабатывают очень плавно. К сожалению, опыта не хватило, чтобы в данном варианте также позволить пользователю смотреть на изображение по клику на график.
Основная работа выполнена, но хочется чего-то более интересного. Так как я сохраняю выход из энкодера на каждой эпохе, то у меня есть история изменений Latent space во время обучения. Будет классно как-то это визуализировать.
В Plotly есть слайдеры - отлично, именно это и будет полезно. Хочется слайдером двигать эпоху, а на графике будут видны изменения. Пришлось повозиться с тем, чтобы преобразовать данные в нужный для Plotly формат, но в итоге получилось: на рисунках 1 и 10 эпоха:
На вид, будто цветок раскрылся =)
Стало понятно, что необходим графический интерфейс, из которого можно будет:
- Открывать разные режимы просмотра Latent space
- Пересоздавать модель
Для этого использовал PyQt5, в котором нарисовал простенький интерфейс.
Приятного использования!