diff --git a/math/factorial/checker.cpp b/math/factorial/checker.cpp new file mode 100644 index 000000000..6a66d5330 --- /dev/null +++ b/math/factorial/checker.cpp @@ -0,0 +1,62 @@ +// https://github.com/MikeMirzayanov/testlib/blob/master/checkers/wcmp.cpp + +// The MIT License (MIT) + +// Copyright (c) 2015 Mike Mirzayanov + +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: + +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. + +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "testlib.h" + +using namespace std; + +int main(int argc, char * argv[]) +{ + setName("compare sequences of tokens"); + registerTestlibCmd(argc, argv); + + int n = 0; + string j, p; + + while (!ans.seekEof() && !ouf.seekEof()) + { + n++; + + ans.readWordTo(j); + ouf.readWordTo(p); + + if (j != p) + quitf(_wa, "%d%s words differ - expected: '%s', found: '%s'", n, englishEnding(n).c_str(), compress(j).c_str(), compress(p).c_str()); + } + + if (ans.seekEof() && ouf.seekEof()) + { + if (n == 1) + quitf(_ok, "\"%s\"", compress(j).c_str()); + else + quitf(_ok, "%d tokens", n); + } + else + { + if (ans.seekEof()) + quitf(_wa, "Participant output contains extra tokens"); + else + quitf(_wa, "Unexpected EOF in the participants output"); + } +} diff --git a/math/factorial/gen/example_00.in b/math/factorial/gen/example_00.in new file mode 100644 index 000000000..a0b2e9a25 --- /dev/null +++ b/math/factorial/gen/example_00.in @@ -0,0 +1,6 @@ +5 +0 +5 +100 +1234567 +998244352 diff --git a/math/factorial/gen/near_half.cpp b/math/factorial/gen/near_half.cpp new file mode 100644 index 000000000..b9318b5a2 --- /dev/null +++ b/math/factorial/gen/near_half.cpp @@ -0,0 +1,16 @@ +#include +#include "../params.h" +#include "random.h" + +int main(int, char* argv[]) { + long long seed = atoll(argv[1]); + auto gen = Random(seed); + + int T = T_MAX; + printf("%d\n", T); + for (int t = 0; t < T; ++t) { + int N = gen.uniform(0.45 * MOD, 0.55 * MOD); + printf("%d\n", N); + } + return 0; +} diff --git a/math/factorial/gen/near_max.cpp b/math/factorial/gen/near_max.cpp new file mode 100644 index 000000000..60232f8cb --- /dev/null +++ b/math/factorial/gen/near_max.cpp @@ -0,0 +1,16 @@ +#include +#include "../params.h" +#include "random.h" + +int main(int, char* argv[]) { + long long seed = atoll(argv[1]); + auto gen = Random(seed); + + int T = T_MAX; + printf("%d\n", T); + for (int t = 0; t < T; ++t) { + int N = gen.uniform(0.95 * MOD, MOD - 1); + printf("%d\n", N); + } + return 0; +} diff --git a/math/factorial/gen/random.cpp b/math/factorial/gen/random.cpp new file mode 100644 index 000000000..efd579d31 --- /dev/null +++ b/math/factorial/gen/random.cpp @@ -0,0 +1,16 @@ +#include +#include "../params.h" +#include "random.h" + +int main(int, char* argv[]) { + long long seed = atoll(argv[1]); + auto gen = Random(seed); + + int T = T_MAX; + printf("%d\n", T); + for (int t = 0; t < T; ++t) { + int N = gen.uniform(0, MOD - 1); + printf("%d\n", N); + } + return 0; +} diff --git a/math/factorial/hash.json b/math/factorial/hash.json new file mode 100644 index 000000000..cb82cf563 --- /dev/null +++ b/math/factorial/hash.json @@ -0,0 +1,16 @@ +{ + "example_00.in": "180d1bbdde5539dd3a77d325f8940297325665f9f1be249c051dad6aa177200a", + "example_00.out": "c73d2af1573bc3649136764834711acaf148eb3588cc639d8bf5a57da9bdd867", + "near_half_00.in": "b9b509e96007f2595f9a1845c56544b107a20794261c0a3913056ebc1c472d48", + "near_half_00.out": "09c498ab37ca56ec5d94ca104802a0a8fbc78539941f994880463e874cb5e266", + "near_half_01.in": "78c59f15afe18136e693628e08aa54d462218ac825d29667c75f5d0e6044babc", + "near_half_01.out": "0d58b0459853ec01a08a006ee4f2288c2304264e5551b3a0d65b93dfe111620b", + "near_max_00.in": "da530850d58cc2e3e79f6104dc0af91b76a0ec7d61fc5120b4487f2a94312d6f", + "near_max_00.out": "06f27a007c8c14ad2cb868844bd5a5e94346dc144ee51b6446a1b684a104902f", + "near_max_01.in": "2d4f31d5ddec6a701ed69d204aaa8c49439066f4d675fd99774795757358dc3c", + "near_max_01.out": "3655b1fb1215b40a9c762042096ef89af95e7aa94da5ca63b7daf72263c023a1", + "random_00.in": "ef2dc732889096480a438e542b5dc79018083e8771651ef8ffa37b0fd1f6c246", + "random_00.out": "65ac0924d4ca6ff146997bc802a42aec719d48cc8556127fc2f5d514efaef047", + "random_01.in": "0842027f68881d70d3376224a68b92f9e19f13840ef9d573acdba3c85d1fc1d0", + "random_01.out": "bd08bbc7d8dd2d279da20ce0ba42d99198728e7b3faccc03ab24c4d5576066f1" +} \ No newline at end of file diff --git a/math/factorial/info.toml b/math/factorial/info.toml new file mode 100644 index 000000000..656f20d2d --- /dev/null +++ b/math/factorial/info.toml @@ -0,0 +1,26 @@ +title = 'Factorial' +timelimit = 10.0 +forum = "https://github.com/yosupo06/library-checker-problems/issues/128" + +[[tests]] + name = "example.in" + number = 1 +[[tests]] + name = "random.cpp" + number = 2 +[[tests]] + name = "near_max.cpp" + number = 2 +[[tests]] + name = "near_half.cpp" + number = 2 + +[[solutions]] + name = "naive.cpp" + wrong = false + allow_tle = true + + +[params] + T_MAX = 5 + MOD = 998244353 diff --git a/math/factorial/sol/correct.cpp b/math/factorial/sol/correct.cpp new file mode 100644 index 000000000..b11b2df0f --- /dev/null +++ b/math/factorial/sol/correct.cpp @@ -0,0 +1,647 @@ +#include +#include +#include +#include + +using namespace std; + +using ll = long long; +using u32 = unsigned int; +using u64 = unsigned long long; +using u128 = unsigned __int128; + +using pi = pair; +using vi = vector; +template +using vc = vector; +template +using vvc = vector>; +template +using vvvc = vector>; +#define vv(type, name, h, ...) \ + vector> name(h, vector(__VA_ARGS__)) +#define vvv(type, name, h, w, ...) \ + vector>> name( \ + h, vector>(w, vector(__VA_ARGS__))) + +#define FOR1(a) for (ll _ = 0; _ < ll(a); ++_) +#define FOR2(i, a) for (ll i = 0; i < ll(a); ++i) +#define FOR3(i, a, b) for (ll i = a; i < ll(b); ++i) +#define FOR4(i, a, b, c) for (ll i = a; i < ll(b); i += (c)) +#define FOR1_R(a) for (ll i = (a)-1; i >= ll(0); --i) +#define FOR2_R(i, a) for (ll i = (a)-1; i >= ll(0); --i) +#define FOR3_R(i, a, b) for (ll i = (b)-1; i >= ll(a); --i) +#define overload4(a, b, c, d, e, ...) e +#define overload3(a, b, c, d, ...) d +#define FOR(...) overload4(__VA_ARGS__, FOR4, FOR3, FOR2, FOR1)(__VA_ARGS__) +#define FOR_R(...) overload3(__VA_ARGS__, FOR3_R, FOR2_R, FOR1_R)(__VA_ARGS__) + +#define all(x) x.begin(), x.end() +#define len(x) ll(x.size()) +#define elif else if + +#define eb emplace_back +#define mp make_pair +#define mt make_tuple +#define fi first +#define se second + +// (0, 1, 2, 3, 4) -> (-1, 0, 1, 1, 2) +int topbit(int x) { return (x == 0 ? -1 : 31 - __builtin_clz(x)); } +int topbit(u32 x) { return (x == 0 ? -1 : 31 - __builtin_clz(x)); } +int topbit(ll x) { return (x == 0 ? -1 : 63 - __builtin_clzll(x)); } +int topbit(u64 x) { return (x == 0 ? -1 : 63 - __builtin_clzll(x)); } + +template +mint inv(int n) { + static const int mod = mint::get_mod(); + static vector dat = {0, 1}; + assert(0 <= n); + if (n >= mod) n %= mod; + while (len(dat) <= n) { + int k = len(dat); + int q = (mod + k - 1) / k; + dat.eb(dat[k * q - mod] * mint::raw(q)); + } + return dat[n]; +} + +template +mint fact(int n) { + static const int mod = mint::get_mod(); + assert(0 <= n && n < mod); + static vector dat = {1, 1}; + while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * mint::raw(len(dat))); + return dat[n]; +} + +template +mint fact_inv(int n) { + static vector dat = {1, 1}; + if (n < 0) return mint(0); + while (len(dat) <= n) dat.eb(dat[len(dat) - 1] * inv(len(dat))); + return dat[n]; +} + +template +mint fact_invs(Ts... xs) { + return (mint(1) * ... * fact_inv(xs)); +} + +template +struct modint { + static constexpr u32 umod = u32(mod); + static_assert(umod < u32(1) << 31); + u32 val; + + static modint raw(u32 v) { + modint x; + x.val = v; + return x; + } + constexpr modint() : val(0) {} + constexpr modint(u32 x) : val(x % umod) {} + constexpr modint(u64 x) : val(x % umod) {} + constexpr modint(int x) : val((x %= mod) < 0 ? x + mod : x){}; + constexpr modint(ll x) : val((x %= mod) < 0 ? x + mod : x){}; + bool operator<(const modint &other) const { return val < other.val; } + modint &operator+=(const modint &p) { + if ((val += p.val) >= umod) val -= umod; + return *this; + } + modint &operator-=(const modint &p) { + if ((val += umod - p.val) >= umod) val -= umod; + return *this; + } + modint &operator*=(const modint &p) { + val = u64(val) * p.val % umod; + return *this; + } + modint &operator/=(const modint &p) { + *this *= p.inverse(); + return *this; + } + modint operator-() const { return modint::raw(val ? mod - val : u32(0)); } + modint operator+(const modint &p) const { return modint(*this) += p; } + modint operator-(const modint &p) const { return modint(*this) -= p; } + modint operator*(const modint &p) const { return modint(*this) *= p; } + modint operator/(const modint &p) const { return modint(*this) /= p; } + bool operator==(const modint &p) const { return val == p.val; } + bool operator!=(const modint &p) const { return val != p.val; } + modint inverse() const { + int a = val, b = mod, u = 1, v = 0, t; + while (b > 0) { + t = a / b; + swap(a -= t * b, b), swap(u -= t * v, v); + } + return modint(u); + } + modint pow(ll n) const { + assert(n >= 0); + modint ret(1), mul(val); + while (n > 0) { + if (n & 1) ret *= mul; + mul *= mul; + n >>= 1; + } + return ret; + } + static constexpr int get_mod() { return mod; } + // (n, r), r は 1 の 2^n 乗根 + static constexpr pair ntt_info() { + if (mod == 167772161) return {25, 17}; + if (mod == 469762049) return {26, 30}; + if (mod == 754974721) return {24, 362}; + if (mod == 880803841) return {23, 211}; + if (mod == 943718401) return {22, 663003469}; + if (mod == 998244353) return {23, 31}; + if (mod == 1045430273) return {20, 363}; + if (mod == 1051721729) return {20, 330}; + if (mod == 1053818881) return {20, 2789}; + return {-1, -1}; + } + static constexpr bool can_ntt() { return ntt_info().fi != -1; } +}; + +using modint998 = modint<998244353>; + +template +vc> matrix_mul(const vc> &A, const vc> &B, int N1 = -1, + int N2 = -1, int N3 = -1) { + if (N1 == -1) { N1 = len(A), N2 = len(B), N3 = len(B[0]); } + vv(u32, b, N3, N2); + FOR(i, N2) FOR(j, N3) b[j][i] = B[i][j].val; + vv(T, C, N1, N3); + + if ((T::get_mod() < (1 << 30)) && N2 <= 16) { + FOR(i, N1) FOR(j, N3) { + u64 sm = 0; + FOR(m, N2) sm += u64(A[i][m].val) * b[j][m]; + C[i][j] = sm; + } + } else { + FOR(i, N1) FOR(j, N3) { + u128 sm = 0; + FOR(m, N2) sm += u64(A[i][m].val) * b[j][m]; + C[i][j] = T::raw(sm % (T::get_mod())); + } + } + return C; +} + +template +struct Monoid_Mul { + using value_type = T; + using X = T; + static constexpr X op(const X &x, const X &y) noexcept { return x * y; } + static constexpr X inverse(const X &x) noexcept { return X(1) / x; } + static constexpr X unit() { return X(1); } + static constexpr bool commute = true; +}; + +template +struct Sliding_Window_Aggregation { + using X = typename Monoid::value_type; + using value_type = X; + int sz = 0; + vc dat; + vc cum_l; + X cum_r; + + Sliding_Window_Aggregation() + : cum_l({Monoid::unit()}), cum_r(Monoid::unit()) {} + + int size() { return sz; } + + void push(X x) { + ++sz; + cum_r = Monoid::op(cum_r, x); + dat.eb(x); + } + + void pop() { + --sz; + cum_l.pop_back(); + if (len(cum_l) == 0) { + cum_l = {Monoid::unit()}; + cum_r = Monoid::unit(); + while (len(dat) > 1) { + cum_l.eb(Monoid::op(dat.back(), cum_l.back())); + dat.pop_back(); + } + dat.pop_back(); + } + } + + X lprod() { return cum_l.back(); } + X rprod() { return cum_r; } + + X prod() { return Monoid::op(cum_l.back(), cum_r); } +}; + +// long でも大丈夫 +// (val * x - 1) が mod の倍数になるようにする +// 特に mod=0 なら x=0 が満たす +ll mod_inv(ll val, ll mod) { + assert(mod >= 0); + if (mod == 0) return 0; + val %= mod; + if (val < 0) val += mod; + ll a = val, b = mod, u = 1, v = 0, t; + while (b > 0) { + t = a / b; + swap(a -= t * b, b), swap(u -= t * v, v); + } + if (u < 0) u += mod; + return u; +} +#line 1 "library/mod/crt3.hpp" + +constexpr u32 mod_pow_constexpr(u64 a, u64 n, u32 mod) { + a %= mod; + u64 res = 1; + FOR(32) { + if (n & 1) res = res * a % mod; + a = a * a % mod, n /= 2; + } + return res; +} + +template +T CRT3(u64 a0, u64 a1, u64 a2) { + static_assert(p0 < p1 && p1 < p2); + static constexpr u64 x0_1 = mod_pow_constexpr(p0, p1 - 2, p1); + static constexpr u64 x01_2 = mod_pow_constexpr(u64(p0) * p1 % p2, p2 - 2, p2); + u64 c = (a1 - a0 + p1) * x0_1 % p1; + u64 a = a0 + c * p0; + c = (a2 - a % p2 + p2) * x01_2 % p2; + return T(a) + T(c) * T(p0) * T(p1); +} +#line 2 "library/poly/convolution_naive.hpp" + +template +vc convolution_naive(const vc &a, const vc &b) { + int n = int(a.size()), m = int(b.size()); + if (n > m) return convolution_naive(b, a); + if (n == 0) return {}; + vc ans(n + m - 1); + if (n <= 16 && (T::get_mod() < (1 << 30))) { + for (int k = 0; k < n + m - 1; ++k) { + int s = max(0, k - m + 1); + int t = min(n, k + 1); + u64 sm = 0; + for (int i = s; i < t; ++i) { sm += u64(a[i].val) * (b[k - i].val); } + ans[k] = sm; + } + } else { + for (int k = 0; k < n + m - 1; ++k) { + int s = max(0, k - m + 1); + int t = min(n, k + 1); + u128 sm = 0; + for (int i = s; i < t; ++i) { sm += u64(a[i].val) * (b[k - i].val); } + ans[k] = T::raw(sm % T::get_mod()); + } + } + return ans; +} + +template +vc convolution_karatsuba(const vc &f, const vc &g) { + const int thresh = 30; + if (min(len(f), len(g)) <= thresh) return convolution_naive(f, g); + int n = max(len(f), len(g)); + int m = (n + 1) / 2; + vc f1, f2, g1, g2; + if (len(f) < m) f1 = f; + if (len(f) >= m) f1 = {f.begin(), f.begin() + m}; + if (len(f) >= m) f2 = {f.begin() + m, f.end()}; + if (len(g) < m) g1 = g; + if (len(g) >= m) g1 = {g.begin(), g.begin() + m}; + if (len(g) >= m) g2 = {g.begin() + m, g.end()}; + vc a = convolution_karatsuba(f1, g1); + vc b = convolution_karatsuba(f2, g2); + FOR(i, len(f2)) f1[i] += f2[i]; + FOR(i, len(g2)) g1[i] += g2[i]; + vc c = convolution_karatsuba(f1, g1); + vc F(len(f) + len(g) - 1); + assert(2 * m + len(b) <= len(F)); + FOR(i, len(a)) F[i] += a[i], c[i] -= a[i]; + FOR(i, len(b)) F[2 * m + i] += b[i], c[i] -= b[i]; + if (c.back() == T(0)) c.pop_back(); + FOR(i, len(c)) if (c[i] != T(0)) F[m + i] += c[i]; + return F; +} + +template +void ntt(vector &a, bool inverse) { + assert(mint::can_ntt()); + const int rank2 = mint::ntt_info().fi; + const int mod = mint::get_mod(); + static array root, iroot; + static array rate2, irate2; + static array rate3, irate3; + + static bool prepared = 0; + if (!prepared) { + prepared = 1; + root[rank2] = mint::ntt_info().se; + iroot[rank2] = mint(1) / root[rank2]; + FOR_R(i, rank2) { + root[i] = root[i + 1] * root[i + 1]; + iroot[i] = iroot[i + 1] * iroot[i + 1]; + } + mint prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 2; i++) { + rate2[i] = root[i + 2] * prod; + irate2[i] = iroot[i + 2] * iprod; + prod *= iroot[i + 2]; + iprod *= root[i + 2]; + } + prod = 1, iprod = 1; + for (int i = 0; i <= rank2 - 3; i++) { + rate3[i] = root[i + 3] * prod; + irate3[i] = iroot[i + 3] * iprod; + prod *= iroot[i + 3]; + iprod *= root[i + 3]; + } + } + + int n = int(a.size()); + int h = topbit(n); + assert(n == 1 << h); + if (!inverse) { + int len = 0; + while (len < h) { + if (h - len == 1) { + int p = 1 << (h - len - 1); + mint rot = 1; + FOR(s, 1 << len) { + int offset = s << (h - len); + FOR(i, p) { + auto l = a[i + offset]; + auto r = a[i + offset + p] * rot; + a[i + offset] = l + r; + a[i + offset + p] = l - r; + } + rot *= rate2[topbit(~s & -~s)]; + } + len++; + } else { + int p = 1 << (h - len - 2); + mint rot = 1, imag = root[2]; + for (int s = 0; s < (1 << len); s++) { + mint rot2 = rot * rot; + mint rot3 = rot2 * rot; + int offset = s << (h - len); + for (int i = 0; i < p; i++) { + u64 mod2 = u64(mod) * mod; + u64 a0 = a[i + offset].val; + u64 a1 = u64(a[i + offset + p].val) * rot.val; + u64 a2 = u64(a[i + offset + 2 * p].val) * rot2.val; + u64 a3 = u64(a[i + offset + 3 * p].val) * rot3.val; + u64 a1na3imag = (a1 + mod2 - a3) % mod * imag.val; + u64 na2 = mod2 - a2; + a[i + offset] = a0 + a2 + a1 + a3; + a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3)); + a[i + offset + 2 * p] = a0 + na2 + a1na3imag; + a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag); + } + rot *= rate3[topbit(~s & -~s)]; + } + len += 2; + } + } + } else { + mint coef = mint(1) / mint(len(a)); + FOR(i, len(a)) a[i] *= coef; + int len = h; + while (len) { + if (len == 1) { + int p = 1 << (h - len); + mint irot = 1; + FOR(s, 1 << (len - 1)) { + int offset = s << (h - len + 1); + FOR(i, p) { + u64 l = a[i + offset].val; + u64 r = a[i + offset + p].val; + a[i + offset] = l + r; + a[i + offset + p] = (mod + l - r) * irot.val; + } + irot *= irate2[topbit(~s & -~s)]; + } + len--; + } else { + int p = 1 << (h - len); + mint irot = 1, iimag = iroot[2]; + FOR(s, (1 << (len - 2))) { + mint irot2 = irot * irot; + mint irot3 = irot2 * irot; + int offset = s << (h - len + 2); + for (int i = 0; i < p; i++) { + u64 a0 = a[i + offset + 0 * p].val; + u64 a1 = a[i + offset + 1 * p].val; + u64 a2 = a[i + offset + 2 * p].val; + u64 a3 = a[i + offset + 3 * p].val; + u64 x = (mod + a2 - a3) * iimag.val % mod; + a[i + offset] = a0 + a1 + a2 + a3; + a[i + offset + 1 * p] = (a0 + mod - a1 + x) * irot.val; + a[i + offset + 2 * p] = (a0 + a1 + 2 * mod - a2 - a3) * irot2.val; + a[i + offset + 3 * p] = (a0 + 2 * mod - a1 - x) * irot3.val; + } + irot *= irate3[topbit(~s & -~s)]; + } + len -= 2; + } + } + } +} + +template +vector convolution_ntt(vector a, vector b) { + if (a.empty() || b.empty()) return {}; + int n = int(a.size()), m = int(b.size()); + int sz = 1; + while (sz < n + m - 1) sz *= 2; + + // sz = 2^k のときの高速化。分割統治的なやつで損しまくるので。 + if ((n + m - 3) <= sz / 2) { + auto a_last = a.back(), b_last = b.back(); + a.pop_back(), b.pop_back(); + auto c = convolution(a, b); + c.resize(n + m - 1); + c[n + m - 2] = a_last * b_last; + FOR(i, len(a)) c[i + len(b)] += a[i] * b_last; + FOR(i, len(b)) c[i + len(a)] += b[i] * a_last; + return c; + } + + a.resize(sz), b.resize(sz); + bool same = a == b; + ntt(a, 0); + if (same) { + b = a; + } else { + ntt(b, 0); + } + FOR(i, sz) a[i] *= b[i]; + ntt(a, 1); + a.resize(n + m - 1); + return a; +} + +template +vc convolution(const vc &a, const vc &b) { + int n = len(a), m = len(b); + if (!n || !m) return {}; + if (min(n, m) <= 50) return convolution_karatsuba(a, b); + return convolution_ntt(a, b); +} + +// Input: f(0), ..., f(n-1) and c. Return: f(c) +template +T lagrange_interpolate_iota(vc &f, T c) { + int n = len(f); + if (int(c.val) < n) return f[c.val]; + auto a = f; + FOR(i, n) { + a[i] = a[i] * fact_inv(i) * fact_inv(n - 1 - i); + if ((n - 1 - i) & 1) a[i] = -a[i]; + } + vc lp(n + 1), rp(n + 1); + lp[0] = rp[n] = 1; + FOR(i, n) lp[i + 1] = lp[i] * (c - i); + FOR_R(i, n) rp[i] = rp[i + 1] * (c - i); + T ANS = 0; + FOR(i, n) ANS += a[i] * lp[i] * rp[i + 1]; + return ANS; +} + +// Input: f(0), ..., f(n-1) and c, m +// Return: f(c), f(c+1), ..., f(c+m-1) +// Complexity: M(n, n + m) +template +vc lagrange_interpolate_iota(vc &f, mint c, int m) { + if (m <= 60) { + vc ANS(m); + FOR(i, m) ANS[i] = lagrange_interpolate_iota(f, c + mint(i)); + return ANS; + } + ll n = len(f); + auto a = f; + FOR(i, n) { + a[i] = a[i] * fact_inv(i) * fact_inv(n - 1 - i); + if ((n - 1 - i) & 1) a[i] = -a[i]; + } + // x = c, c+1, ... に対して a0/x + a1/(x-1) + ... を求めておく + vc b(n + m - 1); + FOR(i, n + m - 1) b[i] = mint(1) / (c + mint(i - n + 1)); + a = convolution(a, b); + + Sliding_Window_Aggregation> swag; + vc ANS(m); + ll L = 0, R = 0; + FOR(i, m) { + while (L < i) { swag.pop(), ++L; } + while (R - L < n) { swag.push(c + mint((R++) - n + 1)); } + auto coef = swag.prod(); + if (coef == 0) { + ANS[i] = f[(c + i).val]; + } else { + ANS[i] = a[i + n - 1] * coef; + } + } + return ANS; +} + +template +vc> prefix_product_of_poly_matrix(vc>> &A, ll k) { + int n = len(A); + + using MAT = vc>; + auto shift = [&](vc &G, T x) -> vc { + int d = len(G); + vvv(T, H, d, n, n); + FOR(i, n) FOR(j, n) { + vc g(d); + FOR(l, d) g[l] = G[l][i][j]; + auto h = lagrange_interpolate_iota(g, x, d); + FOR(l, d) H[l][i][j] = h[l]; + } + return H; + }; + + auto evaluate = [&](vc &f, T x) -> T { + T res = 0; + T p = 1; + FOR(i, len(f)) { + res += f[i] * p; + p *= x; + } + return res; + }; + + ll deg = 1; + FOR(i, n) FOR(j, n) deg = max(deg, len(A[i][j]) - 1); + + vc G(deg + 1); + ll v = 1; + while (deg * v * v < k) v *= 2; + T iv = T(1) / T(v); + + FOR(i, len(G)) { + T x = T(v) * T(i); + vv(T, mat, n, n); + FOR(j, n) FOR(k, n) mat[j][k] = evaluate(A[j][k], x); + G[i] = mat; + } + + for (ll w = 1; w != v; w *= 2) { + T W = w; + auto G1 = shift(G, W * iv); + auto G2 = shift(G, (W * T(deg) * T(v) + T(v)) * iv); + auto G3 = shift(G, (W * T(deg) * T(v) + T(v) + W) * iv); + FOR(i, w * deg + 1) { + G[i] = matrix_mul(G1[i], G[i]); + G2[i] = matrix_mul(G3[i], G2[i]); + } + copy(G2.begin(), G2.end() - 1, back_inserter(G)); + } + + vv(T, res, n, n); + FOR(i, n) res[i][i] = 1; + ll i = 0; + while (i + v <= k) res = matrix_mul(G[i / v], res), i += v; + while (i < k) { + vv(T, mat, n, n); + FOR(j, n) FOR(k, n) mat[j][k] = evaluate(A[j][k], i); + res = matrix_mul(mat, res); + ++i; + } + return res; +} + +template +T prefix_product_of_poly(vc &f, ll k) { + vc>> A(1); + A[0].resize(1); + A[0][0] = f; + auto res = prefix_product_of_poly_matrix(A, k); + return res[0][0]; +} + +using mint = modint998; + +void solve() { + int N; + scanf("%d", &N); + vc f = {1, 1}; + mint b = prefix_product_of_poly(f, N); + printf("%d\n", b.val); +} + +signed main() { + int T; + scanf("%d", &T); + for (int t = 0; t < T; ++t) solve(); + return 0; +} diff --git a/math/factorial/sol/naive.cpp b/math/factorial/sol/naive.cpp new file mode 100644 index 000000000..8cee4078a --- /dev/null +++ b/math/factorial/sol/naive.cpp @@ -0,0 +1,35 @@ +#include + +using namespace std; + +constexpr int mod = 998244353; + +void solve() { + int N; + scanf("%d", &N); + if (N < mod / 2) { + long long x = 1; + for (int n = 1; n <= N; ++n) x = x * n % mod; + printf("%lld\n", x); + } else { + long long x = mod - 1; // (mod-1)! + long long y = 1; + for (int n = N + 1; n <= mod - 1; ++n) { y = y * n % mod; } + // x / y + // xy^{mod-2} + long long n = mod - 2; + while (n) { + if (n & 1) x = x * y % mod; + y = y * y % mod; + n /= 2; + } + printf("%lld\n", x); + } +} + +signed main() { + int T; + scanf("%d", &T); + for (int t = 0; t < T; ++t) solve(); + return 0; +} diff --git a/math/factorial/task.md b/math/factorial/task.md new file mode 100644 index 000000000..cbfdaa7ab --- /dev/null +++ b/math/factorial/task.md @@ -0,0 +1,41 @@ +## @{keyword.statement} + +@{lang.en} + +This problem has $T$ cases. + +Given a non-negative integer $N$. Print $N!$ modulo $@{param.MOD}$. + +@{lang.ja} + +この問題は $T$ ケースあります. +非負整数 $N$ が与えられます.$N!$ を $@{param.MOD}$ で割った余りを出力してください. + +@{lang.end} + +## @{keyword.constraints} + +- $1\leq T\leq 5$ +- $0 \leq N < @{param.MOD}$ + +## @{keyword.input} + +``` +$T$ +$N_1$ +$\vdots$ +$N_T$ +``` + +## @{keyword.output} +@{lang.en} +Print $N!$ modulo $@{param.MOD}$. + +@{lang.ja} + +$N!$ を $@{param.MOD}$ で割った余りを出力してください. +@{lang.end} + +## @{keyword.sample} + +@{example.example_00} diff --git a/math/factorial/verifier.cpp b/math/factorial/verifier.cpp new file mode 100644 index 000000000..cbd5d79bf --- /dev/null +++ b/math/factorial/verifier.cpp @@ -0,0 +1,15 @@ +#include +#include "params.h" +#include "testlib.h" + +int main() { + registerValidation(); + int T = inf.readInt(1, T_MAX); + inf.readChar('\n'); + for (int t = 0; t < T; ++t) { + inf.readLong(0, MOD - 1); + inf.readChar('\n'); + } + inf.readEof(); + return 0; +}