FFT与NTT 日付: 10月 28, 2018 リンクを取得 Facebook × Pinterest メール 他のアプリ 一直只会套模板,这次认真总结一下FFT与NTT。 首先,这两种算法在程序竞赛中主要用于加速多项式乘法。用系数表示的多项式做乘法是一个复杂度$O(n^2)$的卷积运算。但是一个$n$次的多项式可以用$n+1$个不同的点唯一确定,而在点表示的多项式做乘法时,只需要把对应的值相乘即可,这个计算是$O(n)$的。如果有比$O(n^2)$更优的算法能够将多项式从系数表示和点表示互相转化,就能降低复杂度了。 ## FFT(Fast Fourier Transform) 首先需要引入几个概念和定理 #### Complex Roots of Unity(单位复根) $$\omega_n^k = e^{\frac{2\pi ik}{n}} = cos(\frac{2\pi k}{n}) + i\cdot sin(\frac{2\pi k}{n})$$ $\omega_n^ k$是满足$\omega^n=1$的复数根,满足性质 $$\omega_n^j \omega_n^k = \omega_n ^{(j+k)\mod n}$$ $$\omega_n^ 0 =\omega_n^ n= 1, \omega_n ^ {n/2} = -1$$ #### Cancelation Lemma(消去引理) $$\omega^{dk}_{dn} = \omega _n ^k, \quad \forall n,k,d \ge 0$$ #### Halving Lemma(折半引理) $$(\omega_n^ k)^2=\omega^ k_{n/2}, \quad \forall k\ge 0, 2 \mid n$$ 这说明如果$n$是偶数,$n$个$n$次单位复数根的平方的集合就是$n/2$个$n/2$次单位复数根。为FFT数据规模折半提供依据 #### Summation Lemma(求和引理) $$\sum_{j=0}^ {n-1}(\omega_n^ k)^j = 0, \quad \forall n \ge 1, n \nmid k$$ DFT(Discrete Fourier Transform)就是求多项式在$n$个单位复根处的值,逆DFT就是已知多项式在$n$个单位复根处的值,求系数。也就是多项式系数形式和点形式之间的变换。 对于次数界为$n$的多项式$A(x)=\sum_{j=0}^ {n-1} a_j x^j$,其DFT就是求$y_j = \sum_{k=0}^ {n-1} a_k \omega_n^{jk}$。逆DFT运算就是求$a_j=\frac{1}{n}\sum_{k=0}^ {n-1}y_k\omega_n^{-jk}$ Proof $$\begin{aligned} \frac{1}{n}\sum_{k=0}^ {n-1}y_k \omega_n^ {-jk} &= \frac{1}{n}\sum_{k=0}^ {n-1}(\sum_{l=0}^ {n-1}a_l \omega_n^ {kl})\omega_n^ {-jk} \\ &= \frac{1}{n}\sum_{l=0}^ {n-1}a_l \sum_{k=0}^ {n-1}\omega_n^ {k(l-j)} \\ &= \frac{1}{n}a_j \sum_{k=0}^ {n-1} w_n^ 0, \quad (Summation) \\ &= a_j \end{aligned}$$ 那么,怎么快速求DFT呢?我们可以首先令次数界$n$是2的幂次(不够可以补0)。如果要计算$A$在$\omega_n^ 0, \cdots, \omega_n^ {n-1}$处的值,可以把原多项式分解为两个次数界为$n/2$的多项式 $$A_0(x) = a_0 + a_2x+\cdots+a_{n-2}x^ {n/2-1}$$ $$A_1(x)=a_1 + a_3x + \cdots+a_{n-1}x^ {n/2-1}$$ 这样原多项式就可以写作 $$A(x) = A_0(x^2) + xA_1(x^2)$$ 由折半引理,只需要对$A_0(x)$和$A_1(x)$在$\omega_{n/2}^ 0, \cdots, \omega_{n/2}^ {n/2-1}$处求值即可,问题规模减半。如果用递归写就很容易了。但是写成递归的形式常数会比较大,一般都是用迭代写法。考虑递归的情况下,分裂的情况,如图。 background Layer 1 0 1 2 3 4 5 6 7 000 001 010 011 100 101 110 111 0 2 4 6 000 010 100 110 1 3 5 7 001 011 101 111 0 4 000 100 2 6 010 110 1 5 001 101 3 7 011 111 0 000 4 100 2 010 6 110 1 001 5 101 3 011 7 111 我们可以先把原数组排列成最底层的顺序,然后不断向上合并结果。观察最底一层下标的二进制表示,可以发现相当于末$log_2 n$位翻转后排序。对于一个下标$i$,如果将其二进制末$log_2 n$位倒转后的$j$比它大(否则就会交换两次复原)的话,交换原数组中的$i$和$j$位置即可。 ```c++ template void arrange(vector &A) { int n = A.size(); for (size_t i = 1, x = 0, y = 0; i <= n; i++) { if (x > y) swap(A[x], A[y]); // x和y保持末log2(n)位是相反的,x保证枚举了[0, n) x ^= i & -i, y ^= n / (i & -i) >> 1; } } ``` 考虑合并两个子问题(合并过程称为蝴蝶操作)。当前层需要处理的项数为$l$,$A_0(\omega_{n/2}^ k)$和$A_1(\omega_{n/2}^ k)$分别存放在$a[k]$和$a[k+l/2]$中,设$t=\omega_n^ k a[k+l/2]$,$a[k]$和$a[k+l/2]$分别需要被更新为$a[k] + t$和$a[k] - t$。 Code ```c++ namespace FFT { static const double PI = acos(-1.0); template void arrange(vector &A) { int n = A.size(); assert(n == (n & -n)); for (int i = 1, x = 0, y = 0; i <= n; i++) { if (x > y) swap(A[x], A[y]); x ^= i & -i, y ^= n / (i & -i) >> 1; } } template void fourier(vector> &A, int inv) { assert((inv == 1 || inv == -1) && is_floating_point::value); int n = 1 << (32 - __builtin_clz(A.size() - 1)); A.resize(n), arrange(A); for (int l = 1; l < n; l <<= 1) { complex wl(cos(inv * PI / l), sin(inv * PI / l)), w(1, 0), t; for (int i = 0; i < l; i++, w *= wl) for (int s = 0; s < n; s += l + l) t = w * A[s + i + l], A[s + i + l] = A[s + i] - t, A[s + i] += t; } if (inv == -1) for (int i = 0; i < n; i++) A[i] /= n; } template vector multiply(const vector &A, const vector &B) { double bias = is_integral::value ? 0.5 : 0; int s = A.size() + B.size() - 1; vector> U, V; // or long double for (auto x : A) U.emplace_back(x, 0); for (auto x : B) V.emplace_back(x, 0); U.resize(s), V.resize(s), fourier(U, 1), fourier(V, 1); for (size_t i = 0; i < U.size(); i++) U[i] *= V[i]; fourier(U, -1); vector R; for (int i = 0; i < s; i++) R.push_back(U[i].real() + bias); return R; } }; ``` ## NTT(Number Theoretic Transform) 和FFT中利用单位复根相似,NTT利用了原根(Primitive Root)的性质。 对素数$p$,从2开始枚举,满足$g^{p-1} \equiv 1 \pmod p$的$g$是模$p$意义下的原根,原根有$g^ i \neq g^ j \pmod p, \quad i \neq j$的性质 设$g_n=g^{(p-1)/n}, \quad n \mid (p-1)$。$g_n$满足性质$$g_n^ n=g^{p-1}=1 \pmod p$$ $$g_n^ {n/2} = \sqrt{1} \equiv -1 \pmod p$$ #### 消去引理 $$g_{dn}^ {dk} = g_n^ k$$ #### 折半引理 $$(g_n^ k)^ 2 = g_{n/2}^ k$$ #### 求和引理 $$\sum_{j=0}^ {n-1}(g_n^ k)^j = 0 \pmod p, \quad \forall n \ge 1, n \nmid k$$ Code ```c++ namespace NTT { // must exists n*k+1 = P, (1004535809, 3), (786433, 10) static const int P = 998244353, G = 3; template void arrange(vector &A) { int n = A.size(); assert(n == (n & -n)); for (int i = 1, x = 0, y = 0; i <= n; i++) { if (x > y) swap(A[x], A[y]); x ^= i & -i, y ^= n / (i & -i) >> 1; } } // int add(int x, int y) { return x + y < P ? x + y : x + y - P; } int add(int x, int y) { return (x + y) % P; } int mpow(long long a, int k) { int r = 1; for (a %= P; k; k >>= 1, a = a * a % P) if (k & 1)r = r * a % P; return r; } template void fourier(vector &A, int inv) { assert(inv == -1 || inv == 1); int n = 1 << (32 - __builtin_clz(A.size() - 1)); A.resize(n), arrange(A); for (int l = 1; l < n; l <<= 1) { int wl = mpow(G, ((P - 1) / l >> 1) * (l + l + inv)), w = 1, t; for (int i = 0; i < l; i++, w = 1LL * w * wl % P) for (int s = 0; s < n; s += l + l) t = 1LL * w * A[s + i + l] % P, A[s + i + l] = add(A[s + i], P - t), A[s + i] = add(A[s + i], t); } int t = mpow(n, P - 2); if (inv == -1) for (int i = 0; i < n; i++) A[i] = 1LL * A[i] * t % P; } template vector multiply(vector A, vector B) { int s = A.size() + B.size() - 1; A.resize(s), B.resize(s), fourier(A, 1), fourier(B, 1); for (size_t i = 0; i < A.size(); i++) A[i] = 1LL * A[i] * B[i] % P; fourier(A, -1); A.resize(s); return A; } }; ``` ## Arbitrary Moduler(任意模数) NTT的局限性在于对模数的限制,竞赛中常见的模数$10^9+7$就不能用作NTT的模数。对于任意模数,一种方法是用**三模数NTT**,使用三个NTT模数,保证$p_1p_2p_3 > na^2$,然后用CRT合并结果,但是这种方法在模数比较大(比如$10^9+7$)时会爆long long,会导致写起来非常麻烦并且常数爆炸。还有一种方法是**拆系数FFT**,下面总结一下这种方法。 将两个多项式$A,B$的系数拆为 $$A_i=x_i\sqrt{M}+y_i,\quad B_j=x_j\sqrt{M}+y_j$$ 这样系数相乘就变成了 $$A_iB_j=x_ix_jM+(x_iy_j+y_ix_j)\sqrt{M}+y_iy_j$$ 这里$M$相关的部分可以暂时不管,先把三部分的系数分别做卷积,最后再考虑合并。这样拆分后可以保证$x,y \le \sqrt{M}$,两个这样的数相乘,加上卷积中求和引入的数组长度,所有的数值不会超过$nM$,如果$nM$在long long范围内就可以处理了。$M$需要比所有的原系数大,一般取$M=2^{30}$。这种方法一共需要4次DFT与3次IDFT,但是常数还可以优化。 前面的拆分中,原本对多项式$A$的DFT相当于改成了对$X$和$Y$的两次DFT。考虑构造这样两个复系数数组$$P_j = X_j+iY_j, \quad Q_j=X_j-iY_j$$ 每个系数分别是共轭复数。可以发现$$X(\omega^ k)=[P(\omega^ k)+Q(\omega^ k)]/2$$ $$Y(\omega ^ k)=i[Q(\omega^ k)-P(\omega^ k)]/2$$ 如果我们求出了$P$和$Q$的DFT,就可以根据上式计算出$X$和$Y$的DFT。因为$P,Q$系数共轭,存在性质$$\overline{Q(\omega_n^ k)} = P(\omega_n^ {n-k})$$ Proof $$\begin{aligned} P(\omega_n^ {n-k}) &= \sum_{j=0}^ {n-1} (x_j+iy_j)\omega_n^ {(n-k)j}\\ &=\sum_{j=0}^ {n-1}(x_j+iy_j)[cos\frac{2\pi(n-k)j}{n}+i\cdot sin\frac{2\pi(n-k)j}{n}]\\ &=\sum_{j=0}^ {n-1}(x_j+iy_j)(cos\frac{2\pi kj}{n} - i\cdot sin\frac{2\pi kj}{n})\\ &=\sum_{j=0}^ {n-1}\overline{(x_j-iy_j)(cos\frac{2\pi jk}{n}+i\cdot sin\frac{2\pi jk}{n})}\\ &=\sum_{j=0}^ {n-1}\overline{(x_j-iy_j)\omega_n^ {kj}}\\ &= \overline{Q(\omega_n^ k)} \end{aligned}$$ 因此只需对$P$做一次DFT就可以得到$X$和$Y$的DFT。做IDFT时,构造$$R(\omega^ k) = X_1(\omega^ k)X_2(\omega^ k) + iY_1(\omega^ k)Y_2(\omega^ k)$$做IDFT后,得到$$R_i=X_1X_2 +iY_1Y_2$$因此两次IDFT就可以算出所有结果。一共是两次DFT和两次IDFT Arbitrary Moduler FTT ```c++ namespace MFFT { static const double PI = acos(-1.0); template void arrange(vector &A) { int n = A.size(); assert(n == (n & -n)); for (int i = 1, x = 0, y = 0; i <= n; i++) { if (x > y) swap(A[x], A[y]); x ^= i & -i, y ^= n / (i & -i) >> 1; } } template void fourier(vector> &A, int inv) { assert((inv == 1 || inv == -1) && is_floating_point::value); int n = 1 << (32 - __builtin_clz(A.size() - 1)); A.resize(n), arrange(A); vector> W(n >> 1, {1, 0}); for (int l = 1; l < n; l <<= 1) { complex wl(cos(inv * PI / l), sin(inv * PI / l)), t; for (int i = l - 2; i >= 0; i -= 2) W[i] = W[i >> 1]; for (int i = 1; i < l; i += 2) W[i] = W[i - 1] * wl; for (int i = 0; i < l; i++) for (int s = 0; s < n; s += l + l) t = W[i] * A[s + i + l], A[s + i + l] = A[s + i] - t, A[s + i] += t; } if (inv == -1) for (int i = 0; i < n; i++) A[i] /= n; } template vector multiply(const vector &A, const vector &B, int p = 1e9 + 7) { assert(is_integral::value); using CD = complex; // or long double; using LL = long long; int s = A.size() + B.size() - 1; vector U, V; for (auto x : A) U.emplace_back(x >> 15, x & 32767); for (auto x : B) V.emplace_back(x >> 15, x & 32767); U.resize(s), V.resize(s), fourier(U, 1), fourier(V, 1); for (size_t i = 0, n = U.size(); i + i <= n; i++) { size_t j = (n - i) & (n - 1); auto a = U[i], b = U[j], c = V[i], d = V[j]; U[i] = (a + conj(b)) * (c + conj(d)) * 0.25 - (a - conj(b)) * (c - conj(d)) * CD(0, 0.25); U[j] = (b + conj(a)) * (d + conj(c)) * 0.25 - (b - conj(a)) * (d - conj(c)) * CD(0, 0.25); V[i] = CD(0, 0.5) * (conj(b * d) - a * c), V[j] = CD(0, 0.5) * (conj(a * c) - b * d); } fourier(U, -1), fourier(V, -1); vector R(s); for (int i = 0; i < s; i++) { LL t1 = (LL)(U[i].real() + 0.5) % p; LL t2 = (LL)(V[i].real() + 0.5) % p; LL t3 = (LL)(U[i].imag() + 0.5) % p; R[i] = ((t1 << 30) % p + (t2 << 15) % p + t3) % p; } return R; } }; ``` コメント
コメント
コメントを投稿