(23.01.25)
medical ai 논문 리뷰 4탄
이번에 정리할 논문은 3D medical image segmentation task에 transformer를 적용한 논문이다.
- 논문 제목: UNETR: Transformers for 3D Medical Image Segmentation (2021)
- https://arxiv.org/pdf/2103.10504v3.pdf
Abstract
저자는 UNEt TRansformers (UNETR)를 제안한다.
- sequence representiation을 학습하고, global multi scale information을 포착하기 위해 transformer를 encoder로 사용한다.
- encoder와 decoder 아키텍쳐로는 U-shapped network 를 사용한다.
- transformer encoder는 different resolutions와 skip connections를 통해 decoder와 연결된다.
dataset
- BTCV (Multi Atlas Labeling Beyond The Cranial Vault) dataset for multi-organ segmentation
- MSD (Medical Segmentation Decathlon) dataset for brain tumor and spleen segmentation tasks
Introduction
transformer encoder로부터 추출된 representations는 skip connections를 통해 CNN-based decoder와 합쳐진다.
-> decoder에서 transformer가 아니라 CNN-based decoder를 사용하는 이유는 transformer가 localized information을 포착하지 못하기 때문이다.
> 저자가 말하는 main contribution
- propose a novel transformer-based model for volumetric medical imgae segmentation
- transformer encoder directly utilizes the embedded 3D volumes to capture long-range dependencies
- a skip-connected decoder combines the extracted representations at different resolutions and predict segmentation output
- BTCV, MSD dataset 활용, BTCV에서 sota 달성
Method
간단하게 요약하면, 3D image를 1D patches로 바꾸어 transformer의 encoder로 넣고, CNN-based decoder에 넣어 segmentation ouptput을 도출하는 것이다!!
transformer encoder는 기존의 방법과 유사하다.
먼저, 3D input으로 부터 1D sequence를 생성한다.
(P, P, P)는 각 patch의 resolution이다.
그리고, 1D learnable positional embedding 추가하고, patch embedding 해준다.
transformer backbone은 semantic segmenation을 위한 것이므로, class token은 추가하지 않는다.
transformer block은 총 12번 반복된다. transformer block은 위의 그림에 나와있는대로 구성되어 있다.
3, 6, 9, 12번째 수행한 결과는 decoder에서 skip connection을 통해 연결된다.
loss function
- soft dice loss + cross-entropy loss
evaluation metrics
- Dice score
- 95% Hausdorff Distance (HD)
implementation details
- batch size 6
- AdamW optimizer with initial learning rate of 0.0001 for 20,000 iterations
- transformer-based encoder follows ViT-B16 with L=12 layer, embedding size of K=768
- patch resolution 16x16x16
- five-fold cross-validaions (95:5)
Experiments