传送门

题目描述

i=1nj=1m(nmodi)×(mmodj),ij\sum_{i=1}^{n} \sum_{j=1}^{m} (n \bmod i) \times (m \bmod j), i \neq j

mod19940417\bmod 19940417 的值

输入格式

输入只有一行两个整数 nnmm

输出格式

答案 mod19940417\bmod 19940417

输入输出样例

输入 #1

1
3 4

输出 #1

1
1

输入 #2

1
123456 654321

输出 #2

1
116430

说明/提示

数据规模与约定

  • 对于 10%10\% 的数据,保证 n,m103n,m \leq 10^3
  • 对于 30%30\% 的数据,保证 n,m106n,m \leq 10^6
  • 另有 30%30\% 的数据,保证 n100n \leq 100
  • 对于 100%100\% 的数据,保证 1n,m1091 \leq n,m \leq 10^9

题解

首先看到柿子总是想到转换。

i=1nj=1m(nmodi)×(mmodj),ij\sum_{i=1}^{n} \sum_{j=1}^{m} (n \bmod i) \times (m \bmod j), i \neq j

=i=1nj=1m(nnii)×(mmjj),ij= \sum_{i=1}^{n} \sum_{j=1}^{m} (n - \lfloor \frac{n}{i} \rfloor * i) \times (m - \lfloor \frac{m}{j} \rfloor * j), i \neq j

=n2m2i=1n(m2nii)j=1m(n2mjj)+i=1nj=1m(niimjj),ij= n^2m^2 - \sum_{i=1}^{n} (m^2 \lfloor \frac{n}{i} \rfloor * i) - \sum_{j=1}^{m} (n^2 \lfloor \frac{m}{j} \rfloor * j) + \sum_{i=1}^{n} \sum_{j=1}^{m} (\lfloor \frac{n}{i} \rfloor * i * \lfloor \frac{m}{j} \rfloor * j) , i \ne j

=n2m2m2i=1n(nii)n2j=1m(mjj)+i=1n(nii)×j=1m(mjj)i=1min(n,m)(nimii2)= n^2m^2 - m^2 \sum_{i=1}^{n} (\lfloor \frac{n}{i} \rfloor * i) - n^2 \sum_{j=1}^{m} (\lfloor \frac{m}{j} \rfloor * j) + \sum_{i=1}^{n} (\lfloor \frac{n}{i} \rfloor * i) \times \sum_{j=1}^{m} (\lfloor \frac{m}{j} \rfloor * j) - \sum_{i=1}^{min(n,m)} (\lfloor \frac{n}{i} \rfloor \lfloor \frac{m}{i} \rfloor i^2)

a=i=1n(nii),b=j=1m(mjj)a = \sum_{i=1}^{n} (\lfloor \frac{n}{i} \rfloor * i), b = \sum_{j=1}^{m} (\lfloor \frac{m}{j} \rfloor * j)

则原式

=n2m2m2an2b+abi=1min(n,m)(nimii2)= n^2 m^2 - m^2 a - n^2 b + ab - \sum_{i=1}^{min(n,m)} (\lfloor \frac{n}{i} \rfloor \lfloor \frac{m}{i} \rfloor i^2)

运用数论分块可以在O(n)O(\sqrt{n})的时间内求出aabb

加上平方和公式(i=1ni2=n(n+1)(2n+1)6\sum_{i=1}^{n} i^2 = \frac{n(n+1)(2n+1)}{6})可以求出i=1min(n,m)(nimii2)\sum_{i=1}^{min(n,m)} (\lfloor \frac{n}{i} \rfloor \lfloor \frac{m}{i} \rfloor i^2)

没了……

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
#include <cstdio>
#define MOD 19940417
#define min(x,y)((x)<(y)?(x):(y))
using namespace std;
long long inv6 = 3323403, ans, n, m;
long long get(long long to, long long y)
{
long long l, r, sum = 0;
for (l = 1; l <= to; l = r + 1) r = min(y / (y / l), to),
sum = (sum + (l + r) * (r - l + 1) / 2 % MOD * (y / l) % MOD) % MOD;
return sum;
}
int main()
{
scanf("%lld%lld", &n, &m);
long long a = get(n, n), b = get(m, m), t = min(n, m);
ans = (n * n % MOD * m % MOD * m % MOD + a * b % MOD) % MOD;
ans = (ans - m * m % MOD * a % MOD - n * n % MOD * b % MOD) % MOD;
ans = (ans + n * get(t, m) % MOD - n * m % MOD * t % MOD) % MOD;
ans = (ans + m * get(t, n) % MOD) % MOD;
for (long long l = 1, r; l <= t; l = r + 1) {
r = min(n / (n / l), m / (m / l));
long long t1 = r * (r + 1) % MOD * (2 * r + 1) % MOD * inv6 % MOD;
long long t2 = l * (l - 1) % MOD * (2 * l - 1) % MOD * inv6 % MOD;
ans = (ans - (n / l) * (m / l) % MOD * (t1 - t2) % MOD) % MOD;
}
printf("%lld", (ans + MOD) % MOD);
}