Skip to content

Commit 03629c7

Browse files
committed
More [[gnu::target("avx2")]]
1 parent c75eefd commit 03629c7

File tree

4 files changed

+38
-38
lines changed

4 files changed

+38
-38
lines changed

cp-algo/math/cvector.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ namespace cp_algo::math::fft {
115115
checkpoint("dot");
116116
}
117117
template<bool partial = true>
118-
void ifft() {
118+
[[gnu::target("avx2")]] void ifft() {
119119
size_t n = size();
120120
if constexpr (!partial) {
121121
point pi(0, 1);
@@ -177,7 +177,7 @@ namespace cp_algo::math::fft {
177177
}
178178
}
179179
template<bool partial = true>
180-
void fft() {
180+
[[gnu::target("avx2")]] void fft() {
181181
size_t n = size();
182182
bool parity = std::countr_zero(n) % 2;
183183
for(size_t leaf = 0; leaf < n; leaf += 4 * flen) {

cp-algo/math/fft.hpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -125,7 +125,7 @@ namespace cp_algo::math::fft {
125125
checkpoint("dot");
126126
}
127127

128-
void dot(auto &&C, auto const& D) {
128+
[[gnu::target("avx2")]] void dot(auto &&C, auto const& D) {
129129
dot(C, D, A, B, C);
130130
}
131131

@@ -209,7 +209,7 @@ namespace cp_algo::math::fft {
209209
template<modint_type base> uint32_t dft<base>::mod = {};
210210
template<modint_type base> uint32_t dft<base>::imod = {};
211211

212-
void mul_slow(auto &a, auto const& b, size_t k) {
212+
[[gnu::target("avx2")]] void mul_slow(auto &a, auto const& b, size_t k) {
213213
if(std::empty(a) || std::empty(b)) {
214214
a.clear();
215215
} else {
@@ -230,7 +230,7 @@ namespace cp_algo::math::fft {
230230
}
231231
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
232232
}
233-
void mul_truncate(auto &a, auto const& b, size_t k) {
233+
[[gnu::target("avx2")]] void mul_truncate(auto &a, auto const& b, size_t k) {
234234
using base = std::decay_t<decltype(a[0])>;
235235
if(std::min({k, std::size(a), std::size(b)}) < magic) {
236236
mul_slow(a, b, k);
@@ -279,7 +279,7 @@ namespace cp_algo::math::fft {
279279
}
280280
cp_algo::checkpoint("mod split");
281281
}
282-
void cyclic_mul(auto &a, auto &&b, size_t k) {
282+
[[gnu::target("avx2")]] void cyclic_mul(auto &a, auto &&b, size_t k) {
283283
assert(std::popcount(k) == 1);
284284
assert(std::size(a) == std::size(b) && std::size(a) == k);
285285
using base = std::decay_t<decltype(a[0])>;
@@ -312,13 +312,13 @@ namespace cp_algo::math::fft {
312312
}
313313
cp_algo::checkpoint("mod join");
314314
}
315-
auto make_copy(auto &&x) {
315+
[[gnu::target("avx2")]] auto make_copy(auto &&x) {
316316
return x;
317317
}
318-
void cyclic_mul(auto &a, auto const& b, size_t k) {
318+
[[gnu::target("avx2")]] void cyclic_mul(auto &a, auto const& b, size_t k) {
319319
return cyclic_mul(a, make_copy(b), k);
320320
}
321-
void mul(auto &a, auto &&b) {
321+
[[gnu::target("avx2")]] void mul(auto &a, auto &&b) {
322322
size_t N = size(a) + size(b);
323323
if(N > (1 << 20)) {
324324
N--;
@@ -331,7 +331,7 @@ namespace cp_algo::math::fft {
331331
mul_truncate(a, b, N - 1);
332332
}
333333
}
334-
void mul(auto &a, auto const& b) {
334+
[[gnu::target("avx2")]] void mul(auto &a, auto const& b) {
335335
size_t N = size(a) + size(b);
336336
if(N > (1 << 20)) {
337337
mul(a, make_copy(b));

cp-algo/util/checkpoint.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,10 @@
55
#include <string>
66
#include <map>
77
namespace cp_algo {
8-
std::map<std::string, double> checkpoints;
98
template<bool final = false>
10-
void checkpoint([[maybe_unused]] std::string const& msg = "") {
9+
void checkpoint([[maybe_unused]] auto const& msg = "") {
1110
#ifdef CP_ALGO_CHECKPOINT
11+
static std::map<std::string, double> checkpoints;
1212
static double last = 0;
1313
double now = (double)clock() / CLOCKS_PER_SEC;
1414
double delta = now - last;

cp-algo/util/complex.hpp

Lines changed: 26 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -12,38 +12,38 @@ namespace cp_algo {
1212
constexpr complex(): x(), y() {}
1313
constexpr complex(T const& x): x(x), y() {}
1414
constexpr complex(T const& x, T const& y): x(x), y(y) {}
15-
complex& operator *= (T const& t) {x *= t; y *= t; return *this;}
16-
complex& operator /= (T const& t) {x /= t; y /= t; return *this;}
17-
complex operator * (T const& t) const {return complex(*this) *= t;}
18-
complex operator / (T const& t) const {return complex(*this) /= t;}
19-
complex& operator += (complex const& t) {x += t.x; y += t.y; return *this;}
20-
complex& operator -= (complex const& t) {x -= t.x; y -= t.y; return *this;}
21-
complex operator * (complex const& t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
22-
complex operator / (complex const& t) const {return *this * t.conj() / t.norm();}
23-
complex operator + (complex const& t) const {return complex(*this) += t;}
24-
complex operator - (complex const& t) const {return complex(*this) -= t;}
25-
complex& operator *= (complex const& t) {return *this = *this * t;}
26-
complex& operator /= (complex const& t) {return *this = *this / t;}
27-
complex operator - () const {return {-x, -y};}
28-
complex conj() const {return {x, -y};}
29-
T norm() const {return x * x + y * y;}
30-
T abs() const {return std::sqrt(norm());}
15+
[[gnu::target("avx2")]] complex& operator *= (T const& t) {x *= t; y *= t; return *this;}
16+
[[gnu::target("avx2")]] complex& operator /= (T const& t) {x /= t; y /= t; return *this;}
17+
[[gnu::target("avx2")]] complex operator * (T const& t) const {return complex(*this) *= t;}
18+
[[gnu::target("avx2")]] complex operator / (T const& t) const {return complex(*this) /= t;}
19+
[[gnu::target("avx2")]] complex& operator += (complex const& t) {x += t.x; y += t.y; return *this;}
20+
[[gnu::target("avx2")]] complex& operator -= (complex const& t) {x -= t.x; y -= t.y; return *this;}
21+
[[gnu::target("avx2")]] complex operator * (complex const& t) const {return {x * t.x - y * t.y, x * t.y + y * t.x};}
22+
[[gnu::target("avx2")]] complex operator / (complex const& t) const {return *this * t.conj() / t.norm();}
23+
[[gnu::target("avx2")]] complex operator + (complex const& t) const {return complex(*this) += t;}
24+
[[gnu::target("avx2")]] complex operator - (complex const& t) const {return complex(*this) -= t;}
25+
[[gnu::target("avx2")]] complex& operator *= (complex const& t) {return *this = *this * t;}
26+
[[gnu::target("avx2")]] complex& operator /= (complex const& t) {return *this = *this / t;}
27+
[[gnu::target("avx2")]] complex operator - () const {return {-x, -y};}
28+
[[gnu::target("avx2")]] complex conj() const {return {x, -y};}
29+
[[gnu::target("avx2")]] T norm() const {return x * x + y * y;}
30+
[[gnu::target("avx2")]] T abs() const {return std::sqrt(norm());}
3131
[[gnu::target("avx2")]] T const real() const {return x;}
3232
[[gnu::target("avx2")]] T const imag() const {return y;}
33-
T& real() {return x;}
34-
T& imag() {return y;}
35-
static constexpr complex polar(T r, T theta) {return {T(r * cos(theta)), T(r * sin(theta))};}
36-
auto operator <=> (complex const& t) const = default;
33+
[[gnu::target("avx2")]] T& real() {return x;}
34+
[[gnu::target("avx2")]] T& imag() {return y;}
35+
[[gnu::target("avx2")]] static constexpr complex polar(T r, T theta) {return {T(r * cos(theta)), T(r * sin(theta))};}
36+
[[gnu::target("avx2")]] auto operator <=> (complex const& t) const = default;
3737
};
38-
template<typename T> complex<T> conj(complex<T> const& x) {return x.conj();}
39-
template<typename T> T norm(complex<T> const& x) {return x.norm();}
40-
template<typename T> T abs(complex<T> const& x) {return x.abs();}
41-
template<typename T> T& real(complex<T> &x) {return x.real();}
42-
template<typename T> T& imag(complex<T> &x) {return x.imag();}
38+
template<typename T> [[gnu::target("avx2")]] complex<T> conj(complex<T> const& x) {return x.conj();}
39+
template<typename T> [[gnu::target("avx2")]] T norm(complex<T> const& x) {return x.norm();}
40+
template<typename T> [[gnu::target("avx2")]] T abs(complex<T> const& x) {return x.abs();}
41+
template<typename T> [[gnu::target("avx2")]] T& real(complex<T> &x) {return x.real();}
42+
template<typename T> [[gnu::target("avx2")]] T& imag(complex<T> &x) {return x.imag();}
4343
template<typename T> [[gnu::target("avx2")]] T const real(complex<T> const& x) {return x.real();}
4444
template<typename T> [[gnu::target("avx2")]] T const imag(complex<T> const& x) {return x.imag();}
4545
template<typename T>
46-
constexpr complex<T> polar(T r, T theta) {
46+
[[gnu::target("avx2")]] constexpr complex<T> polar(T r, T theta) {
4747
return complex<T>::polar(r, theta);
4848
}
4949
template<typename T>

0 commit comments

Comments
 (0)