book
归档: Haskell 
flag
mode_edit

我也是会写平衡树的人了

二叉查找树(BST)

二叉查找树具有很多很妙的性质:

  1. 对于每一个节点,满足左子树的所有值小于该节点的值小于右子树的所有节点的值

  2. 由 1. ,若我们将一棵二叉查找树进行中序遍历,那么可以得到一个有序序列

  3. 由 1. 2.,考虑如何将有序序列重建为一棵二叉查找树:我们应当首先在区间内找到一个节点作为根节点,然后对于该节点左右两侧的序列递归执行相同的建树操作。考虑「找到一个节点的」过程,实际上选取方式足足有 $n$ 种,而且他们产生的树一定不同,因此:对于同一序列,存在多种合法的二叉查找树。

  4. 由 3. ,考虑二叉查找树的形态,若每次选择区间中点,那么最后会得到一棵相对平衡的树;若每次选择区间左端点,则会得到一条长链。

  5. 二叉查找树可以快速确定元素 $x$ 是否在树内。考虑二叉查找树的查找过程:首先将 $x$ 与树根比较,如果 $x< \mathit{root}$,那么 $x$ 一定在左子树内;同样的,若 $x=root$,那么 $x$ 就是根节点;对于 $x>\mathit{root}$,那么 $x$ 在右子树内。递归在左右子树内应用这个过程,即可完成查找。

  6. 考虑 5. 中过程的时间复杂度,为 $O(\mathit{height})$,由 4. 若树为一条长链,那么时间复杂度为 $O(n)$;若树为一棵平衡的树,那么时间复杂度为 $O(\log n)$。

  7. 考虑二叉查找树的插入过程,若要将 $x$ 插入树中,首先和插入操作一样,比较 $x$ 和根的大小,如果 $x<\mathit{root}$ 那么将左子树替换为插入后的树,反之替换右子树。而边界条件为:向空树中插入一个值,得到一个节点。

  8. 对于 7. 中的插入过程,我们注意到树的形态是不受我们控制的,如果我们插入 $1,2,3,4,5,6\cdots n$,那么树就会退化成一条链。因此,我们需要一种能够维持树的平衡性的数据结构,这就是平衡树。

Treap

如果我们在二叉搜索树的每个节点上额外存储一个值,使得对于这个值,整棵二叉搜索树还满足堆性质:每个节点的关键值大于左右儿子的关键值。然后在新建节点时随机生成一个关键值,就可以使得整棵树期望平衡。

带有旋转操作的数据结构都难以可持久化,既然要用 Haskell 写,就只能写那些可以高效地可持久化的东西——无旋 Treap。

data Treap a = Node Int a Int (Treap a) (Treap a) | Nil deriving Show -- Size Value Key lc rc

getSize :: Treap a -> Int
getSize Nil = 0
getSize (Node sz _ _ _ _) = sz

无旋 Treap 的操作基于分裂和合并。分裂操作可以依据 $v$ 把整棵树分裂为左右两部分:左半边所有节点的值 $\leq v$,右半边所有节点的值 $>v$。而合并是分裂的逆过程,可以将两个值域不重合的 Treap 合并为一个大的 Treap。

分裂

考虑怎样才能以 $\mathit{value}$ 为界把一个 Treap 分裂为两块呢?

不妨设分裂后我们得到了 $lc$ 与 $rc$ 。首先拿着 $\mathit{value}$ 与根节点的 $v$ 进行比较, 如果 $v\leq \mathit{value}$ ,那么根节点连同左子树都应该是属于 $lc$ 的。这样,我们需要继续分裂右子树,右子树可以被分裂为两块:$rlc$ 与 $rrc$ 。由图可知,我们可以把 $rlc$ 当成 $lc$ 的右子树,$rrc$ 自成一派,成为 $rc$。

img

同理可以这样处理 $v>\mathit{value}$ 的情况,代码见下:

split :: (Num a, Ord a) => Treap a -> a -> (Treap a, Treap a)
split Nil _ = (Nil, Nil)
split (Node _ v key lc rc) value
  | v <= value = (Node (getSize l1 + getSize lc + 1) v key lc l1, r1)
  | otherwise = (l2, Node (getSize r2 + getSize rc + 1) v key r2 rc)
    where
      (l1, r1) = split rc value
      (l2, r2) = split lc value

