5 minute read

LLaMA3 技术报告的简要分析。

对于LLMer,最突出的可能是LLaMA3的长上下文支持。在 8,192 个token的序列上训练模型,且通过掩码操作以确保自注意力不会跨越文档边界;同时Llama 3 使用具有 128K tokens的tokenizer;上下文长度也拓展到128K,这对于需要分析大量信息或处理复杂任务的应用程序非常有用。当然还有其处理多模态信息的能力。

但是对于RLer,我们更关心为什么要用DPO而不是PPO。尽管这两种方法在后来的DeepSeek论文中都被证明不如GRPO更简单高效,当然这仅是对于RL finetune LLM而言。至少现在在robotics上PPO仍然是online RL的主流方法。

References

本文将大幅度借鉴(写得太好了,以至于我只想改改格式)一文速览Llama 3:从Llama 3的模型架构到如何把长度扩展到100万——基于NTK-aware插值,感谢作者的辛勤付出。

1. Introduction

LLaMA3 Architecture

Figure: Illustration of the overall architecture and training of Llama 3. Llama 3 is a Transformer language model trained to predict the next token of a textual sequence.

LLaMA3的 model architecture 仍然是基于Transformer的语言模型,使用了自注意力机制和前馈神经网络。它的训练目标是预测文本序列的下一个标记。

Comparison between LLaMA3 and LLaMA2:

LLaMA 3 = LLaMA 2 +
     + Massive Scaling (405B, 15.6T tokens)
     + Improved Pretraining Corpus (15T, multilingual, code, reasoning)
     + Long-Context Support (8K → 128K tokens)
     + New Tokenizer (128K vocab, better compression)
     + Improved RoPE Base (θ = 500{,}000)
     + Post-training with DPO (instead of PPO)
     + Tool Use, Safety Enhancements, Chat Format
     + Multimodal Extensions (image, video, speech) [in progress]

Comparison between LLaMA3 and Transformer:

Transformer VS LLaMA3

Figure: Transformer VS LLaMA3

2. General Overview

2.1 Language model pre-training

In the language model pre-training stage, the model learns the structure of language and obtains large amounts of knowledge about the world from the text it is “reading”. To do this effectively, pre-training is performed at massive scale: we pre-train a model with 405B parameters on 15.6T tokens using a context window of 8K tokens. This standard pre-training stage is followed by a continued pre-training stage that increases the supported context window to 128K tokens.

预训练的模型是405B参数,使用了15.6T tokens;增大了处理的tokens,从8K增加到128K tokens。

2.2 Language model post-training

We align the model with human feedback in several rounds, each of which involves supervised finetuning (SFT) on instruction tuning data and Direct Preference Optimization (DPO; Rafailov et al., 2024). At this post-training2 stage, we also integrate new capabilities, such as tool-use, and observe strong improvements in other areas, such as coding and reasoning. See Section 4 for details. Finally, safety mitigations are also incorporated into the model at the post-training stage.

使用更稳定的DPO,在SFT和DPO的多轮迭代中,引入人类反馈。

2.3 Multi-modal encoder pre-training

主要介绍image encoder和speech encoder。用 image-text pairs 训练 image encoder。

2.4 Vision adapter training

将预训练的 image encoder 集成到预训练的 language model 中。适配器由一系列交叉注意力层(cross-attention)组成,这些层将 image encoder 表示输入到语言模型中。

2.5 Speech adapter training

Finally, we integrate the speech encoder into the model via an adapter that converts speech encodings into token representations that can be fed directly into the finetuned language model.

同样使用一个 adapter 将 speech encoder 的表示转换为 token 表示,之后直接输入到微调的语言模型中。

3. Pre-training Improvements in LLaMA 3

LLaMA3 Pre-training Improvements

Figure: LLaMA3 Decoder-only Transformer

3.1 Massive Scaling (405B parameters, 15.6T tokens)

  • Largest model: 405B parameters
  • Pre-trained on 15.6 trillion tokens
  • Trained with 3.8×10²⁵ FLOPs
  • Uses compute-optimal scaling laws to balance size and data

3.2 Long-context Training (8K → 128K tokens) 长上下文训练

  • Gradual expansion during pretraining (6 stages)
  • Used RoPE with base $\theta=500000$ for extended attention range,这一点是有文献做过研究证明的,主要是对于长度
  • Final phase used 800B tokens just for long-context adaptation

