nikkie-ftnextの日記

イベントレポートや読書メモを発信

trlのSFTTrainerの実装を覗き、初期化時にtokenizerやdata collatorがどのように設定されるかを理解する

はじめに

フウカチャン😭1 nikkieです。

trlというライブラリを使ったLLMのファインチューンのチュートリアルに過去に取り組みました。
その中で、自然言語のドキュメントではどうも細かい点が明確に分かりづらく、実装を見るのが手っ取り早そうと感じ始めました。

そこで今回は、データセットのテキストがどのようにトークンID列に変換されるか(=エンコーディングされるか)に絞って見ていきます。

目次

Gemmaのファインチューン

Hugging Faceの記事の改良版を書いています。

何もしなくてもGemmaは偉人の言葉の続きを生成できるのですが、以下のように誰が言ったかを明確にした所定のフォーマットで続きを生成するようにファインチューンします。

Quote: Imagination is more important than knowledge. Knowledge is limited. Imagination encircles the world.

Author: Albert Einstein

trl.SFTTrainerで訓練する箇所はこんなコードです。

trainer = SFTTrainer(
    model=model,  # AutoModelForCausalLM.from_pretrained("google/gemma-2b", ...)
    train_dataset=dataset,  # load_dataset("Abirate/english_quotes", split="train")
    args=training_args,
    peft_config=lora_config,
    formatting_func=formatting_prompts_func,
)
trainer.train()

ここを見たときに詳細が把握できていないなと思った事項を挙げます。

  • modelしか渡していないが、tokenizerはどのように初期化され適用される?(改良した記事で宿題に挙げています)
  • data collator2を渡していないが、どうなっている?

SFTTrainerの実装を読んで、上記の事項への回答を得ました(※私は納得したという話で、もっと深堀る余地はあると思います)

SFTTrainerはどんなtokenizerを持つのか?

modelに対応したtokenizerを持ちます。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L275

tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)

今回の場合

>>> model.config._name_or_path
'google/gemma-2b'

SFTTrainerの初期化時にtokenizerを渡すこともできます。
https://huggingface.co/docs/trl/sft_trainer#trl.SFTTrainer

The tokenizer to use for training. If not specified, the tokenizer associated to the model will be used.

tokenizerを渡さないとき(デフォルト値Noneなのですが)、tokenizerを設定するコードもありました。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L274-L277

if tokenizer is None:
    tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path)
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

なお、Gemmaのtokenizerは該当しません。

>>> getattr(tokenizer, "pad_token", None)
'<pad>'
>>> tokenizer.pad_token
'<pad>'
>>> tokenizer.eos_token
'<eos>'

pad_tokenにeos_tokenを使う理由は、data collatorの設定に関すると理解しました。
ref: https://huggingface.co/docs/transformers/tasks/language_modeling#preprocess

Use the end-of-sequence token as the padding token (略)

(もしかしてこの設定をしたtokenizerを外から渡したほうがいいのかな?)

SFTTrainerはどんなdata collatorを持つのか?

data collator向けのtokenizerの設定まで分かったので、次はdata collatorを見ていきます。

https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L348-L349

if data_collator is None:
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

data_collatorもSFTTrainerの初期化時に渡せます。
https://huggingface.co/docs/trl/sft_trainer#trl.SFTTrainer

The data collator to use for training.

渡していない場合は、transformers.DataCollatorForLanguageModelingが使われるということですね。
https://huggingface.co/docs/transformers/main/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling

mlm引数はデフォルト値がTrueです。

  • mlm=Trueの場合、ランダムマスクされます(masked language model)
  • mlm=Falseの場合、ランダムマスクはありません(causal language model)

今回はテキスト生成をしたい3ので、mlm=Falseになっているという理解です4

data_collatorを外から渡すシーンとして、生成した部分だけについてモデルを訓練したい状況がドキュメントに書かれています。
「Train on completions only」https://huggingface.co/docs/trl/sft_trainer#train-on-completions-only
trl.DataCollatorForCompletionOnlyLMを使います5

ドキュメントを見ると「### Question: \n ### Answer: 」という形式のテキストで、「 ### Answer:」より後の部分についてだけロスの計算に使いたいようです。
今回扱うGemmaのファインチューンはこのケースに当てはまらないので、SFTTrainerがデフォルトで初期化するDataCollatorForLanguageModelingを使っているわけですね。

SFTTrainerが持つtokenizerはどのようにデータセットに適用されるのか?

