nikkie-ftnextの日記

イベントレポートや読書メモを発信

multilingual-e5のUsageのaverage poolの実装を理解する 〜attention maskが0のトークンを除いて平均しているんだ!〜

はじめに

プラチナランカー!!🙌 nikkieです🟧🟧🟧🟧🟧

transformers・PyTorchの組合せで、文のembeddingsを得るコードで理解したいことがありました。
理解を深める目的でこの記事でアウトプットします。

目次

E5

文のembeddings(埋め込み表現)を得る方法はいくつもあります(特に具体的なモデルがいくつも!)

  • 外部のAPIを使う
    • 例:OpenAIのada
  • テキストを入力するとembeddingsを出力するモデルをローカル環境で動かす
    • sentence-transformersで試しました

後者の方法で注目したモデルがE5。
embeddingsの性能評価でadaに匹敵するということで注目しました。
なんでも「EmbEddings from bidirEctional Encoder rEpresentations」でE5なんだとか(元論文より)。

Hugging Face Hubで公開されており、動かすためのサンプルコード(Usage)もあります。
今回は多言語に対応したE5(multilingual-e5)を見ていきます

参考文献

コネヒトさんのブログで印象に残りました。

ベンチマークにおいて、multilingual-e5 の性能がtext-embedding-ada-002 と大差ないことが報告されている。

Hironsanのエントリで完全理解しました。
[修正 (2023/12/11)] リンクを誤っていたので修正(当初のリンク

元論文

リポジトリhttps://github.com/microsoft/unilm/tree/master/e5
マイクロソフトなんですね)

average_pool関数、何をやってるんだ...?

# https://huggingface.co/intfloat/multilingual-e5-small の Usage より
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
  • attention_mask[..., None]??
  • ~(たしか否定)
  • 1つのtorch.Tensorを返しているっぽいけど、これでaverage poolになってる?

呼び出し方から見ていきましょう

average_pool関数の呼び出し

embeddings = average_pool(
    outputs.last_hidden_state, batch_dict["attention_mask"]
)

outputsはモデル(multilingual-e5-small)の出力したTensorです。

>>> outputs.last_hidden_state.size()
torch.Size([2, 12, 384])

2文あり、どちらも12トークンに分割され、トークンが384次元のベクトルで表されているという理解です。
このトークンを平均して、文を表すベクトルを得たいわけですね1

もう1つのbatch_dictはトークナイザの出力(BatchEncoding)です。

>>> batch_dict["attention_mask"].size()
torch.Size([2, 12])

Usageで与えている文章から意図的に短くして、2文目をpaddingされるようにしています2

input_texts = [
    "query: how much protein should a female eat",
    "query: 南瓜的家",  # 元は 南瓜的家常做法
]
>>> batch_dict["input_ids"]
tensor([[     0,     41,   1294,     12,   3642,   5045,  21308,   5608,     10,
         117776,  73203,      2],
        [     0,     41,   1294,     12,      6,   4617,  39613,     43,   1433,
              2,      1,      1]])
>>> tokenizer.decode([0])
'<s>'
>>> tokenizer.decode([2])
'</s>'
>>> tokenizer.decode([1])
'<pad>'
>>> batch_dict["attention_mask"]  # paddingされたところが0です
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])

脱線:Attention mask

今回Hugging Faceの用語集を知りました。
https://huggingface.co/docs/transformers/glossary#attention-mask

The attention mask is a binary tensor indicating the position of the padded indices so that the model does not attend to them.

意訳:attention maskは、パディングされたインデックスの位置を示す2値(0/1)のtensor。それによりモデルが注意を払わずにすむ

用語集にあった動画です

~attention_mask[..., None].bool()って何?

評価順はこうでした

  1. attention_mask[..., None]
  2. attention_mask[..., None].bool()
  3. ~attention_mask[..., None].bool()

1ではunsqueezeしています3

>>> attention_mask = batch_dict["attention_mask"]
>>> attention_mask[..., None].size()
torch.Size([2, 12, 1])
>>> attention_mask.unsqueeze(-1).size()
torch.Size([2, 12, 1])

