KMP算法

KMP算法

KMP算法(全称Knuth-Morris-Pratt字符串查找算法,由三位发明者的姓氏命名)是可以在文本串s中快速查找模式串p的一种算法。

暴力查找

普通的字符串暴力查找

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
// 暴力匹配
int i = 0, j = 0;
while (i < s.length())
{
if (s[i] == p[j])
++i, ++j;
else
i = i - j + 1, j = 0;
if (j == p.length()) // 匹配成功
{
// 对s[i - j .. i - 1]进行一些操作
cout << i - j << endl;
i = i - j + 1;
j = 0;
}
}

对于文本串的每一位s[i],都需对模式串p[i]进行遍历,若文本串长m,模式串长p,最坏的情况下O(nm)。因此引入KMP算法以快速匹配模式串

KMP算法

基本概念

最长前缀串,后缀串,PMT

对于一个字符串s

我们将s[0, i](i < len)定义为他的前缀, s[j, len-1](j >= 0)定义为他的后缀。

i != len-1j != 0,这成为真前(后)缀。

在此基础上,引出PMT,即(Partial Match Table,部分匹配表

对于一个字符串s的s[i], 其子串s[0,i],其前后缀相等的最长长度,即为s[i]的border(真前缀真后缀两个集合的交集中,最长元素的长度)

例:对于串aaabaaab,其对应的border为0,1,2,0,1,2,3,4

具体原理

kmp函数

定义模式串s指针j, 匹配串t指针i, pmt数组pmt[MAXN]

1
2
a b a b c (i=0)
a b a (j=0)
  1. i遍历整个匹配串

  2. i++的过程中,对s[i]t[j],进行比较来控制j的位置。

    1. 若成功,则j++

    2. 若失败,则进行回跳。

      例:第n位匹配不成功,因为第n-1位匹配成功,根据前后缀相等,第n-1 位在前缀必有对应相等的值,

      且第n-1位为后缀末位,因此其对应前缀末尾,其下标正好等于前缀后缀的长度-1。然而需要进行比较的并不是前缀的最后一位,在回跳的前提下这位亦是已匹配好的,因此需比较的是下一位。所以j = pmt[j-1]-1+1 = pmt[j-1]

    3. j=strlen(j), 即匹配成功,输出结果并回跳。(为什么不回跳至开头:尽可能保留以匹配成功的串,防止遗漏)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
const int MAXN = 1e6 + 10;
int pmt[MAXN];
char s[MAXN], t[MAXN];

void kmp()
{
int slen = strlen(s);
int tlen = strlen(t);
int j = 0, i = 0; //j 是模式串上的指针,i为匹配串上的指针
for(j = 0, i = 0; i < slen; i++)
{
while(j && s[i]!=t[j]) j = pmt[j-1];//连续回跳 若j=0,即回跳已到开头,无法在向前退位
if(s[i]==t[j]) j++;
if(j == tlen) //匹配成功,跳回开头。
{
//添加输出等等
j = pmt[j-1];
}
}
}

getpmt函数

kmp的核心在于pmt数组,要获取pmt数组,即将自己的前缀和自己的后缀进行比较

1
2
a b a b a      (后缀)
a b a b a (前缀)

定义前缀串s指针j, 后缀串t指针i

i遍历整个串,j的位置即为前缀的长度,即pmt数组的值,遍历原理与模式串和匹配串的匹配逻辑相同。

1
2
3
4
5
6
7
8
9
10
11
12
void getpmt()//错开一位自己匹配自己
{
int tlen = strlen(t);
int i = 1, j = 0;
pmt[0] = 0;
for(i = 1, j = 0; i < tlen; i++)
{
while(j && t[i]!=t[j]) j = pmt[j-1]; // 匹配不成功,回跳
if(t[i] == t[j]) j++;
pmt[i] = j;
}
}

代码样例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
const int MAXN = 1e6 + 10;
int nxt[MAXN];
char s[MAXN], t[MAXN];

int main(int argc, char const *argv[])
{
scanf("%s%s", s, t);
int slen = strlen(s);
int tlen = strlen(t);
t[tlen] = '#';
for(int i = 1; i < tlen; i++)
{
int j = nxt[i - 1];
while(j && t[i] != t[j]) j = nxt[j - 1];
if(t[i] == t[j]) j++;
nxt[i] = j;
}
for(int i = 0, pre = 0, j = 0; i < slen; i++)
{
j = pre;
while(j && s[i] != t[j]) j = nxt[j - 1];
if(s[i] == t[j]) j++;

if(j == tlen)
printf("%d\n", 1 + (i - tlen + 1));
pre = j;
}
for(int i = 0; i < tlen; i++)
printf("%d ", nxt[i]);

return 0;