对于预训练用到的数据:

  • Increased from ~1.8T (LLaMA 2) → 15.6T tokens
  • Multilingual, code, reasoning data
  • Careful quality filtering, deduplication, and upsampling for math/code
  • 在数据规模上,llama3.1同llama3一样:在大约15T多语言语料库上预训练,相比之下Llama 2的语料库为1.8T,数据时限则是到23年年底
  • 在数据组成上,llama3.1包含50%通用的知识、25%是数学和推理知识、17%是代码知识、8%是多语言,且其包含了文本、图像、视频、语音等方面的数据
  • 可以看到,高质量的数据其实不到一半,常规的SFT也不一定好搞,那怎么办呢,可以反复体会、反复学习;比如使用退火相关的手段加入高质量的数据,且把学习率搞低一些,以便更细致的学习高质量数据的特征(说白了,退火阶段拉低学习率,以尽可能overfit高质量数据)

3.3 New Tokenizer (128K Vocab)

  • 128K vocabulary via BPE, better compression than LLaMA 2
  • +28K tokens focused on non-English languages
  • Compression improved from 3.17 → 3.94 characters/token (English)

相比于 LLaMA2 的 32K 词表,LLaMA3 使用了具有 128K tokens 的 tokenizer。

相当于,一方面,分词器由 SentencePiece 换为了 Tiktoken,与 GPT4 保持一致,可以更有效地对语言进行编码;二方面,Token词表从LLAMA 2的32K拓展到了128K;基准测试显示,Tiktoken提高了token效率,与 Llama 2 相比,生成的token最多减少了 15%「正由于llama3具有更大的词表,比llama2的tokenizer具有更大的文本压缩率」;

增大了 token 的压缩率,从原来的每个 token 3.17 个字符增加到 3.94 个字符。

3.4 GQA (Grouped Query Attention)

LLaMA2 中只有 34B 和 70B 的模型才使用到 GQA,因为本质上 QGA 是一个减少计算量的trick。 而在 LLaMA3 中,所有模型都使用了 GQA。

GQA

Figure: MHA, GQA

为了提高推理效率,Llama 3在 8B 和 70B 都采用了分组查询注意力(GQA),根据相关实验可以观察到,尽管与 Llama 2 7B 相比,模型的参数多了 1B,但改进的分词器效率和 GQA 有助于保持与 Llama 2 7B 相同的推理效率。

4. Post-training and Instruction Alignment

指令和chat微调:先奖励建模,然后SFT,最后DPO。

在后期训练中,通过在预训练模型的基础上进行几轮对齐来生成最终的聊天模型。每轮都涉及监督微调 (SFT)、拒绝抽样 (RS) 和直接偏好优化DPO。

  • 在预训练基础上使用人类标注数据训练 reward model
  • 通过supervised fine-tuning (SFT) 对pre-trained model checkpoint 进行微调;并进一步通过 DPO 对 模型进行对齐。

LLaMA3 Post-training

Figure: Illustration of the overall post-training approach for Llama 3. Our post-training strategy involves rejection sampling, supervised finetuning, and direct preference optimization.

4.1 Reward Model (RM) Training

以下使用转载内容:

在预训练 checkpoint 的基础上,使用人类标注数据训练奖励模型。在 loss 中减去了边际项,因为观察到在数据扩展后的改进效果逐渐减弱。

  1. 与Llama 2一样,在过滤掉具有相似响应的样本后,使用所有的偏好数据进行奖励建模,比如为每个prompt从两个不同的模型中抽样两个response(比如一个能力强点的、一个能力相对弱点的,分别回答同一个prompt);且要求注释者通过将其偏好强度分类为4个级别之一来进行评分,基于他们对所选response与被拒绝response的偏好程度:显著更好、更好、稍微更好或勉强更好
  2. 此外,除了标准的偏好对(选择的,拒绝的)response外,注释还为某些prompt创建了第三个“edited response”,其中对来自对的选择response进行了进一步编辑以进行改进——即直接编辑所选response或通过反馈提示模型进一步改进其自身response;

因此,每个偏好排名样本都有两个或三个response,具有明确的排名(edited > chosen > rejected)。 在训练过程中,我们将prompt和多个response连接成一行,并随机打乱response。 这是一种近似于将responses放在单独行中并计算分数的标准场景,但在我们的消融实验中,这种方法提高了训练效率而没有损失准确性。

4.2 SFT (Supervised Fine-Tuning)

  • SFT used human-labeled and synthetic data
  • High-quality curated dialog data with format alignment (role formatting, system prompts)
  • Domain-specific mixes: general, multilingual, coding, long context, reasoning

SFT使用的是通过RM做拒绝采样后得到的数据 + 合成数据。

