nikkie-ftnextの日記

イベントレポートや読書メモを発信

transformersのdata collator、何するものぞ?

はじめに

ピーーーーーーー。1nikkieです。

transformersを使ったLLMのファインチューニング(SFT)のコードを最近眺めているのですが、data collatorという概念がよく分かっていません。
1日1エントリを使って調べてみます。

目次

transformersのTrainerに渡すdata_collator

かつて写経したコードに以下が出てきます

trainer = Trainer(
    # 省略
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

data_collator引数に渡っているものが何をするものなのか、つかめていませんでした。

資料をあたっていく

Hugging Face NLP Course

https://huggingface.co/docs/transformers/v4.41.0/en/main_classes/data_collator より

data collators may apply some processing (like padding).

最近読み始めたNLP Courseに解説がありました。
https://huggingface.co/learn/nlp-course/chapter3/2#dynamic-padding

いくつかのテキストをbatchにまとめて扱う中でcollate functionという概念が出てきます。

手元のいくつかのテキストは長さが揃っていません。
ただbatchとしては長さを揃えたいです(例:8件で長さ100)

長さが足りないテキストはどうするかというと、パディングします。

we have to define a collate function that will apply the correct amount of padding to the items of the dataset we want to batch together.

このcollate function用のクラスの1つがtransformers.DataCollatorWithPaddingです。

✍️data collatorは、いくつかのテキストからパディングしたbatchを作れる

https://huggingface.co/learn/nlp-course/chapter3/4?fw=pt では、PyTorchのDataLoaderに渡されています

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=8, collate_fn=data_collator
)

『大規模言語モデル入門』

5.2.6より

本書では、このようなミニバッチ構築の処理を行う関数のことをcollate関数(collate function; data collator)と呼びます。
collateには、バラバラの情報を収集して適切な形にまとめるという意味があります。(Kindle版 p.188)

図5.8が分かりやすいです。

簡単な例を動かす

pip install 'transformers[torch]'

  • transformers 4.41.0
  • torch 2.3.0

テストケースにならいます。
https://github.com/huggingface/transformers/blob/v4.41.0/tests/trainer/test_data_collator.py#L111

% cat vocab.txt
[UNK]
[CLS]
[SEP]
[PAD]
[MASK]
>>> from transformers import BertTokenizer, DataCollatorWithPadding
>>> tokenizer = BertTokenizer("vocab.txt")
>>> data_collator = DataCollatorWithPadding(tokenizer)
>>> batch = data_collator([{"input_ids": [0, 1, 2]}, {"input_ids": [0, 1, 2, 3, 4, 5]}])
>>> batch
{'input_ids': tensor([[0, 1, 2, 3, 3, 3],
        [0, 1, 2, 3, 4, 5]]), 'attention_mask': tensor([[1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1]])}
>>> tokenizer.pad_token_id
3

3つのトークンからなる系列と6つのトークンからなる系列の2つをDataCollatorWithPaddingに入力したところ、2×6のTensor(=batch)が得られました。
このbatchは系列の長さが6に揃っています。
パティングを埋める値には、tokenizerの持つパティング用トークンのIDが使われています。

data collatorには種類がある

🤗 NLP Courseの例ではモデルとdata collatorは以下の組み合わせです。

  • AutoModelForSequenceClassification
  • DataCollatorWithPadding

写経した日経Linux記事では

  • AutoModelForCausalLM
    • 『大規模言語モデル入門』1.2よりCausalLMはテキスト生成に対応
  • DataCollatorForLanguageModeling

また、過去にBERTを事前訓練した記事では

  • RobertaForMaskedLM
  • DataCollatorForLanguageModeling

DataCollatorForLanguageModelingですが、https://huggingface.co/docs/transformers/v4.41.0/en/main_classes/data_collator の引用には続きがあり、

Some of them (like DataCollatorForLanguageModeling) also apply some random data augmentation (like random masking) on the formed batch.

✍️data collatorには、パディングしたbatchを作るだけでなく、パティングしてランダムマスキングもしたbatchを作れるものもある

訓練したいモデル(もっというと、解きたい課題)に応じて、data collatorを選ぶことになりそうです。

終わりに

data collatorを調べました。

  • data collatorは、長さがバラバラのテキストを適切な形のbatchにまとめる
  • 適切な形とは、パディングや、課題によってはランダムマスキング

data collatorはbatchと関わっていたのですね!

予告:次に知りたいのは、CausalLM(テキスト生成)の訓練におけるロスってどうやって計算しているんだろうということです。
data collatorはbatchをランダムマスキングもしているようなので、マスクしたところのロス、なんでしょうか?