KMP 알고리즘이란 문자열 검색 알고리즘으로 Knuth-Morris-Pratt의 이름을 줄여서 KMP라고 부른다.
어떤 긴 문자열 A에서 부분 문자열 B를 갖는 모든 위치를 구한다.
예를들어 A = "abcdabcef"이고, B = "abc"일때 "abc"를 포함하는 모든 시작 위치는 0과 4이므로 0과 4를 구하는 알고리즘이다.
kmp 알고리즘의 소스 코드를 보자
vector<int> kmp(const string& a, const string& b)
{
int n = a.size(), m = b.size(), j = 0;
vector<int> pi = getPi(b);
vector<int> res;
for (int i = 0; i < n; i++) {
while (j > 0 && a[i] != b[j])
j = pi[j - 1];
if (a[i] == b[j]) {
if (j == m - 1) res.push_back(i - m + 1);
else j++;
}
}
return res;
}
아직 이해하기엔 getPi가 뭔지 모른다.
pi에 대해서 알아보자.
pi[i] 란 구하는 부분 문자열 B의 0부터 i번째까지의 문자열 중 접두사도 되고 접미사도 되는 문자열의 최대 길이이다.
말로하면 이해가 안된다. 표를 보며 생각해보자.
B가 "aabaabac"일때 pi값을 각각 구한 것이다.
i | B[0..i] | 접두사이면서 접미사인 문자열 | pi[i] |
0 | a | 없음 | 0 |
1 | aa | a | 1 |
2 | aab | 없음 | 0 |
3 | aaba | a | 1 |
4 | aabaa | aa | 2 |
5 | aabaab | aab | 3 |
6 | aabaaba | aaba | 4 |
7 | aabaabac | 없음 | 0 |
표를 보니 한결 이해가 잘 된다.
그럼 이 pi는 왜 구한 것일까
문자열 A 가 "aab aab aac"이고 B가 "aab aac"라고 해보자.
처음 aabaa까지는 같으므로 그 다음을 조사할 것이다. A는 aabaa다음이 b이고 B는 aabaa다음이 c이므로 시작점 0은 답이 될 수 없다. 이때 pi가 없이 일반적인 for문을 돌린다면 다음 조사할 A의 위치는 1이 될 것이다. 그러나 우린 이미 aabaa까지는 같은 값을 가짐을 알고있고 이를 이용하기 위해 pi를 구하는 것이다. aabaa까지의 pi[4] = 2이다. 즉 aa까지는 일치하므로 그만큼을 더한 A[3]과 A[4]는 일치하다고 판단한 후 A[5]부터 다시 조사를 시작하는 것이다.
getPi를 구하는 소스 또한 kmp와 매우 비슷하다.
vector<int> getPi(const string& str)
{
int m = str.size(), j =0;
vector<int> pi(m, 0);
//시작점이 0이면 자기 자신을 찾기 때문에 안된다.
for (int i = 1; i < m; i++) {
while (j > 0 && str[i] != str[j])
j = pi[j - 1];
if (str[j] == str[i])
pi[i] = ++j;
}
return pi;
}
시작점이 0이 아닌 1부터 구하는 것을 유의하자
코드 원본 : https://github.com/chosh95/STUDY/blob/master/Algorithm/KMP.cpp
chosh95/STUDY
알고리즘 문제풀이. Contribute to chosh95/STUDY development by creating an account on GitHub.
github.com
C++ 코드
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
//pi[i] = str[0..i]까지의 부분 문자열에서
//접두사도 되고 접미사도 되는 최대길이.
//ex. aabaa는 aa가 답이므로 pi[4] = 2;
vector<int> getPi(const string& str)
{
int m = str.size(), j =0;
vector<int> pi(m, 0);
//시작점이 0이면 자기 자신을 찾기 때문에 안된다.
for (int i = 1; i < m; i++) {
while (j > 0 && str[i] != str[j])
j = pi[j - 1];
if (str[j] == str[i])
pi[i] = ++j;
}
return pi;
}
vector<int> kmp(const string& a, const string& b)
{
int n = a.size(), m = b.size(), j = 0;
vector<int> pi = getPi(b);
vector<int> res;
for (int i = 0; i < n; i++) {
while (j > 0 && a[i] != b[j])
j = pi[j - 1];
if (a[i] == b[j]) {
if (j == m - 1) res.push_back(i - m + 1);
else j++;
}
}
return res;
}
int main()
{
string A, B;
cin >> A >> B;
vector<int> v = kmp(A, B);
for (int i = 0; i < v.size(); i++)
cout << v[i] << " ";
cout << endl;
}