머신러닝/Vision

The Lottery Ticket Hypothesis

망나 2019. 6. 24. 11:39

ICLR 2019에서 Best Paper에 선정된 "The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural" 논문에 대한 정리입니다.

 

이 논문은 Network Pruning에 관련된 논문으로 pruning은 쉽게 말해 뉴럴 네트워크의 사이즈를 줄이는 방법을 말합니다. 네트워크의 pruning을 위해서는 pre-trained 된 네트워크에서 불필요한 연결을 제거하여 네트워크의 사이즈를 줄임으로써 기존의 네트워크와 동일한 정확도를 유지하면서도 더 빠른 처리 속도를 얻기 위해 사용되는 기법입니다. Pruning에 대해서는 1990년도의 Yann Lecun 교수님의 논문 "Optimal Brain Damage"에서 처음 다루어졌습니다.

 

오늘 소개할 논문으로 들어가기 전에 pruning에 대한 깊은 이해를 위해서 pruning에 대해 자세히 다룬 논문에 "Learning both Weights and Connections for Efficient Neural Networks"에 대해서 알아보도록 하겠습니다.

 

일단, 왜 pruning이 필요한가에 대한 이유를 이렇게 설명합니다. 큰 뉴럴 네트워크(많은 파라미터를 갖고 있는)는 아주 강력하지만 그만큼 많은 메모리를 차지하고 계산량도 어마어마합니다. 따라서 이를 수용하기 위한 좋은 성능의 하드웨어가 필요하고 이러한 제한 때문에 모바일에 적용하기가 힘들 수 있습니다.

 

그렇기 때문에 보통 over-parameterized 된 뉴럴 네트워크를 pruning을 통해서 중복되거나 불필요한 부분을 제거하여 위와 같은 문제를 해결하려 하는 것입니다.

 

논문에서 아래 그림과 같은 pruning 과정을 소개합니다.

 

Neural Network Pruning

 

  • Train Connectivity: Neural Network를 학습시킵니다. (어떤 connection이 중요한지를 판단하기 위해)
  • Prune Connections: 불필요한 connection을 제거합니다. (특정 threshold보다 낮은 weight를 불필요하다고 판단)
  • Train Weights: 남아있는 sparse Network를 재 학습시키며 pruning 이전과 비슷한 성능을 유지하도록 합니다.

하지만, pruning 뒤에 얻은 sparse 네트워크의 재 학습 과정에서 weight들을 random initialization 시켜 처음부터 다시 학습하는 방식을 이용하면 pruning 하기 전의 원래 네트워크와 비슷한 수준의 정확도를 얻기 힘들다는 문제가 발생했습니다. 이 문제에 대해서는 pruning 된 네트워크의 parameter 수가 적기 때문에 capacity가 그만큼 작아져 기존의 네트워크에 비해 학습이 잘 되지 않는다는 주장이 있습니다.

 

이 문제에 대해서 본 글에서 정리하고자 했던 논문 "The Lottery Ticket Hypothesis"의 저자는 pruning 된 네트워크를 잘 학습시키기 위한 방법을 제안했습니다.

 

논문 제목이 Lottery Ticket Hypothesis인 이유가 재밌는데 실제로 우리가 복권에 당첨되기 위해서 수많은 티켓을 사는

것을 parameter가 많은 큰 네트워크에 비유했습니다. 큰 네트워크에 포함된 subnetwork에는 당첨되지 않는 티켓과 당첨되는 티켓(저자는 논문에서 winning ticket이라고 표현했습니다 )을 포함하는데 여기서 우리는 winning ticket을 구분할 줄 알아야 한다는 거죠. 그러니까 쓸데없이 티켓을 많이 사지 말고 불필요한 weight들을 제거함으로써 당첨이 될 winning 티켓인 subnetwork만으로 네트워크를 구성해야 한다는 것이 저자들의 설명입니다.

 

그렇다면 본격적으로 논문에서 소개하는 방법에 대해서 알아보도록 하겠습니다.

 

The Lottery Ticket Hypothesis.

논문은 다음과 같은 가설을 전제로 연구를 진행합니다. 논문의 원문을 그대로 가지고 왔습니다.

A randomly-initialized, dense neural network contains a subnetwork that is initialized such that --when trained in isolation-- it can match the test accuracy of the original network after training for at most the same number of iterations.

