AugMix

10 minutes to read

AugMix

배경

머신러닝 모델은 학습 데이터의 능력에 크게 의존한다. 따라서 학습 데이터와 테스트 데이터의 분포가 mismatched인 data shift 상황에서 accuracy가 감소되는 문제점이 있다.

이를 해결하려는 기존의 방식들은 크게 세 가지 문제점을 가지고 있다.

  1. Data Memorizing : 테스트 데이터와 유사한 입력 데이터로 모델을 학습시키는 것에 불과해서, 완전히 새로운 데이터를 접하는 경우에는 역시 robust하지 못하다. 즉, corruptions을 학습하는 건 해당 corruptions에서만 robust할 뿐이지 완전히 새로운 corruption에 대해서는 일반화되지 못한다.
  2. Trade-off : clean accuracy*와 robustness는 trade-off 관계이다. 따라서 robust해졌지만 clean accuracy는 감소하는 경우가 많다.
    • (참고) clean accuracy* : corrupted 되지 않은 원본 입력으로 테스트하여 얻은 정확도
  3. Augmented Data with Degraded Performance : 다양성을 증가시키기위해 augmentation primitives를 chain에 직접적으로 구성한다. 즉, augmented data는 original data의 manifold에서 크게 벗어날 수 있다.


개념

AugMix는 image classifier에서 robustness와 uncertainty estimates를 향상시키는 기술로, 구현이 쉽고 학습 단계에서 본 적 없는 corruptions에도 모델이 강인해질 수 있도록 도와준다.

특징

AugMix는 이전의 data augmentation 기법과 차별되는 다음의 특징을 갖는다.

  • augmentation operations가 확률적으로 샘플링됨
    ※ 이후 “알고리즘” 섹션에서 설명할 “1. Augmentation”과 관련 ※
  • 다양한 augmented images를 생산하기 위해 augmentation chains이 layered됨
    ※ 이후 “알고리즘” 섹션에서 설명할 “2. Mixing”과 관련 ※
  • Jensen-Shannon divergence를 consistency loss로 사용함으로써 다양한 augmentations에서도 일관성을 유지할 수 있음
    ※ 이후 “알고리즘” 섹션에서 설명할 “3. Jensen-Shannon Divergence Consistency Loss”와 관련 ※

알고리즘

AugMix는 몇 가지 augmentation chains을 convex combinations하여 사용함으로써, 다양성과 일관성을 보장했다.

AugMix에서 “다양성”과 “일관성”이라는 말은 굉장히 중요하다. 기존이 augmentation 기법은 “다양성”과 “일관성”이 일종의 trade-off 관계였기 때문이다. 가령, 다양성을 주려고 다양한 연산을 적용하다보면 증강된 데이터가 원본 데이터와 크게 멀어지는 경우가 이에 해당한다. AugMix는 “다양성”과 “일관성”을 모두 지켰다는 점에서 의의가 크다.

image-20220302184953251

image-20220302184841118

AugMix 알고리즘은 크게 ①Augmentation②Mixing, 그리고 ③Jensen-Shannon Divergence Consistency Loss로 나눌 수 있다.

① Augmentation은 “다양성”과 관련이 있고, ②Mixing은 “일관성”과 관련이 있다.


1. Augmentation

Augmentation은 data augmentation의 ‘다양성‘을 보장하기 위한 과정이다.

image-20220303101035678

3:	Fill x_aug with zeros
		...
5:	for i = 1, ..., k do
6:		Sample operations op1, op2, op3 ~ Ο
7:		Compose operations with varying depth op12 = op2·op1 and op123 = op3·op2·op1
8:		Sample uniformly from one of these operations chain ~ {op1, op12, op123}

