3 minute read

Deep_Learning study Lec09-4

Boostcourse의 ‘파이토치로 시작하는 딥러닝 기초’를 통한 공부와 정리 Post

Goal of Study

  • Batch Normalization 에 대해 알아본다.

Keyword

  • Batch Normalization
  • 경사 소실(Gradient Vanishing) / 폭발(Exploding)

1. 강의 요약

Gradient Exploding

Gradient Vanishing과 반대로 gradient를 앞 단으로 전파할 때, 이 값이 오히려 너무 큰 값이 되어버리는 경우를 Gradient Exploding이라고 한다.

Solution

  • Change activation function
  • Careful initialization
  • Small learning rate

위는 간접적인 해결 방식이고,

우리는 직접적인 해결 방식인 Batch Normalization이라는 방식을 사용할 것이다.

Internal Covariate Shift(ICS)

ICS가 Gradient Exploding이나 Vanishing 같은 문제를 만든다고 Batch Normalization 저자들은 주장한다.

먼저, Covariate Shift란 Train-set과 Test-set의 분포가 실제로 차이가 있고,
이 분포의 차이가 어떠한 문제점을 발생시킨다는 것이 Covariate Shift의 개념이다.

즉, 데이터 분포에 변화가 생긴다는 것이다.

그렇다면 ICS와 NN(Neural Network)에서는 ICS의 문제점이 어떻게 발생하는지 알아보자.

우리가 데이터를 학습할 때, Neural Network 모델에서 Loss를 계산하고 이를 Back Propagation을 통해 이 Network를 업데이트 하게 된다.

이런 상황에서 이제 ICS라는 문제가 발생하게 된다.

Neural Network의 모델에서 학습시 이 과정에서 layer를 지날때마다, 입력의 분포가 계속 변하게 된다.

결국에는 한 layer마다 Covariate Shift가 발생한다는 것이다. 이러한 현상을 Internal Covariate Shift라고 한다.

Batch Normalization

ReLU 계열의 함수와 He 초기화를 사용하는 것만으로도 어느 정도 기울기 소실과 폭주를 완화시킬 수 있지만,
이 두 방법을 사용하더라도 학습 중에 언제든 문제가 다시 발생할 수 있다.

Batch Normalization은 인공 신경망의 각 층에 들어가는 입력을 평균과 분산으로 정규화하여 학습을 효율적으로 만든다.

즉, 각 layer마다 Normalization을 하는 layer를 두어서 변형된 분포가 나오지 않도록 하는 것이다.

이를, 우리가 Neural Network를 학습할 때, mini-batch로 쪼개서 학습을 하는 일반적인 경우에
이 mini-batch들마다 Normalization을 해주겠다는 의미에서 Batch Normalization라고 한다.

위의 알고리즘이 Batch Normalization이라고 하는 방식을 계산하는 과정이다.

Input : 미니 배치 $B = {x^{(1)}, x^{(2)},…,x^{(m)}}$

배치 정규화는 학습 시 배치 단위의 평균과 분산들을 차례대로 받아 이동 평균($\mu$)과 이동 분산(variance, $\sigma$)을 따로 저장해놓는다.

이를 Sample mean, Sample Variance이라고도 부른다.

이들의 평균을 이용해 Learning mean과 Learning Variance을 계산을 하고, 이 둘은 Batch Normalization의 학습이 끝나게 되면 입력 batch 데이터에 상관이 없이 고정 값이 된다.

테스트 할 때는 해당 배치의 평균과 분산을 구하지 않고 구해놓았던 평균(Learning mean)과 분산(Learning Variance)으로 정규화를 한다.
(이전에 다룬 Train & eval mode에 관한 내용)

mnist_batchnorm

계속해서 MNIST학습 과정에서 개선하는 방향으로 코드에 있어 우리가 다뤄본 학습 방식 혹은 개선 방식에 대해
어느 부분이 개념에 따라 실제로 코드가 달라지는지를 다뤄보고 있다.

mnist_batchnorm을 사용하기 위해 달라지는 사항은 아래와 같다.

Batch Normalization을 적용한 것과 적용하지 않은 것의 성능을 비교하기 위한 코드이다.

# nn layers
# Batchnorm 사용
linear1 = torch.nn.Linear(784, 32, bias=True)
linear2 = torch.nn.Linear(32, 32, bias=True)
linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

