도입
카라츠바의 빠른 곱셈 알고리즘에 앞서 O(n^2) 곱셈 알고리즘을 알아야 한다.
(물론 이 곱셈은 단순 32비트 두 정수의 곱이 아닌, 수만 자리의 숫자들을 다룰 때 사용된다.)
※ normalize() 함수에 음수를 처리하는 부분이 있는데, 이는 카라츠바 알고리즘에서 사용하게 된다.
지금 소개하는 알고리즘에선 사용되는 경우가 없다. (multiply()함수는 덧셈밖에 하지 않는다.)
//num[]의 자릿수 올림을 처리한다.
void normalize(vector<int>& num)
{
num.push_back(0);
//자릿수 올림을 처리한다.
for (int i = 0; i + 1 < num.size(); i++)
{
//자릿수가 음수인 경우 -> 카라츠바 알고리즘에 필요함.
if (num[i] < 0)
{
int borrow = (abs(num[i]) + 9) / 10;
num[i + 1] -= borrow;
num[i] += borrow * 10;
}
else
{
num[i + 1] += num[i] / 10;
num[i] %= 10;
}
}
while (num.size() > 1 && num.back() == 0) num.pop_back();
}
//긴 두 자연수의 곱을 반환한다.
//각 배열에는 각 수의 자릿수가 1의 자리에서부터 시작해 저장되어 있다.
//ex : multiply([3,2,1],[6,5,4]} = 123*456 = 56088 =[8,8,0,6,5]
vector<int> multiply(const vector<int>& a,const vector<int>& b)
{
vector<int> c(a.size() + b.size() + 1, 0);
for (int i = 0; i < a.size(); i++)
{
for (int j = 0; j < b.size(); j++)
{
c[i + j] += a[i] * b[j];
}
}
normalize(c);
return c;
}
int main()
{
vector<int> a, b;
for (int i = 0; i < 10; i++)
{
a.push_back(i);
}
for (int j = 9; j > 0; j--)
{
b.push_back(j);
}
cout << "9876543210 * 123456789 = " << 9876543210 * 123456789 << '\n';
vector<int> c = multiply(a, b);
cout << "a * b = ";
for (int i = c.size() - 1; i >= 0; i--)
{
cout << c[i];
}
return 0;
}
이중 for문이 사용되기 때문에 시간복잡도가 O(n^2)이다.
1960년도에 이보다 빠른 알고리즘을 카라츠바라는 사람이 고안했다.
카라츠바(Karatsuba)의 빠른 곱셈 알고리즘
256자리 수인 a와 b가 있다고 하자. 이 두 수를 각각 반으로 쪼개면 다음과 같이 표현할 수 있다.
a = a1x10^128 + a0
b = b1x10^128 + b0
(a1, b1은 첫 128자리, a0, b0는 다음 128자리)
카라츠바는 axb 를 네 개의 조각을 이용해서 나타낼 수 있다.
a x b = (a1x10^128 + a0) x (b1x10^128 + b0)
= a1 x b1 x 10^256 + (a1xb0 + a0xb1)x10^128 + a0xb0
여기까지만 하면 사실 시간복잡도는 O(n^2)이 나오기 때문에 애써 분할한 이유가 없어진다.
카라츠바는 여기서 더 나아가 네 번의 곱셈할 것을 세 번 만에 하는 방법을 발견했다.
z0 = a0 x b0
z2 = a1 x b1
z1 = a1 x b0 + a0 x b1 = (a0 + a1) x (b0 + b1) - z0 - z2
=> a x b = z2 x 10^256 + z1 x 10^128 + z0
카라츠바 알고리즘 구현
using namespace std;
//num[]의 자릿수 올림을 처리한다.
void normalize(vector<int>& num)
{
num.push_back(0);
//자릿수 올림을 처리한다.
for (int i = 0; i + 1 < num.size(); i++)
{
//자릿수가 음수인 경우 -> 카라츠바 알고리즘에 필요함.
if (num[i] < 0)
{
int borrow = (abs(num[i]) + 9) / 10;
num[i + 1] -= borrow;
num[i] += borrow * 10;
}
else
{
num[i + 1] += num[i] / 10;
num[i] %= 10;
}
}
while (num.size() > 1 && num.back() == 0) num.pop_back();
}
//긴 두 자연수의 곱을 반환한다.
//각 배열에는 각 수의 자릿수가 1의 자리에서부터 시작해 저장되어 있다.
//ex : multiply([3,2,1],[6,5,4]} = 123*456 = 56088 =[8,8,0,6,5]
vector<int> multiply(const vector<int>& a,const vector<int>& b)
{
vector<int> c(a.size() + b.size() + 1, 0);
for (int i = 0; i < a.size(); i++)
{
for (int j = 0; j < b.size(); j++)
{
c[i + j] += a[i] * b[j];
}
}
normalize(c);
return c;
}
//a += b*(10^k);
void addTo(vector<int>& a, const vector<int>& b, int k);
//a -=b; (a>b)
void subFrom(vector<int>& a, const vector<int>& b);
//karatsuba
vector<int> karatsuba(const vector<int>& a, const vector<int>& b)
{
int an = a.size(), bn = b.size();
//a가 b보다 짧은경우 바꾼다.
if (an < bn) return karatsuba(b, a);
//base case : a나 b가 비어있는 경우
if (an == 0 || bn == 0) return vector<int>();
//base case : a가 비교적 짧은 경우 O(n^2) 곱셈으로 변경한다.
if (an <= 5) return multiply(a, b);
int half = an / 2;
//a와 b를 밑에서 half자리와 나머지로 분리한다
vector<int> a0(a.begin(), a.begin() + half);
vector<int> a1(a.begin() + half, a.end());
vector<int> b0(b.begin(), b.begin() + min<int>(b.size(), half));
vector<int> b1(b.begin() + min<int>(b.size(), half), b.end());
//z2=a1*b1
vector<int> z2 = karatsuba(a1, b1);
//z0=a0*b0
vector<int> z0 = karatsuba(a0, b0);
//a0=a0+a1;
//b0=b0+b1
addTo(a0, a1, 0);
addTo(b0, b1, 0);
//z1=(a0+a1)*(b0+b1)-z0-z2
vector<int> z1 = karatsuba(a0, b0);
subFrom(z1, z0);
subFrom(z1, z2);
//result=z0+z1*10^half+z2*10^(half*2)
vector<int> result;
addTo(result, z0, 0);
addTo(result, z1, half);
addTo(result, z2, half + half);
return result;
}
int main()
{
using namespace chrono;
vector<int> a, b, c;
high_resolution_clock::time_point t1, t2;
duration<double> time_span;
string number;
cout << "첫번째 정수 입력: ";
cin >> number;
for (int i = number.size() - 1; i >= 0; i--)
{
a.push_back(number[i] - '0');
}
cout << "두번째 정수 입력: ";
cin >> number;
for (int i = number.size() - 1; i >= 0; i--)
{
b.push_back(number[i] - '0');
}
cout << "O(n^2) : ";
t1 = high_resolution_clock::now();
c = multiply(a, b);
t2 = high_resolution_clock::now();
time_span = duration_cast<duration<double>>(t2 - t1);
for (int i = c.size() - 1; i >= 0; i--)
{
cout << c[i];
}
cout << ", 소요시간 : " << time_span.count() << "초" << '\n';
t1 = high_resolution_clock::now();
c = karatsuba(a, b);
t2 = high_resolution_clock::now();
time_span = duration_cast<duration<double>>(t2 - t1);
cout << "karatsuba(a, b) : ";
for (int i = c.size() - 1; i >= 0; i--)
{
cout << c[i];
}
cout << ", 소요시간 : " << time_span.count() << "초" << '\n';
return 0;
}
void addTo(vector<int>& a, const vector<int>& b, int k)
{
a.resize(max((a.size() + 1), (b.size() + k)));
for (int i = 0; i < b.size(); i++)
{
a[i + k] += b[i];
}
normalize(a);
}
void subFrom(vector<int>& a, const vector<int>& b)
{
a.resize(max(a.size() + 1, b.size() + 1));
for (int i = 0; i < b.size(); i++)
{
a[i] -= b[i];
}
normalize(a);
}
실행 결과
실제로 시간 복잡도를 계산해보면 O(n^lg3)이 나온다.
lg3은 약 1.585이기에 n^2보다 빠른 알고리즘이라 볼 수 있다.
다만, 입력의 크기가 작은 경우 O(n^2) 알고리즘보다 느린 경우가 많다.
karatsuba() 함수에서 일정 크기 이하인 경우 O(n^2) 알고리즘을 쓰는 이유이다.
'알고리즘 > 알고리즘 문제 해결 전략' 카테고리의 다른 글
분할 정복 - 2. 울타리 잘라내기 (0) | 2020.11.22 |
---|---|
분할 정복 - 1. 쿼드 트리 뒤집기 (0) | 2020.11.16 |
분할 정복 (Divide & Conquer) (0) | 2020.11.13 |
무식하게 풀기(Brute force) - 4. 시계 맞추기 (0) | 2020.11.13 |
무식하게 풀기(Brute force) - 3. 게임판 덮기 (0) | 2020.11.09 |