nikkie-ftnextの日記

イベントレポートや読書メモを発信

transformers.AutoModelForCausalLM の from_pretrained() の attn_implementation 引数に泣かされています(Gemmaを例に)

Today(※最近) I Learned です。
といっても全然うまくいっていません

目次

AutoModelForCausalLM.from_pretrainedの引数

手元にはGemmaをファインチューンするnotebookがあります。
名言の続きを、指定したフォーマットで生成できるようにファインチューンしています。

Google ColabのT4 GPU(無料枠)を使っています。

  • Python 3.10
  • torch 2.4.0+cu121
  • transformers 4.44.2
  • trl 0.10.1
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=bnb_config,
    device_map={"": 0},
)

from_pretrainedには**kwargsがあります。
https://huggingface.co/docs/transformers/ja/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained
ここにはモデルの__init__の引数も渡せます。

attn_implementation引数のデフォルト値

from_pretrained**kwargsにはattn_implementationも渡せます。
ドキュメントはfrom_configメソッドのところに見つかりました。 https://huggingface.co/docs/transformers/ja/model_doc/auto#transformers.AutoModelForCausalLM.from_config

Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention).

By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.

torch 2.1.1以上の環境なので、"sdpa"となっています。

>>> model.config._attn_implementation 
'sdpa'

attn_implementation="eager"(訓練できる)

eagerと指定したときは引き続き訓練できました

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=bnb_config,
    device_map={"": 0},
    attn_implementation="eager",
)
>>> model.config._attn_implementation
'eager'

以下の実装を使っているという理解です。
https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L259

attn_implementationによってAttentionの実装を切り替えているのはおそらくこちらの箇所
https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L557-L561

GEMMA_ATTENTION_CLASSES = {
    "eager": GemmaAttention,
    "flash_attention_2": GemmaFlashAttention2,
    "sdpa": GemmaSdpaAttention,
}

attn_implementation="sdpa"かつ、Flash Attention 1(No available kernel)

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=bnb_config,
    device_map={"": 0},
    attn_implementation="sdpa",
)
>>> model.config._attn_implementation
'sdpa'

trlのドキュメントより
https://huggingface.co/docs/trl/en/sft_trainer#using-flash-attention-1

For Flash Attention 1 you can use the BetterTransformer API and force-dispatch the API to use Flash Attention kernel.

ドキュメントに沿って、まずoptimumをインストール。
これは成功します

trainer.train()with文に加えるのですが、実行時にここでエラー。
なおPyTorchのドキュメントを参照し、より新しい書き方としています
https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html

with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
    trainer.train()

RuntimeError: No available kernel. Aborting execution.

https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L540 で送出されています。

解決できるヒントがあるかもしれないのでwarningを控えておきます。

UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:718.)
UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.)
UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:720.)    
UserWarning: Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm 7.5 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:201.)
UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:722.)
UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.)

warningのうち Flash Attention 1 に関係ありそうなものを有効にしたら、「No available kernel」は解消するかも(disableなので、no availableなのかも。宿題事項)

optimumはインストールするだけでいいのかはよくわかっていません

attn_implementation="flash_attention_2"(宿題)

trlのドキュメントはこちら
https://huggingface.co/docs/trl/en/sft_trainer#using-flash-attention-2

flash-attnをインストールします(Installation and features参照)。

pip install ninja  # 1.11.1.1
pip install flash-attn --no-build-isolation  # 2.6.3

インストールは成功。

モデルも読み込めます。

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    quantization_config=bnb_config,
    device_map={"": 0},
    attn_implementation="flash_attention_2",
)

ですが、model.generate()を呼ぶと

RuntimeError: FlashAttention only supports Ampere GPUs or newer.

ドキュメントにも

FlashAttention-2 currently supports:
Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100).

とあるので、Colab課金などでA100を用意してリトライですね

Flash Attentionの実装が重複してる感

ここを見ればOKという資料は見つけられておらず、ずっとよく分からないなあと感じながら手を動かしてました。

flash-attnは3つのバージョンのFlash Attentionをサポート

trlのドキュメントで、Flash Attention 1はoptimumを案内していた理由はよく分かりません。
flash-attnで Flash Attention 1 は使えないのでしょうか?

そして、PyTorch。
scaled dot product attentionをPyTorchでも実装しているそうです。

  • Flash Attention(バージョン 1と理解
  • Memory-Efficient Attention(xformers)
  • PyTorchでC++実装

分からないのは、attn_implementationを指定しないとき、デフォルトの"sdpa"ですが、これをwith文に入れたときと入れないときとで何が違うんでしょう?
それぞれの場合で、Attentionの実装はPyTorchのどれが使われている?

積ん読

自動的にFlash Attentionを使うような構造をしているが、どんな場合でも使用しているわけではない