pytorch の Embedding の挙動について
pytorch の Embedding の挙動について

pytorch の Embedding の挙動について

はじめに

この記事では pytorch の embedding の挙動について記載します

Embedding とは何か

公式の仕様書はこちらになります

公式の説明は以下となっており、非常に的を得ていると思います

A simple lookup table that stores embeddings of a fixed dictionary and size.

意訳すると、 固定長の辞書埋め込みを保存するシンプルなルックアップテーブル になるんじゃないかなと思います。Embedding は、何だか難しそうにも思えてしまうのですが、ここに記載されている通り非常にシンプルなテーブルでしかないという事です

モジュールの解説としては以下のように記載があります

This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.

こちらも意訳すると、 このモジュールはワードエンベディングを保存するために使われる事が多く、インデックスによって取得されます。このモジュールへの入力はインデックスのリスト、出力は対応するワードエンベディングになります というような感じかなと思います。

こちらは割と理解しずらいと思うので、以下細かく挙動を見ていきます

Embedding の挙動の確認

挙動を確認するために実際に動かして内部の動作を確認していきます

まず、Embedding を初期化します

>>> import torch
>>> from torch import nn

>>> torch.manual_seed(42) // 再現性のために seed を固定します
>>> emb = nn.Embedding(2, 5) // 2 x 5 次元の embedding を作ります

この時点で emb.weight の中身を確認すると以下のようになります

>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674]], requires_grad=True)

2 x 5 次元のベクトルがランダムに初期化されている事がわかります

ここで [ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229] が index 0 に、 [-0.1863, 2.2082, -0.6380, 0.4617, 0.2674] が index 1 に対応する事が想像されます

実際に embedding を通してアクセスしてみると、入力した index に対応する embedding が取得できるのが確認できます

>>> print(emb(torch.tensor([0]))) // index 0 を入力としたいが、tensor にする必要がある
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229]],
       grad_fn=<EmbeddingBackward>)

>>> print(emb(torch.tensor([1])))
tensor([[-0.1863,  2.2082, -0.6380,  0.4617,  0.2674]],
       grad_fn=<EmbeddingBackward>)

>>> print(emb(torch.tensor([0, 1])))
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674]],
       grad_fn=<EmbeddingBackward>)

公式の説明にある 固定長の辞書埋め込みを保存するシンプルなルックアップテーブル という事が理解できます

Embedding の学習

Embedding がシンプルなルックアップテーブルだという事は理解できましたが、ランダムに作られたベクトルというだけでは何の役にも立ちません。入力に対して学習をしてこのベクトルに意味を持たせる事が大事です。ここでは Embedding がどのように学習していくのかを見ていきます

まず、Embedding の状態を確認したいので、weight の grad という値を確認します

>>> print(emb.weight.grad)
None

学習を何もしていない状態では grad は特に何も無い事がわかります

では学習を進めるために Optimizer を作ります

>>> optimizer = torch.optim.SGD(emb.parameters(), lr=0.1, momentum=0.9)

torch.optim.SGD は embedding の parameter を受け取ります。これによって optimizer object を通して embedding の parameter を確認したり更新したりする事ができるようになります。また lr は learning rate になります

optimizer を作った上で embedding の loss を計算します ここでは適当に embedding の index 0 と index 1 のユークリッド距離が最小になるような学習をしたいという事にします

>>> loss = torch.linalg.norm(emb(torch.tensor([0])) - emb(torch.tensor([1])))
>>> print(loss)
tensor(2.7101, grad_fn=<CopyBackwards>)

これは embedding[0] と embedding[1] のユークリッド距離を単純に計算しているだけです

確認のため numpy で計算しても同じ結果になります

>>> import numpy as np
>>> a = np.array((0.3367,  0.1288,  0.2345,  0.2303, -1.1229))
>>> b = np.array((-0.1863,  2.2082, -0.6380,  0.4617,  0.2674))
>>> dist = np.linalg.norm(a-b)
>>> print(dist)
2.7101973470579592

この時点では単純に loss を計算しているだけなので、grad はまだ特に更新されていません

>>> print(emb.weight.grad)
None

loss.backward 実行するとはじめて emb.weight.grad に値がはいります

>>> loss.backward()
>>> print(emb.weight.grad)
tensor([[ 0.1930, -0.7673,  0.3219, -0.0854, -0.5130],
        [-0.1930,  0.7673, -0.3219,  0.0854,  0.5130]])

ただし、ここでもまだ weight 自体は更新されていません

>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674]], requires_grad=True)

最後に optimizer.step() を実行する事で emb.weight.grad から計算された値を使って weight が更新されます

>>> optimizer.step()
>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3174,  0.2055,  0.2023,  0.2389, -1.0716],
        [-0.1670,  2.1315, -0.6058,  0.4531,  0.2161]], requires_grad=True)

>>> print(emb.weight.grad)
tensor([[ 0.1930, -0.7673,  0.3219, -0.0854, -0.5130],
        [-0.1930,  0.7673, -0.3219,  0.0854,  0.5130]])

この計算には optimizer の初期化時に指定した learning rate が使われています

検算するとわかりますが、新しい weight は weight - (grad x learning_rate) の式で計算されます

例えば embedding[0][0] の値はもともと 0.3367 でしたが 0.3367 - (0.1930 * 0.1) されて 0.3174 に更新されています

これを 1 step として学習を繰り返していく事で、loss が最小になるように embedding が更新されていきます

for i in range(10):
  optimizer.zero_grad()
  loss = torch.linalg.norm(emb(torch.tensor([0])) - emb(torch.tensor([1])))
  loss.backward()
  optimizer.step()

まとめ

pytorch の embedding の挙動について記載しました

挙動を追ってみる事で公式の説明の通り、非常にシンプルなルックアップテーブルである事がわかりました