AutoAugment[개념]를 통해 주어진 데이터에 대해 최적의 augmentations 기법을 선택한다. 이후, 테스트 과정에서 사용할 ImageNet-C[개념] 데이터와 겹치는 augmentation 기법은 학습 과정에서 배제한다. (e.g., image noising과 image blurring 연산은 학습 과정에서 배제) 또한 augmented data가 기존 data manifold에서 벗어나지 않도록 각 augmentation의 강도를 적절히 설정한다. (e.g., rotation operations의 회전 강도, like 2º or -15º)

augmentation 과정은 augmentation chain을 랜덤으로 $k$(default:3)개 생성하는데, 이때 각 augmentation chain은 랜덤으로 선택 된 1~3가지 augmentation operations로 구성된다.


2. Mixing

Mixing은 augmented data가 original data의 manifold에서 벗어나지 않도록 ‘일관성‘을 보장하기 위한 과정이다.

image-20220303102414162

4:	Sample mixing weights (w1, w2, ..., wk) ~ Dirichlet(α, α, ..., α)
5:	for i = 1, ..., k do
		...
9:		x_aug += w_i · chain(x_orig)
10:	end for
11:	Sample weight m ~ Beta(α, α)
12:	Interpolate with rule x_augmix = m*x_orig + (1-m)*x_aug
13: return x_augmix
14: end function

Augmentations 과정을 통해 생성된 $k$개의 augmentation chains은 mixing 과정에서 결합된다. 알파 합성(Alpha Compositing)에 의한 mixing을 단순하기 구현하기 위해 여기서는 elementwise convex combinations를 사용한다. convex coefficients의 $k$차원 벡터는 $\text{Dirichlet}{(\alpha, \alpha, ..., \alpha)}$ 분포[개념]로부터 랜덤하게 샘플링되어 각 augmentations chains의 결과에 곱해진 후 하나의 결과로 합쳐진다. 이후, 하나로 합쳐진 결과는 $\text{Beta}{(\alpha, \alpha)}$ 분포[개념]로부터 샘플링된 두 번째 random convex combination을 통해 original image와 합쳐진다.

즉, 최종적인 이미지는 연산 선택, 이러한 연산의 강도, augmentation chains의 길이, 그리고 mixing weights의 선택에 의해 여러 개의 랜덤한 sources가 통합된 결과이다.


3. Jensen-Shannon Divergence Consistency Loss

위의 과정을 통해 구한 augmentation scheme을 loss에 결합하여 신경망의 응답을 부드럽게 만든다. (mixing 과정을 통해) 이미지의 semantic content를 거의 보존했기 때문에 loss를 계산할 때 $x_{\text{orig}}, x_{\text{augmix1}}, x_{\text{augmix2}}$는 유사한 정도로 모델에 내장된다.

🗣️ “유사한 정도로 모델에 내장된다”는 것은, $x_{\text{augmix}}$을 $x_{\text{orig}}$에 비해 특별히 적은 가중치로 반영하지 않겠다는 의미이다.

이를 위해, original sample $x_{\text{orig}}$와 이것의 augmented 변형들의 posterior distributions에 대해 Jensen-Shannon divergence[개념]를 최소화한다. 즉, $p_{\text{orig}}=\hat{p}{(y \mid x_{\text{orig}})}$, $p_{\text{augmix1}}=\hat{p}{(y \mid x_{\text{augmix1}})}$, $p_{\text{augmix2}}=\hat{p}{(y \mid x_{\text{augmix2}})}$, 그리고 original loss $\mathcal{L}$에 대해 최종적인 loss는 다음과 같다. \(\mathcal{L}(p_{\text{orig}}, y) + λ\textbf{JS}(p_{\text{orig}}; p_{\text{augmix1}}; p_{\text{augmix2}}).\)

\[\textbf{JS}(p_{\text{orig}}, p_{\text{augmix1}}, p_{\text{augmix2}}) = \frac{1}{3}(\textbf{KL}[p_{\text{orig}} \mid\mid M] + \textbf{KL}[p_{\text{augmix1}} \mid\mid M] + \textbf{KL}[p_{\text{augmix2}} \mid\mid M]).\]