訓練中に目にしたログからSFTTrainer初期化時に、train_datasetformatting_funcが適用されるようです。
ここの実装を見ていきます。

self._prepare_dataset(train_dataset, ...)という呼び出し箇所があります。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L372-L384

_prepare_dataset()メソッド

https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L477

packing引数は指定していないのでデフォルト値のFalse6
そのため、以下の分岐となります。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L518-L527

if not packing:
    return self._prepare_non_packed_dataloader(
        tokenizer,
        dataset,
        dataset_text_field,
        max_seq_length,
        formatting_func,
        add_special_tokens,
        remove_unused_columns,
    )

_prepare_non_packed_dataloader()メソッド

https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L542

tokenizerが適用されるのはここです。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L587-L593

tokenized_dataset = dataset.map(
    tokenize,
    batched=True,
    remove_columns=dataset.column_names if remove_unused_columns else None,
    num_proc=self.dataset_num_proc,
    batch_size=self.dataset_batch_size,
)

tokenize()_prepare_non_packed_dataloader()のスコープで定義された関数です。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L556-L575
一部抜粋します

def tokenize(element):
    outputs = tokenizer(
        element[dataset_text_field] if not use_formatting_func else formatting_func(element),
        add_special_tokens=add_special_tokens,
        truncation=True,
        padding=False,
        max_length=max_seq_length,
        return_overflowing_tokens=False,
        return_length=False,
    )

    # 省略

    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}

tokenizerの__call__()が呼ばれています。

  • 入力されるテキストはformatting_funcが適用されたテキストになる
  • truncationがTruemax_lengthで切られる
  • paddingはFalse(埋めて長さを揃えていない)

SFTTrainer初期化時に、train_datasetformatting_func込みでtokenizerが適用され、トークンIDからなるデータセットに変換されているのですね。

>>> trainer.train_dataset
Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 2508
})
>>> trainer.train_dataset[0]
{'input_ids': [2,
  14782,
  235292,
  1080,
  2448,
  5804,
  235289,
  4784,
  1354,
  603,
  3303,
  3443,
  1816,
  108,
  6823,
  235292,
  29231,
  72661,
  1],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

これらの処理を行ったうえで、trl.SFTTrainerの親クラスtransformers.Trainerの初期化メソッドが呼ばれています。
https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_trainer.py#L413-L425
tokenizerdata_collatorが渡されています。

終わりに

数行のコードで訓練できるtrl.SFTTrainerについて、tokenizerやdata collatorがどのように設定されているのか見てきました。
SFTTrainerを初期化したときに以下が行われています。

  • tokenizer
    • 外から渡さない場合、モデルに対応するtokenizerが初期化される
    • データセットにはtokenizerが適用される(トークンID列に変換される)
  • data collator
    • 外から渡さない場合、transformers.DataCollatorForLanguageModelingmlm引数はFalse
        • data collatorに渡されるtokenizerは、eos_tokenがpaddingに使われるように設定された保証がないような気がします...
    • 生成部分だけ訓練に使いたい場合、trl.DataCollatorForCompletionOnlyLMを外から渡す

ドキュメントでは自然言語の曖昧さからよく分からずにいたところは一通り見られました。
SFTTrainerはeasyではありますが、「CausalLMでtokenizerはeos_tokenをpaddingに使う」ということを知って、easyなインターフェースに任せず外から渡した方がいいかもなと思い始めています(easyだからといって使うとユーザが間違えてしまわないかな?)

今回のコードの最後の状態はこちらです。
SFTConfig7を使うように小さなアップデートをしました。


  1. アクアトープが見られます!
  2. パディングやランダムマスキングして、テキストをbatchにまとめてくれる存在です。
  3. Causal language modelingの冒頭より。「There are two types of language modeling, causal and masked. (略) Causal language models are frequently used for text generation.」「Causal language modeling predicts the next token in a sequence of tokens,
  4. tokenizerのpadding tokenについて書いてあった箇所には「set mlm=False」と続いています。ref: https://huggingface.co/docs/transformers/tasks/language_modeling#preprocess
  5. trl.DataCollatorForCompletionOnlyLMtransformers.DataCollatorForLanguageModelingを継承していました。ref: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/utils.py#L115
  6. packingSFTTrainer初期化時の引数にありますが、SFTConfigで指定するのが推奨のようでした
  7. transformers.TrainingArgumentsを継承したクラスとして導入されていました。ref: https://github.com/huggingface/trl/blob/v0.9.6/trl/trainer/sft_config.py#L21 パラメタオブジェクトってやつですね