알고리즘/알고리즘 문제 해결 전략

분할 정복 - 카라츠바 알고리즘

도입

 

카라츠바의 빠른 곱셈 알고리즘에 앞서 O(n^2) 곱셈 알고리즘을 알아야 한다.

(물론 이 곱셈은 단순 32비트 두 정수의 곱이 아닌, 수만 자리의 숫자들을 다룰 때 사용된다.)

 

※ normalize() 함수에 음수를 처리하는 부분이 있는데, 이는 카라츠바 알고리즘에서 사용하게 된다.

지금 소개하는 알고리즘에선 사용되는 경우가 없다. (multiply()함수는 덧셈밖에 하지 않는다.)

 

//num[]의 자릿수 올림을 처리한다.
void normalize(vector<int>& num)
{
	num.push_back(0);
	//자릿수 올림을 처리한다.
	for (int i = 0; i + 1 < num.size(); i++)
	{
		//자릿수가 음수인 경우 -> 카라츠바 알고리즘에 필요함.
		if (num[i] < 0)
		{
			int borrow = (abs(num[i]) + 9) / 10;
			num[i + 1] -= borrow;
			num[i] += borrow * 10;
		}
		else
		{
			num[i + 1] += num[i] / 10;
			num[i] %= 10;
		}
	}
	while (num.size() > 1 && num.back() == 0) num.pop_back();
}

//긴 두 자연수의 곱을 반환한다.
//각 배열에는 각 수의 자릿수가 1의 자리에서부터 시작해 저장되어 있다.
//ex : multiply([3,2,1],[6,5,4]} = 123*456 = 56088 =[8,8,0,6,5]
vector<int> multiply(const vector<int>& a,const vector<int>& b)
{
	vector<int> c(a.size() + b.size() + 1, 0);
	for (int i = 0; i < a.size(); i++)
	{
		for (int j = 0; j < b.size(); j++)
		{
			c[i + j] += a[i] * b[j];
		}
	}
	normalize(c);
	return c;
}

int main()
{
	vector<int> a, b;
	for (int i = 0; i < 10; i++)
	{
		a.push_back(i);
	}
	for (int j = 9; j > 0; j--)
	{
		b.push_back(j);
	}
	
	cout << "9876543210 * 123456789 = " << 9876543210 * 123456789 << '\n';

	vector<int> c = multiply(a, b);
	cout << "a * b = ";
	for (int i = c.size() - 1; i >= 0; i--)
	{
		cout << c[i];
	}
	return 0;
}

 

이중 for문이 사용되기 때문에 시간복잡도가 O(n^2)이다. 

1960년도에 이보다 빠른 알고리즘을 카라츠바라는 사람이 고안했다.

 

카라츠바(Karatsuba)의 빠른 곱셈 알고리즘

 

256자리 수인 a와 b가 있다고 하자. 이 두 수를 각각 반으로 쪼개면 다음과 같이 표현할 수 있다.

 

a = a1x10^128 + a0

b = b1x10^128 + b0

(a1, b1은 첫 128자리, a0, b0는 다음 128자리)

 

카라츠바는 axb 를 네 개의 조각을 이용해서 나타낼 수 있다.

 

a x b = (a1x10^128 + a0) x (b1x10^128 + b0) 

= a1 x b1 x 10^256 + (a1xb0 + a0xb1)x10^128 + a0xb0

 

여기까지만 하면 사실 시간복잡도는 O(n^2)이 나오기 때문에 애써 분할한 이유가 없어진다.

카라츠바는 여기서 더 나아가 네 번의 곱셈할 것을 세 번 만에 하는 방법을 발견했다.

 

z0 = a0 x b0

z2 = a1 x b1

z1 = a1 x b0 + a0 x b1 = (a0 + a1) x (b0 + b1) - z0 - z2

=>  a x b = z2 x 10^256 + z1 x 10^128 + z0

 

 카라츠바 알고리즘 구현

using namespace std;

//num[]의 자릿수 올림을 처리한다.
void normalize(vector<int>& num)
{
	num.push_back(0);
	//자릿수 올림을 처리한다.
	for (int i = 0; i + 1 < num.size(); i++)
	{
		//자릿수가 음수인 경우 -> 카라츠바 알고리즘에 필요함.
		if (num[i] < 0)
		{
			int borrow = (abs(num[i]) + 9) / 10;
			num[i + 1] -= borrow;
			num[i] += borrow * 10;
		}
		else
		{
			num[i + 1] += num[i] / 10;
			num[i] %= 10;
		}
	}
	while (num.size() > 1 && num.back() == 0) num.pop_back();
}


