[Python] heapqを応用して挿入/削除/最小値取得を効率的に行う

概要

ならし計算量で
・挿入 O(logN)
・削除 O(logN)
・最小値取得 O(logN)
でそれぞれの操作を行えるデータ構造を書きました。
(ちゃんとした平衡二分木ではありません🙇‍♂️ )

説明、計算量、実装に間違いがある可能性が多大にあります。
verifyできる問題や指摘をお持ちの方がいれば @sentya7 までお知らせください。

あらすじ

先日ABC170にて
E - Smart Infants
という問題が出ました。

そもそも発想の工夫やバグらせない実装が必要な骨のある問題だったのですが、解法の主要な部分に値の挿入、削除、最小値取得をO(logN程度で行えるデータ構造を用いるというパートがあります。

そしてこれは解説においてたびたび「例えばC++におけるmultisetなどを用いることで高速に処理できます」 という表現で表されます。
しかし、pythonにはこのようなデータ構造は標準装備されておらず、定期的に思い出しては泣きを見る羽目になります。

もうこんな悔しさを味わいたくない... !
この一心から今回ライブラリとして整備することにしました。

実装例及びE問題の回答も載せてありますので、よければ自由に使ってください。そしてフィードバックがあればください。

この記事を読む前に

そもそも今回の動機を解決するのに一番早いのは、平衡二分探索木を用いることかもしれません。
C++のmultisetも中身で用いられているのはこの平衡二分探索木ですから、言語は違えど実装してあげればいいわけです。 よく用いられるものとしてはAVL木や赤黒木などの歴史あるデータ構造が存在し、解説及び実装例が豊富なため自分でつくる、もしくは借りて貼るという手があります。

平衡二分木ではなく今回のデータ構造を使う利点があるとすれば、実装が軽いため構造の理解やデバッグがしやすいなどがあると思います。豊富が機能が必要でなければ簡単なデータ構造で代用する、という意味でSegment TreeとBinary Indexed Treeの使い分けみたいなものでしょうか。
(ただしこの使い分けはBITが定数倍で早いというのを強い目的としている場合もあります。今回のものが平衡二分木より軽いかは、不明です...)

また以下の記事も大変参考になります(というより、参考にしています )。
【Python】平衡二分木が必要な時に代わりに何とかするテク【競プロ】 - Qiita

本題

満たしたい条件のうち、挿入及び最小値取得までであればheapqを用いてO(logN)で処理することができます。
あと必要なのは任意の要素の削除で、最小値をpop以外の削除操作が続くと辛くなります。

そこで解決策ですが、削除の要求が来ても要素を消さないことにします。いいんでしょうかそんなことして?

もちろん全く操作をしないわけではなく、削除された要素に対してこの要素がすでに削除されているというフラグを外で管理しておきます。

この操作のメリットは、実際のヒープ木を全く変更しなくていい、という点にあります。
heapqにおいては、最小値以外の任意の要素はすぐに見つからない為、配列を線形操作して要素を削除したのちヒープ木を整え直す、という作業が必要になります(要出典)。そのため最悪でO(N)程度の時間が必要になります。

一方で今回の手法であれば木を変形する必要がなく、外部テーブルへのメモもハッシュマップ(pythonであればdefaultdictなど)を用いることでO(1)で行うことができます。

f:id:socha77:20200617061314p:plain
注: 右下では木構造は変化していません。灰色の要素もまだ削除されていません。

確かにひとまず計算量を押さえられたかのように思えますが、本当にこれで動くのでしょうか?この後の運用方法を追いながら理解していきます。

heapqを用いる一番のモチベーションは最小値の効率的な取得です。逆に言えば、外から見えない内部構造はどうなっててもいいし、最小値の取得時に辻褄さえ合わせられればOK と考えることもできます。

① heapqから出てきた値が「削除された」テーブルにない時 (= 本当に最小値である時)

特に問題ありません。そのまま返します。

②「削除された」テーブルにある時 (= 本当の最小値ではない時)

この場合はこのまま返すわけにはいきません。なのでこの値は捨ててしまいましょう。ここでようやく保留していた削除操作の辻褄が合うということです。
あとはまだ削除されていない要素が現れるまでこれを繰り返すことで、外から見れば正しい最小値が帰ってきたことになります。 最小値のpopに限りO(logN)で処理できるという性質が有効に生かされています。(ヒープ木は「親の値<子の値である」以外は保証されていないため、根を覗いてインデックス番号から全体のk番目の値であることを特定する、ということはできません。)

f:id:socha77:20200617070249p:plain
注: 内部構造は二分木ですが、便宜上ただのキューとして表示しています

これにて無事に挿入、削除、最小値取得の機能を実装することができました。

計算量の話

ところで、これは本当に無事でしょうか?
特に最小値取得のパートは今回追加で実装したところなので、手放しで安心して使うのが少し怖い気もしてきます。

察しが良かったりコーナーケースを考えるのが好きな人であれば、以下のようなパターンが頭に浮かぶかもしれません。

f:id:socha77:20200617080914p:plain
全部popしないといけないパターン
ヒープ木の要素を実際には削除していないため、追加や削除のみを(最小値更新をしないまま)繰り返すと残骸が積み重なっていくことになります。 残骸を1つpopするのにO(logN)ですから、最悪パターンでは最小値取得全体としてO(NlogN)かかってしまいます。

そこでここでは、ならし計算量(償却計算量)を考えます。 計算量については理解が甘いため勘違いなどありましたら申し訳ありません。 計算量について、償却/期待/平均など - noshi91のメモ なども参考にしていただければと思います。

ならし計算量とは、大雑把にいうと「1回の操作ではなく、複数回の操作を前提にしてその平均で評価する」という考え方です。データ構造がそもそも大量のデータや操作を効率よく行うためにある、という立場に立っていると認識しています。

さて、先ほどあげた残骸山積みの最悪パターンですが、このような状態になるにはこれ以前におよそN回の削除操作(クエリ)が行われていることがわかります。削除操作を思い出すと、ハッシュマップに値をセットするだけでしたからO(1)、そのツケを今払っているということになります。
N回程度の削除操作のツケがO(NlogN)程度なので、削除操作の計算量は平均してO(logN程度、と考えることができます。
これで、挿入、削除、最小値取得をそれぞれO(logN)で行えるようになりました。

注意点として、heapの中に本来は削除されていて欲しい要素が残っているためNが予想より定数倍程度大きくなること、また挿入、削除クエリの数QNと比べて大きい時に計算量がNではなくてQに依存する可能性があることに気をつけてください。

実装

さて以上の説明を踏まえて、組み込みであるheapqに削除済みフラグを管理するテーブルを合わせて一つのクラスとしていきます。
「本体を更新せず後で〇〇するという情報を保持->特定の操作が起きた時に一括で更新」という操作がSegment Treeにおける遅延評価と似通っていると感じたので、勝手にこの構造を遅延ヒープと呼ぶことにします。(すでに名前がついていたらすいません)

実装するメソッドは

  • push
  • get
  • pop
  • remove
  • __len__

とします。 使い分けるタイミングがあるわかりませんがgetは最小値の値を参照するだけ、popは最小値を削除しつつ返します。
__len__ は配列が空かどうかで場合わけをしたい時があり、残骸の影響で内部のヒープの要素数と外から見たときの要素数がズレるので気をつけて管理する必要があります。
また、ついでですがdequeのように配列を渡すことでその要素で初期化できるようにしておきます。

それではやっていきましょう。

class LazyHeap():
    def __init__(self, init_arr=[]):
        self.heap = []
        self.lazy = defaultdict(int)
        self.len = 0
        for init_element in init_arr:
            heapq.heappush(self.heap, init_element)
            self.len += 1

    def __len__(self):
        return self.len

後で混乱しないように少しだけちゃんと定義をしておきます。 heapに入っている要素の中で削除フラグが立っているものを残骸、それ以外を実体とでもしておきます。 self.lazyに入るkey, valueは、value = heap内に残っている残骸(key)の数です。実体が出し入れされる時はいじらず、また残骸が無事に吐き出された時はちゃんと減らしてあげます。
またself.lenは実体の総数です。( len(self.heap) と異なることに注意してください・)

def push(self, k):
        heapq.heappush(self.heap, k)
        self.len += 1

そのままです。

    def remove(self, k):
        self.lazy[k] += 1
        self.len -= 1

先にremoveからやっておきます。 実体が1つ残骸に変わっているので、self.lenの変更もしておきましょう。

    def _clear(self):
        while True:
            cand = self.heap[0]
            if cand in self.lazy and self.lazy[cand] > 0:
                heapq.heappop(self.heap)
                self.lazy[cand] -= 1
 
            else:
                return

getするにせよpopするにせよ、先に残骸をできるだけ排除しておく必要があります。 実装状のコツとして、先頭の要素を参照する際にいちいちpopする必要はなく配列の先頭を見るだけで大丈夫です。popすると計算量にO(logN)がつくので注意しましょう。(というかやらかしました。)

   def pop(self):
        self._clear()
        return heapq.heappop(self.heap)
        
    def get(self):
        self._clear()
        return self.heap[0]

残骸を削除しておけば無事に最小値を取得できます。

全体(コピペ用)

class LazyHeap():
    def __init__(self, init_arr=[]):
        self.heap = []
        self.lazy = defaultdict(int)
        self.len = 0
        for init_element in init_arr:
            heapq.heappush(self.heap, init_element)
            self.len += 1
 
    def __len__(self):
        return self.len
 
    def push(self, k):
        heapq.heappush(self.heap, k)
        self.len += 1
 
    def pop(self):
        self._clear()
        return heapq.heappop(self.heap)
        
    def get(self):
        self._clear()
        return self.heap[0]
 
    def _clear(self):
        while True:
            cand = self.heap[0]
            if cand in self.lazy and self.lazy[cand] > 0:
                heapq.heappop(self.heap)
                self.lazy[cand] -= 1
 
            else:
                return
 
    def remove(self, k):
        self.lazy[k] += 1
        self.len -= 1
 

Smart Infants 回答例

def getN():
    return int(input())
def getNM():
    return map(int, input().split())
def getList():
    return list(map(int, input().split()))
def getZList():
    return [int(x) - 1 for x in input().split()]
from collections import defaultdict, deque
from sys import exit
import math
import copy
from bisect import bisect_left, bisect_right
import heapq
import sys
# sys.setrecursionlimit(1000000)
INF = 10 ** 17
MOD = 1000000007
 
class LazyHeap():
    def __init__(self, init_arr=[]):
        self.heap = []
        self.lazy = defaultdict(int)
        self.len = 0
        for init_element in init_arr:
            heapq.heappush(self.heap, init_element)
            self.len += 1
 
    def __len__(self):
        return self.len
 
    def push(self, k):
        heapq.heappush(self.heap, k)
        self.len += 1
 
    def pop(self):
        self._clear()
        return heapq.heappop(self.heap)
        
    def get(self):
        self._clear()
        return self.heap[0]
 
    def _clear(self):
        while True:
            cand = self.heap[0]
            if cand in self.lazy and self.lazy[cand] > 0:
                heapq.heappop(self.heap)
                self.lazy[cand] -= 1
 
            else:
                return
 
    def remove(self, k):
        self.lazy[k] += 1
        self.len -= 1
 
def solve():
    N_KINDER = 200200
    # N_KINDER = 2 * (10 ** 5) + 1
    n, q = getList()
    kinder = [LazyHeap() for i in range(N_KINDER)]
    saikyo = LazyHeap()
    saikyo_ref = []
    pos_ref = [-1 for i in range(n)]
    enji_ref = []
  
    for i in range(n):
        a, b = getList()
        b -= 1
        enji_ref.append((-a, b))
        kinder[b].push(-a)
      
    for kin in kinder:
        if kin:
            candidate = kin.get()
            saikyo.push(-candidate)
            saikyo_ref.append(candidate)
        else:
            saikyo_ref.append(0)
 
    for i in range(q):
        c, d = getZList()
        mv_rate = enji_ref[c][0]
        prev = enji_ref[c][1]
        nxt = d
        pos_ref[c] = d
        enji_ref[c] = (mv_rate, d)
        # 移動元についての処理
        kinder[prev].remove(mv_rate)
        
        kp, kn = kinder[prev], kinder[nxt]
        if not kp:
            # 移動して園児0人になった場合
            saikyo_ref[prev] = 0
            saikyo.remove(-mv_rate)
        elif saikyo_ref[prev] != kp.get():
            saikyo.remove(-mv_rate)
            saikyo.push(-kp.get())
            saikyo_ref[prev] = kp.get()
 
        # 移動先についての処理
        kn.push(mv_rate)
        if mv_rate < saikyo_ref[nxt]:
            saikyo.remove(-saikyo_ref[nxt])
            saikyo.push(-mv_rate)
            saikyo_ref[nxt] = mv_rate
     
        print(saikyo.get())
    # print(rate)
 
def main():
    n = getN()
    for _ in range(n):
        solve()
if __name__ == "__main__":
    solve()

本番ACしたかった...

ちなみにSmart Infantsはpypy3で落ちたコードがpython3で通りました。原因はまだわかっていません...

終わりに

ここまで長いことお読みいただきありがとうございました。 これで問題が平衡二分木想定だということだけで涙を飲むことはもうありません。(と言いつつこれは嘘で、データ構造が必要な問題は計算量改善を必要とするくらいの制限ということなので定数倍などで泣きを見る可能性は存分にあります。Pythonを使い続ける以上覚悟をするしかないですね...)

ここまで書いて今更何をという感じなのですがやっぱり平衡二分木を貼った方がいいのでは...?という気持ちにも若干あります まあ、せっかく作ったし考えを整理したので残しておきたいかなって...

冒頭に述べたことの繰り返しになりますが、思いつきで始めた実装であるため不備や勘違いが大いにありそうです。
読んで/使ってみて違和感や指摘がありましたら、またverify用の問題がありましたら気軽に教えていただけると大変助かります。

それでは、お疲れ様でした。