# Batchnorm 사용x
nn_linear1 = torch.nn.Linear(784, 32, bias=True)
nn_linear2 = torch.nn.Linear(32, 32, bias=True)
nn_linear3 = torch.nn.Linear(32, 10, bias=True)

# model
bn_model = torch.nn.Sequential(linear1, bn1, relu,
                            linear2, bn2, relu,
                            linear3).to(device)
nn_model = torch.nn.Sequential(nn_linear1, relu,
                            nn_linear2, relu,
                            nn_linear3).to(device)

Full code

전체 코드는 강의에서 제공하는 Code를 대부분 참고해서 학습해보고, 필요하다면 따라서 코딩해본 본인 코드를 포스팅한다.
Full Code를 참고하면 도움이 될 것 같다.

따로 코드는 포스팅하지 않고 학습 결과만 포스팅하면 다음과 같다.

[Epoch 1-TRAIN] Batchnorm Loss(Acc): bn_loss:0.12966(bn_acc:0.96) vs No Batchnorm Loss(Acc): nn_loss:0.18779(nn_acc:0.94)
[Epoch 1-VALID] Batchnorm Loss(Acc): bn_loss:0.14066(bn_acc:0.96) vs No Batchnorm Loss(Acc): nn_loss:0.20529(nn_acc:0.94)

[Epoch 2-TRAIN] Batchnorm Loss(Acc): bn_loss:0.09642(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.16732(nn_acc:0.95)
[Epoch 2-VALID] Batchnorm Loss(Acc): bn_loss:0.11711(bn_acc:0.96) vs No Batchnorm Loss(Acc): nn_loss:0.18843(nn_acc:0.94)

[Epoch 3-TRAIN] Batchnorm Loss(Acc): bn_loss:0.08405(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.15942(nn_acc:0.96)
[Epoch 3-VALID] Batchnorm Loss(Acc): bn_loss:0.11109(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.20730(nn_acc:0.95)

[Epoch 4-TRAIN] Batchnorm Loss(Acc): bn_loss:0.08102(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.14705(nn_acc:0.96)
[Epoch 4-VALID] Batchnorm Loss(Acc): bn_loss:0.12212(bn_acc:0.96) vs No Batchnorm Loss(Acc): nn_loss:0.18678(nn_acc:0.95)

[Epoch 5-TRAIN] Batchnorm Loss(Acc): bn_loss:0.06202(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.12650(nn_acc:0.96)
[Epoch 5-VALID] Batchnorm Loss(Acc): bn_loss:0.09171(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.17650(nn_acc:0.96)

[Epoch 6-TRAIN] Batchnorm Loss(Acc): bn_loss:0.06397(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.12816(nn_acc:0.96)
[Epoch 6-VALID] Batchnorm Loss(Acc): bn_loss:0.10242(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.19704(nn_acc:0.95)

[Epoch 7-TRAIN] Batchnorm Loss(Acc): bn_loss:0.05521(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.12984(nn_acc:0.96)
[Epoch 7-VALID] Batchnorm Loss(Acc): bn_loss:0.09313(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.19197(nn_acc:0.96)

[Epoch 8-TRAIN] Batchnorm Loss(Acc): bn_loss:0.05698(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.12324(nn_acc:0.97)
[Epoch 8-VALID] Batchnorm Loss(Acc): bn_loss:0.10201(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.18410(nn_acc:0.96)

[Epoch 9-TRAIN] Batchnorm Loss(Acc): bn_loss:0.05078(bn_acc:0.98) vs No Batchnorm Loss(Acc): nn_loss:0.11740(nn_acc:0.97)
[Epoch 9-VALID] Batchnorm Loss(Acc): bn_loss:0.09656(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.19621(nn_acc:0.96)

[Epoch 10-TRAIN] Batchnorm Loss(Acc): bn_loss:0.04743(bn_acc:0.99) vs No Batchnorm Loss(Acc): nn_loss:0.12933(nn_acc:0.96)
[Epoch 10-VALID] Batchnorm Loss(Acc): bn_loss:0.09120(bn_acc:0.97) vs No Batchnorm Loss(Acc): nn_loss:0.21004(nn_acc:0.95)

Learning finished

📣
본 포스팅의 학습 환경 : Anaconda, CPU, Pytorch, Jupyter Notebook
포스팅에서 오류나 궁금한 점은 Comments를 작성해주시면, 많은 도움이 됩니다.💡

Leave a comment