ABC175-D Moving Piece 解説
\(O(N \log K )\) で解きます。また、本問題の「与えられたグラフが巡回置換に分解できる」という特性をほとんど用いていないので、 \(P\) が順列でない場合でも少し書き換えるだけで解くことができます。
問題
リンクはこちら
解法
ダブリングで前処理をしたあといわゆる桁DPをします。ダブリング、桁DPについてあまり知らないという方はまずはこちらの記事をご覧ください。(TODO:だれかの記事を貼らせてもらう)
以下に解法に至るまでのなんとなくの思考過程を記します。
まず、各頂点から行先は一つしかないので、始点と行動回数を決めてしまえばその後の行動と獲得できる点数は一意に定まります。よって、原理的には
- 現在どのマスにいるか
- 何点取っているか
- あと何回まで移動できるか
という状態を保持すればDPによって答えを求めることができるはずです。しかし、点数はDPテーブルの値として直接持つことにしても、移動回数を添え字として保持すると状態数が爆発してしまいます。
そこで、代わりに移動回数の上位ビットから走査してゆけば「 \(K\) より小さい/同じ/大きい」という三状態に絞ることができます。この部分に桁DPの手技を用いることができます。遷移を書く際には、各頂点から \(2 ^ b\) 回移動した際の行先、得点という情報が必要になるので、これらはあらかじめダブリングによって求めておけばOKです。
ダブリング
特に凝ったことはしません。
- \(vertex[b][i] :=\)頂点 \(i\) から \(2 ^ b\) 回移動したときにどの頂点にいるか
- \(score[b][i] :=\)頂点 \(i\) から \(2 ^ b\) 回移動したときに何点もらえるか
という二つの表を用意します。 \(score[0][i] = C _ {P _ i}\) となるのがやや気持ち悪いので、公式editorial同様に一つずらして考え、 \(score[0][i] = C _ i\) となるように \(C\) を書き換えてからダブリングを行っています。
vertex = list() score = list() vertex.append(P) score.append(C) m = 31 # bit数 for b in range(1, m+1): p_bth = [0] * N c_bth = [0] * N for i in range(n): p_bth[i] = vertex[b-1][vertex[b-1][i]] c_bth[i] = score[b-1][i] + score[b-1][vertex[b-1][i]] vertex.append(p_bth) score.append(c_bth)
計算量は \(O(Nm) = O(N \log K)\) です。
桁DP
\(dp[b][i][j] :=\) 下から \(b\) bit目まで見て、頂点 \(i\) に状態 \(j\) で存在するときのスコアの最大値
というテーブルを、 \(b\) の降順に埋めていきます。ここで状態 \(j\) は、
- \(j = 0 :\) 移動回数が \(K\) より小さくなることが確定している
- \(j = 1 :\) 確定していない(\(b + 1\) bit目まで \(K\) に等しい)
を表すとします。 \(K\) の \(b\) bit目を \(K _ b\) とすると、 \(K _ b\) と \(j\) の値によって以下の3通りの遷移が考えられます。
\(j = 0\)
既に移動回数が \(K\) より小さくなることが確定しているので、移動するかしないか自由に決めることができます。遷移先の状態も必ず \(j = 0\) です。具体的には
$$ \begin{aligned} dp[b-1][vertex[b][i]][0] &\leftarrow dp[b][i][0] + score[b][i] \\ dp[b-1][i][0] &\leftarrow dp[b][i][0] \end{aligned} $$
となります。「 \(\leftarrow\) 」はここでは代入ではなく \(\mathrm{chmax}\) だと思ってください。
\(K _ b = 1,\, j = 1\)
ここで移動した場合、 \(b\) bit目まで \(K\) に等しくなるため、遷移先の状態は \(j = 1\) です。移動しなかった場合はこの時点で \(K\) より小さくなることが確定するので、遷移先は \(j = 0\) になります。
$$ \begin{aligned} dp[b-1][vertex[b][i]][1] &\leftarrow dp[b][i][1] + score[b][i] \\ dp[b-1][i][0] &\leftarrow dp[b][i][1] \end{aligned} $$
\(K _ b = 0,\, j = 1\)
ここで移動すると回数が \(K\) を超えてしまうため、移動せずその場にとどまるしかありません。この場合も \(b\) bit目まで \(K\) に等しくなるため、遷移先の状態は \(j = 1\) です。
$$ dp[b-1][i][1] \leftarrow dp[b][i][1] $$
以上の遷移をコードに落とし込むと以下のようになります。ここでは、 \(j\) の値による場合分けは行わず、直接全パターンを書き下しています。
m = 31 # bit数 for b in range(m, -1, -1): for i in range(N): if (K >> b) & 1: dp[b-1][vertex[b][i]][0] = max(dp[b-1][vertex[b][i]][0], dp[b][i][0] + score[b][i]) dp[b-1][vertex[b][i]][1] = max(dp[b-1][vertex[b][i]][1], dp[b][i][1] + score[b][i]) dp[b-1][i][0] = max(dp[b-1][i][0], dp[b][i][0], dp[b][i][1]) else: dp[b-1][vertex[b][i]][0] = max(dp[b-1][vertex[b][i]][0], dp[b][i][0] + score[b][i]) dp[b-1][i][0] = max(dp[b-1][i][0], dp[b][i][0]) dp[b-1][i][1] = max(dp[b-1][i][1], dp[b][i][1])
ただし、以上のコードをそのまま走らせると終端で \(dp[-1]\) にアクセスしてしまいます。
対策としては、添え字 \(b\) を一つ分ずらして考えるのもよいですが、更新に一行前のデータしか必要がないことを利用して、二本の配列 \(prv,\, nxt\) を用意して順繰りに更新していくように書き換えました。これで空間計算量が削減されるだけではなく、速度も向上することが見込まれます。
m = 31 # bit数 for b in range(m, -1, -1): for i in range(N): if (K >> b) & 1: nxt[vertex[b][i]][0] = max(nxt[vertex[b][i]][0], prv[i][0] + score[b][i]) nxt[vertex[b][i]][1] = max(nxt[vertex[b][i]][1], prv[i][1] + score[b][i]) nxt[i][0] = max(nxt[i][0], prv[i][0], prv[i][1]) else: nxt[vertex[b][i]][0] = max(nxt[vertex[b][i]][0], prv[i][0] + score[b][i]) nxt[i][0] = max(nxt[i][0], prv[i][0]) nxt[i][1] = max(nxt[i][1], prv[i][1]) prv, nxt = nxt, prv
通常このようにDP更新を行う際は毎回 \(nxt\) をからっぽ(初期状態)に戻す、などの操作が必要になりますが、今回 \(nxt[b][i][j]\) に残っている値は「ひとつ前のbitを考えているときに移動しなかった場合」の値です(厳密にはやや異なりますが)。更新操作が \(\mathrm{chmax}\) のみなので、毎回の初期化は必要ないと判断し省きました。
計算量は \(O(Nm) = O(N \log K )\) です。
出力
最終的な答えは \(prv\) の中の最大値ですが、これが \(0\) だった場合、問題の「1回以上移動する」という制約に違反している可能性が考えられます。全マスの得点が負であれば、一度も移動をしないことが一番得となるからです。
ここで、 \(P\) が順列であることから、もし得点が正のマスが一つでもあれば、そのマスの一つ手前から始めて1回移動しゲームをやめることで、少なくとも正の得点を取ることができます。よって、DPの答えが0となるのはすべてのマスの得点が0以下である場合に限られることがわかります。この時、最もスコアの大きなマスに1回移動してゲームを終了するのが一番マシです。つまり、答えは \(\mathrm{min}(C)\) です。
以上をまとめるとこのようなコードになります。
N, K = map(int, input().split()) P = list(map(int, input().split())) _c = list(map(int, input().split())) # Cを書き換える C = [0] * N for i in range(N): P[i] -= 1 C[i] = _c[P[i]] m = 31 # bit数 # ダブリング vertex = list() score = list() vertex.append(P) score.append(C) for b in range(1, m+1): p_bth = [0] * N c_bth = [0] * N for i in range(N): p_bth[i] = vertex[b-1][vertex[b-1][i]] c_bth[i] = score[b-1][i] + score[b-1][vertex[b-1][i]] vertex.append(p_bth) score.append(c_bth) # 桁DP MIN = -(1 << 63) prv = [[MIN, 0] for _ in range(N)] nxt = [[MIN, MIN] for _ in range(N)] for b in range(m, -1, -1): for i in range(N): if (K >> b) & 1: nxt[vertex[b][i]][0] = max(nxt[vertex[b][i]][0], prv[i][0] + score[b][i]) nxt[vertex[b][i]][1] = max(nxt[vertex[b][i]][1], prv[i][1] + score[b][i]) nxt[i][0] = max(nxt[i][0], prv[i][0], prv[i][1]) else: nxt[vertex[b][i]][0] = max(nxt[vertex[b][i]][0], prv[i][0] + score[b][i]) nxt[i][0] = max(nxt[i][0], prv[i][0]) nxt[i][1] = max(nxt[i][1], prv[i][1]) prv, nxt = nxt, prv ans = max(max(x) for x in prv) if ans == 0: ans = max(C) print(ans)