Главная » Хабрахабр » Введение в состязательные сети

Введение в состязательные сети

Этой статьей я начинаю серию рассказов о состязательный сетях. Всем привет. Я не буду копировать весь код из примера сюда, только основные его части, поэтому, для удобства советую иметь его рядом для более простого понимания. Как и в предыдущей статье я подготовил соответствующий докер-контейнер в котором уже все готово для того чтобы воспроизвести то что написано здесь ниже. Докер контейнер доступен здесь, а ноутбук, utils.py и докерфайл здесь.

Несмотря на то, что фреймворк состязательных сетей был предложен Йеном Гудфеллоу в его уже знаменитой работе Generative Adversarial Networks ключевая идея пришла к нему из работ по доменной адаптации(Domain adaptation), поэтому и начнем мы обсуждение состязательных сетей именно с этой темы.

Например это могут быть медицинские записи разных социально-демографических групп(мужчины/женщины, взрослые/дети, азиаты/европейцы...). Представьте, что у вас есть два источниках данных о похожих наборах объектов. Типичные анализы крови представителей разных групп будут отличаться, поэтому модель, предсказывающая, скажем, риск сердечно-сосудистых заболеваний(ССЗ), обученная на представителях одной выборки не может применяться к представителям другой выборки.

Типичным решением такой проблемы будет добавление на вход модели признака идентифицирующего выборку, но, к сожалению, такой подход имеет множество недостатков:

  1. Несбалансированная выборка — азиатов больше, чем европейцев
  2. Разные статистики — дети сильно реже страдают от ССЗ, чем взрослые
  3. Недостаточная разметка одной из выборок — мужчины 60-х годов рождения гибли в Афганистане поэтому меньше данных о ССЗ в зависимости от региона рождения чем для женщин.
  4. Данные имеют различный набор признаков — анализы крови людей и мышей и т.д.

Но может и не стоит заморачиваться? Все эти причины могут сильно затруднить процесс обучения модели. Оказывается — стоит. Обучим по одной модели для каждой выборки и успокоимся. А если у вас не два источника данных, а сильно больше? Если вы сможете нивелировать разницу в статистиках из разных обучающих выборок, то, по сути вы сможете сделать одну выборку бОльшего размера чем каждая из исходных.

Я не буду рассказывать о том, как решали задачу адаптации доменов в “донейросетевую” эру, а сразу покажу базовую архитектуру.

Архитектура сети из статьи

В этой статье продемонстрировано как перенести модель классификации с одного источника данных на другой не используя метки для второго источника. В 2014 году, наш соотечественник Ярослав Ганин в соавторстве с Виктором Лемпицким опубликовал очень важную статью "Unsupervised Domain Adaptation by Backpropagation" (доменная адаптация без учителя с помощью обратного распространения ошибки). Представленная модель состояла из 3 подсетей: feature extractor(E), label predictor(P) и domain classifier© связанных между собой как на рисунке.

Слой, где он разрезан назван слоем признаков(features). Пара сетей E+P представляет из себя по сути обыкновенный классификатор, разрезанный где-то посередине. Задача сети E — извлечь такие признаки из данных, чтобы, с одной стороны P смог правильно угадать метку примера, а с другой стороны C не смог определить его источник. Сеть C получает на вход данные с этого слоя и пытается угадать из какого источника пришел пример.

Можно сказать, что каждый пример содержит информацию о своей метке и какую-то еще информацию. Для того чтобы лучше понять зачем это надо и почему это должно работать давайте поговорим об информации. Если вы сможете обучить идеальный автокодировщик на MNIST то вы сможете записать ту же самую информацию в другом виде. В случае с MNIST'ом вся эта информация может быть записана, например, в виде ч/б изображения размером 28х28 пикселей. Например, по изображению не всегда можно понять какая именно цифра была написана, однако какая-то доля информации о метке все же содержится в изображении. Понятно, что в некоторых случаях информация о метке в самом примере может быть неполной. Когда мы обучаем классификатор, мы стараемся максимально извлекать информацию о метке, но сделать это можно огромным количеством способов. Но, помимо метки, изображение имеет еще ряд явных и огромное количество неявных свойств: свойства почерка(толщина, наклон, “завитушки”), расположение(в центре или со сдвигом), шум и т.д. На одном и том же MNIST'е мы можем обучить 100 одинаково эффективных классификаторов, каждый из которых будет иметь свое собственное скрытое представление изображений, что уж говорить о том случае когда источники данных разные.