이러한 loss는 세 개의 $p_{\text{orig}}$, $p_{\text{augmix1}}$, $p_{\text{augmix2}}$ 분포 중 어떤 한 샘플이 나타내는 샘플 분포의 동일성에 대한 평균 정보로 이해할 수 있다.

  • 참고
    • <a name=Jensen-ShannonDivergence>Jensen-Shannon divergence</a>
      • KL Divergence에 비해 upper bounded.
      • 모델을 다양한 입력 범위에 대해 ①stable, ②consistent, ③intensive하게 만들어준다 (Bachman et al., 2014; Zheng et al., 2016; Kannan et al., 2018).
    • 두 개의 augmentation chains을 사용하는 이유
      • $\textbf{JS}(p_{\text{orig}}; p_{\text{augmix1}})$은 성능을 향상시키기에 부족하다.
      • $\textbf{JS}(p_{\text{orig}}; p_{\text{augmix1}}; p_{\text{augmix2}}; p_{\text{augmix3}})$은 성능을 너무 조금 향상시킨다. (‘굳이’인 느낌)


Experiments

Datasets

실험에 사용되는 데이터셋은 크게 CIFAR-10, CIFAR-100, 그리고 ImageNet으로 나뉜다.

  • CIFAR (Krizhevsky & Hinton, 2009)
    • 32×32×3 color natural images
    • 50,000 training & 10,000 testing images
    • CIFAR-$N$ has $N$ categories
  • ImageNet (Deng et al., 2009)
    • 1,000 classes
    • approximately 1.2 milion large-scale color images

실험에는 이러한 데이터셋의 변형된 버전도 사용되는데, 논문에 나와있는 용어는 아니지만, 여기서는 편의에 따라 “Dataset-C”와 “Dataset-P”로 나눈다.

Dataset-C

img

data shift 상황에서의 네트워크 성능을 평가하기 위해 원본 데이터셋에 corruption을 추가한 CIFAR-10-C, CIFAR-100-C, ImageNet-C (Hendrycks & Dietterich, 2019) 데이터셋을 사용한다. 각 데이터셋은 세 가지 타입(blur, weather, and digital corruption), 총 15개의 노이즈가 각각 5가지의 강도로 추가된다.

Dataset-P

imgimgimg

classifier의 예측 안정성(prediction stability)을 평가하기 위해 원본 데이터셋에 dataset-C에 비해 작은 약간의 변화를 추가한 CIFAR-10-P, CIFAR-100-P, ImageNet-P 데이터셋을 사용한다. 각 데이터셋은 영상 타입이며, 예를 들어 시간이 지남에 따라 밝기가 점점 증가한다는 식으로 변형되었다. 이를 통해 밝기가 증가함에 따라 영상 프레임간의 일관되지않거나 급격한 예측이 발생하는 지를 확인한다.


Metrics

네트워크 성능을 평가하기 위한 메트릭으로 본 논문에서는 세 가지를 제안한다.

mCE(Mean Corruption Error)

  • Clean Error : 변형되지 않은 clean data에 대한 일반적인 classification error
  • Corruption Error : 변형된 corrupted data에 대한 classification error
    • $E_{c,s}$ : corruption $c$가 $s(1≤s≤5)$의 강도로 주어졌을 때 error
    • $\text{u}CE_c = \sum_{s=1}^5{E_{c,s}}$ : unnormalized corruption error
    • $CE_c = \sum_{s=1}^5{E_{c,s}}/\sum_{s=1}^t{E_{c,s}^{\text{AlexNet}}}$ : normalized corruption error
    • $\text{m}CE$ : 15가지 corruptions에 대한 평균적인 error

mFP(mean Flip Probability), mFR(mean Flip Rate)

