前言
關(guān)于位置編碼和RoPE
- 應用廣泛,,是很多大模型使用的一種位置編碼方式,包括且不限于LLaMA,、baichuan,、ChatGLM等等
- 我之前在本博客中的另外兩篇文章中有闡述過(一篇是關(guān)于LLaMA解讀的,一篇是關(guān)于transformer從零實現(xiàn)的),,但自覺寫的不是特別透徹好懂
再后來在我參與主講的類ChatGPT微調(diào)實戰(zhàn)課中也有講過,,但有些學員依然反饋RoPE不是特別好理解
考慮到只要花足夠多的時間 心思 投入,沒有寫不清楚的,,講課更是如此,,故為徹底解決這個位置編碼/RoPE的問題,我把另外兩篇文章中關(guān)于位置編碼的內(nèi)容抽取出來,,并不斷深入,、擴展、深入,,比如其中最關(guān)鍵的改進是兩輪改進,,一個12.16那天,一個12.21那天
- 12.16那天
小的改進是把“1.1 標準位置編碼的起源”中,,關(guān)于i,、2i、2i+1的一系列計算結(jié)果用表格規(guī)整了下 如此,,相比之前把一堆數(shù)字一堆,,表格更加清晰、一目了然 大的改進是把“3.1.1 第一種形式的推導(通俗易懂版)”的細節(jié)重新梳理了以下,,以更加一目了然,、一看即懂,,可能是全網(wǎng)關(guān)于RoPE最通俗細致的推導 - 12.21那天
把RoPE的本質(zhì)給強調(diào)出來
最終成為本文
第一部分 transformer原始論文中的標準位置編碼
如此篇文章《Transformer通俗筆記:從Word2Vec,、Seq2Seq逐步理解到GPT,、BERT》所述,,RNN的結(jié)構(gòu)包含了序列的時序信息,而Transformer卻完全把時序信息給丟掉了,,比如“他欠我100萬”,,和“我欠他100萬”,兩者的意思千差萬別,,故為了解決時序的問題,,Transformer的作者用了一個絕妙的辦法:位置編碼(Positional Encoding)
1.1 標準位置編碼的起源
即將每個位置編號,從而每個編號對應一個向量,,最終通過結(jié)合位置向量和詞向量,,作為輸入embedding,就給每個詞都引入了一定的位置信息,,這樣Attention就可以分辨出不同位置的詞了,,具體怎么做呢?
- 如果簡單粗暴的話,,直接給每個向量分配一個數(shù)字,,比如1到1000之間
- 也可以用one-hot編碼表示位置
- transformer論文中作者通過sin函數(shù)和cos函數(shù)交替來創(chuàng)建 positional encoding,其計算positional encoding的公式如下 其中,,pos相當于是每個token在整個序列中的位置,,相當于是0, 1, 2, 3...(看序列長度是多大,比如10,,比如100),,代表位置向量的維度(也是詞embedding的維度,transformer論文中設置的512維)
至于是embedding向量的位置下標對2求商并取整(可用雙斜杠表示整數(shù)除法,,即求商并取整),,它的取值范圍是,比如
位置向量的第多少維 (0 2 4等偶數(shù)維用sin函數(shù)計算) | | | | 0 | | | | 1 | | | | 2 | | | | 3 | | | | 4 | | | | 5 | | | | 6 | | | | .... | | | | 510 | | | | 511 | | | |
相當于 是指向量維度中的偶數(shù)維,,即第0維,、第2維、第4維...,,第510維,,用sin函數(shù)計算 是向量維度中的奇數(shù)維,即第1維,、第3維,、第5維..,,第511維,用cos函數(shù)計算
不要小看transformer的這個位置編碼,,不少做NLP多年的人也不一定對其中的細節(jié)有多深入,,而網(wǎng)上大部分文章談到這個位置編碼時基本都是千篇一律、泛泛而談,,很少有深入,,故本文還是細致探討下
1.2 標準位置編碼的示例:多圖多舉例
考慮到一圖勝千言 一例勝萬語,,舉個例子,,當我們要編碼「我 愛 你」的位置向量,,假定每個token都具備512維,,如果位置下標從0開始時,,則根據(jù)位置編碼的計算公式可得『且為讓每個讀者閱讀本文時一目了然,,我計算了每個單詞對應的位置編碼示例(在此之前,,這些示例在其他地方基本沒有)』
最終得到的可視化效果如下圖所示
1.3 標準位置編碼的coding實現(xiàn)
代碼實現(xiàn)如下
“”“位置編碼的實現(xiàn),,調(diào)用父類nn.Module的構(gòu)造函數(shù)”“” class PositionalEncoding(nn.Module): def __init__(self, d_model, dropout, max_len=5000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) # 初始化dropout層 pe = torch.zeros(max_len, d_model) # 創(chuàng)建一個max_len x d_model的全零張量 position = torch.arange(0, max_len).unsqueeze(1) # 生成0到max_len-1的整數(shù)序列,,并添加一個維度 # 計算div_term,用于縮放不同位置的正弦和余弦函數(shù) div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)) # 使用正弦和余弦函數(shù)生成位置編碼,,對于d_model的偶數(shù)索引,,使用正弦函數(shù);對于奇數(shù)索引,,使用余弦函數(shù),。 pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) # 在第一個維度添加一個維度,以便進行批處理 self.register_buffer('pe', pe) # 將位置編碼張量注冊為緩沖區(qū),,以便在不同設備之間傳輸模型時保持其狀態(tài) x = x + Variable(self.pe[:, :x.size(1)],
本文發(fā)布之后,,有同學留言問,上面中的第11行,、12行代碼
div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
為什么先轉(zhuǎn)換為了等價的指數(shù)+對數(shù)運算,,而不是直接冪運算?是效率,、精度方面有差異嗎,?
這里使用指數(shù)和對數(shù)運算的原因是為了確保數(shù)值穩(wěn)定性和計算效率
- 一方面,直接使用冪運算可能會導致數(shù)值上溢或下溢,。當d_model較大時,,10000.0 ** (-i / d_model)中的冪可能會變得非常小,以至于在數(shù)值計算中產(chǎn)生下溢,。通過將其轉(zhuǎn)換為指數(shù)和對數(shù)運算,,可以避免這種情況,,因為這樣可以在計算過程中保持更好的數(shù)值范圍
- 二方面,在許多計算設備和庫中,,指數(shù)和對數(shù)運算的實現(xiàn)通常比冪運算更快,。這主要是因為指數(shù)和對數(shù)運算在底層硬件和軟件中有特定的優(yōu)化實現(xiàn),而冪運算通常需要計算更多的中間值
所以,,使用指數(shù)和對數(shù)運算可以在保持數(shù)值穩(wěn)定性的同時提高計算效率,。
既然提到了這行代碼,我們干脆就再講更細致些,,上面那行代碼對應的公式為
其中的中括號對應的是一個從 0 到 的等差數(shù)列(步長為 2),,設為
且上述公式與這個公式是等價的
為何,,原因在于,,從而有
最終,再通過下面這兩行代碼完美實現(xiàn)位置編碼
# 使用正弦和余弦函數(shù)生成位置編碼,,對于d_model的偶數(shù)索引,,使用正弦函數(shù);對于奇數(shù)索引,,使用余弦函數(shù),。 pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term)
第二部分 從復數(shù)到歐拉公式
先復習下復數(shù)的一些關(guān)鍵概念
- 我們一般用表示復數(shù),實數(shù) 叫做復數(shù)的實部,,實數(shù) 叫做復數(shù)的虛部
- 復數(shù)的輻角是指復數(shù)在復平面上對應的向量和正向?qū)崝?shù)軸所成的有向角
- 的共軛復數(shù)定義為:,,也可記作,復數(shù)與其共軛的乘積等于它的模的平方,,即,,這是一個實數(shù)
2.1 如何通俗易懂的理解復數(shù)
在我們的日常生活中,經(jīng)常會遇到各種平移運動,,為了描述這些平移運動,,數(shù)學上定義了加減乘除,然還有一類運動是旋轉(zhuǎn)運動,,而加減乘除無法去描述旋轉(zhuǎn)運動,,而有了復數(shù)之后,便不一樣了,,此話怎講,?
根據(jù)復數(shù)的定義:,可以看出來:,,而這個展開過程就揭示了虛數(shù) 背后的本質(zhì),,因為這個展開過程中的兩次乘法可以看成連續(xù)的操作
- 即把 1 經(jīng)過2次完全一樣的操作:,變成了 ?1 ,,那什么樣的操作能得到這個效果呢,?
- 你兩眼一亮,,直呼:旋轉(zhuǎn)啊,先旋轉(zhuǎn) 90度,,再旋轉(zhuǎn) 90 度就可以了啊,,如下圖所示
so, 就代表了旋轉(zhuǎn)(至此,,可能你已經(jīng)隱隱約約意識到,,為何我們在解釋旋轉(zhuǎn)位置編碼時,為何要扯上復數(shù)了),,為形象說明,,再舉兩個例子
2.2 如何快速理解歐拉公式
2.2.1 什么是歐拉公式
當 表示任意實數(shù),, 是自然對數(shù)的底數(shù),, 是復數(shù)中的虛數(shù)單位,則根據(jù)歐拉公式有
表達的含義在于該指數(shù)函數(shù)可以表示為實部為,,虛部為的一個復數(shù)
該歐拉公式相當于建立了指數(shù)函數(shù),、三角函數(shù)和復數(shù)之間的橋梁,但怎么推導出來的呢,,其實很簡單
- 由于有
- 所以,,如果 ,則有
2.2.2 歐拉公式與三角函數(shù)
如何直觀的理解這個歐拉公式呢,?
其實,,可以把看作通過單位圓的圓周運動來描述單位圓上的點,通過復平面的坐標來描述單位圓上的點,,是同一個點不同的描述方式,,所以有,如下圖所示
根據(jù)歐拉公式,,可以輕易推出:
我們把復數(shù)當作向量來看待,,復數(shù)的實部是方向,虛部是方向,,很容易觀察出其幾何意義,,如下圖所示
還在思考怎么得來的?很簡單哦,,還記得向量的加減法么,?
第三部分 旋轉(zhuǎn)位置編碼(RoPE)的推導與實現(xiàn)
3.1 旋轉(zhuǎn)位置編碼的原理與推導
所謂旋轉(zhuǎn)位置編碼,其在位置編碼上刪除了絕對位置嵌入,,而在網(wǎng)絡的每一層增加了蘇劍林等人(2021)提出的旋轉(zhuǎn)位置嵌入(RoPE),,其思想是采用絕對位置編碼的形式 實現(xiàn)相對位置編碼,,且RoPE主要借助了復數(shù)的思想
具體來說,當咱們給self-attention中的向量都加入了位置信息后,,便可以表示為
其中
- 表示「第 個 token 對應的詞向量 」集成「位置信息 」之后的 query 向量
- 而 ,、 則分別表示第 個 token 對應的詞向量 集成位置信息 之后的 key 向量、 value 向量
3.1.1 第一種形式的推導(可能是全網(wǎng)最通俗易懂版)
接著論文中提出為了能利用上 token 之間的相對位置信息,,假定 query 向量 和 key 向量 之間的內(nèi)積操作可以被一個函數(shù) 表示,,該函數(shù) 的輸入是詞嵌入向量 、 ,,和它們之間的相對位置 :
這里面其實有很大的一個關(guān)鍵,,但大部分資料甚至RoPE原始論文都不會給你特別強調(diào)出來,即為何要構(gòu)造這么一個等式呢,?
- 原因在于左邊算是q和k向量的內(nèi)積,,而這恰好是transformer計算自注意力機制的核心一步,右邊等式則意味著m與n的相對位置
如此一來,,該等式便把“q和k的內(nèi)積”與“它們的相對位置”給串起來了 - 也如阿荀所說,,左邊是含有各自絕對位置信息的q向量和k向量,,而這個等式就是RoPE追求的目標,,物理含義就是通過顯式傳入絕對位置信息實現(xiàn)與傳入相對位置信息對等的情況
假定現(xiàn)在詞嵌入向量的維度是兩維 ,然后RoPE利用2維度平面上的向量的幾何性質(zhì),,再結(jié)合復數(shù)的性質(zhì),,神奇般的找到了滿足上述等式的 和 ,其形式如下:
這里面的 Re 表示復數(shù)的實部
- 進一步地,, 可以表示成下面的式子(如果此刻你覺得你有點懵,,沒事,下文馬上會一步一步的詳細推導):
看到這里會發(fā)現(xiàn),,這不就是 query 向量乘以了一個旋轉(zhuǎn)矩陣嗎,?這就是為什么叫做旋轉(zhuǎn)位置編碼的原因 可能有的同學還沒反應過來,怎么就叫「query 向量乘以了一個旋轉(zhuǎn)矩陣」了呢,?我再舉一個來自這里的例子,,以一目了然 如下圖所示,考慮一個矩陣,,它把向量在固定坐標系中逆時針旋轉(zhuǎn)一個角度,,得到 而這個旋轉(zhuǎn)矩陣就是
從而有, - 同理,, 可以表示成下面的式子:
- 最終可以表示如下:
然上述分別關(guān)于,、、的三個式子,,咋一步一步推導來的,?為做細致說明,,特參考此文一步一步解釋下
首先看第一個式子,對于,,這個式子的右邊項有兩部分,,一部分是、一部分是
- 對于前者,,可知其中的是個二維矩陣,,是個二維向量,自然相乘的結(jié)果也必然是一個二維向量,,用表示
- 對于后者,,根據(jù)歐拉公式,可得
- 基于上面第1點結(jié)論,,可知
然后將表示成復數(shù)形式,,可得
從而有
基于上面第2點結(jié)論,可知即是兩個復數(shù)相乘
- 考慮到以下兩個關(guān)于復數(shù)的背景知識
可得
將這個結(jié)果表達成實數(shù)向量形式,,即是
至此,,你也就不難發(fā)現(xiàn),這不就是query向量乘以了一個旋轉(zhuǎn)矩陣么
至于第二個式子,,根據(jù)上述過程同理,,可得key向量
最后第三個式子,函數(shù)g,,則可得
其中,,表示一個復數(shù)的實數(shù)部分,而則表示復數(shù)的共軛
- 考慮到
再結(jié)合上面第一個式子中的推導,,可得
繼續(xù)結(jié)合上面第一個式子中的推導(比如,,及),繼續(xù)可知,,我們現(xiàn)在要證明的是存在
- 總之,,接下來我們就要證明上述函數(shù) g 的計算公式是成立的
首先,回顧一下attention操作,,位置m的query和位置n的key會做一個內(nèi)積操作 即由
可得
「相當于[A,B]與[C,D]做內(nèi)積,,則相當于A B橫著,C D豎著,,最終結(jié)果為AC BD,,最后再把括號里的項全部對應相乘、展開」 - 首先,,把上面第二點的式子整理一下,,總計8項,為了把相關(guān)的項提取出來,第1項 8項合并處理,、第2項 7項合并處理,、第3項 6項合并處理、第4項 5項合并處理
其次,,考慮到
最后,,再把相關(guān)項的特點,兩次調(diào)整下順序即可
依據(jù)以上三點,,從而有
完美! 如此,,也就證明了,位置 m 的 query 和位置 n 的 key 的內(nèi)積就是函數(shù) g
最后,,把上面的式子一,、式子二的最終結(jié)果都分別用矩陣向量乘的形式來表達就是: 接下來,我們要計算兩個旋轉(zhuǎn)矩陣的乘積,,即中間部分的這個式子 展開之后,,可得 從而有
上面都還只是針對詞嵌入維度為2的情況,那對于的通用情況呢,,將2維推廣到任意維度,,可以表示如下:
內(nèi)積滿足線性疊加性,因此任意偶數(shù)維的RoPE,,我們都可以表示為二維情形的拼接,,即將詞嵌入向量元素按照兩兩一組分組
每組應用同樣的旋轉(zhuǎn)操作且每組的旋轉(zhuǎn)角度計算方式如下:
所以簡單來說 RoPE 的 self-attention 操作的流程是
- 對于 token 序列中的每個詞嵌入向量,首先計算其對應的 query 和 key 向量
- 然后對每個 token 位置都計算對應的旋轉(zhuǎn)位置編碼
- 接著對每個 token 位置的 query 和 key 向量的元素按照 兩兩一組 應用旋轉(zhuǎn)變換
- 最后再計算 query 和 key 之間的內(nèi)積得到 self-attention 的計算結(jié)果
3.1.2 第二種形式的推導(蘇劍林版)
與上面第一種形式的推導類似,,為了引入復數(shù),,首先假設了在加入位置信息之前,原有的編碼向量是二維行向量和,,其中和是絕對位置,現(xiàn)在需要構(gòu)造一個變換,,將和引入到和中,,即尋找變換:
也就是說,我們分別為,、設計操作,、,使得經(jīng)過該操作后,,,、就帶有了位置、的絕對位置信息 考慮到Attention的核心計算是內(nèi)積:
故我們希望的內(nèi)積的結(jié)果帶有相對位置信息,,即尋求的這個變換,,應該具有特性:
「怎么理解?很簡單,當m和n表示了絕對位置之后,,m與n在句子中的距離即位置差m-n,,就可以表示為相對位置了,且對于復數(shù),,內(nèi)積通常定義為一個復數(shù)與另一個復數(shù)的共軛的乘積」
- 為合理的求出該恒等式的一個盡可能簡單的解,,可以設定一些初始條件,比如,、,,然后可以先考慮二維情形,然后借助復數(shù)來求解
在復數(shù)中有,,表示取實部的操作(復數(shù) 和“ 復數(shù) 的共軛即 ”之積仍是一個復數(shù)) 因論文100課的群里有學員對該點存在疑問,,故借用七月黃老師的回復補充下:這個等式和復數(shù)乘法和向量乘積的聯(lián)系有關(guān) 考慮兩個復數(shù)
,的共軛是 一方面,,對于等式的右邊項而言 q和k*的乘積是 這個結(jié)果的實部是 二方面,,對于等式的左邊項而言 其對應于對應的實數(shù)向量和對應的實數(shù)向量的乘積
綜合以上兩點,可知右邊項所表示的“復數(shù)q和復數(shù)k的共軛k*的乘積”,,和左邊項做表示的“q,、k所對應向量的乘積”是一樣的
總之,我們需要尋找一種變換,,使得
- 簡單起見,,我們假設存在復數(shù),使得,,然后我們用復數(shù)的指數(shù)形式,,設
- 那么代入方程后就得到兩個方程
方程1: 方程2:Θf(q,m)?Θf(k,n) = Θg(q,k,m?n)
對于方程1,代入得到(接著,,再把和都設為0)
最后一個等號源于初始條件和,,所以現(xiàn)在我們可以很簡單地設,,,即它不依賴于
至于方程2,,同樣代入得到 Θf(q,m)?Θf(k,m) = Θg(q,k,0) = Θf(q,0)?Θf(k,0) = Θ(q)?Θ(k)
這里的、是,、本身的幅角,,而最后一個等號同樣源于初始條件 根據(jù)上式Θf(q,m)?Θf(k,m) = Θ(q)?Θ(k),可得Θf(q,m)?Θ(q)=Θf(k,m)?Θ(k),,所以Θf(q,m)?Θ(q)的結(jié)果是一個只與m相關(guān),、跟q無關(guān)的函數(shù),記為φ(m),,即Θf(q,m)=Θ(q)+φ(m) - 接著令n=m?1代入Θf(q,m)?Θf(k,n) = Θg(q,k,m?n),,可以得到 Θf(q,m)?Θf(k,m-1) = Θg(q,k,1)
然后將 Θf(q,m) 和 Θf(k,m-1) 的等式代入Θf(q,m)=Θ(q)+φ(m),,我們可以得到 Θ(q) + φ(m) - (Θ(k) + φ(m-1)) = Θg(q,k,1),整理一下就得到
即{φ(m)}是等差數(shù)列,,設右端為θ,,那么就解得φ(m)=mθ
綜上,我們得到二維情況下用復數(shù)表示的RoPE:
- 所以說,,尋求的變換就是,,也就是給乘以,相應地,,乘以
做了這樣一個變換之后,,根據(jù)復數(shù)的特性,有: 也就是,,如果把二維向量看做復數(shù),,那么它們的內(nèi)積,等于一個復數(shù)乘以另一個復數(shù)的共軛,,得到的結(jié)果再取實部,,代入上面的變換,也就有: 這樣一來,,內(nèi)積的結(jié)果就只依賴于,,也就是相對位置了 換言之,經(jīng)過這樣一番操作,,通過給Embedding添加絕對位置信息,,可以使得兩個token的編碼,經(jīng)過內(nèi)積變換(self-attn)之后,,得到結(jié)果是受它們位置的差值,,即相對位置影響的
于是,對于任意的位置為的二維向量,,把它看做復數(shù),,乘以,而根據(jù)歐拉公式,,有:
從而上述的相乘變換也就變成了(過程中注意:):
把上述式子寫成矩陣形式:
而這個變換的幾何意義,,就是在二維坐標系下,對向量進行了旋轉(zhuǎn),,因而這種位置編碼方法,被稱為旋轉(zhuǎn)位置編碼
根據(jù)剛才的結(jié)論,,結(jié)合內(nèi)積的線性疊加性,,可以將結(jié)論推廣到高維的情形??梢岳斫鉃?,每兩個維度一組,進行了上述的“旋轉(zhuǎn)”操作,然后再拼接在一起:
由于矩陣的稀疏性,,會造成計算上的浪費,,所以在計算時采用逐位相乘再相加的方式進行:
其中為矩陣逐位相乘操作
3.2 旋轉(zhuǎn)位置編碼的coding實現(xiàn)(分非LLaMA版和LLaMA版兩種)
原理理解了,接下來可以代碼實現(xiàn)旋轉(zhuǎn)位置編碼,,考慮到LLaMA本身的實現(xiàn)不是特別好理解,,所以我們先通過一份非LLaMA實現(xiàn)的版本,最后再看下LLaMA實現(xiàn)的版本
對于,,非LLaMA版的實現(xiàn),,其核心就是實現(xiàn)下面這三個函數(shù) (再次強調(diào),本份關(guān)于RoPE的非LLaMA版的實現(xiàn) 與上面和之后的代碼并非一體的,,僅為方便理解RoPE的實現(xiàn))
3.2.1 非LLaMA版的實現(xiàn)
3.2.1.1 sinusoidal_position_embedding的編碼實現(xiàn)
sinusoidal_position_embedding:這個函數(shù)用來生成正弦形狀的位置編碼,。這種編碼用來在序列中的令牌中添加關(guān)于相對或絕對位置的信息
def sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, device): position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(-1) ids = torch.arange(0, output_dim // 2, dtype=torch.float) theta = torch.pow(10000, -2 * ids / output_dim) # (max_len, output_dim//2) # 即公式里的:pos / (10000^(2i/d)) embeddings = position * theta # (max_len, output_dim//2, 2) embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) # (bs, head, max_len, output_dim//2, 2) embeddings = embeddings.repeat((batch_size, nums_head, *([1] * len(embeddings.shape)))) # (bs, head, max_len, output_dim) # reshape后就是:偶數(shù)sin, 奇數(shù)cos了 embeddings = torch.reshape(embeddings, (batch_size, nums_head, max_len, output_dim)) embeddings = embeddings.to(device)
一般的文章可能解釋道這個程度基本就over了,,但為了讓初學者一目了然計,,我還是再通過一個完整的示例,來一步步說明上述各個步驟都是怎么逐一結(jié)算的,,整個過程和之前此文里介紹過的transformer的位置編碼本質(zhì)上是一回事..
為方便和transformer的位置編碼做對比,,故這里也假定output_dim = 512
- 首先,我們有 ids 張量,,當 output_dim 為 512 時,,則
, ,, ,, , ,, ,, ... , ,, ids = [0,0, 1,1, 2,2, ..., 254,254, 255,255] 然后我們有一個基數(shù)為10000的指數(shù)運算,,使用了公式 torch.pow(10000, -2 * ids / output_dim) - 執(zhí)行 embeddings = position * theta 這行代碼,它會將 position 的每個元素與 theta 的相應元素相乘,,前三個元素為
- 接下來我們將對 embeddings 的每個元素應用 torch.sin 和 torch.cos 函數(shù)
對于 torch.sin(embeddings),,我們將取 embeddings 中的每個元素的正弦值:
對于 torch.cos(embeddings),我們將取 embeddings 中的每個元素的余弦值:
最后,,torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) 將這兩個新的張量沿著一個新的維度堆疊起來,,得到的 embeddings如下 - 最終,得到如下結(jié)果
[sin(\frac{0}{10000^{\frac{0}{512}}}), cos(\frac{0}{10000^{\frac{0}{512}}}), sin(\frac{0}{10000^{\frac{2}{512}}}), cos(\frac{0}{10000^{\frac{2}{512}}}), ..., cos(\frac{0}{10000^{\frac{510}{512}}})], [sin(\frac{1}{10000^{\frac{0}{512}}}), cos(\frac{1}{10000^{\frac{0}{512}}}), sin(\frac{1}{10000^{\frac{2}{512}}}), cos(\frac{1}{10000^{\frac{2}{512}}}), ..., cos(\frac{1}{10000^{\frac{510}{512}}})], [sin(\frac{2}{10000^{\frac{0}{512}}}), cos(\frac{2}{10000^{\frac{0}{512}}}), sin(\frac{2}{10000^{\frac{2}{512}}}), cos(\frac{2}{10000^{\frac{2}{512}}}), ..., cos(\frac{2}{10000^{\frac{510}{512}}})]
3.2.1.2 RoPE的編碼實現(xiàn)
RoPE:這個函數(shù)將相對位置編碼(RoPE)應用到注意力機制中的查詢和鍵上,。這樣,,模型就可以根據(jù)相對位置關(guān)注不同的位置
import torch.nn.functional as F # q,k: (bs, head, max_len, output_dim) # (bs, head, max_len, output_dim) pos_emb = sinusoidal_position_embedding(batch_size, nums_head, max_len, output_dim, q.device) # cos_pos,sin_pos: (bs, head, max_len, output_dim) # 看rope公式可知,,相鄰cos,sin之間是相同的,,所以復制一遍,。如(1,2,3)變成(1,1,2,2,3,3) cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 將奇數(shù)列信息抽取出來也就是cos 拿出來并復制 sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 將偶數(shù)列信息抽取出來也就是sin 拿出來并復制 # q,k: (bs, head, max_len, output_dim) q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1) q2 = q2.reshape(q.shape) # reshape后就是正負交替了 q = q * cos_pos + q2 * sin_pos k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1) k = k * cos_pos + k2 * sin_pos
老規(guī)矩,為一目了然起見,,還是一步一步通過一個示例來加深理解
- sinusoidal_position_embedding函數(shù)生成位置嵌入,。在output_dim=512的情況下,每個位置的嵌入會有512個維度,,但為了簡單起見,,我們只考慮前8個維度,前4個維度為sin編碼,,后4個維度為cos編碼,。所以,我們可能得到類似以下的位置嵌入
# 注意,,這只是一個簡化的例子,,真實的位置嵌入的值會有所不同。 pos_emb = torch.tensor([[[[0.0000, 0.8415, 0.9093, 0.1411, 1.0000, 0.5403, -0.4161, -0.9900], [0.8415, 0.5403, 0.1411, -0.7568, 0.5403, -0.8415, -0.9900, -0.6536], [0.9093, -0.4161, -0.8415, -0.9589, -0.4161, -0.9093, -0.6536, 0.2836]]]])
- 然后,,我們提取出所有的sin位置編碼和cos位置編碼,,并在最后一個維度上每個位置編碼進行復制
sin_pos = pos_emb[..., ::2].repeat_interleave(2, dim=-1) # 提取出所有sin編碼,并在最后一個維度上復制 cos_pos = pos_emb[..., 1::2].repeat_interleave(2, dim=-1) # 提取出所有cos編碼,,并在最后一個維度上復制
- 更新query向量
我們首先構(gòu)建一個新的q2向量,,這個向量是由原來向量的負的cos部分和sin部分交替拼接而成的 我們用cos_pos對q進行元素級乘法,用sin_pos對q2進行元素級乘法,,并將兩者相加得到新的query向量 q2 = torch.stack([-q[..., 1::2], q[..., ::2]], dim=-1).flatten(start_dim=-2) # q2: tensor([[[[-0.2, 0.1, -0.4, 0.3, -0.6, 0.5, -0.8, 0.7], # [-1.0, 0.9, -1.2, 1.1, -1.4, 1.3, -1.6, 1.5], # [-1.8, 1.7, -2.0, 1.9, -2.2, 2.1, -2.4, 2.3]]]]) q = q * cos_pos + q2 * sin_pos
公式表示如下 - ?????更新key向量
對于key向量,,我們的處理方法與query向量類似 k2 = torch.stack([-k[..., 1::2], k[..., ::2]], dim=-1).flatten(start_dim=-2) # k2: tensor([[[[-0.15, 0.05, -0.35, 0.25, -0.55, 0.45, -0.75, 0.65
3.2.1.3 attention的編碼實現(xiàn)
attention:這是注意力機制的主要功能
- 首先,如果use_RoPE被設置為True,,它會應用RoPE,,通過取查詢和鍵的點積(并進行縮放)
- 然后,進行softmax操作來計算注意力分數(shù),,以得到概率,,輸出是值的加權(quán)和,權(quán)重是計算出的概率
- 最后,,旋轉(zhuǎn)后的q和k計算點積注意力后,,自然就具備了相對位置信息
def attention(q, k, v, mask=None, dropout=None, use_RoPE=True): # q.shape: (bs, head, seq_len, dk) # k.shape: (bs, head, seq_len, dk) # v.shape: (bs, head, seq_len, dk) # (bs, head, seq_len, seq_len) att_logits = torch.matmul(q, k.transpose(-2, -1)) att_logits /= math.sqrt(d_k) # 對權(quán)重進行mask,將為0的部分設為負無窮大 att_scores = att_logits.masked_fill(mask == 0, -1e-9) # (bs, head, seq_len, seq_len) att_scores = F.softmax(att_logits, dim=-1) att_scores = dropout(att_scores) # 注意力權(quán)重與值的加權(quán)求和 # (bs, head, seq_len, seq_len) * (bs, head, seq_len, dk) = (bs, head, seq_len, dk) return torch.matmul(att_scores, v), att_scores if __name__ == '__main__': # (bs, head, seq_len, dk) q = torch.randn((8, 12, 10, 32)) k = torch.randn((8, 12, 10, 32)) v = torch.randn((8, 12, 10, 32)) res, att_scores = attention(q, k, v, mask=None, dropout=None, use_RoPE=True) # (bs, head, seq_len, dk), (bs, head, seq_len, seq_len) print(res.shape, att_scores.shape)
3.2.2 LLaMA版的實現(xiàn)
接下來,,我們再來看下LLaMA里是怎么實現(xiàn)這個旋轉(zhuǎn)位置編碼的,,具體而言,LLaMA 的model.py文件里面實現(xiàn)了旋轉(zhuǎn)位置編碼(為方便大家理解,,我給相關(guān)代碼 加了下注釋) 首先,,逐一實現(xiàn)這三個函數(shù) precompute_freqs_cis reshape_for_broadcast apply_rotary_emb
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # 計算頻率 t = torch.arange(end, device=freqs.device) # 根據(jù)結(jié)束位置生成序列 freqs = torch.outer(t, freqs).float() # 計算外積得到新的頻率 freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # 計算復數(shù) return freqs_cis # 返回復數(shù)
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim # 獲取輸入張量的維度 assert 0 <= 1 < ndim # 檢查維度的合理性 assert freqs_cis.shape == (x.shape[1], x.shape[-1]) # 檢查復數(shù)的形狀 shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # 計算新的形狀 return freqs_cis.view(*shape) # 重塑復數(shù)的形狀并返回
) -> Tuple[torch.Tensor, torch.Tensor]: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # 將xq視為復數(shù) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # 將xk視為復數(shù) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) # 重塑復數(shù)的形狀 xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) # 計算xq的輸出 xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) # 計算xk的輸出 return xq_out.type_as(xq), xk_out.type_as(xk) # 返回xq和xk的輸出
之后,在注意力機制的前向傳播函數(shù)中調(diào)用上面實現(xiàn)的第三個函數(shù) apply_rotary_emb,,賦上位置信息
# 對Query和Key應用旋轉(zhuǎn)嵌入 xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
第四部分 線性偏差注意力ALiBi
模型名稱 | 隱藏層維度 | 層數(shù) | 注意力頭數(shù) | 詞表大小 | 訓練數(shù)據(jù)(tokens) | 位置編碼 | 最大長度 | Baichuan-7B | 4,096 | 32 | 32 | 64,000 | 1.2 萬億 | RoPE | 4,096 | Baichuan-13B | 5,120 | 40 | 40 | 64,000 | 1.4 萬億 | ALiBi | 4,096 | Baichuan 2-7B | 4096 | 32 | 32 | 125,696 | 2.6萬億 | RoPE | 4096 | Baichuan 2-13B | 5120 | 40 | 40 | 125,696 | 2.6萬億 | ALiBi | 4096 |
注意看上表的位置編碼那一列,,baichuan 7B無論第一代還是第二代,位置編碼均用的RoPE,,而baichuan 13B則無論是第一代還是第二代,,均用的ALiBi
下面便詳細介紹下該ALiBi
4.1 什么是ALiBi
ALiBi全稱是Attention with Linear Biases,通過論文《Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation》提出,,其不像標準transformer那樣,,在embedding層添加位置編碼,而是在softmax的結(jié)果后添加一個靜態(tài)的不可學習的偏置項(說白了,,就是數(shù)值固定)
具體而言,,如下圖所示
- 當計算每個頭的注意力分數(shù)時,線性偏差注意力方法ALiBi會向每個注意力分數(shù)(,,左)添加一個常數(shù)偏差(右)
When computing attention scores for each head, our linearly biased attention method, ALiBi, adds a constant bias (right) to each attention score (qi· kj , left). 左邊是自注意力得分,,關(guān)于q和k的內(nèi)積 右邊是一個相對距離的矩陣, q1 q2 q3 q4 q5 k1 k2 k3 k4 k5 所以才有 q1和k1之間的距離是0,,所以對應位置就是0 q2和k1之間的距離是「相對位置偏移為“k的索引”1」 - 「q的索引2」,,得到1-2 = -1,就對應到了中間矩陣的取值為-1了 以此類推,,相對距離矩陣的中間對角線上都是0,,然后左下角的取值都是對應的「k的索引」-「q的索引」了 - 那m具體怎么取值呢,按論文中的說法是
當8個heads的時候,,m的取值為:
如果是16個heads,,則m的取值為:
相當于追加了一半的1/sqrt(2)到原來的8個head的每個m的取值
擴展到一般情況就是:對于n個head的話,m的取值就是,,即如下 ,, 這樣的m個坡度了
最終整體的公式便是
對于第i個query來說,他們之間的相對距離就是:k的索引 - q的索引 具體而言,,k的索引 遍歷,,而q的索引 取值為
// 待更
第五部分 LLaMA 2 Long中位置編碼的修改
5.1 LLaMA 2 Long相比LLaMA 2的變化:修改位置編碼 長度達到32K
23年9月底[Submitted on 27 Sep 2023 (v1), last revised 14 Nov 2023 (this version, v3)],GenAI, Meta正式發(fā)布LLaMA 2 Long(這是其論文《Effective Long-Context Scaling of Foundation Models》),,與LLaMA 2相比,,LLaMA 2 Long的變化主要體現(xiàn)在以下兩點
- 一是訓練參數(shù)上,采用了高達4000億token的數(shù)據(jù)源(We build our models by continually pretraining from LLAMA 2 checkpoints with additional 400 billion tokens formed as long training sequences)
——相反,,原始LLaMA 2包含多個變體,,但最多的版本也只有700億 - 二是架構(gòu)上,與LLaMA 2保持不變,,但對位置編碼進行了一個非常小的必要修改,,以此完成高達3.2萬token的上下文窗口支持
5.1.1 LLaMA 2 Long中的位置編碼做了怎樣的修改
在LLaMA 2中,,它的位置編碼采用的是旋轉(zhuǎn)編碼RoPE方法,其通過旋轉(zhuǎn)矩陣來實現(xiàn)位置編碼的外推
- 本質(zhì)上來說,,RoPE就是將表示單詞,、數(shù)字等信息的token embeddings映射到3D圖表上,給出它們相對于其他token的位置——即使在旋轉(zhuǎn)時也如此
- 這就能夠使模型產(chǎn)生準確且有效的響應,,并且比其他方法需要的信息更少,,因此占用的計算存儲也更小
然,Meta的研究人員通過對70億規(guī)模的LLaMA 2進行實驗,,確定了LLaMA 2中的RoPE方法的一個局限性,,即,阻止注意力模塊聚集遠處token的信息
為此,,Meta想出了一個非常簡單的破解辦法,,即
減少每個維度的旋轉(zhuǎn)角度(which essentially reduces the rotation angles of each dimension)
具體而言就是將超參數(shù)“基頻(base frequency)b”從10000增加到500000(increasing the “base frequency b” of ROPE from 10, 000 to 500, 000)
這個問題不要小瞧,值得好好細究下,,llama long原論文提到,,對于這個基頻參數(shù)的修改,除了llama long之外,,社區(qū)Reddit和code llama也同時運用了,,即:“We propose a simple modification to the default RoPE encoding to reduce the decaying effect – increasing the “base frequency b” of ROPE from 10, 000 to 500, 000, which essentially reduces the rotation angles of each dimension. The idea is also concurrently suggested in the Reddit r/LocalLLaMa community and Rozière et al. (2023).”
- 前者所謂社區(qū)Reddit的發(fā)現(xiàn)來源于這里:NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation,其對應的鏈接為(據(jù)ChatGPT的推斷,,該篇帖子發(fā)表于23年7月2日):https://www./r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=46058
有意思的是,,后來大家把Reddit這位網(wǎng)友bloc97的發(fā)現(xiàn)命名為了NTK-aware方法,即“為了解決RoPE嵌入插值時丟失高頻信息(losing high frequency information when interpolating the RoPE embeddings)的問題,,Reddit一網(wǎng)友通過[NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation]開發(fā)了“NTK-aware”插值,,核心思想是:高頻外推,低頻內(nèi)插” 至于NTK-aware的詳細介紹參見此文《大模型長度擴展綜述:從直接外推ALiBi,、插值PI,、NTK-aware插值(對此介紹最詳)、YaRN到S2-Attention》的第三部分 - 后者Rozière et al. (2023)對應的是Code llama: Open foundation models for code, 2023「其Submitted on 24 Aug 2023 (v1), last revised 31 Jan 2024 (this version, v3),,原論文中提到:we increase the base period θ from 10,000 to 1,000,000 for fine-tuning」
至于對code llama的介紹詳見此文《代碼生成的原理解析:從Codex,、GitHub Copliot到CodeLlama,、CodeGeex》的第三部分
此外,在附錄中,Meta還通過可視化為螺旋圖這一非常有趣的方式,,將RoPE ABF與RoPE PI的差異進行了理論分析
- 上圖b旨在說明位置插值對映射向量相對位置的影響,,與上圖a相比,連續(xù)點之間的距離被大幅縮小
Figure 8b aims to illustrate the impact of Position Interpolation on the relative position of the mapped vectors. The distance between the consecutive points got reduced considerably compered to Figure8a. - 上圖c說明了調(diào)整基頻對結(jié)果的影響(The impact of Adjusted Base Frequency is illustrated on Figure 8c)
雖然螺旋頻率增加導致點之間最小距離縮小(although the minimal distance between points got considerably reduced due to the increased frequency of the helix) 但連續(xù)點之間的距離幾乎與上圖a相同(The distance between the consecutive points remained almost the same as on Figure 8a) 即螺旋頻率增加所帶來的影響將在高頻段中逐漸減少(This effect of increasedfrequency of the helix would be reduced in the high)
總之,,與RoPE PI相比,,RoPE ABF的優(yōu)勢主要體現(xiàn)在它能以更大的粒度分配嵌入向量(the embedded vectors),從而使模型更容易區(qū)分位置
此外,,他們還觀察到,嵌入向量之間的相對距離既對RoPE PI的關(guān)鍵參數(shù)有線性依賴性,,也對RoPE ABF的關(guān)鍵參數(shù)也有對數(shù)依賴性,。
這也就是為什么可以很容易地對基頻這一超參數(shù)“下手”
5.1.2 改動之后的效果
這一改動立刻奏效,縮小了RoPE對遠端token的衰減效應,,并且在擴展LLAMA的上下文長度上優(yōu)于一項類似的名為“位置插值”的方法RoPE PI(如下圖所示,,RoPE表示基線方法,RoPE ABF為Meta此次發(fā)明的新方法,,xPos是另一種應用了該方法的旋轉(zhuǎn)編碼變體)
然,,一個問題是,通過上面這個可視化結(jié)果,,Meta觀察到RoPE在長程區(qū)域出現(xiàn)了較大的“振蕩”,,這對于語言建模來說可能不是個好消息
不過,通過報告幾種方法在長序列困惑度和FIRST-SENTENCE-RETRIEVAL兩個任務上的表現(xiàn)來看,,問題不大
而且,,尤其在后者任務上,他們提出的RoPE ABF是唯一一個可以始終保持性能的變體
最終,,LLaMA 2 Long憑借著這一改動,,達成了3.2萬的上下文token,并通過長下文連續(xù)預訓練的共同作用,,獲得了開頭所示的好成績:
除了全面超越LLaMA 2,、在特定任務上超越Claude 2和ChatGPT,Meta也給出了它和一些開源長下文模型的對比,。結(jié)果也相當不賴,,如下圖所示
//待更
后記
最后,說明下為何像開頭說的是「23年12.16日這天對本文做了大修」呢,,原因在于
- 我司《論文審稿GPT第2版》即將進入模型訓練階段,,其涉及到三個候選模型:mistral-yarn、mistral,、llama-longlora
故準備解析下YaRN,,順帶把外推、內(nèi)插都全面介紹下,,而過程中不可避免會提到RoPE,,故也總算把RoPE徹底寫清楚了 - 這些東西,,哪怕是近期最新的技術(shù)、模型等理解了后 會發(fā)現(xiàn)都不難,,但我總想把理解的門檻無限降低,,所以想真正寫清楚或講清楚一個東西,必須得反復琢磨,、反復修改,,以讓更多人因此看懂,更何況當我和我的團隊每天看paper,、做項目,,更可以幫到大家不斷進階、深入
如今博客的訪問PV2000萬,,希望明年達到2000萬UV以上,,以上視為后記
參考文獻與推薦閱讀
- 馬同學關(guān)于向量和歐拉公式的幾篇科普文章
向量的加法 歐拉公式,,復數(shù)域的成人禮 - 關(guān)于歐拉公式的幾篇文章
被眾人膜拜的歐拉恒等式是個什么東東? 怎么向小學生解釋歐拉公式 e^(πi)+1=0,? - 讀懂旋轉(zhuǎn)編碼(RoPE)
- LLM學習記錄(五)--超簡單的RoPE理解方式,這篇文章很不錯
- 蘇劍林:Transformer升級之路:2,、博采眾長的旋轉(zhuǎn)式位置編碼
- LLaMA的解讀與其微調(diào):Alpaca-LoRA/Vicuna/BELLE/中文LLaMA/姜子牙/LLaMA 2
- 關(guān)于ALiBi的兩篇文章
[速讀經(jīng)典]ALiBi - 給注意力加上線性偏置 關(guān)于Transformer中的位置編碼-ALiBi - 最強LLaMA突然來襲,!只改一個超參數(shù),,實現(xiàn)上下文3.2萬token,多個任務打敗ChatGPT,、Claude 2
|