Если рассмотреть данные из двух разных источников(например MNIST и SVHN), то можно сказать что каждый из примеров содержит информацию о метке и об источнике. Идея Ганина заключается в том, что если с помощью нейросети мы можем максимизировать информацию, то ничто не мешает нам ее минимизировать. Если мы способны обучить нейросеть E извлекать признаки содержащие информацию о метке, и делать это одинаково, независимо от того откуда пришел пример, то сеть P обученная только на примерах из одного источника, должна быть способна предсказывать метки и для второго источника.

Таблица результатов

При этом, обе модели, конечно, никогда не видели ни одной метки из MNIST при обучении. И действительно, нейросеть, обученная на примерах из SVHN с применением доменной адаптации определяет класс изображений из MNIST точнее чем сеть обученная только на SVHN — 71% точности против 59%. Фактически это означает, что вы можете переносить обученный классификатор с одной выборки на другую, даже если для второй выборки вы не знаете меток.

Поэтому я разберу другой пример применения этой техники, и, надеюсь, он поможет еще лучше продемонстрировать идею "разделения" информации в слое признаков. Несмотря на то, что задача классификации цифр достаточно проста, обучение сетей использовавшихся в статье может потребовать существенных ресурсов, к тому же, код решающий эту задачу нетрудно найти в интернете.

Давайте научимся представлять примеры MNIST в сжатом виде, но будем это делать так, чтобы в сжатом представлении не содержалось информации о том какая именно цифра была изображена. Очень часто, когда речь заходит об извлечении информации или о ее представлении в каком-то особенном виде на сцену выходят автокодировщики. В то же время, если классификатор по извлеченным признакам не способен угадать метку, то вся информация о метке забыта. Тогда, если декодер нашей сети, получив извлеченные признаки и метку исходной цифры способен восстановить исходное изображение, то мы можем считать что кодировщик не теряет никакой информации помимо метки.

Для этого нам придется создать и обучить 3 сети — кодировщик(E), декодировщик(D) и классификатор(С).

В этот раз сделаем кодировщик сверточным, добавив пару сверточных слоев, для этого будем использовать класс Sequential.

conv1 = nn.Sequential( nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1), nn.BatchNorm2d(num_features=16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Dropout(0.2)
) conv2 = nn.Sequential( nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1), nn.BatchNorm2d(num_features=32), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2), nn.Dropout(0.2)
) self.conv = nn.Sequential( conv1, conv2
)

Информация об этих слоях доступна в огромном количестве в интернете (или, например, в нашей книге), поэтому здесь я не буду подробно их разбирать. По сути, он позволяет нам задавать сразу подсети, в данном случае это последовательность из слоев свертки, нормализации по минибатчам, функции активации, субдискретизации и дропаута.

Слои(или подсети) Sequential в функции forward могут быть использованы точно так же как и любые другие слой

def forward(self, x): x = self.conv(x) x = x.view(-1, 7*7*32) x = self.fc(x) return x

Для того чтобы сделать декодировщик аналогичный кодировщику последние его слои будем задавать с помощью транспонированной свертки

conv1 = nn.Sequential( nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=3, stride=2), nn.BatchNorm2d(num_features=16), nn.ReLU(), nn.Dropout(0.2)
) conv2 = nn.Sequential( nn.ConvTranspose2d(in_channels=16, out_channels=1, kernel_size=2, padding=1, stride=2), nn.Tanh()
)

Существенным отличием декодировщика будет то, что на вход он получает не только признаки полученные из кодировщика, но и метку:

def forward(self, x, y): x = torch.cat([x, y], 1) x = self.fc(x) x = x.view(-1, 32, 7, 7) x = self.deconv(x) return x

torch.cat позволяет сконкатенировать признаки и метку в один вектор, а дальше мы просто восстанавливаем изображение из этого вектора.

И третьей сетью будет обыкновенный классификатор, предсказывающий по признакам извлеченным с помощью кодировщика исходную метку изображения.
Цикл обучения всей модели теперь выглядит следующим образом:

