nikkie-ftnextの日記

イベントレポートや読書メモを発信

sentence-transformersで、SentenceTransformerクラスをインスタンス化してencodeメソッドを呼び出すだけでembeddingsが得られるのは、どんな仕組みによるんだろう? ソースコードリーディングメモ

はじめに

むん!(シャニアニ1幕見ました) nikkieです

以前sentence-transformersを使ってテキストをembeddingsに変換しました。

再度触ってみたところ、「わずか数行のコードで何をやっているんだろう?」と気になり、実装を追いかけました。

目次

sentence-transformersのコードは何をやっている?

今回気になったのはこちらの2行

model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = model.encode(sentences)

使い方はたったこれだけ! 非常に簡単に使えます。
ですが、この裏側が何をやっているのか、ブラックボックスを開いてみたいと思ったのです。
例えば、Hugging Faceで公開されているモデル all-MiniLM-L6-v2 をダウンロードするので、transformersをラップしていそうですよね。

本記事執筆時点で最新のv2.2.2のソースコードを読んでいきます。

SentenceTransformerの初期化

https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L33

  • PyTorchのnn.Sequentialを継承したクラス
  • model_name_or_path引数を指定して呼び出した

今回のようにHugging Faceの公開モデルを指定した場合の処理内容は、大まかには以下のようになっていました。

  1. Hugging Faceからモデルをダウンロード
  2. ダウンロードされているmodules.jsonを読み込む(_load_sbert_modelメソッド。後述)
  3. 2で読み込んだmodules(nn.ModuleからなるOrderedDict)を渡して、親クラス(nn.Sequential)のイニシャライザを呼び出す

_load_sbert_modelメソッド

https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L810

このメソッドで読み込むmodules.jsonの中身は以下のようになっています。

% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/modules.json
[
  {
    "idx": 0,
    "name": "0",
    "path": "",
    "type": "sentence_transformers.models.Transformer"
  },
  {
    "idx": 1,
    "name": "1",
    "path": "1_Pooling",
    "type": "sentence_transformers.models.Pooling"
  },
  {
    "idx": 2,
    "name": "2",
    "path": "2_Normalize",
    "type": "sentence_transformers.models.Normalize"
  }
]

3つのモジュールですね(後述しますが、いずれもnn.Moduleです)。

typeの文字列を元に、sentence_transformers.util.import_from_string2を使ってクラスを取得。
クラスが持っているloadメソッドを使ってインスタンスを得ます。

# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L839-L840
module_class = import_from_string(module_config['type'])
module = module_class.load(os.path.join(model_path, module_config['path']))

3つのモジュールのloadメソッドを見てみましょう。

Transformerの読み込み

sentence-transformersの独自クラスです。
nn.Moduleを継承しており、Hugging Faceのtransformersとは別物です。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L8

loadメソッドでは読み込むconfigファイルを見つけ、それをもとにTransformerをインスタンス化します。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L127-L137

このファイルが使われます。

% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/sentence_bert_config.json
{
  "max_seq_length": 256,
  "do_lower_case": false
}

2つの引数max_seq_lengthとdo_lower_caseが指定されて、Transformerがインスタンス化されます。
インスタンス化の中でtransformersのAutoModelやAutoTokenizerのfrom_pretrainedメソッドが呼び出され、読み込んだモデルやトークナイザが属性に設定されています。

Poolingの読み込み

nn.Moduleを継承したクラスです。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L9

Performs pooling (max or mean) on the token embeddings.

loadでは、Pooling用のconfigファイルを元にインスタンス化します。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Pooling.py#L115-L120

% jq '.' ~/.cache/torch/sentence_transformers/sentence-transformers_all-MiniLM-L6-v2/1_Pooling/config.json
{
  "word_embedding_dimension": 384,
  "pooling_mode_cls_token": false,
  "pooling_mode_mean_tokens": true,
  "pooling_mode_max_tokens": false,
  "pooling_mode_mean_sqrt_len_tokens": false
}

Normalizeの読み込み

nn.Moduleを継承したクラスです。
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Normalize.py#L6

This layer normalizes embeddings to unit length

loadでは、(configを読み込まず常に)Normalizeをインスタンス化します。

class Normalize(nn.Module):
    # https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Normalize.py#L20-L22
    @staticmethod
    def load(input_path):
        return Normalize()

encodeメソッド

SentenceTransformersをインスタンス化した後に呼び出す、embeddingを得るメソッドの実装を見ていきましょう。

https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L111

フラグ引数の扱いのコードが最初と最後にあります。
肝は以下かなと思います。

  1. tokenizeメソッド呼び出し
  2. forwardメソッド呼び出し

tokenizeメソッド

# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L315-L319
def tokenize(self, texts: Union[List[str], List[Dict], List[Tuple[str, str]]]):
    return self._first_module().tokenize(texts)

_first_moduleメソッドを呼び出しています。

# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/SentenceTransformer.py#L331-L333
def _first_module(self):
    return self._modules[next(iter(self._modules))]

これはmodules.jsonに定義された最初のレイヤーという理解です。
今回の場合だと、Transformerクラスのインスタンスですね。

Transformerクラスのtokenizeメソッド
https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L84-L114

loadしたときに、AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")が実行されています。
このトークナイザの__call__メソッドを呼び出しています。

# https://github.com/UKPLab/sentence-transformers/blob/v2.2.2/sentence_transformers/models/Transformer.py#L113
output.update(self.tokenizer(*to_tokenize, ...))

forwardメソッド

親クラスのnn.Sequentialに定義されています。

The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

SentenceTransformerのforwardメソッドは、モジュール3つのforwardメソッドを順に呼び出します。

  1. Transformerのforwardメソッド
  2. Poolingのforwardメソッド
  3. Normalizeのforwardメソッド

これらのforwardが順に呼び出され、呼び出したencode側で後処理をした結果、NumPyのndarrayとしてembeddingsが得られます。

終わりに

ライブラリsentence-transformersで、SentenceTransformerクラスをインスタンス化し、encodeメソッドにテキストを渡すだけでembeddingsが得られるのを可能にしている実装について見てきました。

  • SentenceTransformerクラスはPyTorchのnn.Sequentialを継承
    • forwardは、各モジュール(nn.Module)のforwardを順に呼び出す
  • Hugging Faceに設定ファイルを独自の命名規則で配置して、sentence-transformersが使える環境では再現性を持たせている

PyTorchのnn.Sequentialをうまく使っていると思いました!
JSONファイル一式とそれを読む実装を追加することでHugging Faceを通して、独自ライブラリでの読み込みもサポートした配布ができるのですね。