合并

考虑该如何合并两个值域不重叠的 Treap 呢?

似乎拿着其中一个去和另一个的一棵子树合并即可。不妨设我们要合并 $lc$ 与 $rc$,且 $\max{lc} < \min{rc}$ 。那么我们可以拿着 $lc$ 的右子树与 $rc$ 合并,也可以拿着 $rc$ 的左子树和 $lc$ 合并,都满足「值域不重叠」 的性质。那么我们究竟该拿哪一个呢?这时候就要拿出堆性质了,这里我们实现大根堆,比较两个根节点的 $key$ 值,较小的一个就只能和另一个的子树合并了。

infixl 5 ><

(><) :: (Num a, Ord a) => Treap a -> Treap a -> Treap a
Nil >< u = u
u >< Nil = u
u@(Node _ v1 k1 llc lrc) >< v@(Node _ v2 k2 rlc rrc)
  | k1 <= k2 = Node (getSize rrc + getSize l + 1) v2 k2 l rrc
  | otherwise = Node (getSize llc + getSize r + 1) v1 k1 llc r
  where
    l = u >< rlc
    r = lrc >< v

基本操作

有了分裂和合并,我们的生活就改善了很多。如果要插入一个值 $v$ ,可以先把树以 $v$ 为界分裂成两部分 $lc, rc$,然后再按照 $lc,v,rc$的顺序合并起来。

insert :: (Num a, Ord a) => Treap a -> a -> Int -> Treap a
insert tr v num = l >< t >< r
  where
    (l, r) = split tr v
    t = Node 1 v num Nil Nil

删除稍微复杂一点,考虑到普通平衡树这个题里面,一个值可以出现多次,删除单个值 $v$ 是比较复杂的。我们可以这样操作:

  1. 首先分裂出来一棵树,这棵树里包含了原树里的所有 $v$ 。
  2. 将这棵树的左右子树合并,丢掉根节点。
  3. 将这棵树与之前的分裂出来的片段合并,得到结果

第一步操作比较复杂,可以这样进行:首先按照 $v$ 把整棵树分裂为 $(lc,rc)$,然后再以 $v-1$ 为界分裂 $lc$ 得到 $(_,lrc)$,这样,$lrc$ 就是我们想要的树了。代码见下:

erase :: (Num a, Ord a) => Treap a -> a -> Treap a
erase Nil _ = error "Cannot erase a empty Treap"
erase tr v = ll >< ntr >< r
  where
    (l, r) = split tr v
    (ll, Node _ _ _ rlc rrc) = split l (v - 1)
    ntr = rlc >< rrc

求 $v$ 的排名自然也不复杂,我们只需要以 $v-1$ 为界,得到 $\leq v-1$ 的值的数目即可。

求第 $k$ 大的操作也不复杂,只需要在比较左子树的 $size$ 的同时不断递归就可以了。

以上两个操作的代码见下:

rank :: (Num a, Ord a) => Treap a -> a -> Int
rank Nil _ = error "Cannot rank a empty Treap"
rank tr v = getSize l + 1
  where
    (l, _) = split tr (v - 1)

kth :: (Num a, Ord a) => Treap a -> Int -> a
kth Nil _ = error "Cannot kth a empty Treap"
kth (Node _ v _ lc rc) k
  | lsz + 1 == k = v
  | lsz < k = kth rc (k - lsz - 1)
  | otherwise = kth lc k
  where
    lsz = getSize lc

由于元素可重复,求前驱和后继的代码稍微有一点复杂,实际上是写成循环比较好。思路是这样的,在求前驱的时候,如果遇到一个比 $v$ 更小的元素,就先记下来,再向右边走,看看能不能碰到更大、但是小于 $v$ 的元素。求后继同理。这里放 C++ 的代码。

int prev(node* tr, int value) {
  int res;
  while (tr) {
    if (tr->value < value)
      res = tr->value, tr = tr->rc;
    else
      tr = tr->lc;
  }
  return res;
}

int succ(node* tr, int value) {
  int res;
  while (tr) {
    if (tr->value > value)
      res = tr->value, tr = tr->lc;
    else
      tr = tr->rc;
  }
  return res;
}

