poj-1001-Exponentiation-高精度乘法之karatsuba算法

缘起

之前a掉的题目涉及高精度乘法的时候采用的都是朴素的O(n^2)高精度乘法, 那是题目没卡时间,后来了解到优化的两种算法, 一种是 karatsuba O(n^(log_{2}3) 即约为 O(n^1.5)) 和 FFT(O(nlogn)), 于是先来学习karatsuba.

分析

Karatsuba乘法是一种快速乘法。此算法在1960年由Anatolii Alexeevitch Karatsuba 提出,并于1962年得以发表。karatsuba算法其实玩了一个tricky.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
Karatsuba算法主要应用于两个大数的相乘,原理是将大数分成两段后变成较小的数位,然后做3次乘法,并附带少量的加法操作和移位操作。
现有两个大数,x,y。
首先将x,y分别拆开成为两部分,可得x1,x0,y1,y0。他们的关系如下:
x = x1 * 10^m + x0;
y = y1 * 10^m + y0。其中m为正整数,且x0,y0 小于 10^m。
那么 x*y = (x1 * 10^m + x0)(y1 * 10^m + y0)
=z2 * 10^(2m) + z1 * 10^m + z0,其中:
z2 = x1 * y1;
z1 = x1 * y0 + x0 * y1;
z0 = x0 * y0。
此步骤共需4次乘法(其中z1的计算需要2次乘法),但是由Karatsuba改进以后仅需要3次乘法。因为:
z1 = x1 * y0+ x0 * y1 = (x1 + x0) * (y1 + y0) - x1 * y1 - x0 * y0,
只需要一次乘法(及若干次加减法)得到。
实例展示
设x = 12345,y=6789,令m=3。那么有:
12345 = 12 * 1000 + 345
6789 = 6 * 1000 + 789
下面计算:
z2 = 12 * 6 = 72
z0 = 345 * 789 = 272205
z1 = (12 + 345) * (6 + 789) - z2 - z0 = 11538
然后我们按照移位公式(xy = z2 * 10^(2m) + z1 * 10^(m) + z0)可得:
xy = 72 * 10002 + 11538 * 1000 + 272205 = 83810205