영상 프레임간의 예측 안정성과 관련된 Perturbation robustness를 측정하기 위해 flip probability를 계산한다. flip probability는 인접 프레임간에 미묘한 차이가 존재할 때, 그 예측이 달라지는(이를 “flipped”라고 표현한다) 확률을 의미한다. 10가지perturbation 종류에 대해 평균값을 계산한 것을 mean Flip Probability (mFP)라고 한다. ImageNet-C에서는, AlexNet의 flip probabilities를 normalization하여 mean Flip Rate(mFR)을 계산한다.

CE(Calibration Error)

모델의 uncertainty estimates를 평가하기 위해 모델의 miscalibration을 측정한다. 여기서 “calibrated”는 classifiers가 출력한 정확도를 신뢰할 수 있는 능력을 의미한다. 예를 들어, calibrated model이라면, 70%의 정확도로 예측했을 때 그 신뢰도 역시 70%여야 한다.

  • idealized RMS Calibration Error : $\sqrt{E_C[(P(Y=\hat{Y} \mid C=c) - c)^2]}$

    주어진 신뢰도 c에서의 accuracy와 실제 신뢰도 c의 squared difference를 의미한다.


CIFAR-10 and CIFAR-100에 대한 실험

Training Setup

  • Learning rate : initial learning rate = 0.1, which decays following a cosine learning rate
  • Input images : All input images are pre-processed with standard random left-right flipping and cropping prior to any augmentations.
  • AUGMIX parameters : We do not change AUGMIX parameters across CIFAR-10 and CIFAR-100 experiments for consistency.
  • Epochs : The All Convolutional Network and Wide ResNet train for 100 epochs, and the DenseNet and ResNeXt require 200 epochs for convergence.
  • Optimization : We optimize with stochastic gradient descent using Nesterov momentum.
  • Weight decay : We use a weight decay of 0.0001 for Mixup and 0.0005 otherwise.


Result①: Corruption Error

CIFAR-10-C와 CIFAR-100-C에 대해 여러 augmentation 기법 적용에 따른 corruption error 결과는 다음과 같다. Figure 5는 CIFAR-10-C에 대한 Corruption Error를 그래프화한 것이고, Table 1은 결과를 표로 나타낸 것이다.

image-20220303145123150

image-20220303145916172

AUGMIX는 비교한 여러 augmentation 중 가장 낮은 Corruption Error를 달성했다.

image-20220303151502989

  • Standard에서
    • 비교적 예측이 힘든 corruption 종류 : Gaussian Noise, Glass Blur, Impulse Noise, Shot Noise
    • 비교적 예측이 쉬운 corruption 종류 :Brightness, Fog
  • AugMix에서
    • 비교적 예측이 힘든 corruption 종류 : Gaussian Noise, Glass Blur, Pixelate, Shot Noise
    • 비교적 예측이 쉬운 corruption 종류 : Brightness, Defocus Blur, Zoom Blur, Motion Blur, Fog

image-20220303151553607

위 그림에서 노란색 막대는 Standard와 AugMix 간의 corruption error 차이를 시각화한 것이다. 이때 특히나 그 차이가 심한 부분을 검은색 윤곽선으로 표현했다.

  • corruption error에서 AugMix가 비교적 좋은 결과를 야기하는 corruption 종류 : Gaussian Noise, Shot Noise, Impulse Noise, Glass Blur
  • corruption error에서 AugMix가 성능에 큰 영향을 주지 못하는 corruption 종류 : Fog, Brightness
  • 전체적으로 AugMix는 corruption error를 대략 11~17.8%정도 감소시킨다.

  • (Standard와 AugMix 두 결과에 의하면) Noise corruptions와 Glass Blur, 그리고 Pixelate는 예측이 어렵고, Brightness와 Fog는 예측이 쉽다.

    하지만 Standard와 비교했을 때, AugMix에서는 corruption 종류에 따른 Corruption Error 차이가 덜하다. 즉, AugMix일 때 네트워크가 더 robust해졌다.

Result②: Flip Probability & Calibration Error