//긴 두 자연수의 곱을 반환한다.
//각 배열에는 각 수의 자릿수가 1의 자리에서부터 시작해 저장되어 있다.
//ex : multiply([3,2,1],[6,5,4]} = 123*456 = 56088 =[8,8,0,6,5]
vector<int> multiply(const vector<int>& a,const vector<int>& b)
{
	vector<int> c(a.size() + b.size() + 1, 0);
	for (int i = 0; i < a.size(); i++)
	{
		for (int j = 0; j < b.size(); j++)
		{
			c[i + j] += a[i] * b[j];
		}
	}
	normalize(c);
	return c;
}

//a += b*(10^k);
void addTo(vector<int>& a, const vector<int>& b, int k);
//a -=b; (a>b)
void subFrom(vector<int>& a, const vector<int>& b);

//karatsuba
vector<int> karatsuba(const vector<int>& a, const vector<int>& b)
{
	int an = a.size(), bn = b.size();
	//a가 b보다 짧은경우 바꾼다.
	if (an < bn) return karatsuba(b, a);
	//base case : a나 b가 비어있는 경우
	if (an == 0 || bn == 0) return vector<int>();
	//base case : a가 비교적 짧은 경우 O(n^2) 곱셈으로 변경한다.
	if (an <= 5) return multiply(a, b);
	int half = an / 2;
	//a와 b를 밑에서 half자리와 나머지로 분리한다

	vector<int> a0(a.begin(), a.begin() + half);
	vector<int> a1(a.begin() + half, a.end());
	vector<int> b0(b.begin(), b.begin() + min<int>(b.size(), half));
	vector<int> b1(b.begin() + min<int>(b.size(), half), b.end());

	//z2=a1*b1
	vector<int> z2 = karatsuba(a1, b1);

	//z0=a0*b0
	vector<int> z0 = karatsuba(a0, b0);

	//a0=a0+a1;
	//b0=b0+b1
	addTo(a0, a1, 0);
	addTo(b0, b1, 0);

	//z1=(a0+a1)*(b0+b1)-z0-z2
	vector<int> z1 = karatsuba(a0, b0);
	subFrom(z1, z0);
	subFrom(z1, z2);

	//result=z0+z1*10^half+z2*10^(half*2)
	vector<int> result;
	addTo(result, z0, 0);
	addTo(result, z1, half);
	addTo(result, z2, half + half);
	return result;
}


int main()
{
	using namespace chrono;
	vector<int> a, b, c;
	high_resolution_clock::time_point t1, t2;
	duration<double> time_span;
	string number;

	cout << "첫번째 정수 입력: ";

	cin >> number;

	for (int i = number.size() - 1; i >= 0; i--)
    {
		a.push_back(number[i] - '0');
	}

	cout << "두번째 정수 입력: ";
	cin >> number;

	for (int i = number.size() - 1; i >= 0; i--)
	{
    	b.push_back(number[i] - '0');
	}

	cout << "O(n^2) : ";
	t1 = high_resolution_clock::now();
	c = multiply(a, b);
	t2 = high_resolution_clock::now();
	time_span = duration_cast<duration<double>>(t2 - t1);
	for (int i = c.size() - 1; i >= 0; i--)
	{
		cout << c[i];
	}
	cout << ", 소요시간 : " << time_span.count() << "초" << '\n';
	
	t1 = high_resolution_clock::now();
	c = karatsuba(a, b);
	t2 = high_resolution_clock::now();
	time_span = duration_cast<duration<double>>(t2 - t1);
	cout << "karatsuba(a, b) : ";
	for (int i = c.size() - 1; i >= 0; i--)
	{
		cout << c[i];
	}
	cout << ", 소요시간 : " << time_span.count() << "초" << '\n';
	return 0;
}

void addTo(vector<int>& a, const vector<int>& b, int k)
{
	a.resize(max((a.size() + 1), (b.size() + k)));

	for (int i = 0; i < b.size(); i++)
	{
		a[i + k] += b[i];
	}

	normalize(a);
}

void subFrom(vector<int>& a, const vector<int>& b)
{
	a.resize(max(a.size() + 1, b.size() + 1));
	for (int i = 0; i < b.size(); i++)
	{
		a[i] -= b[i];
	}

	normalize(a);
}

 

 실행 결과

곱하는 두 정수의 크기가 큰 경우

 

곱하는 두 정수의 크기가 작은 경우

 

실제로 시간 복잡도를 계산해보면 O(n^lg3)이 나온다.

lg3은 약 1.585이기에 n^2보다 빠른 알고리즘이라 볼 수 있다.

다만, 입력의 크기가 작은 경우 O(n^2) 알고리즘보다 느린 경우가 많다. 

karatsuba() 함수에서 일정 크기 이하인 경우 O(n^2) 알고리즘을 쓰는 이유이다.