hdu 1402 A * B Problem Plus fft

缘起

既然学习了fft(【1】),自然要找个板题来测一下. 遂选择了 hdu 1402 A * B Problem Plus

分析

题意很裸

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
计算 A*B

【输入】
每一行包含2个正整数分别就是A和B,A,B可达 10^50000

【输出】
对每一行, 输出 A*B 的结果,


【样例输入】
1
2
1000
2

【样例输出】
2
2000

【限制】
Time limit 1000 ms
Memory limit 32768 kB

如果使用原先朴素的 O(N^2)的高精度乘法的话, 会吃T

Status Time Limit Exceeded
Length 7403
Lang C++
Submitted 2019-08-09 11:27:05
Shared
RemoteRunId 30184655

其实你想想也知道——题目名字叫”A * B Problem Plus”, 既然都叫做plus了, 你还用老方法,肯定吃T啊~

哪怕使用 karatsuba

30184817 2019-08-09 11:34:45 Time Limit Exceeded 1402 1000MS 1524K 8491 B G++ yfsyfsyfs

结果也是吃T的.

所以不得不祭出fft来搞它了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "stdafx.h"
#include <stdio.h>
#include <string.h>
#include <complex>
#include <math.h>
#include <algorithm>

#define LOCAL
#define PI (atan(1.0)*4)

#define cp complex<double>

typedef long long LL;

using namespace std;

char a[140000], b[140000]; // 因为乘数不超过5w位(即令x=10的多项式次数<=5w), 乘积不会超过10w位(即多项式乘积结果的次数<=10w),而2^17=13w+, 而进行fft需要次数为2的幂次, 所以取14w
int c[140000]; // 装a*b乘积的结果
cp aa[140000], bb[140000];
int rev[140000];

// fft 模板
void fft(cp *a, int bit, int inv) // inv=1表示fft(即系数转点值), inv=-1表示ifft(即点值转系数),(1<<bit)是a数组的长度
{
int len = (1<<bit);
for (int i = 0; i<len; i++) rev[i] = i;
for (int i = 0; i<len; i++)
{
rev[i] = (rev[i>>1]>>1) | ((i&1)<<(bit-1));
if (i<rev[i]) swap(a[i], a[rev[i]]);
}
for (int mid = 1; mid < len; mid<<=1) // 合并区间的长度为2*mid
{
cp w = cp(cos(PI/mid), inv*sin(PI/mid));
for (int i = 0; i<len; i+=(mid<<1)) // i是合并区间的起点—— [i, i+2mid) 是合并区间
{
cp t = cp(1,0);
for (int j = 0; j<mid; j++, t*=w)
{
cp x = a[i+j], y = t*a[i+mid+j];
a[i+j] = x+y, a[i+mid+j] = x-y; // 蝴蝶变换
}
}
}
} // 短小精悍的20+行的fft模板

int main()
{
#ifdef LOCAL
freopen("d:\\data.in", "r", stdin);
#endif

while(~scanf("%s%s", a, b))
{
if (!strcmp(a, "0") || !strcmp(b, "0")) // 如果至少一个为0, 则答案就是0
{
puts("0");
continue;
}
int lena = strlen(a), lenb = strlen(b);
int n = lena+lenb, bit=0; // a*b的结果最多n位数
while ((1<<bit)<n) ++bit; // 2^bit>=n
int len = (1<<bit); //len>=n, len是2次幂
for (int i = 0; i<len; i++) aa[i] = i<lena?cp(a[lena-i-1]-'0', 0):0; // 小的索引存储低位
for (int i = 0; i<len; i++) bb[i] = i<lenb?cp(b[lenb-i-1]-'0', 0):0; // 初始化aa和bb, 因为fft算法必须要耕作于2幂次的长度的数组上
fft(aa, bit, 1); // 系数转点值
fft(bb, bit, 1); // 系数转点值
for (int i = 0; i< len; i++) aa[i]*=bb[i];
fft(aa, bit, -1); // 点值转系数
for (int i = 0; i<len; i++) c[i] = (int)(aa[i].real()/len+0.5);
for (int i = 0; i<n-1;i++) // 考虑进位
{
c[i+1]+=c[i]/10;
c[i]%=10;
}
bool flag = false;
for (int i = n; ~i; i--) // 去除前导0,打印结果
{
if (c[i]) flag = true;
if (flag) printf("%d", c[i]);
}
puts("");
}
return 0;
}

ac情况

30219737 2019-08-11 17:28:15 Accepted 1402 140MS 6740K 2319 B G++ yfsyfsyfs

注意,这里没有把它整合进之前的高精度模板中, 因为你想想, 一个段是10^9, 那么在卷积过程中(long long作为double进行运算)可能损失精度. 所以采用十进制.

参考

【1】https://yfsyfs.github.io/2019/08/09/FFT/