[논문리뷰] On Distillation of Guided Diffusion Models
📑

[논문리뷰] On Distillation of Guided Diffusion Models

Category
딥러닝
Tags
Deep Learning
Paper Review
Diffusion Model
Published
March 3, 2023
Author
Jay

On Distillation of Guided Diffusion Models

 
 
 

0. Abstract

Classifier-free guided diffusion model은 최근 high-resolution image generation에 좋은 효율성을 보였으나, class-conditional model과 unconditional model 두가지 모두를 evaluating 해야해서, inference time에 있어 오래걸린다는 단점이 있었다.
따라서 본 논문에서는, classifier free guided diffusion model을 distillation하는 방법을 제시한다. 먼저 conditional과 unconditional 모델이 결합된 것과 같은 output을 내는 single model을 학습시키고, 이를 distill한다. 이를 통해 FID/IS score를 유지 하면서도 256배 정도 빠르게 샘플링이 가능하다.
풀어서 설명하자면 식에서 conditional model인 와 unconditional model인 모두를 inference, evaluating해야해서 느리다는 것. 그래서 cfg diffusion model을 single model로 distillation(증류)해서 속도 개선하겠다는 것! (두개였던걸 하나로 줄이는 과정을 반복하면서 1024 step → 4 step으로 256배 빠르게!)
 

1. Introduction

본 논문에서는 cfg diffusion 모델의 sampling 효율을 높이기 위한 two-step distillation 접근방식을 제안한다. 첫번째 step에서는 teacher model의 two diffsuion model을 결합한 것과 같은 output을 내는 single student 모델을 도입하여 학습한다. 두번째 step에서는 첫번째 step에서 학습시킨 student 모델을 fewer-step 모델로 distill(증류)한다. 본 논문의 접근방식을 사용하여 single distilled model은 넓은 범위의 guidance 강도를 활용할 수 있다.
실험을 통해 제안된 distilled model이 4 steps만에 teacher와 비슷한 visual quality로 이미지를 생성할 수 있고, 8~16steps만을 통해 넓은 범위의 guidance 강도에서 teacher와 비슷한 FID/IS score를 보임을 확인했다. (Figure 1. 참조)
notion image
 

2. Background

딥러닝 모델의 Distillation
우리가 딥러닝을 활용해 인공지능 예측 서비스를 만든다고 가정해보자. 연구 및 개발을 통해 만들어진 딥러닝 모델은 다량의 데이터와 복잡한 모델을 구성하여 최고의 정확도를 내도록 설계되었을 것이다. 하지만 모델을 실제 서비스로 배포한다고 생각했을 때, 이 복잡한 모델은 사용자들에게 적합하지 않을 수 있다.
그렇다면 다음의 두 모델이 있다면 어떤 모델을 사용하는 게 적합할까
  • 복잡한 모델 T : 예측 정확도 99% + 예측 소요 시간 3시간
  • 단순한 모델 S : 예측 정확도 90% + 예측 소요 시간 3분
어떤 서비스냐에 따라 다를 수 있겠지만, 배포 관점에서는 단순한 모델 S가 조금 더 적합한 것으로 보인다.
그렇다면, 복잡한 모델 T와 단순한 모델 S를 잘 활용하는 방법도 있지 않을까? 바로 여기서 탄생한 개념이 지식 증류(Knowledge Distillation)입니다. 특히, 복잡한 모델이 학습한 generalization 능력을 단순한 모델 S에 전달(transfer)해주는 것을 말한다.
이 개념을 처음으로 제시한 논문에서는 복잡한 모델을 cumbersome model로, 단순한 모델을 simple model로 나누어 설명하고 있다. 하지만 이후 등장하는 논문에서는 일반적으로 Teacher model과 Student model로 표현하고 있다. 먼저 배워서 나중에 지식을 전파해주는 과정이 선생님과 학생의 관계와 비슷하여 이렇게 표현한다.
 
Classifier-Free-Guidance(CFG)란?
Classifier Guidance
Diffusion Model이 대두되기 전, 기존에 사용되던 GAN 생성 모델에서는 Truncation이나, Low temperature Sampling으로 Diversity를 줄여 샘플의 퀄리티를 개선해왔다. 그러나 이런 기법들은 Diffusion model에서는 효과적이지 못했고 이로 인하여 Diffusion 모델이 GAN보다 성능이 떨어지는 것으로 생각되었으나, Diffusion Models Beat GANs on Image Synthesis이라는 논문에서 classifier guidance라는 방법을 제시하면서 판이 뒤집혔다.
Classifier guidance는 conditional diffusion 모델의 학습 후 과정에서 샘플의 divesity와 fidelity를 trade-off 하는 방법으로 샘플링 과정의 score(negative log likelihood 함수의 1차 미분) 추정 함수 뒤에 gradient of the log likelihood of an auxiliary classifier model를 더하였다.
위 함수에서 는 노이즈가 추가된 데이터를 의미한다.
기존 함수에 특정 class에 대한 classifier의 negative log likelihood 를 더함(w는 조절가능한 가중치) 이 방법으로 classifier를 이용하여 guide를 제시하여 더 높은 퀄리티의 이미지를 얻을 수 있다. 하지만 classfier guidance를 쓰기 위해서는class에 대한 label이 필요하며 classifier를 따로 학습해 주어야한다는 큰 단점이 있다.
 
