226
Часть II • Введение в TensorFlow.js
Листинг 5.9.
Описание модели для обнаружения простых объектов на основе усеченной MobileNet
1
1
Ради большей ясности мы удалили часть кода обработки ошибок.
Глава 5. Перенос обучения: переиспользование предобученных нейронных сетей
227
Ключевая часть модели двойного предсказания в листинге 5.9 создана с помо
щью метода
buildNewHead()
. Схема работы модели приведена на рис. 5.14,
слева
.
Новая верхушка состоит из трех слоев. Слой схлопывания меняет форму вы
ходного сигнала у последнего сверточного слоя усеченной MobileNet так, чтобы
можно было добавить к ней плотные слои. Первый плотный слой — скрытый,
с нелинейностью типа ReLU. Второй плотный слой — итоговый выходной слой
верхушки, а потому и итоговый выходной слой всей модели обнаружения объ
ектов. Функция активации в этом слое — линейная по умолчанию. Данный слой
является ключевым для понимания работы модели, поэтому рассмотрим его
подробнее.
Как видно из кода, количество выходных нейронов итогового плотного слоя
равно пяти. Что отражают эти пять чисел? Они объединяют предсказания фор
мы целевого объекта и ограничивающего прямоугольника. Что интересно, их
смысл определяет не сама модель, а используемая функция потерь. Ранее вы уже
встречали различные виды функций потерь с понятными названиями наподобие
meanSquaredError
, подходящими для соответствующих задач машинного обучения
(например, см. табл. 3.6). Впрочем, это лишь один из двух способов задания функ
ций потерь в TensorFlow.js. Второй способ, который мы и будем здесь использовать,
включает описание пользовательской JavaScriptфункции, соответствующей опре
деленной сигнатуре. Эта сигнатура имеет такой вид.
z
z
Два входных аргумента: 1) истинные метки входных примеров данных и 2) соот
ветствующие предсказания модели. Каждый из них представляет собой двумер
ный тензор, причем форма этих тензоров должна совпадать, а первое измерение
отражать размеры батчей.
z
z
Возвращаемое значение представляет собой скалярный тензор (тензор фор
мы
[]
), значение которого представляет собой средние потери примеров данных
батча.
Наша пользовательская функция потерь, написанная в соответствии с этой
сигнатурой, приведена в листинге 5.10 и графически изображена в правой части
рис. 5.14. Первый входной аргумент
customLossFunction
(
yTrue
) — тензор с истин
ными метками формы
[batchSize,
5]
. Первый входной аргумент (
yPred
) — выход
ное предсказание модели, имеет такую же форму, как и
yTrue
. Из пяти измерений
по второй оси
yTrue
(пяти столбцов, если рассматривать его как матрицу) первое
представляет собой индикатор 01 формы целевого объекта (0 означает треугольник,
а 1 — прямоугольник), что определяется способом синтеза данных (см.
simple-
object-detection/synthetic_images.js
). Оставшиеся четыре столбца описывают
прямоугольник, ограничивающий целевой объект, точнее, значения левой, правой,
верхней и нижней его координат, каждое из которых находится в диапазоне от 0 до
CANVAS_SIZE
(224). Число 224 соответствует высоте и ширине входных изображений,
его источник — размер входного изображения модели MobileNet, лежащей в основе
нашей модели.
Do'stlaringiz bilan baham: |