image-20220303154039743

  • (Figure 6의 좌측 자료로부터) AugMix는 비교적 perturbation에 robust함을 확인할 수 있다.
  • (Figure 6의 우측 자료로부터) AugMix가 clean data와 corrupted data간의 RMS Calibration Error 차이를 효과적으로 감소시켰음을 확인할 수 있다.

image-20220303155419448

  • CIFAR-10 Clean Error
    • Adversarial training에서 현저히 높은 에러를 보여준다.
    • 그 외 기법간의 에러 차이는 근소하다.
  • CIFAR-10-P mean Flip Probability
    • Adversarial training과 AugMix에서 낮은 에러를 보여준다.

∴ AugMix는 clean error는 최대한 보존하면서도 mean Flip Probability는 낮춰주는 데에 효과적이다.

ImageNet에 대한 실험

Baselines

큰 규모의 classes를 갖는 ImageNet에서의 성능 비교를 위해, 다음의 기준으로 기존의 방법과 AugMix의 성능을 비교할 대상을 선정하였다.

  • Cutout은 ImageNet 규모에서 효과적이라는 것이 증명되지 않았으므로 ImageNet-C에서 성능이 검증된 Stylized ImageNet을 비교 대상으로 결정했다.
  • Patch Uniform은 랜덤으로 선택된 이미지 영역에 uniform noise를 주입한다는 것을 제외하고는 Cutout과 유사하므로 비교 대상으로 삼았다. 참고로 원논문에서는 학습시 Gaussian noise를 사용하나, 테스트 대상 데이터와 중복되므로 이를 uniform noise로 바꾸어 성능을 평가했다. 대략 30개가 넘는 hyperparameters를 수정했다.
  • AutoAugment는 높은 성능을 달성하는 augmentation policy를 찾아준다. 이때 출력 결과로 나온 최적의 augmentations 중 테스트 데이터와 겹치는 corruptions은 제거했으므로 논문에서는 AutoAugment*****로 표기한다.
  • Random AutoAugment는 AutoAugment*을 사용하여 랜덤으로 샘플링된 augmentation policy를 갖는다. AutoAugment와 비교했을 때 더 적은 계산으로 더 다양한 augmentation를 제공한다는 특징이 있다.
  • MaxBlurPooling은 최근에 제안된 architectural modification으로, pooling 결과를 smooth하게 만들어준다.
  • Stylized ImageNet (SIN)은 원본 ImageNet 데이터에 추가적으로 style transfer가 적용된 ImageNet을 사용하여 모델을 학습시키는 기술이다

여기에 추가적으로 SIN과 AugMix를 결합한 것 역시 비교 대상으로 삼았다.

Training Setup

학습에는 Goyal et al. (2017)의 표준 training scheme을 따르는 ResNet-50이 사용된다.

  • learning rate : we linearly scale the learning rate with the batch size, and use a learning rate warm-up for the first 5 epochs.
  • **epochs **: AutoAugment and AUGMIX train for 180 epochs.
  • Preprocessing : All input images are first pre-processed with standard random cropping horizontal mirroring.

Result①: Corruption Error

image-20220305184335714

  • AugMix는 standard와 비교했을 때 Clean Error는 23.9→22.4%로 1.5% 감소시킴과 동시에 mean Corruption Error는 80.6→68.4%로 12.2% 감소시켰다. 기존의 augmentation 기법들이 clean error와 corruption error가 trade-off 관계가 있었다는 점을 감안할 때 의미있는 결과이다.
  • AugMix는 다른 augmentation 기법과 결합하여 사용하기에도 좋게 만들어졌다. 여기서는 SIN(Stylized ImageNet) 기법과 결합하여 사용해 본 결과, AugMix만을 사용했을 때 보다 더 좋은 성능을 거두었음을 보여주었다.

Result②: Flip Rate

