大语言模型如何扩充上下文长度?
我们在应用大语言模型遇到的最典型的限制就是输入文本的上下文长度。开源的模型的上下文长度限制从2K到32K不等,商业模型最大上下文限制从2K到100K范围。上下文长度对应用大语言模型有着非常关键的影响,包括知识增强、记忆等Agents完成的工作,都是为了解决大语言模型上下文长度限制而设计的。大语言模型为什么会有上下文长度限制?是否有方法能扩充长度到几倍甚至十几倍?这几个问题困扰我很久。最近一段时间经过调研之后,我发现这些问题已经有了令人兴奋的进展,我也收获一些心得,记录于此。 先说结论: LLM的训练和计算都是没有上下文长度限制的,限制的只有计算资源和模型效果 头部公司和开源社区都有了阶段性的成果,最新的transformers,llama.cpp等开源项目已经内置了扩充上下文长度方法 如何在扩充上下文长度的同时,降低训练成本,保证模型效果,是一个还在不断探索的话题 LLM的上下文长度限制之谜 实际上,目前以Transformer为核心的LLM,理论上而言是没有上下文长度限制的。唯一限制上下文长度的,只有训练时的资源消耗,以及预测时的输出效果。如果不考虑这两点,LLM完全可以支持任意长度的上下文,这里本质原因是,Transformer的核心:Attention算法和上下文长度无关。 为说明这个本质原因,我们回顾下Attention的计算,参考「Attention is all your need」经典论文1 : 定义n为输入的token数量,vocab_size为token词典大小,d为文本embedding的维度大小,k为Query和Key向量的维度大小。那么整个Attention计算过程如下: Inputs是n个token序列,查找(vocab_size, d)大小的Embedding词典,转化为(n, d)的输入矩阵X 类似的,将输入的token位置i经过位置向量计算(查表或者实时计算),转化为(n, d)的词典,和上面的X词典相加,获得带上位置向量的X作为输入。注意位置向量的计算有两种方法,一种是通过查表的方式,即查找一个(pos_size, d)大小的Embedding词典,另外一种是实时计算,根据token的位置i,通过位置embedding算法,计算出对应的位置向量。这两种方法各有优缺点,这将是突破上下文长度限制的重点。 将X乘以\(W^Q,W^K和W^V\)三个Q,K,V权重矩阵,获得Q,K,V值矩阵。其中\(W^Q\)形状为(d, k), \(W^K\)形状为(d,k), \(W^V\)形状为(d, v),注意着三个权重矩阵都和输入长度无关,获得的Q和K矩阵大小是(n, k),V矩阵大小是(n, v) 如下计算attention: $$Attention(Q,K,V) =softmax(\frac {QK^T}{\sqrt{k}})V$$ 其中\(QK^T\)计算结果为(n, n)矩阵,再乘以V的,输出结果为(n,v)矩阵。注意这些计算都是实时计算,计算复杂度和输入长度有关。 5. 在Multi-Head Attention算法中,上述4个步骤所有矩阵变成了张量,增加了h个header,输入矩阵X变成(h, n, d)大小,\(W_q\)大小为(h, d, k), \(W_k\)大小为(h, d, k), \(W_v\)大小为(h, d, v)。Q, K, V矩阵分别大小为(h, n, k), (h, n, k), (h, n, v)。通过将多头concat,输出(n, hv)大小的结果\(Attention(Q,K,V)\),再经过一次线性变化,最终Attention结果为: $$MultiHead(Q, K, V) = Concat(Attention(Q, K, V))W^O$$ \(W^O\)大小为(hv, d),所以最终Attention结果\(MultiHead\)维度为(n, d)。 从上面的计算可见,整个模型的参数,只有位置向量需要查表的时候,是和上下文长度有关,而其他的所有权重矩阵,都和上下文长度无关。如果位置向量实时计算时,attention算法所依赖的所有参数都和上下文长度无关。 那么限制到底在哪里呢?上面已经提到,限制就在: 计算资源 模型效果 先说计算资源的限制,可以证明(过程略),上述第三步计算Q,K,V矩阵的计算复杂度是\(O(nd^2)\),第四步计算attention的计算复杂度是\(O(n^2d)\),所以计算attention总体的计算复杂度是\(O(nd^2+n^2d)\),如果d > n,则\(O(nd^2)\)占计算复杂度的大头,例如LLaMa1模型的n为2048,d为4096,所以可以估计训练复杂度和训练输入的上下文长度呈线性关系。以LLaMa1模型举例,训练一次的计算成本约为300万美元,如果将输入长度扩大到8倍(即16K),那么训练成本将增长到2400万美元。因此如果要在预训练阶段用更长的上下文长度训练,这个成本将变得难以接受。...