2022년 Snap Inc. 에서 게재한 논문인 EfficentFormer 논문을 리뷰합니다.
1. Introduction
해당 논문은 주요 내용은 다음과 같습니다.
- "Vision Transformer(ViT)가 high performance를 내면서 (mobile device에서) mobilenet만큼 빨라질 수 있을까"에 대한 의문점에서 시작
- 기본적으로 VIT는 accuracy 성능은 좋은데 lightweight CNN(e.g. MobileNet)보다 느리다는 단점을 가짐
- 그 의문점을 풀기 위해 기존 ViT의 inefficient한 구조에 대해 분석
- Efficient한 구조를 갖는 dimension-consistent한 ViT 모델(EfficientFormer) 을 제안
- 특히나, 해당 논문은 FLOPs나 parameter수가 아닌 inference speed에 초점을 맞춤
위의 EfficientFormer 모델은 아래 그림에서 보이듯이 inference speed도 빠르면서 좋은 accuracy성능을 보임을 알 수 있습니다.
2. On-device Latency Analysis of Vision Transformers
기존 ViT의 구조중 어떤 operation이나 architecture가 on-device inference speed에 악영향을 주는지 확인하기 위해 실험 및 분석을 하였고 아래와 같은 4가지 observations을 확인하였습니다. 아래 그림은 iPhone12 device에서 latency profiling을 진행한 결과를 나타냅니다.
Observation 1: 큰 kernel과 stride를 갖는 patch embedding이 mobile device의 inference speed에 악영향을 줌.
위의 그림에서 Patch Embedding은 일반적으로 non-overlapping convolution layer로 구현되는데 non-overlapping하게 하기 위해 large kernel과 stride를 사용하게 됩니다. 하지만 대부분의 compiler에서는 large kernel convolution을 지원하지 않고 large kernel convolution은 기존의 acceleration algorithm(e.g. Winograd)으로부터 가속화 되지 않습니다. 위의 latency profiling그림에서 실제로 일반적인 transformer의 patch embedding이 inference하는데 시간을 많이 잡아먹는 것을 확인 가능합니다.
저자들은 기존의 non-overlapping patch embedding을 여러 개의 3x3 convolutions(hardware-efficient한 구조)으로 대체하여 모델을 design하겠다고 합니다.
Observation 2. token mixer의 선택에 consistent feature dimension이 중요하고 Multi-Head Self Attention(MSHA)이 주된 speed bottleneck아님.
Token mixer란 말 그대로 token간의 information을 섞는 역할을 하며 token mixer의 대표적인 예시로는 MSHA가 있고 아래 그림에서 보이듯이 transformer variant 논문들에서는 token mixer을 spatial MLP나 Pooling layer을 사용함을 알 수 있다.
저자들은 token mixer의 선택으로 머가 좋은 지 분석하기 위해 pooling과 MSHA를 비교 실험하였습니다. (token mixer을 shifted window attention을 사용한 논문이 있는데 해당 모듈은 대부분의 mobile compiler에서 지원하지 않는다고 하여 비교 안 하였음)
- pooling을 사용하는 PoolFormer-s24, MSHA를 사용하는 LeViT-256를 비교
- LeViT는 4D tensor에 대해 Conv연산후, MSHA연산을 하는데 MSHA는 3D tensor에 대해 연산하므로 reshape이 빈번하게 필요함. 헌데 reshape연산이 inference speed의 bottleneck으로 작용(위의 latency profiling그림 참고!)
- PoolFormer는 4D tensor에 대해 Conv연산후 4D tensor대상으로 연산하는 Pooling을 사용하기 때문에 reshape이 필요 없고 결과적으로 PoolFormer가 LeViT보다 빠름
- 여기서, 4D tensor 연산에서 3D tensor연산으로 바뀌지 않고 그대로 4D에서 4D tensor연산하는 것을 "consistent feature dimension" 하다고 함
- DeiT-S와 LeViT-256를 비교했을 때 3D연산을 하는 MSHA자체는 inference speed에 대해 큰 overhead를 가져오지 않음을 확인
- 오직 빈번한 reshape 연산이 없을 경우에만!
위의 분석을 통해서 token mixer선택에 있어서 consistent feature dimension을 유지하기 위해서 reshape operation을 거의 사용하지 않도록 하는 dimension-consistent network를 제안하게 됩니다. 제안하는 EfficientFormer는 token mixer선택에 있어서 4D tensor연산을 하는 pooling과 3D tensor연산을 하는 MSHA를 모두 사용하게 됩니다. (자세한 EfficientFormer의 내용은 뒤에서...)
Observation 3: Conv-BN이 LN(Layer Normalization)-Linear보다 latency-favorable하며 Conv-BN을 사용했을 경우 Accuracy drop이 acceptable함
기본적으로 Layer Normalization(LN)-Linear구조는 3D linear projection을 하게 되므로 MSHA와 같이 사용됩니다. 하지만 LN은 전체 network inference time 중 10~20% 정도를 차지하는 것을 위의 latency profiling 그림에서 볼 수 있습니다. 이는 LN이 inference를 잴 경우에 running statistics를 collect 해야 하기 때문에 생기는 시간입니다.
이에 반해, Conv-BN구조는 4D tensor에 대해 연산을 하고 inference시에 BN이 Conv구조에 folding 될 수 있으므로 latency를 낮추는 데 더 용이합니다. 하지만 Conv-BN은 LN-Linear보다는 "조금" accuracy성능이 낮게 나온다고 하네요.
그래서 EfficientFormer가 token mixer가 pooling일 때는 Conv-BN구조를 사용할 것이고 MSHA일 경우는 LN-Linear를 사용하도록 할 것입니다.
Observation 4: nonlinearity(activation function)의 latency는 hardware와 compiler에 의존적
"Towards efficient vision transformer inference: a first study of transformers on mobile devices." 논문에서는 GeLU가 hardware에 inefficient 하다고 했지만 실제로 저자들이 iPhone12에서 실험해봤을 때 GeLU가 ReLU만큼 느리지 않다는 것을 확인하였습니다. 반대로, HardSwish는 iPhone12에서 느린 것을 확인하였습니다. (LeViT-256에서 HardSwish 사용 시: 44.5ms, GeLu사용 시: 11.9ms)
그래서, EfficientFormer에서는 GeLU을 사용합니다.
3. Design of EfficientFormer
위의 observations을 기반으로 저자들은 EfficientFormer을 제안하게 됩니다. EfficientFormer는 patch embedding(\(PatchEmbed\))와 여러 개로 쌓은 meta block(\(MB\))로 구성됩니다.
\[
Y = \prod^m_i MB_i (PatchEmbed(X^{B,3,H,W}_0 )). \quad Eq.(1)
\]
- \(X_0\): Input image
- \(B\), \(H\), \(W\): Batch size, Height, Width
- \(Y\): Output
- \(m\): transformer block 총 개수
그리고 \(MB\)는 token mixer(\(TokenMixer\))와 \(MLP\) block으로 표현됩니다.
\[
X_{i+1} = MB_i (X_i) = MLP(TokenMixer(X_i)). \quad Eq.(2)
\]
여기서 \(X_{i | i >0 } \)은 \(i^{th} MB\)으로부터 forward된 intermediate feature를 뜻합니다.
추가로 저자들은 Stage(\(S\))를 정의하였는데요. Stage는 여러개의 MetaBlock들로 구성되어 있고 각 Stage는 같은 spatial size를 가지며 각 Stage가 가진 MetaBlock의 수를 \(N_i \times\)로 표현한다고 하네요. 총 stage개수는 4개이며 각 stage 사이마다 embedding operation(\(Embedding\))이 있습니다. Embedding operation은 embedding dimension으로 project시키기 위함과 token 길이를 downsample하기위해 사용됩니다.
아래부터는 EfficientFormer의 상세한 구조 디자인 설명을 드리도록 하겠습니다.
3.1 Dimension-consistent Design
Section 2에서 말씀드린 observation을 통해 dimension consistent design을 제안하게 됩니다. 위의 그림에서 보이듯이 4D partition부분과 3D partition부분으로 나누는데 처음에는 4D partition으로 stage를 시작하고 마지막부분에 3D partition부분을 수행하여 reshape연산을 최소화하여 dimension consistent design을 구성하게됩니다.
- 4D partition: \(MB^{4D} \)로 표현되며 Conv-net style과 token mixer로 pooling layer 사용하여 구현
- 3D partition: \(MB^{3D} \)로 표현되며 linear projection과 token mixer로 MSHA 사용하여 구현
위의 EfficientFormer구조 그림은 예시일 뿐이며 실제 4D, 3D partition길이는 NAS를 통해 찾는다고 합니다!
먼저 , input image는 patch embedding에 의해 processing된다고 말씀드렸는데 observation 1에 근거하여 patch embedding은 2개의 3x3 convolution(stride 2)으로 구현됩니다.
\[
X_i^{B,C_{j|j=1}, \frac{H}{4}, \frac{W}{4} } = PatchEmbed(X^{B,3,H,W}_0). \quad Eq.(3)
\]
\(C_j\)는 j-th stage의 channel 수를 의미합니다. 그다음으로 \(MB^{4D} \)는 \(Pool\) mixer를 사용하여 다음과 같이 표현됩니다.
\[
\begin{array}{l} I_i = Pool(X_i^{ B,C \frac{H}{2^{j+1}},\frac{W}{2^{j+1}}}) + X_i^{ B,C \frac{H}{2^{j+1}},\frac{W}{2^{j+1}}} \cr X_i^{ B,C \frac{H}{2^{j+1}},\frac{W}{2^{j+1}}} = Conv_B(Conv_{B,G} ( I_i ) ) + I_i \end{array} \quad Eq.(4)
\]
\( Conv_{B,G} \)는 연속된 Conv-BN-GeLU을 의미합니다. \(MB^{4D} \) block 연산후에는 한번의(one-time) reshape 연산으로 4D에서 3D로 feature dimension을 변경합니다. 해당 feature를 입력으로 \(MB^{3D}\)는 다음과 같이 연산합니다.
\[
\begin{array}{l} I_i = Linear(MSHA(Linear(LN(X_i^{B, \frac{HW}{4^{j+1}},C_j})))) + X_i^{B, \frac{HW}{4^{j+1}},C_j}, \cr X_i^{B, \frac{HW}{4^{j+1}},C_j} = Linear(Linear_G(LN(I_i))) + I_i \end{array} \quad Eq.(5)
\]
\(Linear_G\)는 Linear-GeLU를 의미하고 MSHA연산은 다음과 같다.
\[
MSHA(Q,K,V)= Softmax(\frac{Q \cdot K^T}{\sqrt C_j } +b ) \cdot V. \quad Eq.(6)
\]
\(Q, K,V\)는 각각 query, key, value를 뜻하며 linear projection으로부터 학습되는 variable이다. 또한 \(b\)는 parameterized attention bias로 position encoding역할로 사용된다.
3.2 Latency Driven Slimming
3.2.1 Design of Supernet
dimension-consistent design을 하기위해 저자들은 supernet으로부터 architecture search를 하는 NAS방법을 사용한다. Supernet는 다음과 같은 MetaPath(\(MP\))을 정의하여 구성됩니다.
\[
\begin{array}{l} MP_{i,j=1,2} \in \{ MB^{4D}_i , II_i \}, \cr MP_{i,j=3,4} \in \{ MB^{4D}_i , MB^{3D}_i II_i \}. \end{array} \quad Eq.(7)
\]
여기서 \( II \)은 identity path을 의미하고 \(j\)는 \(j^{th}\) stage, \(i\)는 \(i^{th}\)block을 의미한다. 즉, supernet의 training시에 stage 1,2에는 \(MB^{4D}_i\) 또는 \(II_i\)이 선택될 수 있는 것이고 stage 3,4에는 \(MB^{4D}_i\), \(MB^{3D}_i\) 또는 \(II_i\)이 선택 가능하다는 것이다.
여기서 stage 3,4에서만 \(MB^{3D}_i\) 이 추가된 이유는 아래와 같다.
- MSHA는 token 길이에 따라 quadratic(4배)하게 computation cost가 커지므로 상대적으로 token 길이가 작은 뒤쪽의 stage를 사용
- 초기 stage에는 low-level feature를 학습하고 마지막 stage들은 long-term dependencies을 학습한다는 측면에서 뒤쪽의 stage에 MSHA를 적용하는 게 옳음
3.2.2 Search Space
- \(C_j \): 각 stage의 channel 수
- \( N_j\): 각 stage의 block의 수
- \(\mathcal{N} \): \(MB^{3D}\)에 적용할 마지막 block수
3.2.3 Search algorithm
NAS에서는 supernet을 학습을 완료하고 나면 어떤 path(subnet)가 best인지 찾는 search algorithm이 필요합니다. 저자들은 supernet의 학습이 완료되면 바로 어떤 path가 best인지 알 수 있는 efficient한 gradient-based search algorithm을 제안합니다.
해당 search algorithm은 3가지 step을 수행합니다.
(1) supernet training시에 Gumble Softmax sampling을 함께 사용하여 선택된 \(MP\)의 importance score을 측정합니다.
\[
X_{i+1} = \sum_n \frac{e^{ ( \alpha^n_i + \epsilon^n_i )} / \tau }{\sum_n e^{ ( \alpha^n_i + \epsilon^n_i ) / \tau }} \cdot MP_{i,j} \cdot (X_i). \quad Eq.(8)
\]
\(\alpha\)는 trainable parameter로 MP의 importance score를 나타내고 해당 block이 선택될 확률을 뜻한다. \( \epsilon \sim U(0,1)\)은 exploration역할을 하게 되고 \( \tau\)는 temperature, \(n\)은 선택 가능한 block은 type을 의미한다.
(2) 16배수로 나누어진 channel(width)들을 가지는 여럿 \(MB^{4D}\)와 \(MB^{3D}\)의 on-device latency lookup table을 구축한다.
(3) single-width를 가지는 supernet기준으로 채널 수를 조절하는 gradual slimming 을 진행한다.
supernet를 구성할 때 각 MP에 대해 여럿 다양한 channel(width) 수 path가 없었는데 그 이유는 저자들은 single-width supernet구조에서 channel수를 줄이는 작업을 진행하였다. (이는 여럿 다양한 channel수에 대한 search도 supernet training시에 할 경우 memory-consuming이 크기 때문)
Gradual slimming은 다음과 같이 수행됩니다.
- \(S_{1,2}, S_{3,4}\)의 각 \(MP_i\)에 대해 importance score을 \( \frac{\alpha^{4D}_i}{\alpha^I_i}, \frac{\alpha^{4D}_i+ \alpha^{3D}_i}{\alpha^I_i} \)로 정의
- 각 stage에 대한 importance score를 구하기 위해 각 stage안에 포함되는 \(MP_i\)의 importance score를 summation
- 다음 3가지 옵션에 대해 action(수행)해보면서 per-latency accuracy drop \( \frac{- \%}{ms}\)을 기준으로 3개 중 1개의 옵션을 취함
- Option 1: 가장 낮은 importance score를 가지는 \(MP\)에 대해 \(II\))(identity)를 선택
- Option 2: 첫 번째 \(MB^{3D}\)를 제거
- Option3: 가장 낮은 importance score를 가지는 \(MP\)에 대해 channel수를 16으로 나눔 (16으로 나눈 \(MP\)에 대해 latency lookup table이 존재하므로 해당 latency 사용)
위의 gradual slimming에 대한 algorithm은 아래를 참고하시면 됩니다.
위의 graudal slimming을 수행 완료하여 최종 선택된 EfficientFormer 구조는 아래와 같습니다.
4. Experiments and Discussion
저자들은 PyTorch 1.11과 Timm library를 통해 EfficientFormer를 구현하였고 mobile speed는 A14 bionic chip이 장착되고 NPU사용이 가능한 iPhone12에서 1000번 inference하고 평균을 내어 결과를 내었다고 합니다. CoreMLToolssms run-time model을 deploy하기 위해 사용하였습니다.
4.1 Image classifciation
ImageNet-1K dataset에 대해 실험하였고 300 epochs 학습 기준으로 결과를 비교하였습니다. EfficientFormer는 AdamW optimizer를 사용하였으며 5 epochs의 warm-up training과 consine annealing scheduler적용하였습니다. 또한 initial learning rate는 \(10^{-3} \times (batch sizze / 1024) \), minimum learning rate는 \(10^{-5}\)이며 distillation을 위한 teacher model을 RegNetY-16GF(82.9% top-1 accuracy on ImageNet-1k)으로 설정하였습니다.
4.2 EfficientFormer as Backbone
Detection이나 segmentation task에서도 performance가 뛰어난지 확인하기 위해 EfficientFormer를 backbone으로 사용하였습니다. Mask-RCNN에 EfficientFormer을 combine하였고 COCO-2017 dataset기준으로 결과를 측정하였습니다. EfficientFormer의 weight는 ImageNet-1K pretrained weight로 initialization하였고 AdamW optimizer와 initial learning \( 1 \times 10^{-4} \)을 사용하였고 12 epochs만 학습했다고 합니다.
'AI paper review > Mobile-friendly' 카테고리의 다른 글
YOLOv7: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors 논문 리뷰 (6) | 2022.07.22 |
---|---|
[MobileOne] An Improved One millisecond Mobile Backbone 논문 리뷰 (0) | 2022.06.25 |
Lite Pose 논문 리뷰 (0) | 2022.04.18 |
MobileViT 논문 리뷰 (0) | 2022.03.28 |
EfficientNetv2 논문 리뷰 (0) | 2022.03.24 |