Часть II • Введение в TensorFlow.js
3.3.1. Унитарное кодирование категориальных данных
Прежде чем изучать модель, предназначенную для решения задачи классификации
ирисов, необходимо поговорить о способе представления целевых меток (видов
цветов) в этой задаче многоклассовой классификации. Во всех предыдущих при
мерах машинного обучения в книге представление целевых признаков было более
простым, например в виде одного числа в задаче предсказания времени скачивания
и задаче предсказания цен на бостонскую недвижимость, а также представления
01 бинарных целевых признаков в задаче обнаружения фишинга. В задаче же клас
сификации ирисов три вида цветков представлены несколько менее привычным
способом, с помощью так называемого
унитарного кодирования
(onehot encoding).
Откройте файл
data.js
и взгляните на строку:
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
Здесь
shuffledTargets
— простой JavaScriptмассив, состоящий из целочислен
ных меток для примеров в перетасованном виде. Значения всех его элементов равны
0, 1 или 2: соответственно трем видам ирисов в наборе данных. С помощью вызова
tf.tensor1d(shuffledTargets).toInt()
набор преобразуется в одномерный тензор
с типом элементов
int32
. Затем он передается в функцию
tf.oneHot()
, возвраща
ющую двумерный тензор формы
[numExamples,
IRIS_NUM_CLASSES]
.
numExamples
—
число примеров данных, содержащихся в
targets
, а
IRIS_NUM_CLASSES
— просто
константа, равная 3. Чтобы взглянуть на фактические значения
targets
и
ys
, можете
добавить сразу после вышеупомянутой строки код для вывода в консоль наподобие
следующего:
const ys = tf.oneHot(tf.tensor1d(shuffledTargets).toInt(), IRIS_NUM_CLASSES);
// Добавленные строки кода для вывода в консоль значений `targets` и `ys`
console.log('Value of targets:', targets);
ys.print();
1
После этих изменений процессупаковщик Parcel, запускаемый Yarnкомандой
watch
в терминале, автоматически пересобирает вебфайлы. Далее вы можете от
крыть инструменты разработчика (devtool) на соответствующей вкладке браузера
и обновить страницу. Сообщения, выводимые вызовами
console.log()
и
print()
,
попадают в консоль devtool. Они выглядят примерно так:
Value of targets: (50) [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0]
Tensor
[[1, 0, 0],
1
В отличие от targets, ys не просто JavaScriptмассив, а тензорный объект, использующий
память GPU. Следовательно, с помощью обычного вызова console.log его значение не по
смотреть. Метод print() специально предназначен для извлечения значений из GPU,
их форматирования с учетом формы тензора в удобном для человека виде и вывода
в консоль.
Do'stlaringiz bilan baham: |