Microsoft에서 나온 논문인 LoRA를 오늘 리뷰해 봅니다. LoRA는 GPT와 같은 Large Language Models(LLM)을 특정 task에 fine-tuning(adaptation)하는 데 있어서 time, resource cost가 너무 크다는 단점을 해결하기 위한 방법입니다.
1. Introduction
LLM은 기본적으로 pre-trained model로부터 특정 task(e.g. summarization, question and answering, ...)에 adaptation하기 위해 fine-tuning을 해야 합니다. Fine-tuning을 하면서 LLM모델의 weight parameters를 모두 다시 학습하게 되는데 이게 엄청난 cost!!입니다. 예를 들어 GPT-2(or 3), RoBERTa large모델의 경우 fine-tuning만 몇 달이 걸리게 됩니다.
그래서 이를 해결하기 위해 해당 논문에서는 Low-Rank Adaptation(LoRA)를 제안하게 됩니다. 이름에서 유추가능하듯이 LoRA는 Low-Rank 방법을 이용하여 time, resource cost를 줄이게 됩니다.
Low-Rank 방법을 사용하게 된 motivation 및 basis는 "Measuring the Intrinsic Dimension of Objective Landscapes" 논문과 "Intrinsic Dimensionality Explains the Effectiveness of Language Model Fine-Tuning." 논문에서 말하길 "over-parameterized model은 low intrinsic dimension으로 존재하고 있다"라는 사실에서 기반하고 있습니다. 그래서 저자들은 model adaptation동안의 weight change에도 low intrinsic rank를 가질 거라고 가정하게 되고 Low-Rank 방법을 사용하게 됩니다.
LoRA는 기존 pre-trained weights는 frozen해두고 몇 개의 dense(fc) layers만 학습하는 것인데 이때 학습방법이 dense layer의 weight을 low rank로 decomposition한 matrices만을 optimization하는 것입니다.
그래서 위 Figure 1과 같이 fine-tuning시에 pre-trained weights \( W \)는 frozen해두고 low rank decomposition된 weights \(A\), \(B\)만 학습하고 \( W \)에 summation하게 됩니다. Low rank로 decomposition된 weights는 당연하게도 기존 \(W\)보다 훨씬 작은 크기의 weight이기 때문에 time, resource cost를 줄일 수 있게 됩니다. 또한 pre-trained model을 가지고 있는 상태에서 특정 task에 adaptation하기 위해서 \(A\)와 \(B\)만 storage에 저장하면 되고 다른 task에 adaptation하기 위해 또 다른 \(A'\), \(B'\)만 갈아 끼우면 되기 때문에 storage, task switching면에서 매우 효율적입니다. 추가적으로 inference시에도 fine-tuned model의 latency성능이 낮아지지도 않습니다.
1.1 Terminologies and Conventions
- \( d_{model} \): Transformer의 input, output dimension size
- \( W_q, W_k, W_v, W_o \): Self-attention moduel의 query/key/value/output projection matrices
- \(W \) or \(W_0 \): Pre-trained weight
- \( \Delta W \): Adaptation동안의 accumulated gradient update
- \( r \): LoRA module의 rank
- Model optimization방법으로 Adam을 사용
- Transformer MLP feedforward dimension: \(d_{ffn} = 4 \times d_{model} \)
2. Problem Statement
LoRA방법은 training objective에 상관없이 모두 사용가능(agnostic)하지만 해당 논문에서는 LLM에 focus맞추어 설명합니다.
\( \Phi \)로 parameterized되어 있는 pre-trained language model \( P_{\Phi} (y| x) \)가 주어졌다고 가정합니다. \( P_{\Phi} (y| x) \)는 GPT와 같은 generic한 multi-task learner입니다.
해당 pre-trained language model을 downstream text generation task에 adaptation 하는 상황을 생각해봅시다. Downstream task의 예시로는 summarization, natural language to SQL (NL2SQL) 등이 있습니다. Adaptation을 위해 각 downstream task은 context-target pair의 training dataset \( Z = \{( x_i , y_i) \}_{i=1, \ldots, N} \)을 가집니다. ( \(x_i \)와 \(y_i \)는 token sequences) NL2SQL task의 경우 \(x_i \)은 natural language query이고 \( y_i \)은 그에 대한 SQL command일 것이고 summarization task의 경우에는 \(x_i \)은 article 내용이고 \(y_i \)는 그에 대한 요약내용이겠죠.
기존의 full fine-tuning이라면 model은 pre-trained weights \( \Phi_0 \)으로 initialized될 것이고 아래와 같은 conditional language modeling objective를 minimize하기위해 \( \Phi_0 + \Delta \Phi \)을 update 합니다:
\[
max_{\Phi} \sum_{ (x,y) \in Z} \sum^{ |y|}_{t=1} log(P_{\Phi} (y_t | x, y_{ < t} ))\quad \cdots Eq. (1)
\]
위의 full fine-tuning을 사용할 경우에 "각" downstream task를 위해 \( | \Phi_0 | \) dimension과 같은 크기의 \( | \Delta \Phi |\)을 매번 재학습해야 한다는 문제점을 가집니다. GPT-3와 같이 1,750억개의 weights를 가진 pre-trained model을 사용하게 되면 엄청난 cost가 들것입니다.
이를 해결하기위해 LoRA는 update해야하는 parameter를 \( \Delta \Phi = \Delta \Phi ( \Theta ) \)와 같이 encode하여 훨씬 작은 size의 parameter \( \Theta \)로 대체 학습하는 것입니다( \( | \Theta | \ll | \Phi_0 | \) ). 그래서 최적의 \( \Delta \Phi \)를 찾는 task는 \( \Theta \)를 optimization하는 것으로 대체됩니다:
\[
max_{\Phi} \sum_{ (x,y) \in Z} \sum^{ |y|}_{t=1} log(P_{\Phi_0 + \Delta \Phi ( \Theta ) } (y_t | x, y_{ < t} ))\quad \cdots Eq. (2)
\]
위와 같은 LoRA방식으로 GPT-3을 fine-tuning할 경우 기존 full fine-tuning보다 학습해야 할 parameter수가 전체의 0.01%로 줄어듭니다. 아래 section에서는 LoRA방법에서 정확히 어떻게 \( \Theta \)가 표현되는지 얼마나 작은 size로 encode되는 지 알아보도록 하죠!
3. Our method
3.1 Low-Rank Parameterized Update Matrices
리마인드하면 LoRA는 adaptation 동안에 low intrinsic rank를 가진 weight로 update하는 방법입니다. 수학적으로 pre-trained weight matrix \( W_0 \in \mathbb{R}^{d \times k } \) 에 대해 \( W_0 + \Delta W = W_0 + BA \)로 update하는 것입니다. 즉, \( W_0 \)은 frozen되고 low rank로 decomposition된 \(B \in \mathbb{R}^{d \times r} \)와 \( A \in \mathbb{R}^{r \times k} \)만을 학습하는 것입니다(rank \(r \ll min(d,k) \)을 만족함 ).
그리고 \(W_0 \)와 \( \Delta W = BA \)는 같은 input에 곱해지고 그들의 output vector는 coordinate-wise하게 합(summation)해집니다. 이에 대해 forward pass를 표현하면 다음과 같습니다:
\[
h = W_0 x + \Delta W x = W_0 x + BAx \quad \cdots Eq. (3)
\]
\(A\)는 random Gaussian initialization되고 \(B\)는 0으로 initialization됩니다. 그래서 training 시작 시에 \( \Delta W = BA \)또한 0입니다. 그리고 \( \Delta W x\)는 \( \frac{ \alpha}{r} \)으로 scaling됩니다. Adam으로 optimization 할 때 \( \alpha \)를 tuning하는 것은 learning rate를 tuning하는 것과 같이 하였습니다. 그래서 \( \alpha \)을 처음 \( r\)값으로 정하였다고 합니다. Scaling은 \(r\)값을 변화시킬때 hyperparameter를 재조정할 필요를 줄이는 데 도움이 됩니다.
위는 실제 LoRA코드를 snippet한 것인데 위에 설명드린 수식과 내용과 일치하는 것을 알 수 있습니다. (코드에서 확인해보니 \(r, \alpha \)값은 보통 (8, 16) 또는 (16, 32)을 사용하였습니다.)
3.1.1 No additional Inference Latency
LoRA를 사용하여 inference하려고 할 때는 기존 pre-trained weight \(W_0\)에 학습한 \(BA\)를 더해주고 사용하면 되기 때문에 infernece latency성능 하락은 전혀 없습니다. 그리고 \(W_0\)을 기반으로 또 다른 task로 학습한 \(B'A'\)가 있을 경우 \(BA\)을 빼주고 \(B'A'\)을 더해주어 사용하면 되기 때문에 reusability이 좋습니다.
3.2 Applying LoRA to Transformer
논문에서는 trainable weight를 최소화하기위해 LoRA를 모든 layer 및 module에 적용하지않습니다. 오직 LoRA를 Transformer의 attention weights인 \(W_q\)또는 \(W_k \), \(W_v\)에만 적용하였고 나머지 MLP module에는 적용하지 않았습니다. (실제 성능 실험에서는 \(W_q \)와 \(W_v\)에만 LoRA적용하였습니다.) 이렇게 셋팅하고 진행함으로써 1,750억개의 parameter를 가진 GPT-3에 대해 fine-tuning시에 원래 VRAM를 1.2TB사용하던 것이 LoRA를 통해 350GB로 줄어들었습니다. 또한 training speed또한 25%가량 줄었다고 합니다.
4. Experiment Results
GPT-2기준 성능비교 시 기존 방법들보다 trainable weight도 적으며 다양한 데이터셋에 대해 성능도 잘 나온것을 확인가능합니다. (실험에 대한 더 많은 내용은 논문 참조부탁드립니다. ㅠ)
'AI paper review' 카테고리의 다른 글
Rate-Perception Optimized Preprocessing for Video Coding 논문 리뷰 (1) | 2023.12.31 |
---|---|
Segment Anything 논문 리뷰 (0) | 2023.04.07 |
GPT-1: Improving Language Understanding by Generative Pre-Training 논문 리뷰 (0) | 2023.02.13 |
The Forward-Forward Algorithm: Some Preliminary Investigations 논문 리뷰 (0) | 2023.01.28 |
EfficientDet Scalable and Efficient Object Detection (0) | 2022.03.11 |