FFT

快速计算多项式的乘法
个人感觉算法导论上面的知识的体系比较的详细,当然本文也会推荐几个学习的连接,感觉还是会板子就好了,不要太注重细节。

acdreamer
多项式乘法的终极版
FFT详解

引入

多项式的乘法其实就是卷积的计算:
example:
\((x^3+3x^2+2x+4)(x^2+3x+3) = \).
我们可以手动的计算。
但是我们还可以卷积进行计算。
将系数表示出来:
\((1 3 2 4)*(1 3 3)\), 按照普通的卷积的计算的方式,先翻转,然后就行移动求和就行了。
然而可以参考知乎大佬的文章,就会发现一个特别简单的计算离散卷积的方法。
(此处有公式)
最后算出来的结果,和模拟手算算出来的结果一致。(1 6 14 19 18 12).

使用Python进行验证

1
2
3
4
5
6
7
8
import numpy as np
x = np.array([1, 3, 2, 4])
y = np.array([1, 3, 3])
import scipy.signal
scipy.signal.convolve(x, y)
#输出
Out[7]: array([ 1, 6, 14, 19, 18, 12])

多项式的乘法

比如上面的例子:
输入:
4 3
1 3 2 4
1 3 3
输出:
1 6 14 19 18 12

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<complex>
using namespace std;
typedef complex<double> cd;//复数类的定义
const int maxl=2094153;//nlogn的最大长度(来自leo学长的博客)
const double PI=3.14159265358979;//圆周率,不解释
cd a[maxl],b[maxl];//用于储存变换的中间结果
int rev[maxl];//用于储存二进制反转的结果
void getrev(int bit){
for(int i=0;i<(1<<bit);i++){//高位决定二进制数的大小
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}//能保证(x>>1)<x,满足递推性质
}
void fft(cd* a,int n,int dft){//变换主要过程
for(int i=0;i<n;i++){//按照二进制反转
if(i<rev[i])//保证只把前面的数和后面的数交换,(否则数组会被翻回来)
swap(a[i],a[rev[i]]);
}
for(int step=1;step<n;step<<=1){//枚举步长的一半
cd wn=exp(cd(0,dft*PI/step));//计算单位复根
for(int j=0;j<n;j+=step<<1){//对于每一块
cd wnk(1,0);//!!每一块都是一个独立序列,都是以零次方位为起始的
for(int k=j;k<j+step;k++){//蝴蝶操作处理这一块
cd x=a[k];
cd y=wnk*a[k+step];
a[k]=x+y;
a[k+step]=x-y;
wnk*=wn;//计算下一次的复根
}
}
}
if(dft==-1){//如果是反变换,则要将序列除以n
for(int i=0;i<n;i++)
a[i]/=n;
}
}
int output[maxl];
char s1[maxl],s2[maxl];
int main(){
int n1, n2;
scanf("%d %d", &n1, &n2);
for(int i=n1-1; i>=0; i--){
scanf("%lf", &a[i]);
}
for(int i=n2-1; i>=0; i--){
scanf("%lf", &b[i]);
}
int s = 2;//2的幂次
int bit = 1;
for(bit=1;(1<<bit)<n1+n2-1;bit++){
s<<=1;//找到第一个二的整数次幂使得其可以容纳这两个数的乘积
}
getrev(bit);fft(a,s,1);fft(b,s,1);//dft
for(int i=0;i<s;i++)a[i]*=b[i];//对应相乘
fft(a,s,-1);//idft
for(int i=0;i<s;i++){//还原成十进制数
output[i]+=(int)(a[i].real()+0.5);//注意精度误差
}
int i;
for(i=n1+n2;!output[i]&&i>=0;i--);//去掉前导零
if(i==-1)printf("0");//特判长度为0的情况
for(;i>=0;i--){//输出这个十进制数
printf("%d ",output[i]);
}
putchar('\n');
return 0;
}

大数的乘法

输入两个大数,然后输出乘法的结果

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include<cstdio>
#include<cstdlib>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<complex>
using namespace std;
typedef complex<double> cd;//复数类的定义
const int maxl=2094153;//nlogn的最大长度(来自leo学长的博客)
const double PI=3.14159265358979;//圆周率,不解释
cd a[maxl],b[maxl];//用于储存变换的中间结果
int rev[maxl];//用于储存二进制反转的结果
void getrev(int bit){
for(int i=0;i<(1<<bit);i++){//高位决定二进制数的大小
rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
}//能保证(x>>1)<x,满足递推性质
}
void fft(cd* a,int n,int dft){//变换主要过程
for(int i=0;i<n;i++){//按照二进制反转
if(i<rev[i])//保证只把前面的数和后面的数交换,(否则数组会被翻回来)
swap(a[i],a[rev[i]]);
}
for(int step=1;step<n;step<<=1){//枚举步长的一半
cd wn=exp(cd(0,dft*PI/step));//计算单位复根
for(int j=0;j<n;j+=step<<1){//对于每一块
cd wnk(1,0);//!!每一块都是一个独立序列,都是以零次方位为起始的
for(int k=j;k<j+step;k++){//蝴蝶操作处理这一块
cd x=a[k];
cd y=wnk*a[k+step];
a[k]=x+y;
a[k+step]=x-y;
wnk*=wn;//计算下一次的复根
}
}
}
if(dft==-1){//如果是反变换,则要将序列除以n
for(int i=0;i<n;i++)
a[i]/=n;
}
}
int output[maxl];
char s1[maxl],s2[maxl];
int main(){
scanf("%s%s",s1,s2);//读入两个数
int l1=strlen(s1),l2=strlen(s2);//就算"次数界"
int bit=1,s=2;//s表示分割之前整块的长度
for(bit=1;(1<<bit)<l1+l2-1;bit++){
s<<=1;//找到第一个二的整数次幂使得其可以容纳这两个数的乘积
}
for(int i=0;i<l1;i++){//第一个数装入a
a[i]=(double)(s1[l1-i-1]-'0');
}
for(int i=0;i<l2;i++){//第二个数装入b
b[i]=(double)(s2[l2-i-1]-'0');
}
getrev(bit);fft(a,s,1);fft(b,s,1);//dft
for(int i=0;i<s;i++)a[i]*=b[i];//对应相乘
fft(a,s,-1);//idft
for(int i=0;i<s;i++){//还原成十进制数
output[i]+=(int)(a[i].real()+0.5);//注意精度误差
output[i+1]+=output[i]/10;
output[i]%=10;
}
int i;
for(i=l1+l2;!output[i]&&i>=0;i--);//去掉前导零
if(i==-1)printf("0");//特判长度为0的情况
for(;i>=0;i--){//输出这个十进制数
printf("%d",output[i]);
}
putchar('\n');
return 0;
}

未解决的问题

文章目录
  1. 1. 引入
    1. 1.1. 使用Python进行验证
  2. 2. 多项式的乘法
  3. 3. 大数的乘法
  4. 4. 未解决的问题
{{ live2d() }}