AGC038-C LCMs 解説

公式解説が分かりづらかったので

問題

リンクはこちら

概要

$$ \sum _ {i=0} ^ {n-1} \sum _ {j=i+1} ^ n \mathrm{lcm}(A _ i, A _ j) $$

を求めなさい。

解法

 以下、計算は素数mod上で行うため、何かで割る操作は逆元を掛けているものだと思ってください。

 LCMは扱いづらいので、GCDを考える問題に変換します。

$$ \mathrm{lcm}(x, y)=\frac{xy}{\mathrm{gcd}(x, y)} $$

を利用すると、GCDに注目することで求める式は以下のように変形できます。

$$ \begin{aligned} \sum _ {i=0} ^ {n-1} \sum _ {j=i+1} ^ n \mathrm{lcm}(A _ i, A _ j) =& \sum _ {i=0} ^ {n-1} \sum _ {j=i+1} ^ n \frac{A _ i A _ j}{\mathrm{gcd}(A _ i, A _ j)} \\ \\ =& \sum _ {v=1} ^ {C} \frac{\sum _ {i < j, \mathrm{gcd}(A _ i, A _ j) = v} A _ i A _ j }{v} \\ &( C = \mathrm{max}(A) \leq 10 ^ 6 ) \end{aligned} $$

 \(f[v] := \sum _ {i < j, \mathrm{gcd}(A _ i, A _ j) = v} A _ i A _ j\) とおくと、全ての \(v\) について \(f[v]\) が求められれば良いです。 しかしこれを直接求めるには、すべてのペア \(A _ i,\, A _ j\) のGCDを求めて分類していく必要があり、結局愚直に計算するのと同じだけの時間がかかってしまいます。

 そこで、厄介なGCDの条件を少し緩めてみます。

$$ \mathrm{gcd}(A _ i,\, A _ j) = v \Rightarrow v | A _ i \land v | A _ j \ (a | b \Leftrightarrow aはbの約数) $$

であるので、 \(g[v] := \sum _ {i < j, v | A _ i \land v | A _ j} A _ i A _ j\) とおいて、問題を

  • \(g\) から \(f\) を求めることができるか
  • \(g\) を求めることができるか

の二点に分解して考えてみます。

\(g\) から \(f\) を求めることができるか

 あるペア \(x,\, y\) に注目します。 \(\mathrm{gcd}(x, y) = v\) とすると、 \(f\) では \(f[v]\) にのみ \(xy\) が加算されますが、 \(g\) では \(u | v\) であるすべての \(u\) について \(g[u]\) に \(xy\) が加算されています。 ここで、いったん逆に \(f\) から \(g\) を計算する方法を考えると、これは約数集合におけるゼータ変換をしていることが分かります。(約数集合におけるゼータ変換についてはこちらのブログに書かれています。)

\(f\) から \(g\) を求める操作は、まさに上記のブログ冒頭に書かれている操作です。

 よって、 \(g\) から \(f\) を求める際は、逆変換であるメビウス変換を行えばよいです。

計算量は \(O(C\log C)\) です。

# メビウス変換 in-placeなのでgがそのままfになる
for i in range(C-1, 0, -1):
    for j in range(i*2, C, i):
        g[i] = (g[i] - g[j]) % mod

\(g\)を求めることができるか

 \(g[v] := \sum _ {i < j, v | A _ i \land v | A _ j} A _ i A _ j\) は二つの要素 \(A _ i,\, A _ j\) を考える必要があり面倒なので、できれば一つの要素を考えるだけで済ませたいです。 そこで、新しく \(h[v] := \sum _ {v | A _ i} A _ i\) を考えます。これは比較的簡単に計算できるので、できればここから \(g[v]\) を求めたいです。

 実は、因数分解を考えると、 \(l[v] := \sum _ {v | A _ i} A _ i ^ 2\) を用いて

$$ g[v] = \frac{h[v] ^ 2 - l[v]}{2} $$