CFG: Classifier-Free Guidance
Classifier-Free Diffusion Guidance 논문에서는 Classifier Guidance가 noise level에 따른 classifier를 따로 학습시켜야 하고 classifier에 기반하여 모델을 평가하는 수치인 IS와 FID를 편법으로 올리는 방법(adversarial attack) 이라고 주장한다.
연구진들은 그래서 모델 자체만으로 작동하고 매우 간단하면서도, 효과가 확실한 방법인 Classifier-Free-Guidance를 제시했다.그 방법은 학습 코드 한 줄만 바꾸기만 해도 바로 적용가능하기 때문에 매우 간단하다. w는 조절가능한 가중치이다.
위 식에서 는 라는 텍스트 조건에 의해 conditioned(조건화) 된 것이고, 는 unconditioned 된 것이다. 이 식을 에 대하여 정리하면,
위 식을 보면 이 방법에서 guidance가 어떻게 작동 하는지 직관적으로 이해할 수 있다. sample의 conditional likelihood에서 unconditional likelihood를 뺀 값을 늘린다는 간단한 구조 덕분이다.
 
CFG의 역할
CFG =
stable-diffusion을 써보면 알겠지만 CFG값은 diversity를 늘려준다. 보통 CFG값을 11로 두고 사용하는데, 이보다 줄어들면 diversity가 낮아져 text에 더 combined된 정확한 이미지가 생성되고, 이보다 높이면 diversity가 높아져 더 다양하고 창의적인 장면을 생성할 수 있다.
즉, CFG는 sample quality와 diversity간의 tradeoff를 직접 다룰 수 있게 해준다.
 
 

3. Distilling a classifier-free guided diffusion model

주어진 학습된 guided model (teacher)에 대해, 본 접근 방식은 두 단계로 나뉜다.
Step One,
먼저 모든 time-step 에서 teacher의 output과 매치될 수 있는 continuous-time student model 를 학습가능한 파라미터 에 대해 정의한다. guidance strength의 범위인 에 대 student 모델에 대해 다음의 objective로 최적화한다.
notion image
이때, 이다. guidance strength 를 병합하기 위해, 가 student모델의 입력으로 주어지는 -conditioned model을 소개한다. 초기화(initialization)이 중요한 역할을 수행하므로, 우리는 teacher conditional 모델의 파라미터와 동일하게 student모델을 초기화했다. -conditioning에 관계된 새롭게 도입된 파라미터를 제외하고. 자세한건 아래와 같다.
The model architecture we use is a U-Net model similar to the ones used in Classifier-free diffusion guidance. We use the same number of channels and attention as used in Classifier-free diffusion guidance for both ImageNet 64x64 and CIFAR-10.
notion image
 
Step Two,
두번째 스텝에서는, 를 매 스탭마다 sampling step을 이등분해가며, fewer-step student인 로 distill한다. 주어진 에 대해 은 sampling step이라고 할 때, 우리는 student 모델을 한 step만에 teacher의 two-step DDIM sampling 결과와 매칭되도록 학습한다(i.e. from to and from to ).
step의 teacher 모델을 step의 student모델로 distilling 후에, 우리는 step student 모델을 다시 teacher로 설정하여 같은 절차를 반복할 수 있으며, teacher 모델을 step student 모델로 만들 수 있다. 각 스텝에서 student model을 teacher 모델과 같은 파라미터로 세팅한다. 디테일은 아래와 같다.
We first use the student model from Step-one as the teacher model. We start from 1024 DDIM sampling steps and progressively distill the student model from Step-one to a one step model. We train the student model for 50,000 parameter updates, except for sampling step equals to one or two where we train the model for 100,000 parameter updates, before the number of sampling step is halved and the student model becomes the new teacher model.
(첫번째 iteration에서는 teacher모델인 cfg에서 2번이었던 inference가 1번으로 줄어드는게 이해가 되는데, 그 이후에도 2배씩 계속 줄어든다는 부분은 이해가 안된다. Objective 구성에서 수학적으로 가능한 구조인건가..? teacher에서 student로 distill하는게 step을 절반만 가져갈 수 있으니, 그다음에도 동치로서 2배씩 줄어드는..?)
notion image
 
N-step deterministic and stochastic sampling
한번 가 학습되면, 에 대해 DDIM sampling이 가능하다. 주어진 distilled model 의 sampling procedure이 주어진 initialization 에 대해 deterministic하다는 것을 주의해야한다.
사실, N-step stochastic sampling도 가능하다. 다만, determinisitic sampler와 비교하여 stochastic sampling은 모델 약간 다른 time-steps에 evaluating하는게 필요하고, training algorithm을 약간 변형해야한다.
 

4. Experiments

본 논문의 접근방식으로 4-steps안에 경쟁력있는 FID/IS 스코어를 달성하는 것을 발견했다.
notion image
notion image
notion image
 

Reference