Bog'liq Цэй Ш., Байлесчи С., и др. - JаvaScript для глубокого обучения (Библиотека программиста) - 2021
Глава 11. Основы глубокого обучения с подкреплением
467 Функция потерь для предсказания Q-значений и обратного
распространения ошибки
Воспользуемся уже знакомой вам функцией потерь
meanSquaredError
для вычисле
ния расхождения полученных предсказанных и целевых Qзначений (рис. 11.16).
Таким образом, нам удалось свести процесс обучения DQN к задаче регрессии, схо
жей с предыдущими примерами, например Bostonhousing и Jenaweather. Сигнал
рассогласования функции потерь
meanSquaredError
служит движущей силой об
ратного распространения ошибки, а обновление динамической DQN производится
на основе полученных в результате обновлений весовых коэффициентов.
Схема на рис. 11.16 включает части, уже показанные на рис. 11.14 и 11.15. На ней
не только эти части собраны воедино, но и добавлены новые прямоугольники и стрел
ки для функции потерь
meanSquaredError
и основанного на ней обратного распростра
нения ошибки (см. правую нижнюю часть схемы). Таким образом, на ней приведена
полная картина алгоритма глубокого Qобучения нашего агента для игры «Змейка».
Код в листинге 11.9, содержащий внутренности метода
trainOnReplayBatch()
класса
SnakeGameAgent
из файла
snake-dqn/agent.js
, играющего ключевую роль
в нашем алгоритме RL, тесно связан со схемой на рис. 11.16. В этом методе опи
сана функция потерь для вычисления
meanSquaredError
между предсказанными
и целевыми Qзначениями. Далее он вычисляет градиенты
meanSquaredError
от
носительно весовых коэффициентов динамической DQN с помощью функции
tf.variableGrads()
(в разделе Б.4 вы найдете подробный обзор функций вычисле
ния градиентов TensorFlow.js, в частности, и
tf.variableGrads()
). Далее, с помо
щью оптимизатора вычисленные градиенты применяются для обновления весовых
коэффициентов DQN, подталкивая тем самым динамическую DQN в сторону более
точных оценок Qзначений. После повторения этого миллионы раз получается DQN,
обеспечивающий неплохие игровые результаты в игре «Змейка». Отвечающая за
вычисление целевых Qзначений (
targetQs
) часть кода из следующего листинга
уже приводилась в листинге 11.8.
Вот и все, что мы хотели рассказать о внутреннем устройстве алгоритма глубоко
го Qобучения. Запустить обучение на основе этого алгоритма в среде Node.js можно
с помощью следующей команды:
yarn train --logDir /tmp/snake_logs
При наличии настроенного должным образом GPU с поддержкой CUDA можете
добавить в эту команду флаг
--gpu
для ускорения обучения. В результате указания
флага
--logDir
данная команда во время обучения заносит в журнал в каталоге
журналов TensorFlow.js следующие метрики: 1) скользящее среднее совокупных
вознаграждений за 100 последних эпизодов игры (
cumulativeReward100
); 2) скольз
ящее среднее количества «съеденных» за 100 последних эпизодов игры фруктов
(
eaten100
); 3) значение параметра исследования (
epsilon
) и 4) скорость обучения,
выражаемую в ходах в секунду (
framesPerSecond
). Просмотреть эти журналы можно,
запустив TensorFlow.js с помощью следующих команд и перейдя в браузере по HTTP
URL клиентской части TensorFlow.js (по умолчанию
http://localhost:6006
):
pip install tensorboard tensorboard --logdir /tmp/snake_logs