오늘은 Hinton님의 The Forward-Forward Algorithm: Some Preliminary Investigations 논문을 리뷰입니다!
해당 논문의 목적은 기존 deep learning model의 학습방법인 backpropagation에 대한 단점을 지적하고 새로운 학습방법인 Forward-Forward 알고리즘을 제안하였습니다.
1. What is wrong with Backpropagation
Deep learning model의 backpropagation은 인간의 뇌가 학습하는 방법과 유사하게 설계되어있다고 알고 계신분들이 많은데요. 실제로 그렇지 않다고 하고 근거는 아래와 같습니다.
- Backward pass를 하기위해 neural activity를 저장하거나 error derivate를 전파하는 과정이 인간의 뇌에서 일어나지 않았고 그런 증거가 발견되지 않았다고 함
- 인간의 뇌는 중단되는 시간이 따로 없이 다른 sensory processing stage을 통해서 sensory data을 전달할 필요가 있고 그때 그때 봐가며 learning이 될 수 있어야 함
- Error derivatives를 propagate하기위해 중단되는 시간이 생김
- Backpropagation으로는 real time으로 inference와 learning이 불가함
- Backprogation은 모델의 forward계산의 정확한 knowledge가 필요. 즉, differentiable할 수 없는 black-box에 대해 forward-pass한다면 backpropatgation못하는 문제점이 있음
- 이에 대한 방안으로 강화학습이 있지만 강화학습은 high variance를 가진다는 문제점을 가짐
그래서, 해당 논문의 주요 목적은 unknown non-linearities가 포함된 neural network에 강화학습을 사용할 필요가 없고 Foward-Foward algorithm(FF)을 사용하면 된다는 것입니다. 그리고 FF는 neural acitivities를 저장하지 않거나 error derivatives를 propagate하지 않고도 sequential data를 pipelining하면서 학습가능하다는 장점을 가집니다.
그럼 FF가 장점만 가지냐? 그건 아닙니다. 단점은 아래와 같습니다.
- FF는 어떤 때에는 backpropagation보다 느림
- Generalized되지않았기 때문에 다양한 task, application에 사용하기 아직 힘듬
- 큰 dataset으로 학습된 큰 model의 학습능력을 내기위해서는 backpropagation을 사용해야함
FF가 backpropagation보다 우수할 수 있는 두 가지 영역은 cortex안에서 model의 학습과 강화학습에 의존하지 않고 매우 낮은 전력의 analog hardward를 사용하는 방법입니다.
2. The Forward-Forward Algorithm
Forward-Forward Algorithm (FF)는 Boltzmann machine과 Noise Contrastive Estimation의 영감을 받은 greedy multi-layer learning방법입니다. 아이디어는 backpropagation의 forward, backward passes를 2개의 forward passes로 대체하는 것입니다.
FF의 첫 번째 forward는 기존과 같이 top-down 형식으로 forward를 진행하고 두 번째 forward는 각 layer에서 weight update를 진행합니다.
기존과 다른 또다른 점은 contrastive learning을 하는것입니다. 즉, positive pass와 negative pass를 나누어 weight를 update합니다. Positive pass는 real(positive) data에서 작동하며 매 hidden layer에서 goodness(잘햇어!)을 향상시키기 위해 weight를 조절합니다. 반대로 negative pass에서는 negative data에서 작동하며 매 hidden layer에서 goodness를 낮추기 위해 weight를 조절합니다. 그래서, FF를 사용하기 위한 충분 조건은 different data와 opposite objectives를 가져야하는 것입니다.
그럼 각 layer마다 goodness function을 어떻게 정의할까요? 해당 논문에서는 해당 layer안에서 rectified linear neurons의 activities의 제곱(square)의 합으로 정의한다고 합니다. FF learning은 real data에 대해서는 특정 threshold보다 높게 goodness가 출력되도록 하고 negative data에 대해서는 threshold보다 낮게 측정되도록 하는게 목적입니다. 이를 수식화하면 다음과 같습니다.
\[
p(positive) = \sigma ( \sum_j y^2_j - \theta ) , \quad \cdots Eq. (1)
\]
\( \sigma \)는 logistic function(i.e. sigmoid)이며 \( \theta \)는 threshold이며 \( y_j \)는 layer normalization전의 \( j \)번째 hidden unit의 acitivity값입니다. 위의 objective function이 loss function이 되고 Pytorch기준으로 loss.backward, optimizer.step을 통해 weight update합니다.
그리고 negative data는 외부에서 제공되거나 neural net의 top-down connection을 이용해서 predict되어 생성가능하다 합니다.
2.1 Learning multiple layers of representation with a simple layer-wise goodness function
만약 first hidden layer의 acitivities를 second hidden layer의 input으로 사용하고 싶다면 어떻게 해야할까?
FF는 first hidden layer의 hidden vector의 length을 normalize하여 second layer의 input으로 보낸다고 합니다. 이렇게 하면 first hidden layer에서 goodness를 계산하기 위해 사용했던 정보를 제거할 수 있고 next hidden layer에 first hidden layer의 relative acitivities의 정보를 사용할 수 있도록 합니다. (즉, relative acitivies는 layer normalization에 영향을 받아 없어지거나 하지 않는 것) 달리 표현하면 first hidden layer의 activity vector는 length와 orientation을 가지고 length(before layer normalization)는 그 layer의 goodness를 define하는데 사용되고 orientation(after layer normalization)은 next layer에 전달하기위해 사용됩니다.
3. Some experiments with FF
FF가 small neural network에서 어떻게 작동하지 설명하도록 합니다.
3.1 The backpropagation baseline
새로운 learning 알고리즘을 설명하고 성능을 확인하기에 가장 적합한 MNIST에 대해 실험하려고 합니다. 그 전에 backpropagation을 사용했을 때 성능에 대해 이야기합니다. CNN을 사용하면 0.6% test error을 가진다고 합니다. 그리고 permutation-invarient task에서는 FC layer와 ReLU를 사용하면 1.4% test error를 가진다고 합니다. 즉, complicated regularizer사용 없이 backpropagation을 사용하면 1.4% test error 성능을 가지게 됩니다.
permutation-invarient task란?
입력 벡터 요소의 순서와 상관없이 같은 출력을 생성하는 모델을 뜻하며 대표적인 모델로는 MLP이다. permuation-invarient task가 아닌 모델로는 입력 이미지의 픽셀의 순서를 고려하는 CNN이 있다.
3.2 A simple supervised example of FF
※ 해당 글에서는 unsuperivsed, nlp, reccurent net에 대한 내용은 skip하고 supervised에 대한 내용만 다루겠습니다.
Supervised leraning은 single task, small model을 사용하고 싶을 때 유용한 방법입니다. 이를 FF에 적용하기 위해서는 input에 label을 포함시키는 방법을 사용합니다. Positive data에는 옳바른 label을 포함하고 있는 input image들로 구성되고 negative data는 틀린 label을 포함하고 있는 input image들로 구성됩니다. Positive, negative data의 다른 점은 오직 label이기 때문에 FF알고리즘은 label과 연관되어 있지 않은 image의 feature 정보는 모두 무시할 것입니다.
그럼 label을 어떻게 data에 포함할까요?
MNIST기준으로 설명드리면 class가 10개이므로 image의 첫 10 pixels에 label정보를 기입하는 것입니다. 이렇게 하여 논문에서는 60 epochs, 4 hidden fc layers, 2000 ReLUs로 구성된 network로 1.36% test errors를 얻었다고 합니다. 해당 결과는 backpropagation을 사용했을 때는 20 epoch으로 낼 수 있는 성능이라고 합니다.
FF로 training하고 나면 inference는 어떻게 진행할까요?
Inference시에는 test image의 첫 10 pixel에 neutral label(모두 0.1값)을 포함시켜서 single forward pass를 통해 classify한다고 합니다. 첫 번째 hidden layer의 activities(features)을 제외하고 다른 모든 hidden layer의 acitivities값에 대해 softmax을 적용하고 다 더해줍니다. 더했을 때 가장 큰 값을 가진 class index가 model의 최종 output class가 됩니다.
(이 부분이 제가 읽은 책중에 천개의 뇌 이론과 비슷하더라고요. 책에서는 수많은 뇌세포가 투표를 통해 객체가 무엇인지 판단한다고 하는데 각 뇌세포를 layer의 node라고 생각한다면 여러 layer의 여러 nodes의 acitivity값의 합(투표)으로 최종 class를 결정하니 비슷하네요)
이 방법은 neutral label을 사용하기 때문에 빠르지만 sub-optimal한 방법입니다. 그래서 논문에서는 input image에 특정한 하나의 label을 가진 input을 사용하는 것이 좋다고 합니다. 0 label을 가진 image, 1 label을 가진 image, ...., 9 label을 가진 image를 개별적으로 넣어보고 더했을 때 가장 높은 gooodness를 가진 label을 최종 output class 선택합니다.
추가적으로 해당 논문에서는 FF를 위한 data augmentation방법인 image jitttering을 제안하였습니다. 각 image마다 모든 방향으로 최대 2 pixels까지 shifting하여 총 25개의 다른 image를 생성하게 됩니다.
Image jittering을 통하여 pixel간의 spatial layout knowledge를 학습하도록 하게 하였고 결론적으로 permutation invariant을 없앴다고 합니다. 해당 augmentation과 함께 500 epochs을 학습하였을 때 CNN과 비슷한 test error인 0.64%을 도출하였다고 합니다.
그리고 흥미로운 결과로는 first hidden layer의 recpetive field를 보았을 때 아래와 같이 image의 첫 10 pixels에서 class label이 학습되는 것을 볼수 있습니다.
'AI paper review' 카테고리의 다른 글
Rate-Perception Optimized Preprocessing for Video Coding 논문 리뷰 (1) | 2023.12.31 |
---|---|
LoRA: Low-Rank Adaptation of Large Language Models 논문 리뷰 (0) | 2023.05.16 |
Segment Anything 논문 리뷰 (0) | 2023.04.07 |
GPT-1: Improving Language Understanding by Generative Pre-Training 논문 리뷰 (0) | 2023.02.13 |
EfficientDet Scalable and Efficient Object Detection (0) | 2022.03.11 |