Часть III • Продвинутые возможности глубокого обучения с TensorFlow.js
В процессе обучения по умолчанию используется tfjsnode. Впрочем, как и в пре
дыдущих примерах сверточных сетей, скорость обучения можно значительно по
высить за счет использования tfjsnodegpu. Для этого при наличии настроенного
должным образом GPU с поддержкой CUDA достаточно добавить в команду
yarn
train
флаг
--gpu
. Обучение ACGAN займет не менее нескольких часов. Для мони
торинга хода выполнения этого «долгоиграющего» задания можно воспользоваться
TensorBoard, добавив флаг
--logDir
:
yarn train --logDir /tmp/mnist-acgan-logs
После открытия TensorBoard в отдельном терминале с помощью следующей
команды:
tensorboard --logdir /tmp/mnist-acgan-logs
можно перейти по URL TensorBoard (выводится в консоль серверным процессом
TensorBoard) в браузере и посмотреть на кривые потерь. На рис. 10.15 приведены
примеры кривых потерь процесса обучения. Одна из характерных особенностей кри
вых потерь обучения GAN — они не обязательно стремятся вниз, подобно кривым
потерь большинства других типов нейронных сетей. Как потери для дискриминатора
(dLoss на рисунке), так и потери для генератора (gLoss на рисунке) меняются не
монотонно и замысловато отражают друг друга.
Ни одна из функций потерь не приближается к нулю ближе к концу обучения,
они просто устанавливаются на определенном уровне (сходятся). На этом этапе про
цесс обучения завершается, относящаяся к генератору часть модели сохраняется на
диск для использования на шаге генерации в браузере:
await generator.save(saveURL);
Запустить демонстрацию генерации в браузере можно с помощью команды
yarn
watch
. Эта команда компилирует файл
mnist-acgan/index.js
и соответствующие
HTML и CSSресурсы, после чего открывает вкладку в браузере и отображает
страницу демонстрации
1
.
Страница демонстрации загружает обученный генератор ACGAN, сохраненный
на предыдущем этапе. Поскольку для текущего этапа дискриминатор особо не ну
жен, он не сохраняется и не загружается. После загрузки генератора можно сформи
ровать батч латентных векторов вместе с батчем желаемых индексов классов цифр
и вызвать с ними в качестве аргументов метод
predict()
генератора. Выполняющий
эти действия код из файла
mnist-acgan/index.js
:
const latentVectors = getLatentVectors(10);
const sampledLabels = tf.tensor2d(
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 1]);
const generatedImages =
generator.predict([latentVectors, sampledLabels]).add(1).div(2);
1
Можете полностью пропустить шаги обучения и сборки и перейти сразу к странице де
монстрации, размещенной по адресу http://mng.bz/4eGw.
Do'stlaringiz bilan baham: |