[논문 요약] COCONUT: Training Large Language Models to Reason in a Continuous Latent Space

저자

  • Shibo Hao et al, FAIR at Meta 

요약(NotebookLM)

  • 본 논문은 LLM의 추론 능력 향상을 위한 새로운 방법인 "Coconut(Chain of Continuous Thought)"을 제시합니다. 기존의 Chain-of-Thought(CoT) 방식이 언어 공간에서 추론 과정을 단계별로 생성하는 것과 달리, Coconut은 LLM의 Hidden States를 연속적인 사고(continuous thought)로 활용하여, 언어 제약 없이 추론을 수행합니다. 여러 추론 과제에 대한 실험 결과, Coconut은 특히 계획과 탐색이 필요한 복잡한 추론 과제에서 CoT보다 우수한 성능을 보이며, 추론 토큰 수도 감소시키는 것으로 나타났습니다. 이 연구는 잠재 추론(latent reasoning)의 가능성을 보여주고, 향후 연구 방향을 제시한다.

풀고자 하는 문제

  • 기존의 토큰, 즉 language space에서의 추론은 많은 한계점을 갖는다.
    • 추론 단계에서 생성되는 많은 토큰들이 Fluency를 위한 토큰들이 대부분이다.
    • 중요한 핵심적인 토큰들을 생성하려면 매우 복잡한 planning 과정을 거쳐야 한다.
    • 인간 역시 추론 단계에서는 뇌의 언어 영역을 담당하는 부분은 In-activaton된다.
  • LLM이 언어의 제약에서 벗어나서 추론을 할 수 있다면 ideal 할 것이다. 

제안하는 방식

  • 본 논문에서는 latent space 상에서 reasonining이 가능한 새로운 패러다임인, COCONUT을 제안한다.
  • COCONUT은 last hidden state(a continuous thought)를 다음 토큰의 input embedding으로 사용한다. 따라서 thought가 Fully Differentiable하다.
  • 기존 트랜스포머가 다음과 같다면(\(E_t=\text{token embedding}, P_t=\text{Positional Embedding}\)), 

제안하는 방식에서는 다음과 같이 변경된다.  

$$E_t =[e(x_1), e(x_2), ..., e(x_i), h_i, h_{i+1},..., h_{t-1}])$$

즉, 토큰 임베딩 대신, 이전 스텝의 hidden states가 입력 임베딩으로 사용된다. 

  • Training Details
    • CoT Data를 Supervised data로 사용하되, CoT를 단계별로 삭제해 나가는 Multi-stage learning (iCoT) 방식을 채택함

  • 위 그림처럼 Stage 0에서는 regular CoT와 같지만, k 번째 Stage에서는 k 개의 CoT 토큰이 (k x c)개의 continuous thought로 바뀜. c는 hyperparameter
  • 매 stage가 시작 될 때마다 optimizer는 초기화 시키고 <bot>, <cot> 토큰을 사용해서 continuous thought의 시작과 끝을 표시함.
  • 이 페이퍼의 가장 큰 한계점은 parellelism이다. 생각(continuous thought)은 미리 생성해둘 수 없기 때문에 t-1 시점까지 생각이 생성됐으면 그제서야 t 시점의 생각을 생성해 낼 수 있다. token이었다면 teacher forcing을 통해 미리 생성해둘 수 있음. 이 부분을 해결하는 것이 가장 큰 future work이 아닐까...
  • inference 시에 <eot> 토큰을 언제 모델에 입력시킬지는 a) binary classifier를 학습, b) constant length, 이 두가지 방법이 있으나 둘 다 잘 동작하기 때문에 단순함을 위해 b로 결정함

실험 결과

  • 3가지 데이터셋으로 실험
    • Math Reasoning: GSM8K, 3 stage, 2 latent thoughts per stage
    • Logical Reasoning: 5-hop ProntoQA, ProsQA(ProntoQA의 문제가 너무 쉬워서 새롭게 만든 데이터셋)
      • 6 stage, 1 latent thought per stage
  • 모든 데이터셋에 대해서 50 에폭 정도 학습 한 듯..
  • 매우 특이하게 모델로 GPT-2를 사용함. 이유는 안 나와있음..

  • 실험 결과를 보면, COCONUT이 No-CoT 대비해서는 항상 앞서지만, CoT에 대해서는 GSM8K에서 밀린다.
  • ProsQA는 다른 2개의 task보다 상대적으로 복잡한 planning이 필요한데, 이 task에서 특히 COCONUT의 성능 향상이 도드라진다(CoT는 이 task에서 No-CoT 대비해서 큰 성능 차이가 없다는 게 주목할만한 결과)
  • 중요한 점은, curriculum learning 없이 그냥 continuous thought를 사용하는 것은 전혀 성능 효과가 없었다는 것.
    • 토큰을 하나씩 제거하는 것처럼 여러가지 다른 task들을 생각해 볼 수 있을 듯

한계점 및 향후 과제

  • Scaling up 및 pretraining 단계에서 continuous thought를 실행해보는것(병렬처리가 안되서 아직은 불가능..)