Introduction
在數位電路裡面,如果一個電路沒有 latch 或 flip flop 這類的元件,它的輸出值只會取決於目前的輸入值,和上個時間點的輸入值是無關的,這種的電路叫作 combinational circuit 。
對於類神經網路而言,如果它的值只是從輸入端一層層地依序傳到輸出端,不會再把值從輸出端傳回輸入端,這種神經元就相當於 combinational circuit ,也就是說它的輸出值只取決於目前時刻的輸入值,這樣的類神經網路稱為 feedforward neural network 。
如果一個電路有 latch 或 flip flop 這類的元件,它的輸出值就跟上個時間點的輸入值有關,這種的電路它稱為 sequential circuit 。
所謂的 Recurrent Neural Network ,是一種把輸出端再接回輸入端的類神經網路,這樣可以把上個時間點的輸出值再傳回來,記錄在神經元中,達成和 latch 類似的效果,使得下個時間點的輸出值,跟上個時間點有關,也就是說,這樣的神經網路是有 記憶 的。
Recurrent Neural Network
由一個簡單神經元所構成的 Recurrent Neural Network ,構造如下:
這個神經元在 時間,訓練資料的輸入值為 ,訓練資料的答案為 ,神經元 的輸出值 ,可用以下公式表示:
其中, 為輸入神經元 的值, 是給目前的時間(current)時,輸入值 的權重, 是給上個時間點(previous)時,輸出值 的權重,而 為 bias 。從上圖可看出,紫色的線將神經網路的輸出端 連回輸入端 ,使得於時間 的輸出值跟上個時間點 的輸出值有關。
可以把這個神經元從時間點 到時間點 的運算,展開成下圖:
從上圖,最左邊開始,依序將 輸入神經元 ,而依序得出的值為 。神經元 在時間點 的輸出值 ,會接到時間點 時的輸入值 。
Training Recurrent Neural Network
訓練 recurrent neural network 的方法,和訓練 feedforward neural network 的方法一樣,都可以用 back propagation 。但是在 recurrent neural network 中,要依據時間順序,將值從最後一個時間點,回傳到第一個時間點。
在時間點 時的 cost function 為:
計算 recurrent neural network 的 back propagation 要分為兩部分來算,先算好時間點位於 的偏微分 值,再依序往前算出時間點 之前的偏微分值,如下:
其中, 為 到 中的其中一個時間點。用 Backward Propagation 詳細推導過程 所提到的推導方法,可推導出 、 與 的值,並令 代入以上公式,得出:
此公是可分為兩部分,當 時,與 時。計算 的方式不同。
在 時, 的傳遞過程就如同 feedforward neural network ,如下圖:
若 時, 要算 之前,要先從 時間點將 傳遞過來,傳遞過程如下圖:
因為需要把 從後面的時間點往前面傳,故這個過程又稱為 back propagation through time 。
於時間點 計算完 後,用以下公式將 時間點算出的偏微分值,更新到神經元的權重:
用 Backward Propagation 詳細推導過程 ,求出 、 和 的值,並換成用 代替 代入以上公式,得出:
此過程如下圖所示:
Implementation
再來是實作的部分,以下是個簡單的應用,用 Recurrent Neural Network 來預測一個字串序列中,下一個可能出現的字是什麼。例如,給定以下字串:
1
|
|
根據這個字串的特徵,如果連續出現了兩個 0 ,可以預測下個出現的為 1 ,若前面兩個字為 10 則可預測下個出現的自為 0 ,以此類推。
以下為實作部分:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
|
其中, x
為輸入的序列, n_out
為神經元預測的結果。進行這個演算法之前,首先,先給權重 w_p, w_c, w_b
的初始值用介於 -0.5~0.5 之間的隨機值。再來是進行訓練過程,用 for loop 進行了 10000 次的訓練,在每次的訓練過程中,先進行 forward propagation 依時間順序,算出每個時間點的 n_out
。再來是用 back propagation through time 來更新 w_p, w_c, w_b
的值。訓練完後,進行一次 forward propogation 用訓練過程得出的權重來預測序列的下一個字,並將預測結果印出。
到 interactive mode 執行以下程式,輸入序列 001001001 。
1 2 3 4 5 6 7 8 9 10 11 |
|
左側為輸入序列,右側為預測的結果,可以發現 recurrent neural network 可以預測出下個字可能會是 0 還是 1 。當左側為 1 時,右側的數字會接近於 1 。
Further Reading
關於 recurrent neural network 可參考 coursera 課程 Geoffrey Hinton. Neural Networks for Machine Learning
https://www.coursera.org/course/neuralnets