정리하자면, 입력 \(x\)와 초기 파라미터 \(\theta = \theta_{0} \sim  D_{\theta}\)인 기존의 dense feed-forward neural network \(f(x;\theta)\)와 기존의 neural network를 mask \(m \in \left \{ 0,1 \right \}^{\left | \theta  \right |}\)으로 pruning 하여 얻은 subnetwork \(f(x;m\odot\theta)\)이 있을 때,

 

Network \(f(x;\theta)\)를 \(j\) iterations 만큼 학습시켰을 때 accuracy \(a\)를 얻고

Subnetwork \(f(x;m\odot\theta)\)를 \(j^{'}\) iterations 만큼 학습시켰을 때 accuracy \(a^{'}\)를 얻었다면

 

The lottery ticket 가설은 \(j^{ '}\leq j\)(학습에 필요한 iteration 수가 적고),  \(a^{'}\geq a\)(test accuracy가 높고),  \(\left \| m \right \|_{0}\ll \left | \theta  \right | \)(모델의 parameter 수도 적은3가지 조건을 만족하는 \(\exists m\)을 예측하는 것입니다.

 

즉, 쉽게 말하면 기존의 Network를 pruning 하여 기존보다 적은 수의 파라미터로 기존보다 덜 학습을 해도 test 정확도는 높은 Subnetwork를 얻는 것이 목적입니다.

 

하지만, 기존의 pruning 방법으로 subnetwork \(f(x;m\odot\theta)\)를 찾고 weights를 randomly reinitialized \(f(x;m\odot\theta^{'})\)하여 재 학습을 시킬 경우에는 기존의 neural network와 비슷한 수준의 정확도 얻을 수가 없는데 그 이유로는 network의 parameter 수가 줄어들기 때문에 이전처럼 효과적으로 학습이 이루어지지 않는다는 것입니다.

 


 

Method of Identifying winning tickets.

따라서 본 논문에서는 subnetwork(winning tickets)을 찾는 다음과 같은 방법을 소개합니다.

 

논문에서 제안하는 winning tickets을 식별하는 방법 (Identifying winning tickets)

 

방법은 정말 간단한데 1~3번까지는 기존의 neural network pruning과 동일하고 마지막 4번의 방법만 추가됐다고 볼 수 있습니다. 보시면 처음 초기화된 network weights \(\theta_{0}\) 값을 저장해 뒀다가 pruning 후 얻은 subnetwork의 weights를 초기값과 같은 \(\theta_{0}\)를 다시 넣어서 재 학습을 진행하는 간단한 방법을 소개하고 있습니다.

 

또한 기존의 pruning 방식이 network를 1번 학습시키고, \(p\)% 만큼 가지치기를 하고 나머지 weights를 초기화하는 one-shot 접근법이었다면, 본 논문에서는 \(n\) 라운드만큼 각 라운드마다 \(p^{1/n}\)%만큼 가지치기를 반복하는 iterative pruning 방법을 사용한다고 합니다.

 

제안 방법 도식화

 

정리하자면, 위 그럼 1~4를 총 \(n\) 번 반복시키는 것이 본 논문에서 제안하는 방법입니다.

 


 

Experimental Result.

 

제안하는 방법의 성능을 검증하기 위해서 저자들은 다음과 같은 Network를 사용하여 실험을 진행하였습니다.

 

논문에서 실험에 사용된 네트워크 구조들

 

 

- Winning tickets in fully-connected networks(Lenet) trained on MNIST - 

 

아래 그래프를 보면 제안하는 방법을 적용하게 되면 기존의 weights를 초기화하는 방법(reinit)뿐만 아니라 pruning을 하기 전 기존의 네트워크(100.0)에 비해서도 test accuracy가 높은 결과를 확인할 수 있습니다. 숫자의 의미는 prune 되고 남아있는 네트워크 parameter의 비율입니다. 즉 21.1은 기존의 네트워크에 비해 21.1%의 parameter만 남기고 78.9%의 parameter가 prune 된 subnetwork를 의미합니다.

 

Lenet 모델에 대한 Test accuracy (iterative pruning)

 

또한 pruning ~20%까지는 기존의 네트워크에 비해서 학습 속도가 향상되는 결과를 보였으며, ~13.5%까지는 test accuracy가 증가하는 결과를 보였습니다. 아래 그래프를 보면 해당 결과를 확인할 수 있습니다.

 

one-shot pruning과 iterative pruning에 대한 early-stopping iteration과 accuracy 비교 그래프

 

 

- Winning tickets in Convolutional Networks trained on CIFAR10 - 

 

Convolutional networks에서의 실험 결과를 보면, Conv-2/4는 최대 3.5x, Conv-6은 최대 2.5x 만큼 빨리 minimum validation loss에 도달했습니다.. 또한 test accuracy의 경우 Conv-2/4/6 각각 최대 3.4%, 3.5%, 3.3% 높은 결과를 얻을 수 있었다고 합니다.

 

Conv-2/4/6 네트워크에 대한 early-stopping iteration과 test accuracy에 대한 실험 결과

 

추가적으로 Dropout과 제안 방법을 함께 사용했을 때도 더 나은 성능을 보여주는 결과를 얻었습니다. 

 

Conv-2/4/6 iteratively pruned and trained with dropout

 

 

- Winning tickets in VGG and Resnet trained on CIFAR10 - 

 

마지막으로 VGG와 Resnet을 사용한 실험에서는 이전과는 다른 방식인 global pruning을 사용하였다고 합니다. 이전 실험에서 각 layer별로 pruning을 하였다면 global pruning은 모든 layer에 대해서 한번에 pruning을 하는 방식입니다. 이러한 방식을 사용한 이유는 네트워크의 layer에 따라 parameter 수의 차이가 크기 때문입니다. VGG-19의 경우에는 각 layer별 parameter 수가 1.7K, 36K, 2.35M개로 차이가 많이 납니다. 이러한 상황에서 각 layer별로 동일한 비율로 pruning을 진행하게 되면 첫번째 layer처럼 parameter가 작으면 bottleneck이 되어 winning tickets을 식별하는데 방해가 될 수 있습니다. 따라서 global pruning을 사용하였고 더 좋은 성능을 얻을 수 있었다고 합니다.

 

VGG-19에 대한 test accuracy(30k, 60k, 그리고 120k iterations)
Resnet-18에 대한 test accuracy(10k, 20k, 그리고 30k iterations)

 



Discussion.

winning ticket 초기화의 중요성. 네트워크 pruning 후 얻은 winning ticket의 weight를 random reinitialize하게되면 학습속도가 더 느려지고 테스트 정확도 또한 감소합니다. 하지만 논문에서 제안한 방법으로 초기 weight를 이용하면 더 빨리 학습하고 높은 테스트 정확도를 얻게 됩니다. 이러한 원인의 한가지 가능성은 네트워크의 초기 weight가 최종 학습 후에 얻는 값과 비슷하다는 가정을 할 수 있습니다. 하지만, 논문의 Appendix F의 실험에서 이 반대의 경우를 확인할 수 있었습니다 - winning ticket의 weight가 다른 weight들 보다 더 많이 변화하는 경우. 따라서 weight 초기화의 중요성은 그 자체가 아닌 optimization algorithm, dataset, 그리고 model과 연관이 있다고 볼 수 있습니다.

 

winning ticket 구조의 중요성. winning ticket을 발견하기 위해서 training data를 활용하기 때문에 winning ticket의 구조가 특정 문제에 대해 맞춤화된 inductive bias가 encode되었다고 볼 수 있습니다. 

 

winning ticket 일반화. 최근 연구에 따르면 neural network에서 더 압축될 수 있는 network의 일반화 경계가 더 엄격하다(tighter generalization bounds)는 것을 증명했습니다. Ther lottery ticket 가설은 이러한 관계를 보충 설명할 수 있습니다 - larger network는 명시적으로 간단한 표현을 포함하고 있습니다.

 

neural network 최적화의 의미. 최근 연구에서 충분히 overparameterize된 2 layer relu network를 SGD로 학습했을 경우 global optima에 도달한다는 것을 증명했습니다. 여기서 저자들은, SGD로 neural network를 최적화하는데 winning ticket의 존재는 필수적이고 SGD가 잘 초기화된 subnetwork를 찾아내고 학습시킨다고 추측합니다. 이러한 논리로 볼 때, overparametrize된 network는 winning ticket의 가능성을 가진 더 많은 subnetwork 조합을 포함하고 있기 때문에 더 쉽게 학습할 수 있다고 합니다.

 

 

 

ICLR 논문 발표 영상

 

'머신러닝 > Vision' 카테고리의 다른 글

Generative Adversarial Networks?  (0) 2020.08.12
About GAN  (0) 2020.08.12
ResNet 이해하기  (0) 2019.06.04
Generative Adversarial Networks (GANs)  (0) 2019.02.27
Batch Normalization?  (0) 2018.10.23