不难知道karatsuba算法的时间复杂度满足
$$
T(n) = 3T(n/2)+O(n)
$$
所以karatsuba算法的复杂度是O(n^(log_{2}3)的.

我们依旧使用 poj 1001 Exponentiation 学习这种算法. 题目翻译参见 【1】

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
#include "stdafx.h"
#include <iostream>
#include <fstream>
#include <sstream>
#include <string>
#include <algorithm>
#include <iomanip>
#include <vector>

#define LOCAL
typedef long long LL;

using namespace std;

class bigint // 大整数类
{
public:
bigint(LL a) // 用long long新建bigint, 注意, long long 的范围是 10^19, 而且科学计数法的最高位还是9, 这就是为什么我们要选择long long 作为大整数的一节的原因. 因为做大整数乘法的时候 10^9*10^9=10^18 也不会超出long long 的范围
{
while(a) num.push_back(a%BASE),a/=BASE;
}

bigint(){};

bigint(string s) // 用 字符串 构造大整数, 例如输入的字符串是 "1034004543543", 我们要将其存储为 004543543-->1034, 其中num[0]=004543543, num[1] = 1034
{
int n = s.length();
for (int i = n-1; i>=0; i-=BASE_LEN) // 从高位开始构建
{
int j = max(0, i-BASE_LEN+1); // 则参与本次vector节点构建的字符是[j,...,i]
LL t = 0;
while(j<=i)
{
t = t*10+s[j]-'0'; // 洪特规则, 求出[j,...,i]表示的long long
j++;
}
num.push_back(t);
}
}

string toString() // 将大整数转为字符串输出(方便结果输出)
{
if (!num.size())
{
return "0";
}
stringstream ss;
bool flag = true; // 是否在输出vector的栈顶元素? (它是大整数的最高位部分), 之所以要区分是因为栈顶元素是不能补0的
while(!num.empty())
{
if (flag)
{
ss << num.back();
flag = false;
}
else
{
ss << setw(BASE_LEN) << setfill('0') <<num.back(); // 熟练使用STL 能极大减少编码的复杂度(而且看上去会很优雅)
}
num.pop_back();
}
return ss.str();
}

static bigint add(const bigint &a,const bigint &b) // 大整数加法 a+b(a,b皆非负), 其实就是小学生加法, 复杂度是O(n)的, 比字符串版本的大整数加法要快, 因为这里使用BASE_LEN作为一节, 所以其实一次做9位数加法(而且long long相加很快), 即常数优化的.效率更高
{
bigint ans;
for (LL i =0, carry=0; ; i++) // 注意, 大整数的结构是 num 的栈顶元素存储高位, 所以要从低位相加, carry 是上一步相加结果的进位(即i-1)
{
if (!carry && i>=a.num.size() && i >= b.num.size()) break; // 如果a和b都没有弹药了, 而且上一步并没有发生进位
if (i<a.num.size()) // 如果a还有弹药
{
carry += a.num[i];
}
if (i<b.num.size()) // 如果b还有弹药
{
carry += b.num[i];
}
ans.num.push_back(carry%BASE);
carry /= BASE; // 作为a.num[i+1]和b.num[i+1]的进位
}
return ans;
}

static bigint substract(const bigint &a,const bigint &b) // 高精度减法 a-b, 这里假设 a>=b, 复杂度是 O(N)
{
bigint ans;
unsigned int i; // 注意, 因为size() 返回的是无符号整数, 所以这里最好使用unsigned int
if (!b.num.size())
{
return a;
}
for (i = 0; i<b.num.size();i++)
{
ans.num.push_back(a.num[i]-b.num[i]); // 可能为负值
}
for (; i<a.num.size();i++)
{
ans.num.push_back(a.num[i]); // 把a剩下的塞进去
}
for (i = 0; i<ans.num.size()-1;i++) // 从ans的低位开始扫描负值, 即处理小学生减法之借位
{
if (ans.num[i]<0)
{
ans.num[i]+=BASE;
ans.num[i+1]--; // 借位
}
}
while(!ans.num.empty()&&!ans.num.back()) // 删除无意义的前导0
{
ans.num.pop_back();
}
return ans;
}

static int cmp(const bigint &a, const bigint &b) // 大于b返回1, 等于b返回0,小于b返回-1
{
if (a.num.size() != b.num.size())
{
return a.num.size()< b.num.size()?-1:1;
}
for (int i = a.num.size()-1; ~i; i--)
{
if (a.num[i]!=b.num[i])
{
return a.num[i]<b.num[i]?-1:1;
}
}
return 0;
}

static bigint multiply(const bigint &a, const bigint &b) // 非负大整数乘法 a*b 复杂度 O(n^2)
{
bigint ans(vector<LL>(a.num.size()+b.num.size(),0)); // m位*n位最多m+n位,最少m+n-1位
for (unsigned int i = 0; i< a.num.size(); i++)
{
for (unsigned int j =0; j<b.num.size();j++)
{
ans.num[i+j]+=a.num[i]*b.num[j]; // ans[i+j]的来源不止一个, 例如 ans[3]可以是 a[1]*b[2]也可以是a[2]*b[1]
ans.num[i+j+1]+=ans.num[i+j]/BASE; // 进位
ans.num[i+j]%=BASE;
}
}
while(!ans.num.empty() && !ans.num.back()) // 去掉无意义的前导0
{
ans.num.pop_back();
}
return ans;
}

static bigint pow(bigint &a, int &b) // 求a^b, b是整型(已经够大了)
{
bigint ans(vector<LL>(1,1)); // ans=1
while(b) // 快速幂
{
if (b&1)
{
ans = multiply(ans, a);
}
if (b>1)
{
a = multiply(a,a);
}
b>>=1;
}
return ans;
}

static string pow(string &a, int &b) // 使用karatsuba算法作为底层高精度乘法来封装高精度求幂
{
string ans="1"; // ans=1
while(b) // 快速幂
{
if (b&1)
{
ans = karatsuba(ans, a);
}
if (b>1)
{
a = karatsuba(a,a);
}
b>>=1;
}
return ans;
}

static bigint div(bigint &a,bigint &b)//非负大数相除a/b
{
int tmp;
if ((tmp=cmp(a,b))!=1)
{
if (!tmp) // 如果a==b, 则 a 减去b变成0
{
a = bigint(vector<LL>(1,0));
}
bigint ans(vector<LL>(1,tmp?0:1));
return ans;
}
int size_a=a.num.size(),size_b=b.num.size();//注意,此时a>b
bigint ans;//m位除以n位(m>=n)最多m-n+1位,最少m-n位,但是这里不能预设多少位,因为除法与加减乘都不一样,它是从高位开始算的而其它三种是从低位开始算的
for (int i = size_a-size_b; ~i; i--)
{
LL q=0;//每次除得的商
ok(a,b,i,q);
ans.num.push_back(q);
}
reverse(ans.num.begin(),ans.num.end());//必须倒序存储
while(!ans.num.empty() && !ans.num.back()) // 去掉无意义的前导0
{
ans.num.pop_back();
}
return ans;
}

static bigint res(bigint &a, bigint &b) // 高精度计算a%b
{
div(a,b);
return a;
}

bool iszero()
{
bigint zero("0");
return !num.size() || !cmp(*this, zero);
}

static bigint fac(int n) // 高精度计算 n!
{
bigint ans(vector<LL>(1,1));//ans=1
for (int i = 2; i <= n; i++) ans=multiply(ans,bigint(vector<LL>(1,i)));
return ans;
}

static string karatsuba(const string &x,const string &y) // 高精度乘法之karatsuba算法 x*y,
{
int length_x = x.length();
int length_y = y.length();
int m = max(x.length(), y.length())>>1; // x表示为 x1*10^m+x2, y表示为y1*10^m+y2, 则 x*y=10^(2m)*x1*y1+10^m*((x1+x2)*(y1+y2)-x1*y1-x2*y2)+x2*y2
if (m<=1) // 递归出口
{
return multiply(bigint(x),bigint(y)).toString();
}
string x1 = m>length_x?"0":x.substr(0, length_x-m);
string x2 = m>length_x?x:x.substr(length_x-m, m); // 取后m位
int i = 0;
while(i<x2.length() && x2[i]=='0')++i;
if (i==x2.length()) x2 = "0";
string y1 = m>length_y?"0":y.substr(0, length_y-m);
string y2 = m>length_y?y:y.substr(length_y-m, m); // 取后m位
i = 0;
while(i<y2.length() && y2[i]=='0')++i;
if (i==y2.length()) y2 = "0";
string x1y1 = karatsuba(x1, y1); // 分治
string x2y2 = karatsuba(x2, y2);
string cross = karatsuba(add(bigint(x1), bigint(x2)).toString(), add(bigint(y1), bigint(y2)).toString()); // 交叉项, 因此karatsuba只需要做3次乘法和若干次加减法
cross = substract(bigint(cross), bigint(x1y1)).toString(); // 注意, 这里能保证是非负的
cross = substract(bigint(cross), bigint(x2y2)).toString();
if (cross!="0") for (int i = 0; i<m;i++) cross.push_back('0'); // *10^m, 不为0的话, 才能增加0
if (x1y1!="0") for (int i = 0; i<(m<<1); i++) x1y1.push_back('0'); // * 10^{2m}
x1y1 = add(bigint(x1y1), bigint(cross)).toString();
x1y1 = add(bigint(x1y1), bigint(x2y2)).toString();
return x1y1;
}

private:
const static int BASE_LEN=9,BASE=1000000000;// 10亿为基数
vector<LL> num; // 构建大整数是这样的, 比如1034004543543, 在我们的bigint类中存储为 004543543-->1034, num[0]=004543543, num[1] = 1034
bigint(vector<LL> num):num(num){} //使用一个vector构建大整数
static void ok(bigint &a,bigint b,int i,LL &q)
{
for (int j = 0; j < i; j++) b.num.insert(b.num.begin(),0);//首先要给b补i个段的0
int t = cmp(a,b);
if (!~t) return; // 如果 a<b 则返回(q就是0)
if (!t) // 如果 a==b
{
a = bigint(vector<LL>(1,0));
q = 1;
return;
}
LL k;
bigint tmp_b;
while(~t) // 只要 a >=b
{
for (k = 1, tmp_b = b; ~cmp(a, tmp_b); k*=10, tmp_b=multiply(tmp_b, bigint(vector<LL>(1,10)))) // 这里可以考虑 a=999减去b=10, 减去tmp_b=10(初始化为b),tmp_b=100,得到a=889,而再减tmp_b=1000的时候减不动了, 但是889依旧大于b, 所以要再来一轮——减去10,减去100, 得779. 这样比999每次减去10快多了,这可以看做是一个优化
{
q+=k;
a = substract(a, tmp_b);
}
t = cmp(a,b);
} // 最后 a<b
}
};


int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
#endif
string a;
int b;
while(cin >> a >> b)
{
int pos = 0; // 最后结果小数点要左移动pos位
bool flag = false; // 有没有越过前导0
string aa;
for (int i = 0; i<a.length();i++)
{
if (a[i]=='.')
{
pos = (5-i)*b;
}
else
{
if (a[i]!='0')
{
flag = true; // 已经越过前导零
aa.push_back(a[i]);
}
else if (flag) // 是0 但是已经越过前导0了
{
aa.push_back(a[i]);
}
else // 是0 但是尚未越过前导0
{
continue;
}
}
}
bigint A(aa);
string _a = A.toString();
string ans = bigint::pow(_a, b); // 使用karatsuba算法求出乘幂
a.clear(); // a清空了用来装结果
while(pos&&!ans.empty()&&ans.back()=='0') // 抵消pos来去掉ans尾部的0
{
pos--;
ans.pop_back();
}
flag = pos; // 要不要加 .
while(pos--) // 开始装入真正的结果了
{
if (!ans.empty())
{
a.push_back(ans.back());
ans.pop_back();
}
else
{
a.push_back('0');
}
}
if (flag)
{
a.push_back('.');
}
while (!ans.empty())
{
a.push_back(ans.back());
ans.pop_back();
}
string ret(a.rbegin(), a.rend());
cout << ret << endl;
}
return 0;
}

ac情况

影法师 Accepted 128kB 9ms 10115 B G++ 刚刚

参考

【1】https://yfsyfs.github.io/2019/08/08/poj-1001-Exponentiation-%E9%AB%98%E7%B2%BE%E5%BA%A6%E4%B9%98%E6%B3%95-%E5%BF%AB%E9%80%9F%E5%B9%82/