Mark Chang's Blog

Machine Learning, Deep Learning and Python

Viterbi Algorithm

1.Introduction

本文接續先前提到的 Hidden Markov Model

Natural Language Processing – Hidden Markov Model

繼續探討 part of speech tagging 的演算法

先前提到, 如果要在 Hidden Markov Model 找出一個機率最大的 tagging sequence

則必須把每一個序列都列出來, 看哪一個是機率最大的

但如果 Tag 種, 那麼長度為 的序列, 就有 種可能的 tagging sequence

由此可知, 暴力列舉的演算法非常沒有效率

2.Viterbi Algorithm

那麼, 來看看暴力列舉法, 到底出了什麼問題

在計算這些序列的機率時, 假設目前已經計算到了第 個字 , 這個字的 tag 時, 上一個字的 tag , 則此種情況發生的機率為

其中, 有三種可能, 分別是 , ,

假設在 之前的序列為 ,且在 這三個 state 的機率分別為

若是使用暴力列舉法, 要分別計算 , , 這三個序列的機率, 如下圖

viterbi1

由上圖可知 , 在 這個 state ,就多增加了 3 個序列傳遞下去, 之後每個 state 都會因為前面有 N 種不同的 state , 而增加 N 種不同的序列, 這樣一直增加, 序列的數量呈指數函數成長, 最後再一起比誰的機率比較大

這就是造成暴力列舉法沒效率的原因

其實, 可以用 Dynamic Programming 的概念, 在計算每個 state 的機率時, 就直接比較序列的機率大小,

只保留機率最大的一條序列, 傳遞下去, 如下圖

viterbi2

這樣每一個 state 就只會傳遞一個序列下去, 不會使序列數量呈指數成長

這就是所謂的 Viterbi Algorithm

3.Example

舉個例子,

有個研究者, 想根據某地人們生活日記中, 記載每天吃冰淇淋的數量, 來推斷當時的天氣變化如何

在某個地點有兩種天氣, 分別是 HotCold , 而當地的人們會記錄他們每天吃冰淇淋的數量, 數量分別為 1 , 23 ,

則可以把天氣變化的機率, 以及天氣吃冰淇淋數量的關係, 用 Hidden Markov Model 表示,

由於天氣是未知的, 為 hidden state , 天氣的集合為

而冰淇淋數量是已知的, 為 observable , 冰淇淋數量的集合為

天氣的 Transition Matrix , , 以及天氣變化對於冰淇淋數量的 Output Matrix , 如下

如果冰淇淋的記錄 , 用 Viterbi Algorithm 計算看看可能的天氣序列是什麼

首先, 從 initial state 開始, 計算第一個 state 以及 的機率

計算結果如下圖

p1

再來, 我們來計算第二個 state 的機率, 我們分別要算序列 的序列

結果如下圖:

p2

Viterbi Algorithm , 在同一個 state 只需要保留機率最大的序列即可

因此我們在 state 上, 只需要保留機率為 的序列 , 傳遞下去, 如下圖

p3

就用這樣的概念, 之後的每一個 state 都這樣計算, 就會得到每個 state 的機率值, 如下圖

p4

全部算完後, 再比較序列的最後一個 state, 哪一個機率比較大, 較大者可以傳遞到 end state (也就是最後的答案)

然後, 從 end state 回溯之前保留下來的序列, 如下圖

p6

得出最有可能的序列是 HCC

動畫版:

p7

4.Implementation

接著我們來實作一下, 先前提到的暴力列舉法, 以及 Viterbi 演算法

並比較這兩者的performance差異

新增一個python script檔案, 檔名為 viterbi.py ,並加入以下程式碼

1.前例中的Model, 和必要模組

viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import timeit
_STATE=['H','C']
_PI={'H':.8, 'C':.2}
_A={ 'H':{'H':.7, 'C':.3 }, 'C':{'H':.4,'C':.6} }
_B={'H':{1:.2,2:.4,3:.4}, 'C':{1:.5,2:.4,3:.1} }

def p_aij(i, j):
    return _A[i][j]

def p_bik(i, k):
    return _B[i][k]

def p_pi(i):
    return _PI[i]

其中, timeit 是用來計時的模組, 用於比較演算法所花的時間,

_STATE , _PIstate 的種類和 initial state 的機率,

_A_BTransition Matrix 和 *Output Matrix’,

p_aij(i, j) , p_bik(i, k) , p_pi(i) 分別是 ,

2.暴力列舉法的function: brute_force_algo(obs_init, print_seq=False)

viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
def brute_force_algo(obs_init, print_seq=False):
    start = timeit.default_timer()
    seq_val=[];
    def rec(obs, val_pre, qseq_pre):
        if len(obs) >0:
            for q in _STATE:
                if len(qseq_pre) == 0 :
                    val = val_pre * p_pi(q) * p_bik(q, obs[0])
                else:
                    q_pre = qseq_pre[-1]
                    val = val_pre * p_aij(q_pre,q) * p_bik(q, obs[0])
                qseq = qseq_pre + [q]
                rec(obs[1:], val, qseq)
        else:
            seq_val.append((qseq_pre, val_pre))
    rec(obs_init, 1, [])
    if print_seq:
        for (seq,val) in seq_val:
            print 'seq : %s , value : %s'%(seq, val)
    print 'result of brute_force_algo:'
    stop = timeit.default_timer()
    print 'max_seq : %s  max_val : %s'%(
          reduce(lambda x1,x2: x2 if x2[1] > x1[1] else x1, seq_val))
    print 'runtime : %s'%(stop - start )

