Posted in

论文分析|高效长文本生成——让模型更高效、更智能!_AI阅读总结 — 包阅AI

包阅导读总结

思维导图:

文章地址:https://mp.weixin.qq.com/s/es9Z1Gh1dtvDGE1UhmI9PQ

文章来源:mp.weixin.qq.com

作者:LLM??SPACE

发布时间:2024/8/5 6:26

语言:中文

总字数:2804字

预计阅读时间:12分钟

评分:80分

标签:大模型,长文本生成,内存优化,Transformer,MST技术


以下为原文内容

本内容来源于用户推荐转载,旨在分享知识与观点,如有侵权请联系删除 联系邮箱 media@ilingban.com

https://arxiv.org/abs/2407.15892

贡献者:

Cheng Luo,Jiawei Zhao,Zhuoming Chen,
Beidi Chen,Anima Anandkumar

  1. 权重:模型的参数,包括所有层的权重矩阵,需要在训练前加载到显存中。
    1. 激活值(Activations):计算并存储每层的激活,这些值需要保存以便于在反向传播时计算梯度。

    2. 中间值(Intermediate Values)——计算时每一层时都需要储存:在模型的不同层,特别是多头自注意力(Multi-Head Attention)层和多层感知器(MLP)层中,计算过程中会产生中间值,如Q(Query)、K(Key)、V(Value)张量,以及MLP层的中间线性变换结果。
    • Transformer 中的Attention层,计算 QKV 和注意力矩阵;
    • Transformer 中的MLP层,将序列嵌入维度放大再缩小;
    • LM Head(Language Modeling Head)将嵌入映射为 logits,之后计算 loss / 下一个 token。
  2. 反向传播时需要计算并存储:梯度(Gradients)——在反向传播过程中计算得到的模型权重的梯度,用于更新模型参数,保持优化器的状态等等。

Innovation

文章提出了MST方法:多个小序列(Mini-sequences)迭代处理
核心思路:通过将输入序列划分为多个小序列(mini-sequences),并迭代处理这些小序列,从而减少了中间内存的使用。
MST方法关注的是计算过程中间状态因为Transformer 在计算过程中会产生非常巨大的中间状态:

但与此相对的是 GPU 的显存是有限的,而且 GPU 的显存分了不同层次,速度快的 HBM 非常小,搬运到低层次的 SRAM 很费时间。因此降低中间状态的大小就非常重要。所幸语言模型的计算在token 之间依赖比较小,因此可以并行操作,这就导向一个关键技巧,也就是分块处理。
假设有 16 个 token 要处理,每个 token 处理时需要X的内存,如果一次性处理,那就需要 16X 的内存;但如果分两块处理,那么整个过程最多只需要 8X 的内存;第一次计算完成后,相应内存就可以释放了。当然,不是分块越多越好,分块越多,各种搬运记录计算结果,汇总结果所需的时间也会越长。
FlashAttention 其实就在使用这个技巧,因为 Attention 的中间状态太大了,随着序列长度二次方增长谁也受不了。随着 Attention 的中间状态被 FlashAttention 和 Ulysses 打下来,我们自然就盯上其他中间状态了。
本文就在讨论分块解决 Llama 3 中 MLP 层和 LM Head 的中间状态。
分解过程都是类似的,都是分解、计算、汇总。设序列长度为N, 嵌入维度为d,序列嵌入X∈R^(N*d)那么就将X沿着分成 M 块,每块计算结果,然后汇总。

MLP层:Llama3 的 MLP 层有三个线性层 Wgate,Wup,Wdown。前向计算过程为:

• 首先将 N*d 大小的序列矩阵通过Wgate,Wup放大为两个N*l的矩阵

• 然后两个矩阵逐元素点乘,最后通过Wdown缩小为原来的 N*d 大小。

汇总的过程就很简单,直接在序列维度拼接就可以。

LMHead 计算:计算loss 的过程是将序列矩阵右乘一个矩阵映射到词表,也就是得到 NxV 的矩阵,然后算交叉熵。此时需要检查标签是否在块的范围内。汇总的过程是将 loss 加和再除以 M 取平均。

MST 没有改变 FLOP,但会增加 HBM 访问次数,标准 MLP 的访问次数为 Θ (Nd+NI+dI),而 MST 的访问次数为Θ (Nd+NI+dIM)。标准 LM Head 的次数是 Θ (Nd+NI+dV),而 MST 下这个数字Θ (Nd+NI+dVM)。如果序列长度很短,中间状态占主导,dI,dV占主导,MST会降低GPU 显存的访问量