카라츠바의 빠른 곱셈
카라츠바의 빠른 곱셈 알고리즘은 수백, 수만자리나 되는 큰 두개의 정수를 곱하는 알고리즘이다.
필요성
카라츠바 알고리즘을 소개하기에 앞서, 두자릿수 이상의 두 수를 곱하는 과정은 다음과 같다.
(자릿수 올림 적용)
(자릿수 올림 미적용)
이를 코드로 구현하면 다음과 같이 구현 할 수 있다.
//num[]의 자릿수를 올림을 처리한다. void normalize(vector<int>& num) { num.push_back(0); //자릿수 올림을 처리한다. int size = num.size(); for (int i = 0; i < size - 1; i++) { if (num[i] < 0) { int borrow = (abs(num[i]) + 9) / 10; num[i + 1] -= borrow; num[i] += borrow * 10; } else { num[i + 1] += num[i] / 10; num[i] %= 10; } } //앞에 남은 0을 제거한다. while (num.size() > 1 && num.back() == 0) num.pop_back(); } //초등학교식 정수 곱셈 vector<int> multiply(const vector<int>& a, const vector<int>& b) { vector<int> c(a.size() + b.size() + 1, 0); int aSize = a.size(); int bSize = b.size(); for (int i = 0; i < aSize; i++) for (int j = 0; j < bSize; j++) c[i + j] += a[i] * b[j]; normalize(c); return c; }
이 알고리즘의 시간복잡도는 두 정수의 길이가 모두 n이라고 할 때 O(n^2)이다. 2중 for문을 이용하고 있으니 이 점을 이해하기는 어렵지 않을 것이다.
카라츠바 알고리즘은 이 시간복잡도를 O(n^log(3)) 까지 낮춰주기 위해 사용된다.
log(3) = 1.585...이므로 O(n^2) 보다 훨씬 적은 곱셈을 필요로 한다.
만약 n이 10만이라고 하면 곱셈 횟수는 대략 100배 정도 차이가 난다.
아이디어
카라츠바 알고리즘이 어떻게 진행되는지 설명하기에 앞서 카라츠바 알고리즘이 시간복잡도를 O(log(3))으로 낮추기 위해 사용한 아이디어에 대해 설명하고자 한다.
자릿수가 n인 두개의 수 a,b를 단순히 곱하기 위해서는 O(n^2)이 소요되지만, 덧셈과 뺄셈을 하는데에는 O(n)시간만에 해결 할 수 있다.
카라츠바 알고리즘은 곱셈의 횟수를 줄이고, 덧셈과 뺄셈 횟수를 늘리는 방식으로 구현된다.
과정
이제 카라츠바 알고리즘이 어떻게 진행되는지 보도록 하자.
카라츠바 알고리즘은 곱하는 256자리의 두 정수 a,b를 다음과 같이 나눈다.
a1,b1은 각각 a,b의 첫 128자리, a0,b0는 각각 a,b의 뒷 128자리를 나타낸다.
이제 a * b의 계산 과정은 다음과 같이 나눌 수 있다.
이 상태에서는 n/2 크기의 두 정수의 곱셈이 총 4번 사용된다. 이 곱셈 횟수를 줄이기 위해 다음 수식을 이용한다.
이 수식을 수정하면 다음의 결과를 얻을 수 있다.
z2 = a1 * b1;
z0 = a0 * b0;
z1 = (a0 + a1) * (b0 + b1) - z0 - z2;
이렇게 수정하고 나면 a*b는 n/2 크기의 두 정수의 곱셈 3번, 덧셈 2번, 뺄셈 2번으로 수행 할 수 있다.
이를 재귀적으로 처리하여 a1*b1, a0*b0에 대해서도 적용하면 곱셈 결과를 얻을 수 있다.
구현
다음은 카라츠바 알고리즘을 구현한 코드다.
#include <iostream> #include <vector> #include <algorithm> using namespace std; //num[]의 자릿수를 올림을 처리한다. void normalize(vector<int>& num) { num.push_back(0); //자릿수 올림을 처리한다. int size = num.size(); for (int i = 0; i < size - 1; i++) { if (num[i] < 0) { int borrow = (abs(num[i]) + 9) / 10; num[i + 1] -= borrow; num[i] += borrow * 10; } else { num[i + 1] += num[i] / 10; num[i] %= 10; } } //앞에 남은 0을 제거한다. while (num.size() > 1 && num.back() == 0) num.pop_back(); } //초등학교식 정수 곱셈 vector<int> multiply(const vector<int>& a, const vector<int>& b) { vector<int> c(a.size() + b.size() + 1, 0); int aSize = a.size(); int bSize = b.size(); for (int i = 0; i < aSize; i++) for (int j = 0; j < bSize; j++) c[i + j] += a[i] * b[j]; normalize(c); return c; } //a += b * (10^k) void addTo(vector<int>& a, const vector<int>& b, int k) { int originalASize = a.size(); if (a.size() < b.size() + k) a.resize(b.size() + k); a.push_back(0); int aSize = a.size(); for (int i = originalASize; i < aSize; i++) a[i] = 0; int bSize = b.size(); for (int i = 0; i < bSize; i++) a[i + k] += b[i]; normalize(a); } // a -= b // a>= b인 경우에만 사용 가능하다. void subFrom(vector<int>& a, const vector<int>& b) { int bSize = b.size(); for (int i = 0; i < bSize; i++) a[i] -= b[i]; normalize(a); } vector<int> karatsuba(const vector<int>& a, const vector<int>& b) { int an = a.size(), bn = b.size(); //a가 b보다 짧을 경우 둘을 바꾼다. if (an < bn) return karatsuba(b, a); //기저 사례 : a나 b가 비어있는 경우 if (an == 0 || bn == 0) return vector<int>(); //기저 사례 : a가 비교적 짧은 경우, O(n^2) 곱셈으로 변경한다.(성능을 위함) if (an <= 50) return multiply(a, b); int half = an / 2; vector<int> a0(a.begin(), a.begin() + half); vector<int> a1(a.begin() + half, a.end()); vector<int> b0(b.begin(), b.begin() + min<int>(b.size(), half)); vector<int> b1(b.begin() + min<int>(b.size(), half), b.end()); //z2 = a1 * b1 vector<int> z2 = karatsuba(a1, b1); //z0 = a0 * b0 vector<int> z0 = karatsuba(a0, b0); //z1 = ((a0 + a1) * (b0 + b1)) - z0 - z2 addTo(a0, a1, 0); addTo(b0, b1, 0); vector<int> z1 = karatsuba(a0, b0); subFrom(z1, z0); subFrom(z1, z2); //병합 과정 //ret = z0+z1*10^half + z2 * 10(half*2) vector<int> ret(z2.size() + half * 2, 0); addTo(ret, z0, 0); addTo(ret, z1, half); addTo(ret, z2, half * 2); return ret; } int main() { vector<int> a; vector<int> b; for (int i = 0; i < 100; i++) a.push_back(i % 10); for (int i = 0; i < 73; i++) b.push_back(i % 10); vector<int> c = karatsuba(b, a); int cSize = c.size(); for (int i = 0; i < cSize; i++) cout << c[i]; return 0; }
시간복잡도
이제 카라츠바 알고리즘의 시간복잡도를 따져볼 차례다.
위의 구현에서는 a의 길이가 50 이하이면 O(n^2) 곱셈 알고리즘을 이용하도록 했지만, 계산 편의를 위해서 시간복잡도 분석에서는 한 자리 숫자에 도달해야 O(n^2) 곱셈 알고리즘을 이용한다고 친다.
a,b의 자릿수 n이 2^k 이라고 할 때 재귀호출의 깊이는 k가 된다.
한번 쪼갤 때 마다 수행해야 할 곱셈이 3배(a1*b1, a0*b0, (a0+a1) * (b0+b1))로 늘어나기 때문에 마지막 단계에서는 3^k 개의 부분 문제가 있고, 마지막 단계에서는 두 수 모두 한자리 숫자니까 곱셈 한번이면 충분하다.
따라서 곱셈 횟수는 총 O(3^k)다.
여기서 n=2^k 라고 가정했으니 k=log(n)이고,
O(3^k) = O(3^log(n)) = O(n^log(3))
이다.
병합과정도 따져보자면, 병합과정은 더하기, 빼기만으로 구현된다. 더하기와 빼기는 O(n) 시간 안에 해결되고, 재귀 호출로 단계가 내려갈 때마다 숫자의 길이는 절반이 되고 문제의 개수는 3배가 되기 때문에, 깊이 i에서 병합에 필요한 연산 횟수는 ((3/2)^i) * n 이다.
따라서 병합을 위해 필요한 총 연산 수는
이 함수는 n^log(3)과 같은 속도로 증가한다.
따라서, 최종 시간복잡도는 O(n^log3)이 된다.
출처
알고리즘 문제 해결전략 - 인사이트