오늘은 Meta AI의 Segment Anything논문을 리뷰합니다. 논문 이름이 목적과 내용을 뜻하는 논문이네요. 아래 사진과 같이 어떤 이미지든(zero-shot) segment할 수 있다는 것을 의미합니다. Zero-shot transfer이 가능하며 어떤 task에도 generalization될 수 있다는 점에서 ChatGPT와 같이 이러한 모델을 foundation model이라합니다.
1. Introduction
- ChatGPT와 같은 Large Language Models (LLM)은 (1) zero-shot generalization이 뛰어나고 (2) hand-crafted 질문 text을 입력으로 아주 적절한 대답(response)을 즉각적으로 출력할 수 있는 prompt engineering이 가능합니다.
- 이러한 모델을 foundation model이라고 지칭
- 전체적인 학습 flow는 web-scale의 dataset으로 pre-trained하고 특정 task(e.g. QA, translation)에 맞춰 fine-tuning
- 해당 논문의 목표: LLM과 비슷한 학습 방법을 사용해서 image segmentation용 foundation model을 만들어 보자!
- 하지만... 다음과 같은 문제가 있음
- 어떤 task가 zero-shot generalization을 가능하게 할까?
- 어떤 model architecture을 사용해야할까?
- 어떤 data가 해당 task와 model에 적합할까?
- 하지만... 다음과 같은 문제가 있음
위의 문제를 해결하기 위해 논문에서 해당 방법론 제안
1. Task
- LLM에서 사용하는 방식과 비슷하게 prompting technique을 기반으로 promptable segmentation task 제안
- 목표는 어떠한 형태의 segmentation prompt가 주어져도 valid한 segmentation mask(result)을 출력하도록 하는 것
- Segmentation prompt는 아래 사진과 같이 segment할 image object에 대한 spatial 또는 text information이기만 하면 됨
- 아래 사진에서 가능한 segmentation prompt 종류가 명시됨: points, boxes, segment mask, text
- 심지어 해당 prompt가 애매모호(ambiguity)하거나 여러 object를 지칭해도 됨
- 예를 들어 prompt 중 한종류인 point label이 shirt를 가리키고 있을 때 실제 의도는 shirt그 자체일 수도 있지만 shirt를 입고있는 사람이 될 수 도 있음
- 해당 prompt가 여러 object를 지칭해도 model의 output은 반드시 여러 object중 하나의 reasonable mask를 뽑아내도록 학습할 것임
2. Model
- 다음과 같은 3가지 constraints를 만족하는 model을 만들어야 함
- Flexible prompts를 지원
- interactive하게, real-time으로 segmentation mask를 compute 가능
- (prompt에 대한) ambiguity-aware한 특성
- 위 constraint을 만족시키는 모델 Segmen Anything Model (SAM) 제안
- SAM은 image encoder, prompt encoder, mask decoder로 구성 됨
- Image encoder: image을 입력으로 image embedding 출력
- Prompt encoder: prompt를 입력으로 prompt embedding 출력 (flexible!)
- Mask decoder: 위의 두 embedding값을 입력으로 segmentation mask 출력 (fast!)
- Input image는 같고 prompt가 다를 경우 image embedding reuse가능
- Prompt는 point, box, mask, text를 받을 수 있도록 flexible하게 구성
- 하나의 prompt입력에 대해 여러 segmentation mask를 출력할 수 있도록 하여 ambiguity-aware특성 만족
- SAM은 image encoder, prompt encoder, mask decoder로 구성 됨
3. Data engine
- SAM model이 strong generalization을 얻기 위해서는 방대한 dataset이 필요함
- 이를 위해 data engine을 구축
- Model-in-the-loop dataset annotation을 사용하는 방식
- Data engine은 총 3가지 strategy로 구성
- Assisted-manual: 기존의 annotation task와 비슷하게 SAM이 annotator를 assist하는 형식
- Semi-automatic: prompting하여 선택된 objects들 중 subset만 SAM이 automatic하게 mask를 생성해주고 나머지는 annotator가 진행
- Fully-automatic: foreground points들의 regular grid prompt를 입력으로 SAM이 이미지당 100개까지의 high-quality mask를 생성 (아래 왼쪽 그림 참조)
4. Dataset
- Data engine의 fully automatic strategy로 생성된 최종 dataset이 SA-1B
- 1B masks와 11M의 licensed, privacy-preserving images로 구성됨
2. Segment Anything Task
2.1 Task
Segment Anything Task인 promptable segmentation task에 대해 더 자세히 설명드리겠습니다.
본 논문에서는 promptable segmentation task을 주어진 어떠한 prompt라도 valid한 segmenation mask를 출력하는 것으로 정의합니다. 여기서 'valid'의 정의는 ambiguous(모호한)해도 되고 여러 물체를 가리키고 있어도 괜찮습니다. 다만 반드시 여러 물체 중 하나를 꼭 가리키는 segmentation mask이어야합니다. 이렇게 task를 정의한 이유는 이 방법이 natural pre-training algorithnm이며 prompt를 이용한 zero-shot transfer하기에 가장 general한 방법이기 때문입니다.
2.2 Pre-training
Pre-training은 image와 prompts(points, boxes, masks)을 입력으로 하여 나온 model output인 predicted segmentation mask와 ground truth의 차이를 최소화하도록 학습합니다. 여기서의 주된 목적은 prompt가 ambiguous해도 어떤 prompt에 대해서도 valid한 mask를 prediction하도록 하는 것이 목적입니다. 그래서 model의 prediction이 ambiguity를 포함하게 되는데 이게 user가 사용하기에 효과적이고 data engine의 automatic annotation에도 flexible하게 사용가능하게 합니다.
2.3 Zero-shot transfer
어떠한 prompt에도 적절하게 pre-training되기 때문에 특정 task에 zero-transfer하기 용이합니다. 예를 들어 '고양이'를 detect하는 bounding box detector를 한 유저가 가지고 있는 상태에서 '고양이'를 segmentation하고 싶다면 SAM모델에 bounding box output을 prompt로 주는 방식으로 해결할 수 있습니다. (SAM은 bounding box를 prompt로 입력받아 pre-training되어있기 떄문입니다.)
3. Segment Anything Model (SAM)
Promptable segmentation task을 위한 SAM모델은 (1) Image encoder, (2) flexible prompt encoder, (3) fast mask decoder 로 이루어져 있습니다.
3.1 Image Encoder
Pre-trained Vision Transfomer(ViT)의 하나인 Masked autoencoders (MAE)의 enocder 사용하여 image encoder를 구성하였습니다. MAE는 high-resolution Input을 process하기위해 적용되었습니다. Image에 대해 여러 prompt가 존재한다면 Image encoder는 image당 한번만 실행됩니다.
3.2 Prompt encoder
Prompt를 (1) sparse(points, boxes, text)와 dense(mask)으로 나누어 정의하였습니다.
Sparse set에 해당하는 points와 boxes는 learned embeddings으로 합산된 positional encodings을 사용하여 표현하고 text는 CLIP으로부터 상용화된 text encoder로 text를 표현합니다. Dense set은 convolutions과 image embedding(from image encoder)과 함께 summed element-wise으로 embedding됩니다.
3.3 Mask decoder
Mask decoder의 역할은 image embedding, prompt embedding과 output token을 효과적으로 segmentation mask에 mapping하는 것입니다. Mask decoder는 transformer decoder block과 dynamic mask prediction head를 사용하였습니다. 해당 decoder block은 전체 embedding을 update하기위해 prompt self-attention과 cross-attention을 2가지 방향으로 사용하였습니다. ([1] prompt-to-image, [2] image-to-prompt embeddings) 아래와 그림과 같이 2개의 blocks이후에는 image embedding을 upsampling하고 MLP는 output token을 dynamic linear classifier에 mapping하게 됩니다. 그리고 각 image location에 대해 mask foreground probability를 계산합니다.
3.4 Resolving ambiguity
SAM은 하나의 prompt에 대해 여러개의 output masks를 predict하도록 합니다. 아래 사진과 같이 3개의 mask outputs이면 대부분 cases을 처리할 수 있다는 것을 발견하여 한 prompt에 대해 총 3개의 output masks를 예측하도록 학습하였습니다. Training시에 masks에 대해 minimum loss만 backprop하였다고 합니다. 그리고 3개의 masks에 대해 rank를 매기기 위해 confidence score(estimated IOU)를 예측하도록 하였습니다.
3.5 Losses and training
Focal loss와 Dice loss의 linear comibation으로 mask prediction에 대해 supervise learning진행하였습니다. 그리고 여러 geometric prompts를 섞어서 training 진행하여 단일로 prompt를 사용했을 때보다 robust하게 학습되도록 하였습니다. (text prompt사용 시에만인듯 합니다?) Mask마다 총 11번의 random sampling prompts를 진행하여 data engine에 SAM 모델이 integrate되도록 하였습니다.
4. Segment Anything Data Engine
Data engine은 3가지 stage로 구성: (1) model-assisted manual annotation stage, (2) semi-automatic stage, (3) fully automatic stage 되어있습니다.
4.1 Assisted-mamanual stage
전문적인 annotator들이 SAM을 기반한 browser-based segmentation tool을 통해 foreground/ background object를 클릭해가며 labeled mask를 만드는 stage입니다. (노가다ㅠ..) Labeling시에 object의 semantic constraints를 따로 두지않았다고 하며 stuff, things에 대해 자유롭게 labeling하도록 하였다고 합니다.
이 stage에 사용되는 SAM은 public segmentation datasets으로 학습되었으며 어느정도 충분한 data가 더 쌓이면 SAM을 retrain하였다고 합니다. 그리고 더 많이 쌓였다면 image encoder를 ViT-B에서 더 큰 모델인 ViT-H로 변경하였다고 합니다. 총 6번의 retraining작업을 진행하였습니다. (retraining이 진행됨에 따라 annotation속도가 빨라진다고 하네요) 이 stage에서 총 120k images에 대해 4.3M masks를 얻었다고 합니다.
4.2 Semi-automatic stage
이 단계에서는 masks의 다양성을 높여 SAM의 성능을 향상시키는 데 목표합니다. Annotator들에게 SAM으로부터 어느정도 masks labeling이 되어있는 이미지를 주고 annotate되지않은 부분을 추가적으로 annotate하도록 합니다. 즉, 덜 중요한 object들에 대해 초점을 더 맞춘것입니다. Confident mask를 detect하기 위해 첫 번째 stage에서 얻은 masks에 대해 generic object category를 이용하여 detect하도록 bounding box detector(Faster R-CNN)를 학습하였습니다. 해당 stage에서는 추가적으로 180k images에 대해 5.9M masks를 만들었습니다.
4.3 Fully automatic stage
여기서부터는 annotator들이 없습니다. 이전 2개의 stage을 통해 충분히 다양한 masks을 모아서 모델의 성능을 향상시켰고 여기서는 ambiguity-aware model을 만듭니다. 즉, 애매모호한 prompt가 입력으로 들어와도 납득할만한 masks를 출력하도록 합니다. 32x32의 regular grid에 point prompt을 입력으로 넣어 각 point에 대해 valid object를 segment하도록 합니다. 예를 들어, 한 point가 물체의 part 또는 subpart에 놓여있다면 SAM은 subpart, part, 전체 object에 대한 masks를 출력하도록 하는 것입니다. 그리고 IOU prediction module을 사용하여 confident(stable) masks만 선택되도록 합니다. (0.5-threshold부터 0.5+threshold값안에 속하는 것만 stable하다고 정의) 마지막으로 non-maximal suppression (NMS)을 통해 duplicate된 mask결과를 filtering합니다. 작은 masks를 찾기 위해 여러개의 overlapping zoomed-in image crop방식도 사용했다고 합니다. 이렇게 하여 총 11M image에 대해 1.1B masks를 생성해내었다고 합니다. (이 결과가 그대로 SA-1B dataset이 된것은 아닙니다!)
5. Segment Anything Dataset
5.1 Images
1,100만개의 image는 license가 있다고 하며 high-resolution(평균 3300x4950사이즈)이라고 합니다. 해당 이미지들을 SA-1B dataset으로 release할때는 1500 pixels까지 downsampling하였다고 합니다.
5.2 Masks
11억개의 masks를 만들었으며 그중 99.1%가 automatic하게 생성된 것이라 합니다. Automatic하게 생성된 masks가 사람이 annotate한 결과와 별반 다르지 않았다고 하여 SA-1B dataset은 오직 automatic하게 생성된 masks로만 구성되어있다고 합니다.
5.3 Mask properties
아래 사진은 SA-1B dataset과 다른 segmentation dataset간의 object center의 spatial distribution을 나타낸것입니다. SA-1B가 다른 dataset에 비해 image corner의 coverage가 뛰어난 것을 알 수 있습니다. 특히 COCO나 Open Images dataset은 중앙에 편향된 masks만 가진것을 볼 수 있습니다.
아래 Fig 6의 legend에서는 다른 dataset들간의 크기를 비교하였습니다. SA-1B가 2번째로 큰 Open Images dataset보다 11배 많은 image, 400배 많은 masks를 가진다고 합니다. Fig 6 왼쪽에서는 image당 mask의 distribution을 비교하였습니다. SA-1B가 다른 dataset들에 비해 한 image에 더 많은 mask label이 있는 것을 알 수 있습니다. Fig 6 중앙에서는 SA-1B가 masks의 개수가 많기 때문에 다른 dataset들에 비해 small, midium size의 masks가 많습니다. 마지막으로 Fig 6 오른쪽에서는 masks shape의 complexity를 측정한 표입니다. Complexity를 측정하기 위해 mask concavity(오목함)를 측정하였고 SA-1B dataset에는 작은 masks가 많지만 다른 dataset들과 비슷하게 concavity을 가진다는 것을 알 수 있습니다.
'AI paper review' 카테고리의 다른 글
Rate-Perception Optimized Preprocessing for Video Coding 논문 리뷰 (1) | 2023.12.31 |
---|---|
LoRA: Low-Rank Adaptation of Large Language Models 논문 리뷰 (0) | 2023.05.16 |
GPT-1: Improving Language Understanding by Generative Pre-Training 논문 리뷰 (0) | 2023.02.13 |
The Forward-Forward Algorithm: Some Preliminary Investigations 논문 리뷰 (0) | 2023.01.28 |
EfficientDet Scalable and Efficient Object Detection (0) | 2022.03.11 |