具体而言,对于一个人工prompt,让模型生成若干个回答,然后采样其中的K个response(通常在 10 到 30 之间),然后让RM针对这多个response逐一进行质量上的打分,最终把得分最高的response保留下来(作为之后SFT数据的一部分,此举也符合 Bai 等人2022的研究),其它的则丢弃。

(十分推荐细读一文速览Llama 3:从Llama 3的模型架构到如何把长度扩展到100万——基于NTK-aware插值

4.3 DPO Instead of PPO

  • Replaces PPO with Direct Preference Optimization (DPO)
  • Better stability, easier implementation
  • Reduces variance in preference learning compared to PPO

Meta进一步使用直接偏好优化DPO 对得到的SFT模型进行训练「DPO本质上是个二分类,就是从人工标注的<Prompt,Good Response,Bad Response>三元数据里学习,调整模型参数鼓励模型输出Good Response,不输出Bad Response」,以实现人类偏好的对齐。

我们关心与PPO的比较:

此外,它们还探索了诸如PPO(Schulman等,2017)等在线算法,但发现DPO在大规模模型上所需的计算量更少,并且表现更好,特别是在像IFEval(Zhou等,2023)这样的指令跟随基准上。

除此之外,Meta还对DPO做了一些针对LLaMA3的修改:

  • 在DPO损失中屏蔽格式化token
  • 使用NLL损失进行正则化

详细内容转载自一文速览Llama 3:从Llama 3的模型架构到如何把长度扩展到100万——基于NTK-aware插值

  • 在DPO损失中屏蔽格式化token

从选择的和拒绝的response中屏蔽特殊格式化token,包括标题和终止token,以稳定DPO训练 Masking out formatting tokens in DPO loss: We mask out special formatting tokens including header and termination tokens (described in Section 4.1.1) from both chosen and rejected responses in the loss to stabilize DPO training

因为他们观察到,这些token对损失的贡献可能导致模型行为不理想,例如尾部重复或突然生成终止token(We observe that having these tokens contribute to the loss may lead to undesired model behaviors such as tail repetition or abruptly generating termination tokens)

假设这是由于DPO损失的对比性质——在选择的和拒绝的response中存在共同token导致学习目标冲突,因为模型需要同时增加和减少这些token的可能性

  • 使用NLL损失进行正则化

在选择的序列上添加了一个额外的负对数似然(NLL)损失项,缩放系数为 0.2,类似于Pang等人(2024) 这有助于通过保持生成的期望格式并防止选择response的对数概率下降——来进一步稳定DPO训练(Pang等人,2024;Pal等人,2024)

4.4 Enhanced Capabilities

Code

  • Trained code experts with continued pretraining on 1T code tokens
  • LCFT (Long-context Finetuning) for 16K code windows
  • SFT + rejection sampling specifically targeted at code correctness and readability

Multilingual

  • Upsampled non-English tokens during pretraining
  • Included synthetic instruction tuning in other languages

Reasoning & Math

  • Filtered incorrect traces with reward models
  • MCTS used to generate correct multi-step chains
  • Self-correction from failed generations

Long-context (in SFT and DPO)

  • Hierarchical summarization, QA over large context
  • Used 0.1% synthetic long-context data for robust adaptation

Tool Use

  • Brave Search, Python interpreter, Wolfram Alpha API
  • Tool chaining via step-by-step planning in multi-turn dialogues

Factuality & Hallucination Mitigation

  • Added real-time feedback via tools to reduce hallucinations
  • Execution verification in code and math

Steerability

  • Used system prompts for style, tone, format
  • Fine-grained response control

5. Multi-Modal Integration (In Progress)

5.1 Multi-modal Encoder Overview

  • Separate image and speech encoders trained on paired datasets
  • Adapter-based integration for cross-modal reasoning

5.2 Vision Adapter

  • Cross-attention layers map image embeddings into LLM hidden states
  • Task-specific adapters used for retrieval, captioning, VQA

5.3 Speech Adapter

  • Speech encodings converted to token-like embeddings
  • Directly fed into the LLM decoder via adapter

6. Evaluation and Results

  • LLaMA 3.1 405B achieves GPT-4-level performance on reasoning, coding, tool use
  • Top-tier multilingual and long-context benchmarks
  • Evaluation includes ARC, GSM8K, BFCL, InfiniteBench

7. Conclusion and Outlook

  • LLaMA 3 is not only a scaled-up LLaMA 2
  • Major changes in tokenizer, scaling, alignment, safety, and modality
  • Next steps: full release of multimodal variants, extended tool chains, and continual tuning pipelines