と計算できます。導出は面倒なので書きません。ごめんね

よって、各 \(A _ i\) について、全ての約数 \(v\) に対し

$$ \begin{aligned} h[v] &\leftarrow h[v] + A _ i \\ l[v] &\leftarrow l[v] + A _ i ^ 2 \end{aligned} $$

を計算することで、 \(g\) を求めることができます。

計算量は、各 \(A _ i\) の約数の列挙が \(O(\sqrt C)\) なので、合わせて \(O(N\sqrt C + C)\) です。

より高速に \(h,\, l\) を求める

 制約が \(N \leq 2 \times 10 ^ 5, C \leq 10 ^ 6\) なので、上の計算量だとちょっときついかもしれません。 \(h,\, l\) の定義を見ると、これはゼータ変換によって簡単に計算できる形になっています。そこで、各 \(A _ i\) に対し

$$ \begin{aligned} h[A _ i] &\leftarrow h[A _ i] + A _ i \\ l[A _ i] &\leftarrow l[A _ i] + A _ i ^ 2 \end{aligned} $$

を行った後、 \(h,\, l\) をそれぞれゼータ変換することで、より高速に求めることができます。 \(g\) を計算するパートを含めても、計算量は \(O(N + ClogC)\) です。

h = [0] * C
l = [0] * C
for x in A:
    h[x] = (h[x] + x) % mod
    l[x] = (l[x] + x**2) % mod
for i in range(1, C):
    for j in range(i*2, C, i):
        h[i] = (h[i] + h[j]) % mod
        l[i] = (l[i] + l[j]) % mod

提出コード(Pypy3)

めちゃくちゃTLEしてしまいました……

さらなる高速化

  \(g\) を求める際、 \(l\) を引き2で割ることで補正を行いましたが、この補正を行わず、最後にまとめて行うことを考えます。 \(g\) の計算式を

$$ g[v] = h[v] ^ 2 $$

に置き換えて最後まで計算すると、最終的に出てくる値は、各ペアにつき2回ずつ加算され、さらに同じ要素同士のペアも加味されていることになります。

 まず、同じ要素同士のペア \(A _ i,\, A _ i\) について、\(\mathrm{lcm}(A _ i, A _ i)=A _ i\) です。よって、 \(\sum _ {i = 0} ^ N A _ i\) を引くことで、これらの寄与分を取り除くことができます。これで、 \(l\) を計算する必要はなくなりました。 残りの値は各ペアにつき2回加算されているので、単純に2で割ればOKです。これでいくらか定数倍が軽くなりました。

 また、逆元を逐一計算すると合計で \(O(C\log mod)\) かかってしまうため、あらかじめまとめて \(O(C)\) で計算しています。 計算量は全部合わせて \(O(N + C\log C)\) です。

import sys
readline = sys.stdin.buffer.readline
 
def make_modinv_list(n, mod=10**9+7):
    inv_list = [0]*(n+1)
    inv_list[1] = 1
    for i in range(2, n+1):
        inv_list[i] = (mod - mod//i * inv_list[mod%i] % mod)
    return inv_list
 
mod = 998244353
N = int(readline())
A = list(map(int, readline().split()))
C = max(A) + 1
inv_list = make_modinv_list(C, mod)
h = [0] * C
for x in A:
    h[x] = (h[x] + x) % mod
for i in range(1, C):
    for j in range(i*2, C, i):
        h[i] = (h[i] + h[j]) % mod
g = [0] * C
for i in range(1, C):
    g[i] = h[i] * h[i] % mod
for i in range(C-1, 0, -1):
    for j in range(i*2, C, i):
        g[i] = (g[i] - g[j]) % mod
ans = 0
for i in range(1, C):
    ans = (ans + g[i] * inv_list[i]) % mod
ans -= sum(A)
print(ans * inv_list[2] % mod)

提出コード(Pypy3)

無事ACです。