for x, y in mnist_train: y_onehot = utils.to_onehot(y, 10) # train classifier C C.zero_grad() z = E(x) C_loss = NLL_loss(C(z), y) C_loss.backward(retain_graph=True) C_optimizer.step() # train decoder D and encoder E E.zero_grad() D.zero_grad() AE_loss = MSE_loss(D(z, y_onehot), x) C_loss = NLL_loss(C(z), y) FADER_loss = AE_loss - beta*C_loss FADER_loss.backward() D_optimizer.step() E_optimizer.step()

Сначала мы используем кодировщик для извлечения признаков изображения и обновляем веса только классификатора так, чтобы он лучше предсказывал метку:

z = E(x)
C_loss = NLL_loss(C(z), y)
C_loss.backward(retain_graph=True)
C_optimizer.step()

Следующим шагом мы обучаем автокодировщик и накладываем дополнительное требование извлекать такие признаки, по которым классификатору будет труднее восстановить метку: При этом мы просим PyTorch сохранить граф вычислений для того чтобы использовать его повторно.

AE_loss = MSE_loss(D(z, y_onehot), x)
C_loss = NLL_loss(C(z), y) FADER_loss = AE_loss - C_loss
FADER_loss.backward()
D_optimizer.step()
E_optimizer.step()

Таким образом мы поочередно учим то классификатор, то автокодировщик преследующие противоположные цели. Обратите внимание на то, что веса классификатора при этом не обновляются, однако веса кодировщика обновляются в сторону увеличения ошибки классификатора. В этом и заключается идея состязательных сетей.

В то же время, мы обучаем декодировщик используя эту информацию в совокупности с меткой уметь восстанавливать исходный пример. В результате обучения такой модели мы хотим получить кодировщик извлекающий из примеров всю необходимую для восстановления примера информацию за исключением метки. На изображении ниже каждая строка получена восстановлением изображения из признаков одной из цифр в сочетании с 10-ю возможными метками. Но что если мы подадим на вход декодировщику другую метку? Цифры взятые за основу расположены на диагонали(точнее не сами исходные примеры, а восстановленные, но с использованием "правильной" метки).

Перенос стиля между цифрами

Кроме того, видно, что строка полученная из цифры 1 нестабильна, я объясняю это тем, что в написании единицы содержится не очень много информации о стиле, пожалуй, только толщина линии и наклон, но точно нет информации о "завитушках". На мой взгляд этот пример отлично демонстрирует идею об извлечении информации отличной от метки, так как видно, что в одной и той же строке все цифры "написаны" в одном стиле. Поэтому остальные цифры написанные в том же стиле могут оказаться довольно разнообразны, хотя в каждом отдельном случае стиль будет один на всю строку, но на разных этапах обучения он будет меняться.

Аналогичным образом из модель извлекает признаки из фотографий лиц и “забывает” метки типа наличия бороды или очков. Осталось только добавить, что подобный подход был опубликован на NIPS’17 в статье от команды Facebook. Вот пример того что у них получилось взятый из статьи:

Пример из статьи FADER Networks

В следующей статье я расскажу о том как генерировать изображения с нуля и почему эта конкретная модель не умеет так делать. Хотя в этом посте мы и рисовали “новые” цифры, но для этого нам приходилось использовать уже существующие цифры чтобы выбрать стиль.


Оставить комментарий

Ваш email нигде не будет показан
Обязательные для заполнения поля помечены *

*

x

Ещё Hi-Tech Интересное!

[Перевод] Что мне нравилось в Поле Аллене

Воспоминания Билла Гейтса о Поле Аллене, с которым они вместе, будучи ещё студентами, основали в 1975 году компанию «Microsoft» (название компании предложил именно Пол) Я хочу выразить свои сожаления его сестре Джоди, его семье и множеству его друзей и коллег ...

Где работать в ИТ #2: «СКБ Контур»

В конце октября ей исполняется 30 лет, количество всех сотрудников перевалило за 8 тысяч. «СКБ Контур» — одна из крупнейших и старейших ИТ компаний в России. По оценкам, собранным на сервисе оценки работодателей «Моего круга», в июле 2018 «Контур» разделил ...