Today(※最近) I Learned です。
といっても全然うまくいっていません
目次
- 目次
- AutoModelForCausalLM.from_pretrainedの引数
- attn_implementation引数のデフォルト値
- attn_implementation="eager"(訓練できる)
- attn_implementation="sdpa"かつ、Flash Attention 1(No available kernel)
- attn_implementation="flash_attention_2"(宿題)
- Flash Attentionの実装が重複してる感
- 積ん読
AutoModelForCausalLM.from_pretrainedの引数
手元にはGemmaをファインチューンするnotebookがあります。
名言の続きを、指定したフォーマットで生成できるようにファインチューンしています。
Google ColabのT4 GPU(無料枠)を使っています。
- Python 3.10
- torch 2.4.0+cu121
- transformers 4.44.2
- trl 0.10.1
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
quantization_config=bnb_config,
device_map={"": 0},
)
from_pretrainedには**kwargsがあります。
https://huggingface.co/docs/transformers/ja/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained
ここにはモデルの__init__の引数も渡せます。
attn_implementation引数のデフォルト値
from_pretrainedの**kwargsにはattn_implementationも渡せます。
ドキュメントはfrom_configメソッドのところに見つかりました。
https://huggingface.co/docs/transformers/ja/model_doc/auto#transformers.AutoModelForCausalLM.from_config
Can be any of "eager" (manual implementation of the attention), "sdpa" (using F.scaled_dot_product_attention), or "flash_attention_2" (using Dao-AILab/flash-attention).
By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual "eager" implementation.
torch 2.1.1以上の環境なので、"sdpa"となっています。
>>> model.config._attn_implementation
'sdpa'
attn_implementation="eager"(訓練できる)
eagerと指定したときは引き続き訓練できました
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
quantization_config=bnb_config,
device_map={"": 0},
attn_implementation="eager",
)
>>> model.config._attn_implementation
'eager'
以下の実装を使っているという理解です。
https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L259
attn_implementationによってAttentionの実装を切り替えているのはおそらくこちらの箇所
https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/gemma/modeling_gemma.py#L557-L561
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
"sdpa": GemmaSdpaAttention,
}
attn_implementation="sdpa"かつ、Flash Attention 1(No available kernel)
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
quantization_config=bnb_config,
device_map={"": 0},
attn_implementation="sdpa",
)
>>> model.config._attn_implementation
'sdpa'
trlのドキュメントより
https://huggingface.co/docs/trl/en/sft_trainer#using-flash-attention-1
For Flash Attention 1 you can use the BetterTransformer API and force-dispatch the API to use Flash Attention kernel.
ドキュメントに沿って、まずoptimumをインストール。
これは成功します
trainer.train()をwith文に加えるのですが、実行時にここでエラー。
なおPyTorchのドキュメントを参照し、より新しい書き方としています
https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html
with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
trainer.train()
RuntimeError: No available kernel. Aborting execution.
解決できるヒントがあるかもしれないのでwarningを控えておきます。
UserWarning: Memory efficient kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:718.) UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at ../aten/src/ATen/native/transformers/sdp_utils_cpp.h:495.) UserWarning: Flash attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:720.) UserWarning: Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm 7.5 gpu. (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:201.) UserWarning: CuDNN attention kernel not used because: (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:722.) UserWarning: The CuDNN backend needs to be enabled by setting the enviornment variable`TORCH_CUDNN_SDPA_ENABLED=1` (Triggered internally at ../aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:496.)
warningのうち Flash Attention 1 に関係ありそうなものを有効にしたら、「No available kernel」は解消するかも(disableなので、no availableなのかも。宿題事項)
optimumはインストールするだけでいいのかはよくわかっていません
attn_implementation="flash_attention_2"(宿題)
trlのドキュメントはこちら
https://huggingface.co/docs/trl/en/sft_trainer#using-flash-attention-2
flash-attnをインストールします(Installation and features参照)。
pip install ninja # 1.11.1.1 pip install flash-attn --no-build-isolation # 2.6.3
インストールは成功。
モデルも読み込めます。
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
quantization_config=bnb_config,
device_map={"": 0},
attn_implementation="flash_attention_2",
)
ですが、model.generate()を呼ぶと
RuntimeError: FlashAttention only supports Ampere GPUs or newer.
ドキュメントにも
FlashAttention-2 currently supports:
Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100).
とあるので、Colab課金などでA100を用意してリトライですね
Flash Attentionの実装が重複してる感
ここを見ればOKという資料は見つけられておらず、ずっとよく分からないなあと感じながら手を動かしてました。
flash-attnは3つのバージョンのFlash Attentionをサポート
trlのドキュメントで、Flash Attention 1はoptimumを案内していた理由はよく分かりません。
flash-attnで Flash Attention 1 は使えないのでしょうか?
そして、PyTorch。
scaled dot product attentionをPyTorchでも実装しているそうです。
分からないのは、attn_implementationを指定しないとき、デフォルトの"sdpa"ですが、これをwith文に入れたときと入れないときとで何が違うんでしょう?
それぞれの場合で、Attentionの実装はPyTorchのどれが使われている?
積ん読
自動的にFlash Attentionを使うような構造をしているが、どんな場合でも使用しているわけではない