はじめに
プラチナランカー!!🙌 nikkieです🟧🟧🟧🟧🟧
transformers・PyTorchの組合せで、文のembeddingsを得るコードで理解したいことがありました。
理解を深める目的でこの記事でアウトプットします。
目次
- はじめに
- 目次
- E5
- average_pool関数、何をやってるんだ...?
- 現時点の理解を反映したコード(Usageの再実装)
- 終わりに
- P.S. sentence-transformersの実装と一致!
- 変更履歴
E5
文のembeddings(埋め込み表現)を得る方法はいくつもあります(特に具体的なモデルがいくつも!)
- 外部のAPIを使う
- 例:OpenAIのada
- テキストを入力するとembeddingsを出力するモデルをローカル環境で動かす
- sentence-transformersで試しました
後者の方法で注目したモデルがE5。
embeddingsの性能評価でadaに匹敵するということで注目しました。
なんでも「EmbEddings from bidirEctional Encoder rEpresentations」でE5なんだとか(元論文より)。
Hugging Face Hubで公開されており、動かすためのサンプルコード(Usage)もあります。
今回は多言語に対応したE5(multilingual-e5)を見ていきます
参考文献
コネヒトさんのブログで印象に残りました。
ベンチマークにおいて、multilingual-e5 の性能がtext-embedding-ada-002 と大差ないことが報告されている。
Hironsanのエントリで完全理解しました。
[修正 (2023/12/11)] リンクを誤っていたので修正(当初のリンク)
元論文
リポジトリ:https://github.com/microsoft/unilm/tree/master/e5
(マイクロソフトなんですね)
average_pool関数、何をやってるんだ...?
# https://huggingface.co/intfloat/multilingual-e5-small の Usage より def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor: last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
attention_mask[..., None]
??~
(たしか否定)- 1つのtorch.Tensorを返しているっぽいけど、これでaverage poolになってる?
呼び出し方から見ていきましょう
average_pool関数の呼び出し
embeddings = average_pool(
outputs.last_hidden_state, batch_dict["attention_mask"]
)
outputsはモデル(multilingual-e5-small)の出力したTensorです。
>>> outputs.last_hidden_state.size() torch.Size([2, 12, 384])
2文あり、どちらも12トークンに分割され、各トークンが384次元のベクトルで表されているという理解です。
このトークンを平均して、文を表すベクトルを得たいわけですね1。
もう1つのbatch_dictはトークナイザの出力(BatchEncoding)です。
>>> batch_dict["attention_mask"].size() torch.Size([2, 12])
Usageで与えている文章から意図的に短くして、2文目をpaddingされるようにしています2
input_texts = [ "query: how much protein should a female eat", "query: 南瓜的家", # 元は 南瓜的家常做法 ]
>>> batch_dict["input_ids"] tensor([[ 0, 41, 1294, 12, 3642, 5045, 21308, 5608, 10, 117776, 73203, 2], [ 0, 41, 1294, 12, 6, 4617, 39613, 43, 1433, 2, 1, 1]]) >>> tokenizer.decode([0]) '<s>' >>> tokenizer.decode([2]) '</s>' >>> tokenizer.decode([1]) '<pad>' >>> batch_dict["attention_mask"] # paddingされたところが0です tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
脱線:Attention mask
今回Hugging Faceの用語集を知りました。
https://huggingface.co/docs/transformers/glossary#attention-mask
The attention mask is a binary tensor indicating the position of the padded indices so that the model does not attend to them.
意訳:attention maskは、パディングされたインデックスの位置を示す2値(0/1)のtensor。それによりモデルが注意を払わずにすむ
用語集にあった動画です
~attention_mask[..., None].bool()
って何?
評価順はこうでした
attention_mask[..., None]
attention_mask[..., None].bool()
~attention_mask[..., None].bool()
>>> attention_mask = batch_dict["attention_mask"] >>> attention_mask[..., None].size() torch.Size([2, 12, 1]) >>> attention_mask.unsqueeze(-1).size() torch.Size([2, 12, 1]) >>> attention_mask[..., None] tensor([[[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [1]], [[1], [1], [1], [1], [1], [1], [1], [1], [1], [1], [0], [0]]])
続いて、同じサイズで要素がBoolからなるTensorを作ります。
最後に、TensorのBool値を入れ替えます。
>>> ~attention_mask[..., None].bool() tensor([[[False], [False], [False], [False], [False], [False], [False], [False], [False], [False], [False], [False]], [[False], [False], [False], [False], [False], [False], [False], [False], [False], [False], [ True], [ True]]])
つまり、attention maskを元に、paddingされたトークンをTrueで示すTensorができたわけです!
これをmaskとして使っていきます
トークンのembeddingsから平均算出
masked_fillにより、maskがTrueの要素はvalue=0.0で埋められます。
paddingのembeddingsはすべて0ベクトルにするわけです(ブロードキャストですね)
>>> is_padding_mask = ~attention_mask[..., None].bool() >>> last_hidden_states = outputs.last_hidden_state >>> last_hidden_states.masked_fill(is_padding_mask, 0.0) tensor([[[ 0.2526, -0.0612, -0.2972, ..., 0.1744, 0.0097, 0.1362], [ 0.1764, -0.3090, -0.2646, ..., 0.1445, 0.1458, 0.2989], [ 0.1566, -0.3100, -0.2594, ..., 0.1459, 0.1691, 0.2915], ..., [ 0.3416, -0.0517, -0.3631, ..., 0.2228, 0.2226, 0.4097], [ 0.2260, -0.1668, -0.2421, ..., -0.0132, 0.0760, 0.2289], [ 0.2526, -0.0612, -0.2972, ..., 0.1744, 0.0097, 0.1362]], [[ 0.2115, 0.0041, -0.2504, ..., 0.3300, 0.1579, -0.0122], [ 0.2307, -0.0033, -0.4641, ..., 0.4944, 0.3775, 0.1403], [ 0.2374, 0.0025, -0.4539, ..., 0.4880, 0.3873, 0.1412], ..., [ 0.2115, 0.0041, -0.2504, ..., 0.3300, 0.1579, -0.0123], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, 0.0000, ..., 0.0000, 0.0000, 0.0000]]], grad_fn=<MaskedFillBackward0>)
文のベクトル = トークンのembeddingsの合計 / トークンの合計
上記のTensorから文ごとに、トークンのembeddingsの(つまりdim=1で)合計を取ります4。
このときpaddingのトークンは0.0になっているので、合計に寄与しません。
>>> last_hidden = last_hidden_states.masked_fill(is_padding_mask, 0.0) >>> last_hidden.size() torch.Size([2, 12, 384]) >>> last_hidden.sum(dim=1).size() torch.Size([2, 384])
文ごとにトークンの数を出します。
ここでattention_maskが0となるpaddingトークンは寄与しません。
>>> attention_mask.sum(dim=1) tensor([12, 10]) >>> attention_mask.sum(dim=1)[..., None] tensor([[12], [10]])
つまり、(attention_maskが0である)paddingトークンを除いてトークンのembeddingの平均をとっているわけです。
まさに、average pool!
embeddingsの正規化
Usageでは処理が続きます。
embeddings = F.normalize(embeddings, p=2, dim=1)
https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
p=2, dim=1は引数のデフォルト値であり、
it uses the Euclidean norm over vectors along dimension 1 for normalization.
つまり、ベクトルの要素を2乗して足すと1になるように正規化されています5
>>> import torch.nn.functional as F >>> normalized = F.normalize(embeddings, p=2, dim=1) >>> normalized.norm(p=2, dim=1) tensor([1.0000, 1.0000]) >>> torch.sum(normalized * normalized, 1) tensor([1.0000, 1.0000])
勾配を除く
masked_fillでTensorを見たときにgrad_fn
がありました。
これは勾配計算が有効になっているということですね。
今回の利用シーンではテキストをembeddingsに変換できればよく、勾配は不要なので、average_pool
関数をtorch.no_grad
でデコレートしました。
https://pytorch.org/docs/stable/generated/torch.no_grad.html
[追記 (2023/12/11)] torch.inference_modeなるものもあるようで、次に試す際は使い分けを調べたいと思います
現時点の理解を反映したコード(Usageの再実装)
- Python 3.11.4
pip install 'transformers[torch]'
ライブラリのバージョン
accelerate==0.25.0 certifi==2023.11.17 charset-normalizer==3.3.2 filelock==3.13.1 fsspec==2023.12.1 huggingface-hub==0.19.4 idna==3.6 Jinja2==3.1.2 MarkupSafe==2.1.3 mpmath==1.3.0 networkx==3.2.1 numpy==1.26.2 packaging==23.2 psutil==5.9.6 PyYAML==6.0.1 regex==2023.10.3 requests==2.31.0 safetensors==0.4.1 sympy==1.12 tokenizers==0.15.0 torch==2.1.1 tqdm==4.66.1 transformers==4.35.2 typing_extensions==4.8.0 urllib3==2.1.0
終わりに
multilingual-e5-smallのUsageのaverage_pool
関数の実装の理解を深めました。
このモデルは文中のすべてのトークンのベクトルを返しており、attention maskが0のトークンを除いて平均して、文を表すベクトルを得ているんだ!
P.S. sentence-transformersの実装と一致!
以前ソースコードを見ましたが、
PoolingやNormalizeと実装としては同じだと思います。
以下のコードでもできるんですね〜(attention maskを数値として扱って、内積をとってる!)
# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L85-L86 input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
sentence-transformersのPoolingはaverage pool以外にもCLSやmaxを使ったpoolingを備えていますね。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L16-L20
またsentence-transformersではdetachして勾配を除いていました。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L182
変更履歴
- 2023/12/11 修正1点・追記1点
- CLSトークンのベクトルを文を表すベクトルとするアプローチもあると理解しています↩
-
tokenizerの
__call__
にpadding=True
を渡しています。 https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.padding↩ - ドキュメントではまだ見つけられていないのですが、こちらがほしい説明でした ↩
- Tensorのsumメソッド https://pytorch.org/docs/stable/generated/torch.sum.html↩
- Tensorのnormメソッド https://pytorch.org/docs/stable/generated/torch.norm.html↩