最近流行りのSelf Attentionを使ったGAN(Generative Adversarial Network)の実装にチャレンジしてみました。
GAN(Generative Adversarial Network)は敵対的生成ネットワークと呼ばれるもので、画像を生成するGenerator、生成した画像を偽物か本物か識別するDiscriminatorの2種類のネットワークを作成し、GeneratorはDiscriminatorを騙せるような画像を作れるように、DiscriminatorはGeneratorが作る画像を偽物と識別できるようにお互いが学習していく事で、Generatorが本物と見分けがつかない画像を生成できるようになります。
Generatorは偽札を作る悪者、Discriminatorは偽札を見破る警察機関みたいなものですね。
Self Attentionは、入力値X(テンソル、行列)から、Quely, Key, Valueの3つの情報を作って、QuelyとKeyの内積をとって、Softmax関数で0~1の間で合計が1になるように規格化したAttention Mapを作り、Attention MapとValueの内積をとってSelf Attention Mapを作成し、X +Self Attention Map × γ(係数、初期値0から学習)を計算し出力するモジュールです。
Xの各要素の関係性を学習することで、画像処理においては画素のどこに注目するか、自然言語処理においては単語間の関係性を考慮する事が可能になります。
(何言ってんだこの人??って思いますよね?。私も実際に概要を理解するのに、各種資料を読み込んで3か月以上悩みました。)
両方を組み合わせたものがSelf Attention GAN、略してSAGAN!!
う~ん、かっちいい!!
実際の所、Attention機構は最近必須の技術と化しています。
我々のお仕事で例えると、処方内容・薬歴・患者情報等から患者さんにどういった所に注意すべきか伝えるかを考えるのがAttention Mapの作成に当たるのかな?
(的確なAttention Mapを作るために、我々は知識と経験を積んで学習していく訳ですね)
とりあえず猫画像を学習させて、新しい猫画像を生成する事を目指します。
まだ、実験段階なので画像は白黒64×64ピクセル×約1700枚で学習させてみました。
エポック数(学習に回す回数)を200で実行。GPU使って1時間程度かかりました。
結果がこちら!
上が、学習に使った画像。下が生成したインチキ猫画像。
何だこりゃ???
学習が全然足りないようですね。なんとなく、猫っぽい心霊画像といった感じの物がちょっと生成できたぐらいでした。
さらに、改良と学習を進めてみます。