image-20220305184649885

  • Flip Rate는 Perturbation robustness 측정 지표이다. ImageNet-P에 각 augmentation 기법을 테스트한 결과, 가장 낮은 mean Flip Rate를 달성하였음을 확인할 수 있다.

  • 하지만 테스트 데이터에 적용된 corruption의 종류에 따라 다른 Flip Rate를 보여주는데, noise corrruption에 대해서는 Patch Uniform 기법이 더 성능이 좋음을 알 수 있다.

    Patch Uniform은 랜덤으로 선택된 이미지 영역에 uniform noise를 주입”하는 augmentation 기법이다. 저자는 모든 종류의 augmentation 기법에서 테스트단에서 마주칠 수 있는 종류의 corruptions은 학습에서 배제했는데, 여기서는 Gaussian이나 Shot noise가 아니라는 이유에서 허용된 것으로 보인다. 사실상, 다른 분포를 갖는 노이즈를 학습한 것이므로 성능이 특히 더 좋아진 게 아닐까 싶다. 하지만 이는 역시 기존의 Corruption Memorizing한다는 문제점에 해당한다.



참고자료

AutoAugment

What is the difference between 6 and 9? - Quora

지나친 data augmentation은 manifold intrusion으로 이어질 수 있다. 따라서 이전까지는 데이터 도메인에 따라 어떤 augmentation을 적용할지 선택하는 과정이 수동으로 이루어졌다 (e.g., “새의 종류를 판별하는 문제에서는 histogram color swapping과 같은 augmentation을 사용하면 안되겠군”). AutoAugment주어진 데이터에서의 효과적인 augmentation을 찾는 과정을 '**자동으로**' 수행하는 것을 목표로 한다. 이를 위해 탐색공간(search space)을 정의하여 최적화된 augmentation 기법을 찾는다. 이때 탐색공간은 하나의 augmentation policy는 5개의 sub-policies로 구성하고, 각 sub-policy는 2개의 image operation으로 구성하여 정의된다. 각 sub-policy가 생성하는 augmentation 기법은 다음과 같다. \(\text{1개의 augmentation policy}\\ = \text{2개의 이미지 연산}\\ =\text{이미지 연산}×\text{이미지 연산}\\ =(\text{augmentation 기법} × \text{확률} × \text{강도})^2\) AutoAugment는 규모가 작은 데이터셋에서 특히나 더 뛰어난 성능을 보여주지만, 높은 계산 복잡도로 인해 일반적인 연구 환경에 적용이 어렵다는 단점이 있다.

AugMix에서는 실험 데이터에 적용할 적절한 augmentation 기법을 뽑아내기위해 AutoAugment를 사용했다. 이후, 뽑아낸 최적화된 augmentation 기법들 중 ImageNet-C와 겹치는 corruptions에 대해서는 배제하고 실험했다.


Dirichlet & Beta Distribution

Dirichlet distributionBeta distribution은 주로 image classification에서 사용되는 parametric distribution*이다.

  • (참고) parametric distribution* : 모수를 가정하여 추정하는 분포. 데이터가 적어도 잘 동작한다는 장점이 있지만, 모수에 영향을 받으므로 데이터에 따른 분포 업데이트 반영이 어렵다는 단점이 있다.

Beta distribution는 univariate 특성을 가지며 매개변수 $\alpha$, $\beta$에 대해 [0,1] 범위에서 정의되는 연속확률 분포이다.

img

\[f(x; α, β) = \frac{\Gamma{(\alpha+\beta)}}{\Gamma{(\alpha)}\Gamma{(\beta)}}x^{\alpha-1}(1-x)^{\beta-1}\\ \Gamma{(n)} = (n-1)!\]

Dirichlet distribution는 multivariate 특성을 가지며 다음의 연속확률 분포이다.

\[B(\alpha) = \frac{\Pi_{i=1}^k{\Gamma{(\alpha_i)}}}{\Gamma{(\sum_{i=1}^k{\alpha_i})}}\]



References