#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
#define int long long
using namespace std;
// Array capacity. It must hold the NTT length chosen in main(), i.e. the
// smallest power of two strictly greater than n + 1. (1 << 19) + 10 covers
// every n up to 524286 (in particular n <= 5 * 10^5); the previous 5e5 + 10
// overflowed as soon as the length reached 524288.
const int N = (1 << 19) + 10;
// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with primitive root G = 3.
const int mod = 998244353, G = 3;
// Problem input.
int n, k;
// Scratch array for the bit-reversal permutation computed by dft().
int tr[N];
// Modular inverse of G; assigned in main().
int invG;
// Modular exponentiation: returns a^b mod `mod` by binary exponentiation.
// The base is first normalised into [0, mod), so a * a stays below mod^2 and
// fits in the 64-bit `int` (the file-level macro maps int to long long) even
// when a caller passes an unreduced or negative base.
// Expects b >= 0; a negative exponent returns 1 instead of looping forever.
int qpow(int a, int b) {
    a %= mod;
    if (a < 0) a += mod;
    int res = 1;
    while (b > 0) {
        if (b & 1) res = res * a % mod;
        a = a * a % mod;
        b >>= 1;
    }
    return res;
}
// Modular addition for operands already in [0, mod): returns (a + b) mod `mod`
// using one conditional subtraction rather than a division.
int add(int a, int b) {
    int sum = a + b;
    if (sum >= mod) sum -= mod;
    return sum;
}
// Modular subtraction for operands already in [0, mod): returns (a - b) mod `mod`
// by adding mod back once when the raw difference is negative.
// Operands outside [0, mod) are not normalised; callers must reduce them first.
int sub(int a, int b) {
    int diff = a - b;
    return diff < 0 ? diff + mod : diff;
}
voiddft(int *f, int n, int op){ for (int i = 0; i < n; ++i) { tr[i] = (tr[i >> 1] >> 1) | ((i & 1) ? (n >> 1) : 0); } for (int i = 0; i < n; ++i) { if (i < tr[i]) swap(f[i], f[tr[i]]); } for (int p = 2; p <= n; p <<= 1) { int len = p / 2; int tG = qpow(op ? G : invG, (mod - 1) / p); for (int l = 0; l < n; l += p) { int buf = 1; for (int k = l; k < l + len; ++k) { int tmp = buf * f[k + len] % mod; f[k + len] = sub(f[k], tmp); f[k] = add(f[k], tmp); buf = buf * tG % mod; } } } int invn = qpow(n, mod - 2); if (!op) { for (int i = 0; i < n; ++i) f[i] = f[i] * invn % mod; } }
// Tables filled in main() for indices 0..n:
// fac[i] = i!, ifac[i] = (i!)^-1, inv[i] = i^-1 (all mod `mod`).
int fac[N], ifac[N], inv[N];
// Polynomial coefficient arrays built in main() and transformed in place by dft().
int a[N], b[N];
// cnt1 = (n - 1) % k + 1 and cnt2 = k - cnt1, set in main(). Presumably these
// count the residue classes mod k of size L + 1 and L (L = (n - 1) / k); confirm.
int cnt1, cnt2;
// Binomial coefficient C(n, m) mod `mod`, read from the precomputed factorial
// tables (fac/ifac must be filled up to n, and n < N).
// Returns 0 when m lies outside [0, n]. Products fit because the file-level
// macro makes `int` a 64-bit long long.
int C(int n, int m) {
    if (m < 0 || n < m) return 0;
    return fac[n] * ifac[m] % mod * ifac[n - m] % mod;
}
// Point values of the combined generating function; inverse-transformed in main().
int g[N];
signedmain(){ scanf("%lld%lld", &n, &k); inv[1] = fac[0] = fac[1] = ifac[0] = ifac[1] = 1; for (int i = 2; i <= n; ++i) { fac[i] = fac[i - 1] * i % mod; inv[i] = (mod - mod / i) * inv[mod % i] % mod; ifac[i] = ifac[i - 1] * inv[i] % mod; } for (int i = 0; i <= (n - 1) / k; ++i) { a[i] = C((n - 1) / k + 1 - i, i); } for (int i = 0; i < (n - 1) / k; ++i) { b[i] = C((n - 1) / k - i, i); } cnt1 = (n - 1) % k + 1; cnt2 = k - cnt1; invG = qpow(G, mod - 2); int m = n; for (n = 1; n <= m + 1; n <<= 1); dft(a, n, 1), dft(b, n, 1); for (int i = 0; i <= n; ++i) { g[i] = qpow(a[i], cnt1 * 2) * qpow(b[i], cnt2 * 2) % mod; } dft(b, n, 0); dft(g, n, 0); int ans = 0; for (int i = 0; i <= m; ++i) { if (i & 1) ans = sub(ans, g[i] * fac[m - i]) % mod; else ans = add(ans, g[i] * fac[m - i]) % mod; } printf("%lld", ans); return0; }