はじめに
この記事では pytorch の embedding の挙動について記載します
Embedding とは何か
公式の仕様書はこちらになります
公式の説明は以下となっており、非常に的を得ていると思います
A simple lookup table that stores embeddings of a fixed dictionary and size.
意訳すると、 固定長の辞書埋め込みを保存するシンプルなルックアップテーブル
になるんじゃないかなと思います。Embedding は、何だか難しそうにも思えてしまうのですが、ここに記載されている通り非常にシンプルなテーブルでしかないという事です
モジュールの解説としては以下のように記載があります
This module is often used to store word embeddings and retrieve them using indices. The input to the module is a list of indices, and the output is the corresponding word embeddings.
こちらも意訳すると、 このモジュールはワードエンベディングを保存するために使われる事が多く、インデックスによって取得されます。このモジュールへの入力はインデックスのリスト、出力は対応するワードエンベディングになります
というような感じかなと思います。
こちらは割と理解しずらいと思うので、以下細かく挙動を見ていきます
Embedding の挙動の確認
挙動を確認するために実際に動かして内部の動作を確認していきます
まず、Embedding を初期化します
>>> import torch
>>> from torch import nn
>>> torch.manual_seed(42) // 再現性のために seed を固定します
>>> emb = nn.Embedding(2, 5) // 2 x 5 次元の embedding を作ります
この時点で emb.weight
の中身を確認すると以下のようになります
>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
[-0.1863, 2.2082, -0.6380, 0.4617, 0.2674]], requires_grad=True)
2 x 5 次元のベクトルがランダムに初期化されている事がわかります
ここで [ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229]
が index 0 に、 [-0.1863, 2.2082, -0.6380, 0.4617, 0.2674]
が index 1 に対応する事が想像されます
実際に embedding を通してアクセスしてみると、入力した index に対応する embedding が取得できるのが確認できます
>>> print(emb(torch.tensor([0]))) // index 0 を入力としたいが、tensor にする必要がある
tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229]],
grad_fn=<EmbeddingBackward>)
>>> print(emb(torch.tensor([1])))
tensor([[-0.1863, 2.2082, -0.6380, 0.4617, 0.2674]],
grad_fn=<EmbeddingBackward>)
>>> print(emb(torch.tensor([0, 1])))
tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
[-0.1863, 2.2082, -0.6380, 0.4617, 0.2674]],
grad_fn=<EmbeddingBackward>)
公式の説明にある 固定長の辞書埋め込みを保存するシンプルなルックアップテーブル
という事が理解できます
Embedding の学習
Embedding がシンプルなルックアップテーブルだという事は理解できましたが、ランダムに作られたベクトルというだけでは何の役にも立ちません。入力に対して学習をしてこのベクトルに意味を持たせる事が大事です。ここでは Embedding がどのように学習していくのかを見ていきます
まず、Embedding の状態を確認したいので、weight の grad という値を確認します
>>> print(emb.weight.grad)
None
学習を何もしていない状態では grad は特に何も無い事がわかります
では学習を進めるために Optimizer を作ります
>>> optimizer = torch.optim.SGD(emb.parameters(), lr=0.1, momentum=0.9)
torch.optim.SGD
は embedding の parameter を受け取ります。これによって optimizer
object を通して embedding の parameter を確認したり更新したりする事ができるようになります。また lr
は learning rate になります
optimizer を作った上で embedding の loss を計算します ここでは適当に embedding の index 0 と index 1 のユークリッド距離が最小になるような学習をしたいという事にします
>>> loss = torch.linalg.norm(emb(torch.tensor([0])) - emb(torch.tensor([1])))
>>> print(loss)
tensor(2.7101, grad_fn=<CopyBackwards>)
これは embedding[0] と embedding[1] のユークリッド距離を単純に計算しているだけです
確認のため numpy で計算しても同じ結果になります
>>> import numpy as np
>>> a = np.array((0.3367, 0.1288, 0.2345, 0.2303, -1.1229))
>>> b = np.array((-0.1863, 2.2082, -0.6380, 0.4617, 0.2674))
>>> dist = np.linalg.norm(a-b)
>>> print(dist)
2.7101973470579592
この時点では単純に loss を計算しているだけなので、grad はまだ特に更新されていません
>>> print(emb.weight.grad)
None
loss.backward
実行するとはじめて emb.weight.grad
に値がはいります
>>> loss.backward()
>>> print(emb.weight.grad)
tensor([[ 0.1930, -0.7673, 0.3219, -0.0854, -0.5130],
[-0.1930, 0.7673, -0.3219, 0.0854, 0.5130]])
ただし、ここでもまだ weight 自体は更新されていません
>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3367, 0.1288, 0.2345, 0.2303, -1.1229],
[-0.1863, 2.2082, -0.6380, 0.4617, 0.2674]], requires_grad=True)
最後に optimizer.step()
を実行する事で emb.weight.grad
から計算された値を使って weight が更新されます
>>> optimizer.step()
>>> print(emb.weight)
Parameter containing:
tensor([[ 0.3174, 0.2055, 0.2023, 0.2389, -1.0716],
[-0.1670, 2.1315, -0.6058, 0.4531, 0.2161]], requires_grad=True)
>>> print(emb.weight.grad)
tensor([[ 0.1930, -0.7673, 0.3219, -0.0854, -0.5130],
[-0.1930, 0.7673, -0.3219, 0.0854, 0.5130]])
この計算には optimizer の初期化時に指定した learning rate
が使われています
検算するとわかりますが、新しい weight は weight - (grad x learning_rate)
の式で計算されます
例えば embedding[0][0]
の値はもともと 0.3367
でしたが 0.3367 - (0.1930 * 0.1)
されて 0.3174
に更新されています
これを 1 step として学習を繰り返していく事で、loss が最小になるように embedding が更新されていきます
for i in range(10):
optimizer.zero_grad()
loss = torch.linalg.norm(emb(torch.tensor([0])) - emb(torch.tensor([1])))
loss.backward()
optimizer.step()
まとめ
pytorch の embedding の挙動について記載しました
挙動を追ってみる事で公式の説明の通り、非常にシンプルなルックアップテーブルである事がわかりました