KMP 算法+Python 实现

KMP 算法

字符串匹配有很多算法可以实现,Knuth-Morris-Pratt 算法(简称KMP)是最常用的之一,时间复杂度为 O(m+n),其中 m 为模式串的长度,n 为目标串的长度。
KMP 算法的核心为部分匹配表 PMT(Partial Match Table),在理解 PMT 之前,需要先了解字符串的前缀与后缀分别指什么。

  • 前缀:对于字符串 AB,若存在任意非空字符串 S,使得 A=BS,则称 BA 的前缀。如 hello 的前缀集合为 [hell, hel, he, h]
  • 后缀:对于字符串 AB,若存在任意非空字符串 S,使得 A=SB,则称 BA 的后缀。如 hello 的后缀集合为 [ello, llo, lo, o]

那么,PMT 中的值即为当前位置字符串的前缀集合与后缀集合中公共元素的最大长度。比如字符串 ababa 的 PMT 值为 2,因为前缀集合为 [abab, aba, ab, a],后缀集合为 [baba, aba, ba, a],公共元素为 abaa,最大长度为 2。
如何使用 PMT 来加速字符串的匹配呢?
对于字符串 ababababca 与 模式串 abababca,我们首先给出模式串的 PMT 数组:

开始匹配时,会在下图红色位置匹配失败,这时需要回溯,普通方法直接令 i=1, j=0 从头开始匹配,这样复杂度极高,KMP 算法利用模式串 PMT 数组的特性,可以回溯到指定位置而减少一部分不必要的匹配。

我们目前有以下两点信息:

  1. 由图可以知道j 位置之前的 pattern 部分 与 i 位置之前的 string 同长度部分是完全相同的(用灰色底色显示);
  2. 此时 j 位置之前的 pattern 部分字符串 PMT 值为 4,说明此子串前四个字母和后四个字母完全相同。

结合这两点可知,i 位置之前长度为 4 的子串与 pattern 开头的长度为 4 的子串完全相同。那这部分直接跳过,匹配下一个位置即可!!如下图:

综上,每次失配时只需要回溯到 pattern 当前位置的上一个位置的 PMT 值的位置开始匹配就可以,那么将 PMT 数组统一向右移一位得到 next 数组,表示当前位置失配时需要回溯的目标位置,即:

最开始的位置补 0 只是为了写代码方便,用来指示已经回溯到 pattern 串的开头位置了。

Python 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class KMP(object):
"""KMP algorithm."""
def kmp(self, s, pattern):
"""KMP string matching function.

Args:
s: str, target string.
pattern: str, pattern string.

Returns:
index: int, the starting matching index of s (-1 if cannot match).
"""
__next = self.get_next(pattern)
i = j = 0 # 分别为 s 和 pattern 的指针
while i < len(s) and j < len(pattern):
if j == -1 or s[i] == pattern[j]:
# j 为 -1 是由 j = __next[j] 回溯产生的,即回溯到了 pattern 开头。
# 说明此处没有公共前缀与后缀,两个指针同时后移,相当于从 pattern 的头部开始重新匹配。
i += 1
j += 1
else:
j = __next[j]
if j == len(pattern):
return i - j
return -1

def get_next(self, pattern):
"""Get 'next' array for a pattern string."""
# 注:得到 next 数组的方法相当于 pattern 自身与自身的匹配算法
# 不同的是,因为 next[0] = -1,所以相当于 pattern 与 pattern[1:] 匹配
__next = [0] * len(pattern)
__next[0] = -1
i = 1 # pattern[1:] 的指针
j = 0 # pattern 的指针
while i < len(pattern) - 1: # pattern[1:] 长度只有 len(pattern) - 1
if j == -1 or pattern[i] == pattern[j]:
i += 1
j += 1
__next[i] = j
else:
j = __next[j]
return __next


if __name__ == "__main__":
s = "ababababca"
p = "abababca"
# print("Next:", KMP().get_next(p))
print(KMP().kmp(s, p))

参考

https://www.zhihu.com/question/21923021/answer/281346746
http://www.ruanyifeng.com/blog/2013/05/Knuth%E2%80%93Morris%E2%80%93Pratt_algorithm.html