其中, input argument obs_initobservable ,

print_seq 是用來控制是否要印出計算過程中產生的序列, 或者只印出最後結果

演算法用 recursive function 的方式實現, 詳細內容在此不詳述,

3.viterbi演算法的function: viterbi_algo(obs_init, print_seq=False)

viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
def viterbi_algo(obs_init, print_seq=False):
    start = timeit.default_timer()
    state_snapshot=[]
    def rec(obs, val_pre, qseq_pre):
        if len(obs) > 0:
            val = {}
            qseq = {}
            for q in _STATE:
                if len(val_pre) == 0:
                    val.update({ q:p_pi(q) * p_bik(q, obs[0]) })
                    qseq.update({q:[]})
                    state_snapshot.append(([q],val[q]))
                else:
                    val_temp = [( qseq_pre[q_pre]+[q_pre],
                                val_pre[q_pre] * p_aij(q_pre,q) * p_bik(q, obs[0] ))
                                for q_pre in _STATE ]
                    max_q_seq = reduce(lambda x1,x2: x2 if x2[1] > x1[1] else x1, val_temp)
                    state_snapshot.append((max_q_seq[0]+[q],max_q_seq[1]))
                    val.update({ q:max_q_seq[1]  })
                    qseq.update({ q:max_q_seq[0] })
            return rec(obs[1:],val,qseq)
        else:
            val_temp =[( qseq_pre[q]+[q] , val_pre[q] ) for q in _STATE ]
            max_q_seq = reduce(lambda x1,x2: x2 if x2[1] > x1[1] else x1, val_temp)
            return max_q_seq
    seq,val = rec(obs_init, {},[])
    if print_seq:
        for (seq,val) in state_snapshot:
            print 'seq : %s , value : %s'%(seq, val)
    print 'result of viterbi_algo:'
    print 'max_seq : %s , max_value : %s'%(seq, val)
    stop = timeit.default_timer()
    print 'runtime : %s'%(stop - start )

其中, 所使用的參數, 和暴力列舉法一樣

唯演算法部份改用 Viterbi algorithm , 但也是以 recursive function 的形式寫成, 在此不詳述

接著到interactive mode 載入 viterbi.py

1
>>> import viterbi

先來看一下暴力列舉法

1
2
3
4
5
6
7
8
9
10
11
12
>>> viterbi.brute_force_algo([3,1,1],True)
seq : ['H', 'H', 'H'] , value : 0.006272
seq : ['H', 'H', 'C'] , value : 0.00672
seq : ['H', 'C', 'H'] , value : 0.00384
seq : ['H', 'C', 'C'] , value : 0.0144
seq : ['C', 'H', 'H'] , value : 0.000224
seq : ['C', 'H', 'C'] , value : 0.00024
seq : ['C', 'C', 'H'] , value : 0.00048
seq : ['C', 'C', 'C'] , value : 0.0018
result of brute_force_algo:
max_seq : ['H', 'C', 'C']  max_val : 0.0144
runtime : 0.000172138214111

暴力列舉法會列出所有的序列, 並找出機率最大的序列

再來是 viterbi algorithm

1
2
3
4
5
6
7
8
9
10
>>> viterbi.viterbi_algo([3,1,1],True)
seq : ['H'] , value : 0.32
seq : ['C'] , value : 0.02
seq : ['H', 'H'] , value : 0.0448
seq : ['H', 'C'] , value : 0.048
seq : ['H', 'H', 'H'] , value : 0.006272
seq : ['H', 'C', 'C'] , value : 0.0144
result of viterbi_algo:
max_seq : ['H', 'C', 'C'] , max_value : 0.0144
runtime : 0.000169038772583

Viterbi algorithm 不會把每個序列都列出來, 而是用 Dynamic Programming 的方式, 保留下比較有可能的 subsequence , 最後也可以得出正確結果 , 且runtime比暴力列舉法快

再來, 增加一下input sequence的長度, 讓這兩種演算法的差異突顯出來,

為了避免印出過多序列, 第二個參數輸入 False , 如下

1
2
3
4
5
6
7
8
9
>>> viterbi.brute_force_algo([3,1,1,2]*4,False)
result of brute_force_algo:
max_seq : ['H', 'C', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'C', 'C', 'C']  max_val : 2.83168745718e-11
runtime : 0.245212078094

>>> viterbi.viterbi_algo([3,1,1,2]*4,False)
result of viterbi_algo:
max_seq : ['H', 'C', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'C', 'C', 'C'] , max_value : 2.83168745718e-11
runtime : 0.000442981719971

比較一下, 暴力列舉法的 runtime 呈指數函數成長, 而 viterbi algorithm 的是呈多項式函數成長

5. Reference

本文參考至兩本教科書

Foundations of Statistical Natural Language Processing

Speech and Language Processing

以及台大資工系 陳信希教授的 自然語言處理 課程講義

Comments