|
1 | 1 | // @brief Many Factorials |
2 | 2 | #define PROBLEM "https://judge.yosupo.jp/problem/many_factorials" |
3 | 3 | #pragma GCC optimize("Ofast,unroll-loops") |
| 4 | +#define CP_ALGO_CHECKPOINT |
4 | 5 | #include <bits/stdc++.h> |
5 | | -//#define CP_ALGO_CHECKPOINT |
6 | 6 | #include "blazingio/blazingio.min.hpp" |
7 | | -#include "cp-algo/util/checkpoint.hpp" |
8 | | -#include "cp-algo/util/simd.hpp" |
9 | | -#include "cp-algo/util/bump_alloc.hpp" |
10 | | -#include "cp-algo/math/common.hpp" |
| 7 | +#include "cp-algo/math/factorials.hpp" |
11 | 8 |
|
12 | 9 | using namespace std; |
13 | | -using namespace cp_algo; |
14 | | - |
15 | | -constexpr int mod = 998244353; |
16 | | -constexpr int imod = -math::inv2(mod); |
17 | | - |
18 | | -template<bool use_bump_alloc = false, int maxn = 100'000> |
19 | | -vector<int> facts(vector<int> const& args) { |
20 | | - constexpr int accum = 4; |
21 | | - constexpr int simd_size = 8; |
22 | | - constexpr int block = 1 << 18; |
23 | | - constexpr int subblock = block / simd_size; |
24 | | - using T = array<int, 2>; |
25 | | - using alloc = conditional_t<use_bump_alloc, |
26 | | - bump_alloc<T, 30 * maxn>, |
27 | | - allocator<T>>; |
28 | | - basic_string<T, char_traits<T>, alloc> odd_args_per_block[mod / subblock]; |
29 | | - basic_string<T, char_traits<T>, alloc> reg_args_per_block[mod / subblock]; |
30 | | - constexpr int limit_reg = mod / 64; |
31 | | - int limit_odd = 0; |
32 | | - |
33 | | - vector<int> res(size(args), 1); |
34 | | - auto prod_mod = [&](uint64_t a, uint64_t b) { |
35 | | - return (a * b) % mod; |
36 | | - }; |
37 | | - for(auto [i, xy]: views::zip(args, res) | views::enumerate) { |
38 | | - auto [x, y] = xy; |
39 | | - auto t = x; |
40 | | - if(t >= mod / 2) { |
41 | | - t = mod - t - 1; |
42 | | - y = t % 2 ? 1 : mod - 1; |
43 | | - } |
44 | | - int pw = 0; |
45 | | - while(t > limit_reg) { |
46 | | - limit_odd = max(limit_odd, (t - 1) / 2); |
47 | | - odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2}); |
48 | | - t /= 2; |
49 | | - pw += t; |
50 | | - } |
51 | | - reg_args_per_block[t / subblock].push_back({int(i), t}); |
52 | | - y = int(y * math::bpow(2, pw, 1ULL, prod_mod) % mod); |
53 | | - } |
54 | | - cp_algo::checkpoint("init"); |
55 | | - uint32_t b2x32 = (1ULL << 32) % mod; |
56 | | - auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) { |
57 | | - uint64_t fact = 1; |
58 | | - for(int b = 0; b <= limit; b += accum * block) { |
59 | | - u32x8 cur[accum]; |
60 | | - static array<u32x8, subblock> prods[accum]; |
61 | | - for(int z = 0; z < accum; z++) { |
62 | | - for(int j = 0; j < simd_size; j++) { |
63 | | - cur[z][j] = uint32_t(b + z * block + j * subblock); |
64 | | - cur[z][j] = proj(cur[z][j]); |
65 | | - prods[z][0][j] = cur[z][j] + !cur[z][j]; |
66 | | - #pragma GCC diagnostic push |
67 | | - #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" |
68 | | - cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod); |
69 | | - #pragma GCC diagnostic pop |
70 | | - } |
71 | | - } |
72 | | - for(int i = 1; i < block / simd_size; i++) { |
73 | | - for(int z = 0; z < accum; z++) { |
74 | | - cur[z] += step; |
75 | | - cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z]; |
76 | | - prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod); |
77 | | - } |
78 | | - } |
79 | | - cp_algo::checkpoint("inner loop"); |
80 | | - for(int z = 0; z < accum; z++) { |
81 | | - for(int j = 0; j < simd_size; j++) { |
82 | | - int bl = b + z * block + j * subblock; |
83 | | - for(auto [i, x]: args_per_block[bl / subblock]) { |
84 | | - auto ans = fact * prods[z][x - bl][j] % mod; |
85 | | - res[i] = int(res[i] * ans % mod); |
86 | | - } |
87 | | - fact = fact * prods[z].back()[j] % mod; |
88 | | - } |
89 | | - } |
90 | | - cp_algo::checkpoint("mul ans"); |
91 | | - } |
92 | | - }; |
93 | | - uint32_t b2x33 = (1ULL << 33) % mod; |
94 | | - process(limit_reg, reg_args_per_block, b2x32, identity{}); |
95 | | - process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;}); |
96 | | - for(auto [i, x]: res | views::enumerate) { |
97 | | - if (args[i] >= mod / 2) { |
98 | | - x = int(math::bpow(x, mod - 2, 1ULL, prod_mod)); |
99 | | - } |
100 | | - } |
101 | | - cp_algo::checkpoint("inv ans"); |
102 | | - return res; |
103 | | -} |
| 10 | +using base = cp_algo::math::modint<998244353>; |
104 | 11 |
|
105 | 12 | void solve() { |
106 | 13 | int n; |
107 | 14 | cin >> n; |
108 | | - vector<int> args(n); |
| 15 | + vector<base> args(n); |
109 | 16 | for(auto &x : args) {cin >> x;} |
110 | 17 | cp_algo::checkpoint("read"); |
111 | 18 | auto res = facts(args); |
|
0 commit comments