前言
通過(guò)本博客內(nèi)之前的文章可知,,自回歸解碼的標(biāo)準(zhǔn)做法是緩存序列中先前標(biāo)記的鍵(K)和值(V) 對(duì),從而加快注意力計(jì)算速度,。然而,,隨著上下文窗口或批量大小的增加,多頭注意力 (MHA)模型中與 KV 緩存大小相關(guān)的內(nèi)存成本顯著增長(zhǎng)
對(duì)于較大的模型,,KV 緩存大小成為瓶頸,,鍵和值投影可以在多個(gè)頭之間共享,而不會(huì)大幅降低性能,,可以使用
- 具有單個(gè) KV 投影的原始多查詢格式(MQA),,ChatGLM2-6B即用的這個(gè)
不過(guò),多查詢注意(Multi-query attention,,簡(jiǎn)稱MQA)只使用一個(gè)鍵值頭,,雖大大加快了解碼器推斷的速度,但MQA可能導(dǎo)致質(zhì)量下降,,而且僅僅為了更快的推理而訓(xùn)練一個(gè)單獨(dú)的模型可能是不可取的 - 或具有多個(gè) KV 投影的分組查詢注意力(grouped-query attention,,簡(jiǎn)稱GQA),LLaMA2和Mistral均用的這個(gè)
這是一種多查詢注意的泛化,它通過(guò)折中(多于一個(gè)且少于查詢頭的數(shù)量,,比如4個(gè))鍵值頭的數(shù)量,,使得經(jīng)過(guò)強(qiáng)化訓(xùn)練的GQA以與MQA相當(dāng)?shù)乃俣冗_(dá)到接近多頭注意力的質(zhì)量,即速度快 質(zhì)量高
經(jīng)實(shí)驗(yàn)論證,,GQA 變體在大多數(shù)評(píng)估任務(wù)上的表現(xiàn)與 MHA 基線相當(dāng),,并且平均優(yōu)于 MQA 變體
多頭注意力MHA | 分組查詢注意力GQA | 多查詢注意力MQA |
| LLaMA2 | ChatGLM2 |
| Mistral | Google Gemini |
| Google gemma2 | |
// 待更
第二部分 ChatGLM2之多查詢注意力(Muti Query Attention)
2.1 MQA的核心特征:各自Query矩陣,但共享Key 和 Value 矩陣
多查詢注意力(Muti Query Attention)是 2019 年Google一研究者提出的一種新的 Attention 機(jī)制(對(duì)應(yīng)論文為:Fast Transformer Decoding: One Write-Head is All You Need,、這是其解讀之一),,其能夠在保證模型效果的同時(shí)加快 decoder 生成 token 的速度
除了ChatGLM2用的MQA之外,23年12月Google最新推出的「多模態(tài)大模型Gemini」的注意力機(jī)制也使用的Multi-Query Attention
那其與17年 Google提出的transformer中多頭注意力機(jī)制(簡(jiǎn)稱MHA)有啥本質(zhì)區(qū)別呢,?有意思的是,,區(qū)別在于:
- 我們知道MHA的每個(gè)頭都各自有一份不同的Key、Query,、Value矩陣
- 而MQA 讓所有的頭之間 共享 同一份 Key 和 Value 矩陣,,每個(gè)頭只單獨(dú)保留了一份 Query 參數(shù),從而大大減少 Key 和 Value 矩陣的參數(shù)量
總之,,MQA 實(shí)際上是將 head 中的 key 和 value 矩陣抽出來(lái)單獨(dú)存為一份共享參數(shù),,而 query 則是依舊保留在原來(lái)的 head 中,每個(gè) head 有一份自己獨(dú)有的 query 參數(shù)
如下圖圖右所示
總之,,MHA 和 MQA 之間的區(qū)別只在于建立 Wqkv Layer 上
self.Wqkv = nn.Linear( # 【關(guān)鍵】Multi-Head Attention 的創(chuàng)建方法
3 * self.d_model, # 有 query, key, value 3 個(gè)矩陣, 所以是 3 * d_model
query, key, value = qkv.chunk( # 【關(guān)鍵】每個(gè) tensor 都是 (1, 512, 768)
self.Wqkv = nn.Linear( # 【關(guān)鍵】Multi-Query Attention 的創(chuàng)建方法
d_model + 2 * self.head_dim, # 只創(chuàng)建 query 的 head 向量,,所以只有 1 個(gè) d_model
device=device, # 而 key 和 value 不再具備單獨(dú)的頭向量
query, key, value = qkv.split( # query -> (1, 512, 768)
[self.d_model, self.head_dim, self.head_dim], # key -> (1, 512, 96)
dim=2 # value -> (1, 512, 96)
對(duì)比上面的代碼,你可以發(fā)現(xiàn)
- 在 MHA 中,,query, key, value 每個(gè)向量均有 768 維度
- 而在 MQA 中,,只有 query 是 768 維,而 key 和 value 均只剩下 96 維了,,恰好是 1 個(gè) head_dim 的維度
因此,,可以確認(rèn):在 MQA 中,除了 query 向量還保存著 8 個(gè)頭,,key 和 value 向量都只剩 1 個(gè)「公共頭」了,,這也正好印證了論文中所說(shuō)的「所有 head 之間共享一份 key 和 value 的參數(shù)」
剩下的問(wèn)題就是如何將這 1 份參數(shù)同時(shí)讓 8 個(gè)頭都使用,代碼里使用矩陣乘法 matmul 來(lái)廣播,,使得每個(gè)頭都乘以這同一個(gè) tensor,,以此來(lái)實(shí)現(xiàn)參數(shù)共享:
def scaled_multihead_dot_product_attention(
q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads) # (1, 512, 768) -> (1, 8, 512, 96)
kv_n_heads = 1 if multiquery else n_heads
k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 96, 512) if not multiquery
# (1, 512, 96) -> (1, 1, 96, 512) if multiquery
v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads) # (1, 512, 768) -> (1, 8, 512, 96) if not multiquery
# (1, 512, 96) -> (1, 1, 512, 96) if multiquery
attn_weight = q.matmul(k) * softmax_scale # (1, 8, 512, 512)
attn_weight = torch.softmax(attn_weight, dim=-1) # (1, 8, 512, 512)
out = attn_weight.matmul(v) # (1, 8, 512, 512) * (1, 1, 512, 96) = (1, 8, 512, 96)
out = rearrange(out, 'b h s d -> b s (h d)') # (1, 512, 768)
return out, attn_weight, past_key_value
第三部分 LLaMA2之分組查詢注意力——Grouped-Query Attention
23年,Google的研究者們提出了一種新的方法,,即分組查詢注意(GQA,,論文地址為:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints)
- 舉個(gè)例子,一般模型中會(huì)有這么兩個(gè)參數(shù):n_heads,、n_kv_heads,,其中,,n_heads的個(gè)數(shù)便是Q的個(gè)數(shù)(相當(dāng)于多少個(gè)頭 則多少個(gè)Q),n_kv_heads指的是K,、V的個(gè)數(shù)
- 因?yàn)槎鄠€(gè)頭會(huì)共享一個(gè)K或V,,則頭和Q的個(gè)數(shù)會(huì)大于K V的個(gè)數(shù),比如可能8個(gè)頭下:8個(gè)Q,、4個(gè)K,、4個(gè)V,即如下圖圖中所示
// 待更