← Back to feed
Papers·5일 전

HKUST MaskAlign — diffusion transformer 학습 가속을 위한 토큰 부분 정렬

HKUST MaskAlign — diffusion transformer 학습 가속을 위한 토큰 부분 정렬

HKUST 팀이 diffusion transformer 학습 시 clean-image representation 정렬의 비효율을 해결하는 MaskAlign을 제안했습니다. 기존 full-token 정렬은 토큰별 gradient 불균형을 유발해 모델이 모든 clean-image 토큰에 의존하게 만드는 문제가 있었는데, 무작위 토큰 부분집합에 대해서만 정렬을 적용해 이 의존성을 완화했습니다. 추가로 pre-mask token mixing block으로 정보 손실을 보충했으며, ImageNet 256×256에서 FID 1.81을 달성해 기대되는 성능을 보여줍니다.

HKUST 팀이 diffusion transformer 학습을 가속하는 MaskAlign을 공개했습니다. 기존 representation alignment의 토큰 불균형 문제를 지적하고, 부분 정렬로 해결합니다.

핵심 결론

  • 벤치ImageNet 256×256 class-conditional generation에서 FID 1.81 (DiT-XL/2 기반).
  • 학습기존 full-token alignment 대비 20% 적은 step으로 동등 성능 도달.

방법

  • 문제 진단Full-token 정렬 시 특정 토큰의 gradient norm이 크고 공간적으로 안정된 선호도를 보임. 이는 모델이 clean-image 모든 토큰에 과도하게 의존하게 만듦.
  • MaskAlign매 iteration마다 무작위 토큰 subset(예: 50%)만 정렬 대상으로 삼아, clean-image token set에 대한 의존성을 줄임.
  • Pre-mask mixingMasking 전에 lightweight token mixing block을 추가해 정보 손실을 보상. 단순한 MLP로 구현.

한계·조건

  • 실험 규모ImageNet 256×256만 보고. 512×512나 text-to-image 확장은 미검증.
  • 코드GitHub 공개 예정 — 현재 abstract만 공개.

편집자 한 줄

Representation alignment의 토큰 레벨 분석이 흥미롭습니다. DiT 계열 학습 효율을 높이는 실용적인 접근으로 보입니다.

  • #diffusion
  • #transformer
  • #representation-alignment
  • #hkust
HKUST
원문 보기 →

Comments

— 첫 댓글을 남겨보세요 —