Skip to content

Commit 31f9138

Browse files
committed
optimize fft
1 parent 7198637 commit 31f9138

File tree

2 files changed

+147
-68
lines changed

2 files changed

+147
-68
lines changed

atcoder/convolution.hpp

Lines changed: 139 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -14,95 +14,165 @@ namespace atcoder {
1414

1515
namespace internal {
1616

17+
template <class mint,
18+
int g = internal::primitive_root<mint::mod()>,
19+
internal::is_static_modint_t<mint>* = nullptr>
20+
struct fft_info {
21+
static constexpr int rank2 = bsf_constexpr(mint::mod() - 1);
22+
std::array<mint, rank2 + 1> root; // root[i]^(2^i) == 1
23+
std::array<mint, rank2 + 1> iroot; // root[i] * iroot[i] == 1
24+
25+
std::array<mint, std::max(0, rank2 - 2 + 1)> rate2;
26+
std::array<mint, std::max(0, rank2 - 2 + 1)> irate2;
27+
28+
std::array<mint, std::max(0, rank2 - 3 + 1)> rate3;
29+
std::array<mint, std::max(0, rank2 - 3 + 1)> irate3;
30+
31+
fft_info() {
32+
root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
33+
iroot[rank2] = root[rank2].inv();
34+
for (int i = rank2 - 1; i >= 0; i--) {
35+
root[i] = root[i + 1] * root[i + 1];
36+
iroot[i] = iroot[i + 1] * iroot[i + 1];
37+
}
38+
39+
{
40+
mint prod = 1, iprod = 1;
41+
for (int i = 0; i <= rank2 - 2; i++) {
42+
rate2[i] = root[i + 2] * prod;
43+
irate2[i] = iroot[i + 2] * iprod;
44+
prod *= iroot[i + 2];
45+
iprod *= root[i + 2];
46+
}
47+
}
48+
{
49+
mint prod = 1, iprod = 1;
50+
for (int i = 0; i <= rank2 - 3; i++) {
51+
rate3[i] = root[i + 3] * prod;
52+
irate3[i] = iroot[i + 3] * iprod;
53+
prod *= iroot[i + 3];
54+
iprod *= root[i + 3];
55+
}
56+
}
57+
}
58+
};
59+
1760
// In-place butterfly pass over `a`; the added code path reads its twiddle
// factors from fft_info<mint>.
// NOTE(review): this span is a rendered diff. A bare "N-" marker line appears
// to precede a line removed by the commit, "N+" an added line, and a fused
// pair of numbers an unchanged line — confirm against the repository before
// treating any single line as part of the current function.
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
1861
void butterfly(std::vector<mint>& a) {
19-
static constexpr int g = internal::primitive_root<mint::mod()>;
2062
int n = int(a.size());
2163
// h bounds the number of layers processed below; presumably log2 of the
// (power-of-two) size — TODO confirm against ceil_pow2 in internal_bit.hpp.
int h = internal::ceil_pow2(n);
2264

23-
static bool first = true;
24-
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
25-
if (first) {
26-
first = false;
27-
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
28-
int cnt2 = bsf(mint::mod() - 1);
29-
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
30-
for (int i = cnt2; i >= 2; i--) {
31-
// e^(2^i) == 1
32-
es[i - 2] = e;
33-
ies[i - 2] = ie;
34-
e *= e;
35-
ie *= ie;
36-
}
37-
mint now = 1;
38-
for (int i = 0; i <= cnt2 - 2; i++) {
39-
sum_e[i] = es[i] * now;
40-
now *= ies[i];
41-
}
42-
}
43-
for (int ph = 1; ph <= h; ph++) {
44-
int w = 1 << (ph - 1), p = 1 << (h - ph);
45-
mint now = 1;
46-
for (int s = 0; s < w; s++) {
47-
int offset = s << (h - ph + 1);
48-
for (int i = 0; i < p; i++) {
49-
auto l = a[i + offset];
50-
auto r = a[i + offset + p] * now;
51-
a[i + offset] = l + r;
52-
a[i + offset + p] = l - r;
65+
// Function-local static: the tables are built once per `mint` instantiation.
static const fft_info<mint> info;
66+
67+
int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
68+
while (len < h) {
69+
// Exactly one layer left: do a single radix-2 step.
if (h - len == 1) {
70+
int p = 1 << (h - len - 1);
71+
mint rot = 1;
72+
for (int s = 0; s < (1 << len); s++) {
73+
int offset = s << (h - len);
74+
for (int i = 0; i < p; i++) {
75+
auto l = a[i + offset];
76+
auto r = a[i + offset + p] * rot;
77+
a[i + offset] = l + r;
78+
a[i + offset + p] = l - r;
79+
}
80+
// bsf(~s) is the index of the lowest zero bit of s; it selects the
// factor that advances rot for the next block.
rot *= info.rate2[bsf(~(unsigned int)(s))];
81+
}
82+
len++;
83+
} else {
84+
// 4-base
85+
// Two layers are consumed per pass here (len += 2 below).
int p = 1 << (h - len - 2);
86+
mint rot = 1, imag = info.root[2];
87+
for (int s = 0; s < (1 << len); s++) {
88+
mint rot2 = rot * rot;
89+
mint rot3 = rot2 * rot;
90+
int offset = s << (h - len);
91+
for (int i = 0; i < p; i++) {
92+
// Products are kept as unreduced unsigned 64-bit values. mod2 (= mod * mod)
// is added before each subtraction so the unsigned difference does not
// wrap (assuming val() < mod — TODO confirm for the mint type in use).
auto mod2 = 1ULL * mint::mod() * mint::mod();
93+
auto a0 = 1ULL * a[i + offset].val();
94+
auto a1 = 1ULL * a[i + offset + p].val() * rot.val();
95+
auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val();
96+
auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val();
97+
auto a1na3imag =
98+
1ULL * mint(a1 + mod2 - a3).val() * imag.val();
99+
auto na2 = mod2 - a2;
100+
// Each store converts an unsigned 64-bit sum back to mint.
a[i + offset] = a0 + a2 + a1 + a3;
101+
a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
102+
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
103+
a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
104+
}
105+
// Same lowest-zero-bit indexing as the radix-2 branch, using rate3.
rot *= info.rate3[bsf(~(unsigned int)(s))];
53106
}
54-
now *= sum_e[bsf(~(unsigned int)(s))];
107+
len += 2;
55108
}
56109
}
57110
}
58111

59112
// In-place inverse-direction butterfly pass over `a`; the added code path
// reads its twiddle factors from fft_info<mint> (iroot / irate2 / irate3).
// NOTE(review): this span is a rendered diff. A bare "N-" marker line appears
// to precede a line removed by the commit, "N+" an added line, and a fused
// pair of numbers an unchanged line — confirm against the repository before
// treating any single line as part of the current function.
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
60113
void butterfly_inv(std::vector<mint>& a) {
61-
static constexpr int g = internal::primitive_root<mint::mod()>;
62114
int n = int(a.size());
63115
// h is the starting value of len below; presumably log2 of the
// (power-of-two) size — TODO confirm against ceil_pow2 in internal_bit.hpp.
int h = internal::ceil_pow2(n);
64116

65-
static bool first = true;
66-
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
67-
if (first) {
68-
first = false;
69-
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
70-
int cnt2 = bsf(mint::mod() - 1);
71-
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
72-
for (int i = cnt2; i >= 2; i--) {
73-
// e^(2^i) == 1
74-
es[i - 2] = e;
75-
ies[i - 2] = ie;
76-
e *= e;
77-
ie *= ie;
78-
}
79-
mint now = 1;
80-
for (int i = 0; i <= cnt2 - 2; i++) {
81-
sum_ie[i] = ies[i] * now;
82-
now *= es[i];
83-
}
84-
}
117+
// Function-local static: the tables are built once per `mint` instantiation.
static const fft_info<mint> info;
118+
119+
int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
120+
// len counts down by 2 per radix-4 pass; the radix-2 branch handles the
// final layer when one is left over.
while (len) {
121+
if (len == 1) {
122+
int p = 1 << (h - len);
123+
mint irot = 1;
124+
for (int s = 0; s < (1 << (len - 1)); s++) {
125+
int offset = s << (h - len + 1);
126+
for (int i = 0; i < p; i++) {
127+
auto l = a[i + offset];
128+
auto r = a[i + offset + p];
129+
a[i + offset] = l + r;
130+
// mod is added before subtracting r so the difference is not negative;
// the multiplication by irot is carried out in 64 bits.
a[i + offset + p] =
131+
(unsigned long long)(mint::mod() + l.val() - r.val()) *
132+
irot.val();
133+
// (stray empty statement)
;
134+
}
135+
// bsf(~s) is the index of the lowest zero bit of s.
irot *= info.irate2[bsf(~(unsigned int)(s))];
136+
}
137+
len--;
138+
} else {
139+
// 4-base
140+
int p = 1 << (h - len);
141+
mint irot = 1, iimag = info.iroot[2];
142+
for (int s = 0; s < (1 << (len - 2)); s++) {
143+
mint irot2 = irot * irot;
144+
mint irot3 = irot2 * irot;
145+
int offset = s << (h - len + 2);
146+
for (int i = 0; i < p; i++) {
147+
// Inputs are widened to unsigned 64-bit; mod is added before each
// subtraction so the unsigned difference does not wrap (assuming
// val() < mod — TODO confirm for the mint type in use).
auto a0 = 1ULL * a[i + offset + 0 * p].val();
148+
auto a1 = 1ULL * a[i + offset + 1 * p].val();
149+
auto a2 = 1ULL * a[i + offset + 2 * p].val();
150+
auto a3 = 1ULL * a[i + offset + 3 * p].val();
151+
152+
auto a2na3iimag =
153+
1ULL *
154+
mint((mint::mod() + a2 - a3) * iimag.val()).val();
85155

86-
for (int ph = h; ph >= 1; ph--) {
87-
int w = 1 << (ph - 1), p = 1 << (h - ph);
88-
mint inow = 1;
89-
for (int s = 0; s < w; s++) {
90-
int offset = s << (h - ph + 1);
91-
for (int i = 0; i < p; i++) {
92-
auto l = a[i + offset];
93-
auto r = a[i + offset + p];
94-
a[i + offset] = l + r;
95-
a[i + offset + p] =
96-
(unsigned long long)(mint::mod() + l.val() - r.val()) *
97-
inow.val();
156+
// Outputs 1..3 are scaled by irot, irot^2 and irot^3 respectively.
a[i + offset] = a0 + a1 + a2 + a3;
157+
a[i + offset + 1 * p] =
158+
(a0 + (mint::mod() - a1) + a2na3iimag) * irot.val();
159+
a[i + offset + 2 * p] =
160+
(a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) *
161+
irot2.val();
162+
a[i + offset + 3 * p] =
163+
(a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) *
164+
irot3.val();
165+
}
166+
// Same lowest-zero-bit indexing as the radix-2 branch, using irate3.
irot *= info.irate3[bsf(~(unsigned int)(s))];
98167
}
99-
inow *= sum_ie[bsf(~(unsigned int)(s))];
168+
len -= 2;
100169
}
101170
}
102171
}
103172

104173
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
105-
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
174+
std::vector<mint> convolution_naive(const std::vector<mint>& a,
175+
const std::vector<mint>& b) {
106176
int n = int(a.size()), m = int(b.size());
107177
std::vector<mint> ans(n + m - 1);
108178
if (n < m) {
@@ -150,7 +220,8 @@ std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
150220
}
151221

152222
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
153-
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
223+
std::vector<mint> convolution(const std::vector<mint>& a,
224+
const std::vector<mint>& b) {
154225
int n = int(a.size()), m = int(b.size());
155226
if (!n || !m) return {};
156227
if (std::min(n, m) <= 60) return convolution_naive(a, b);

atcoder/internal_bit.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@ int ceil_pow2(int n) {
1717
return x;
1818
}
1919

20+
// Compile-time-evaluable bit scan forward: index of the lowest set bit of `n`.
// @param n `1 <= n` (with `n == 0` no bit is ever found and the scan runs
//          past the width of the type)
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
constexpr int bsf_constexpr(unsigned int n) {
    int x = 0;
    // The mask is unsigned on purpose: a signed `1 << 31` shifts into the
    // sign bit, which is not portable before C++20 and mixes signedness
    // with `n`.
    while (!(n & (1U << x))) x++;
    return x;
}
27+
2028
// @param n `1 <= n`
2129
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
2230
int bsf(unsigned int n) {

0 commit comments

Comments
 (0)