>>> attention_mask[..., None]
tensor([[[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1]],

        [[1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [1],
         [0],
         [0]]])

続いて、同じサイズで要素がBoolからなるTensorを作ります。

最後に、TensorのBool値を入れ替えます。

>>> ~attention_mask[..., None].bool()
tensor([[[False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False]],

        [[False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [False],
         [ True],
         [ True]]])

つまり、attention maskを元に、paddingされたトークンをTrueで示すTensorができたわけです!
これをmaskとして使っていきます

トークンのembeddingsから平均算出

masked_fillにより、maskがTrueの要素はvalue=0.0で埋められます。
paddingのembeddingsはすべて0ベクトルにするわけです(ブロードキャストですね)

>>> is_padding_mask = ~attention_mask[..., None].bool()
>>> last_hidden_states = outputs.last_hidden_state
>>> last_hidden_states.masked_fill(is_padding_mask, 0.0)
tensor([[[ 0.2526, -0.0612, -0.2972,  ...,  0.1744,  0.0097,  0.1362],
         [ 0.1764, -0.3090, -0.2646,  ...,  0.1445,  0.1458,  0.2989],
         [ 0.1566, -0.3100, -0.2594,  ...,  0.1459,  0.1691,  0.2915],
         ...,
         [ 0.3416, -0.0517, -0.3631,  ...,  0.2228,  0.2226,  0.4097],
         [ 0.2260, -0.1668, -0.2421,  ..., -0.0132,  0.0760,  0.2289],
         [ 0.2526, -0.0612, -0.2972,  ...,  0.1744,  0.0097,  0.1362]],

        [[ 0.2115,  0.0041, -0.2504,  ...,  0.3300,  0.1579, -0.0122],
         [ 0.2307, -0.0033, -0.4641,  ...,  0.4944,  0.3775,  0.1403],
         [ 0.2374,  0.0025, -0.4539,  ...,  0.4880,  0.3873,  0.1412],
         ...,
         [ 0.2115,  0.0041, -0.2504,  ...,  0.3300,  0.1579, -0.0123],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<MaskedFillBackward0>)

文のベクトル = トークンのembeddingsの合計 / トークンの合計

上記のTensorから文ごとに、トークンのembeddingsの(つまりdim=1で)合計を取ります4
このときpaddingのトークンは0.0になっているので、合計に寄与しません。

>>> last_hidden = last_hidden_states.masked_fill(is_padding_mask, 0.0)
>>> last_hidden.size()
torch.Size([2, 12, 384])
>>> last_hidden.sum(dim=1).size()
torch.Size([2, 384])

文ごとにトークンの数を出します。
ここでattention_maskが0となるpaddingトークンは寄与しません。

>>> attention_mask.sum(dim=1)
tensor([12, 10])
>>> attention_mask.sum(dim=1)[..., None]
tensor([[12],
        [10]])

つまり、(attention_maskが0である)paddingトークンを除いてトークンのembeddingの平均をとっているわけです。
まさに、average pool!

embeddingsの正規化

Usageでは処理が続きます。

embeddings = F.normalize(embeddings, p=2, dim=1)

https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html

p=2, dim=1は引数のデフォルト値であり、

it uses the Euclidean norm over vectors along dimension 1 for normalization.

つまり、ベクトルの要素を2乗して足すと1になるように正規化されています5

>>> import torch.nn.functional as F
>>> normalized = F.normalize(embeddings, p=2, dim=1)
>>> normalized.norm(p=2, dim=1)
tensor([1.0000, 1.0000])
>>> torch.sum(normalized * normalized, 1)
tensor([1.0000, 1.0000])

勾配を除く

masked_fillでTensorを見たときにgrad_fnがありました。
これは勾配計算が有効になっているということですね。

今回の利用シーンではテキストをembeddingsに変換できればよく、勾配は不要なので、average_pool関数をtorch.no_gradデコレートしました。
https://pytorch.org/docs/stable/generated/torch.no_grad.html
[追記 (2023/12/11)] torch.inference_modeなるものもあるようで、次に試す際は使い分けを調べたいと思います

現時点の理解を反映したコード(Usageの再実装)

  • Python 3.11.4
  • pip install 'transformers[torch]'

ライブラリのバージョン

accelerate==0.25.0
certifi==2023.11.17
charset-normalizer==3.3.2
filelock==3.13.1
fsspec==2023.12.1
huggingface-hub==0.19.4
idna==3.6
Jinja2==3.1.2
MarkupSafe==2.1.3
mpmath==1.3.0
networkx==3.2.1
numpy==1.26.2
packaging==23.2
psutil==5.9.6
PyYAML==6.0.1
regex==2023.10.3
requests==2.31.0
safetensors==0.4.1
sympy==1.12
tokenizers==0.15.0
torch==2.1.1
tqdm==4.66.1
transformers==4.35.2
typing_extensions==4.8.0
urllib3==2.1.0

終わりに

multilingual-e5-smallのUsageのaverage_pool関数の実装の理解を深めました。
このモデルは文中のすべてのトークンのベクトルを返しており、attention maskが0のトークンを除いて平均して、文を表すベクトルを得ているんだ!

P.S. sentence-transformersの実装と一致!

以前ソースコードを見ましたが、

PoolingやNormalizeと実装としては同じだと思います。
以下のコードでもできるんですね〜(attention maskを数値として扱って、内積をとってる!)

# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L85-L86
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

sentence-transformersのPoolingはaverage pool以外にもCLSやmaxを使ったpoolingを備えていますね。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L16-L20

またsentence-transformersではdetachして勾配を除いていました。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L182

変更履歴

  • 2023/12/11 修正1点・追記1点

  1. CLSトークンのベクトルを文を表すベクトルとするアプローチもあると理解しています
  2. tokenizerの__call__padding=Trueを渡しています。 https://huggingface.co/docs/transformers/main_classes/tokenizer#transformers.PreTrainedTokenizer.__call__.padding
  3. ドキュメントではまだ見つけられていないのですが、こちらがほしい説明でした
  4. Tensorのsumメソッド https://pytorch.org/docs/stable/generated/torch.sum.html
  5. Tensorのnormメソッド https://pytorch.org/docs/stable/generated/torch.norm.html