はじめに
むん!(シャニアニ1幕見ました) nikkieです
以前sentence-transformersを使ってテキストをembeddingsに変換しました。
再度触ってみたところ、「わずか数行のコードで何をやっているんだろう?」と気になり、実装を追いかけました。
目次
sentence-transformersのコードは何をやっている?
今回気になったのはこちらの2行
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(sentences)
使い方はたったこれだけ! 非常に簡単に使えます。
ですが、この裏側が何をやっているのか、ブラックボックスを開いてみたいと思ったのです。
例えば、Hugging Faceで公開されているモデル all-MiniLM-L6-v2 をダウンロードするので、transformersをラップしていそうですよね。
本記事執筆時点で最新のv2.2.2のソースコードを読んでいきます。
SentenceTransformerの初期化
- PyTorchのnn.Sequentialを継承したクラス
model_name_or_path
引数を指定して呼び出した
今回のようにHugging Faceの公開モデルを指定した場合の処理内容は、大まかには以下のようになっていました。
- Hugging Faceからモデルをダウンロード
- ダウンロードされているmodules.jsonを読み込む(
_load_sbert_model
メソッド。後述) - 2で読み込んだmodules(nn.ModuleからなるOrderedDict)を渡して、親クラス(nn.Sequential)のイニシャライザを呼び出す
_load_sbert_modelメソッド
このメソッドで読み込むmodules.jsonの中身は以下のようになっています。
% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/modules.json [ { "idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer" }, { "idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling" }, { "idx": 2, "name": "2", "path": "2_Normalize", "type": "sentence_transformers.models.Normalize" } ]
3つのモジュールですね(後述しますが、いずれもnn.Moduleです)。
typeの文字列を元に、sentence_transformers.util.import_from_string
2を使ってクラスを取得。
クラスが持っているloadメソッドを使ってインスタンスを得ます。
# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L839-L840 module_class = import_from_string(module_config['type']) module = module_class.load(os.path.join(model_path, module_config['path']))
3つのモジュールのloadメソッドを見てみましょう。
Transformerの読み込み
sentence-transformersの独自クラスです。
nn.Moduleを継承しており、Hugging Faceのtransformersとは別物です。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L8
loadメソッドでは読み込むconfigファイルを見つけ、それをもとにTransformerをインスタンス化します。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L127-L137
このファイルが使われます。
% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/sentence_bert_config.json { "max_seq_length": 256, "do_lower_case": false }
2つの引数max_seq_lengthとdo_lower_caseが指定されて、Transformerがインスタンス化されます。
インスタンス化の中でtransformersのAutoModelやAutoTokenizerのfrom_pretrained
メソッドが呼び出され、読み込んだモデルやトークナイザが属性に設定されています。
Poolingの読み込み
nn.Moduleを継承したクラスです。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L9
Performs pooling (max or mean) on the token embeddings.
loadでは、Pooling用のconfigファイルを元にインスタンス化します。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L115-L120
% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/1_Pooling/config.json { "word_embedding_dimension": 384, "pooling_mode_cls_token": false, "pooling_mode_mean_tokens": true, "pooling_mode_max_tokens": false, "pooling_mode_mean_sqrt_len_tokens": false }
Normalizeの読み込み
nn.Moduleを継承したクラスです。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Normalize.py#L6
This layer normalizes embeddings to unit length
loadでは、(configを読み込まず常に)Normalizeをインスタンス化します。
class Normalize(nn.Module): # https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Normalize.py#L20-L22 @staticmethod def load(input_path): return Normalize()
encodeメソッド
SentenceTransformersをインスタンス化した後に呼び出す、embeddingを得るメソッドの実装を見ていきましょう。
フラグ引数の扱いのコードが最初と最後にあります。
肝は以下かなと思います。
- tokenizeメソッド呼び出し
- forwardメソッド呼び出し
tokenizeメソッド
# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L315-L319 def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]): return self._first_module().tokenize(texts)
_first_module
メソッドを呼び出しています。
# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L331-L333 def _first_module(self): return self._modules[next(iter(self._modules))]
これはmodules.jsonに定義された最初のレイヤーという理解です。
今回の場合だと、Transformerクラスのインスタンスですね。
Transformerクラスのtokenizeメソッド
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L84-L114
loadしたときに、AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
が実行されています。
このトークナイザの__call__
メソッドを呼び出しています。
# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L113
output.update(self.tokenizer(*to_tokenize, ...))
forwardメソッド
親クラスのnn.Sequentialに定義されています。
The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.
SentenceTransformerのforwardメソッドは、モジュール3つのforwardメソッドを順に呼び出します。
- Transformerのforwardメソッド
- Poolingのforwardメソッド
- Normalizeのforwardメソッド
これらのforwardが順に呼び出され、呼び出したencode側で後処理をした結果、NumPyのndarrayとしてembeddingsが得られます。
終わりに
ライブラリsentence-transformersで、SentenceTransformerクラスをインスタンス化し、encodeメソッドにテキストを渡すだけでembeddingsが得られるのを可能にしている実装について見てきました。
- SentenceTransformerクラスはPyTorchのnn.Sequentialを継承
- forwardは、各モジュール(nn.Module)のforwardを順に呼び出す
- Hugging Faceに設定ファイルを独自の命名規則で配置して、sentence-transformersが使える環境では再現性を持たせている
PyTorchのnn.Sequentialをうまく使っていると思いました!
JSONファイル一式とそれを読む実装を追加することでHugging Faceを通して、独自ライブラリでの読み込みもサポートした配布ができるのですね。