代码(普通平衡树)

Haskell

{-# OPTIONS_GHC -O2 #-}
-- {-# LANGUAGE Strict #-}
module Main where
import Prelude hiding (succ)
import Data.Char (digitToInt, isSpace)
import Text.Printf (printf)
import qualified Data.Text as T
import qualified Data.Text.IO as I


data Treap a = Node Int a Int (Treap a) (Treap a) | Nil deriving Show -- Size Value Key lc rc

getSize :: Treap a -> Int
getSize Nil = 0
getSize (Node sz _ _ _ _) = sz

split :: (Num a, Ord a) => Treap a -> a -> (Treap a, Treap a)
split Nil _ = (Nil, Nil)
split (Node _ v key lc rc) value
  | v <= value = (Node (getSize l1 + getSize lc + 1) v key lc l1, r1)
  | otherwise = (l2, Node (getSize r2 + getSize rc + 1) v key r2 rc)
    where
      (l1, r1) = split rc value
      (l2, r2) = split lc value

infixl 5 ><

(><) :: (Num a, Ord a) => Treap a -> Treap a -> Treap a
Nil >< u = u
u >< Nil = u
u@(Node _ v1 k1 llc lrc) >< v@(Node _ v2 k2 rlc rrc)
  | k1 <= k2 = Node (getSize rrc + getSize l + 1) v2 k2 l rrc
  | otherwise = Node (getSize llc + getSize r + 1) v1 k1 llc r
  where
    l = u >< rlc
    r = lrc >< v


insert :: (Num a, Ord a) => Treap a -> a -> Int -> Treap a
insert tr v num = l >< t >< r
  where
    (l, r) = split tr v
    t = Node 1 v num Nil Nil

erase :: (Num a, Ord a) => Treap a -> a -> Treap a
erase Nil _ = error "Cannot erase a empty Treap"
erase tr v = ll >< ntr >< r
  where
    (l, r) = split tr v
    (ll, Node _ _ _ rlc rrc) = split l (v - 1)
    ntr = rlc >< rrc

rank :: (Num a, Ord a) => Treap a -> a -> Int
rank Nil _ = error "Cannot rank a empty Treap"
rank tr v = getSize l + 1
  where
    (l, _) = split tr (v - 1)

kth :: (Num a, Ord a) => Treap a -> Int -> a
kth Nil _ = error "Cannot kth a empty Treap"
kth (Node _ v _ lc rc) k
  | lsz + 1 == k = v
  | lsz < k = kth rc (k - lsz - 1)
  | otherwise = kth lc k
  where
    lsz = getSize lc

prev' :: Ord a => Treap a -> a -> a -> a
prev' Nil _ res = res
prev' (Node _ v _ lc rc) vl res
  | v < vl = prev' rc vl v
  | otherwise = prev' lc vl res

prev :: (Num a, Ord a) => Treap a -> a -> a
prev tr v = prev' tr v 0

succ' :: Ord a => Treap a -> a -> a -> a
succ' Nil _ res = res
succ' (Node _ v _ lc rc) vl res
  | v > vl = succ' lc vl v
  | otherwise = succ' rc vl res

succ :: (Num a, Ord a) => Treap a -> a -> a
succ tr v = succ' tr v 0

int :: String -> Int
int str = int' (filter (not . isSpace) str) 0
  where
    int' [] x = x
    int' ('-':xs) _ = -1 * (int' xs 0)
    int' (x:xs) p = int' xs $ p * 10 + digitToInt x

repM :: Monad m => Int -> a -> (a -> m a) -> m a
repM 0 x _ = return x
repM n x f = f x >>= \y -> repM (n - 1) y f

repM_ :: Monad m => Int -> a -> (a -> m a) -> m ()
repM_ n x f = repM n x f >> return ()

main :: IO ()
main = do
  n <- int <$> T.unpack <$> I.getLine
  repM_ n (Nil, 0) $ \(root, seed) -> do
    [op, num] <- map (int . T.unpack) . T.words <$> I.getLine
    let res = (1919 * seed * seed + 19260817 * seed + 2333) `mod` 1000000007
    case op of 1 -> return $ (insert root num res, res)
               2 -> return $ (erase root num, res)
               3 -> printf "%d\n" (rank root num) >> return (root, res)
               4 -> printf "%d\n" (kth root num) >> return (root, res)
               5 -> printf "%d\n" (prev root num) >> return (root, res)
               6 -> printf "%d\n" (succ root num) >> return (root, res)
               _ -> error "???"
  return ()

C++

#include <bits/stdc++.h>

using namespace std;

namespace mgt {

struct node {
  int value, key, size; // key is for heap
  node *lc, *rc;
  node(int value) {
    this->value = value;
    lc = rc = 0;
    size    = 1;
    key     = rand();
  }
};

using pnns = pair<node*, node*>;

const int maxn = (int)1e6 + 10;

node* newNode(int value) { return new node(value); }

void updateSize(node* tr) {
  if (tr)
    tr->size = 1;
  else
    return;
  if (tr->lc)
    tr->size += tr->lc->size;
  if (tr->rc)
    tr->size += tr->rc->size;
}

int getSize(node* tr) { return tr == 0 ? 0 : tr->size; }

pnns split(node* tr, int value) { // get 2 trees: forall node in fst, node.value
                                  // <= value, vice versa for snd
  if (tr == nullptr)
    return pnns(nullptr, nullptr);
  if (tr->value <= value) {
    auto x = split(tr->rc, value);
    tr->rc = x.first;
    updateSize(tr);
    return pnns(tr, x.second);
  } else {
    auto x = split(tr->lc, value);
    tr->lc = x.second;
    updateSize(tr);
    return pnns(x.first, tr);
  }
}

node* merge(node* u, node* v) { // make sure max(u.value) <= min(v.value)
  if (!(u && v))
    return u == 0 ? v : u;
  if (u->key <= v->key) {
    v->lc = merge(u, v->lc);
    updateSize(v);
    updateSize(u);
    return v;
  } else {
    u->rc = merge(u->rc, v);
    updateSize(u);
    updateSize(v);
    return u;
  }
}

node* insert(node* tr, int value) {
  auto x = split(tr, value);
  auto t = newNode(value);
  return merge(x.first, merge(t, x.second));
}

node* erase(node* tr, int value) {
  auto x = split(tr, value);
  auto y = split(x.first, value - 1);
  auto z = merge(y.second->lc, y.second->rc);
  return merge(y.first, merge(z, x.second));
}

int getRank(node* tr, int value) {
  auto tmp = split(tr, value - 1);
  int res  = getSize(tmp.first) + 1;
  merge(tmp.first, tmp.second);
  return res;
}

int find(node* tr, int rank) {
  if (getSize(tr->lc) + 1 == rank)
    return tr->value;
  if (getSize(tr->lc) < rank)
    return find(tr->rc, rank - getSize(tr->lc) - 1);
  return find(tr->lc, rank);
}

int prev(node* tr, int value) {
  int res;
  while (tr) {
    if (tr->value < value)
      res = tr->value, tr = tr->rc;
    else
      tr = tr->lc;
  }
  return res;
}

int succ(node* tr, int value) {
  int res;
  while (tr) {
    if (tr->value > value)
      res = tr->value, tr = tr->lc;
    else
      tr = tr->rc;
  }
  return res;
}

template <class T> inline T gn() {
  register int k = 0, f = 1;
  register char c = getchar();
  for (; !isdigit(c); c = getchar())
    if (c == '-')
      f = -1;
  for (; isdigit(c); c = getchar())
    k = k * 10 + c - '0';
  return k * f;
}
} // namespace mgt

using mgt::gn;

int main() {
  int n           = gn<int>();
  mgt::node* root = nullptr;
  for (int i = 1; i <= n; ++i) {
    int op = gn<int>();
    if (op == 1)
      root = mgt::insert(root, gn<int>());
    else if (op == 2)
      root = mgt::erase(root, gn<int>());
    else if (op == 3)
      printf("%d\n", mgt::getRank(root, gn<int>()));
    else if (op == 4)
      printf("%d\n", mgt::find(root, gn<int>()));
    else if (op == 5)
      printf("%d\n", mgt::prev(root, gn<int>()));
    else if (op == 6)
      printf("%d\n", mgt::succ(root, gn<int>()));
  }
}