Skip to content

树套树学习笔记

检测到 KaTeX 加载失败,可能会导致文中的数学公式无法正常渲染。

树套树是处理区间问题或二维数点问题的一种常见的数据结构。

其实树套树的原理很简单,就是利用外层树的树高为 $O(\log n)$ 和内层树允许动态开点的性质。经过一系列处理可以使得时空复杂度均保持在 $O(n \log^2 n)$ 的级别。

但树套树处理问题的局限性在于询问需要可以被分成 $\log n$ 段区间分别处理后合并。

线段树套 STL

支持操作

  1. 修改某一位置上的数值;
  2. 查询 $x$ 在区间内的前驱(前驱定义为小于 $x$,且最大的数)。

理论上还可以支持以下操作:

  1. 查询 $x$ 在区间内的后继(后继定义为大于 $x$,且最小的数)。

代码

#include <iostream>
#include <limits>
#include <set>

using std::cin;
using std::cout;
const char endl = '\n';

const int N = 5e4 + 5;

int n, m, a[N];

// Segment-tree node for the multiset-based variant. The inherited multiset is
// the per-node value container; nodes created with an interval start out with
// INT_MIN / INT_MAX sentinels so a "step back from lower_bound" lookup always
// has an element to land on.
struct node : std::multiset<int> {
    int l = 0, r = 0;  // closed index interval [l, r] covered by this node

    // Placeholder state for the statically allocated array: empty set, no
    // sentinels.
    node() = default;

    // Node covering [_l, _r], pre-seeded with both sentinels.
    node(int _l, int _r) {
        l = _l;
        r = _r;
        insert(std::numeric_limits<int>::min());
        insert(std::numeric_limits<int>::max());
    }
} tr[N << 2];

// Builds the subtree stored at heap index u for the index interval [l, r].
// Every node receives a copy of all values a[l..r]; children live at 2u and
// 2u + 1.
void build(int u, int l, int r) {
    tr[u] = node(l, r);

    int idx = l;
    while (idx <= r) {
        tr[u].insert(a[idx]);
        ++idx;
    }

    if (l != r) {
        const int half = (l + r) >> 1;

        build(2 * u, l, half);
        build(2 * u + 1, half + 1, r);
    }
}

void modify(int u, int p, int x) {
    tr[u].erase(tr[u].find(a[p]));
    tr[u].insert(x);

    if (tr[u].l == tr[u].r) return;

    int mid = tr[u].l + tr[u].r >> 1;

    if (p <= mid) modify(u << 1, p, x);
    else modify(u << 1 | 1, p, x);
}

int query(int u, int l, int r, int x) {
    if (l <= tr[u].l && tr[u].r <= r) {
        return *--tr[u].lower_bound(x);
    }

    int mid = tr[u].l + tr[u].r >> 1;
    int res = std::numeric_limits<int>::min();

    if (l <= mid) res = std::max(res, query(u << 1, l, r, x));
    if (r > mid) res = std::max(res, query(u << 1 | 1, l, r, x));

    return res;
}

// Entry point: reads n values and m operations, then answers them.
//   op 1: "p x"   -> set position p to x
//   other: "l r x" -> print the predecessor of x within [l, r]
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;

    for (int idx = 1; idx <= n; ++idx) cin >> a[idx];

    build(1, 1, n);

    while (m--) {
        int op;
        cin >> op;

        if (op != 1) {  // op == 2
            int l, r, x;
            cin >> l >> r >> x;

            cout << query(1, l, r, x) << endl;
            continue;
        }

        int p, x;
        cin >> p >> x;

        // modify() reads the old a[p], so the array is updated afterwards.
        modify(1, p, x);
        a[p] = x;
    }

    return 0;
}

线段树套平衡树

支持操作

  1. 查询 $x$ 在区间内的排名;
  2. 查询区间内排名为 $k$ 的值;
  3. 修改某一位置上的数值;
  4. 查询 $x$ 在区间内的前驱(前驱定义为小于 $x$,且最大的数);
  5. 查询 $x$ 在区间内的后继(后继定义为大于 $x$,且最小的数)。

代码

这份代码在洛谷上被卡 TLE 了两个测试点,在 LibreOJ 上测试要比 FHQ Treap 版本慢上 1000 多毫秒,还比 FHQ Treap 长不少。

#include <iostream>
#include <limits>

using std::cin;
using std::cout;
const char endl = '\n';

const int N = 5e4 + 5;

// Splay tree over ints. Equal values share one node and are tracked by a
// per-node counter. The constructor inserts INT_MIN and INT_MAX sentinels, so
// every user value has both a predecessor and a successor node; rank()
// therefore counts the INT_MIN sentinel as well.
//
// The tree owns its nodes through raw pointers and every node stores a pointer
// back to this object's `root` member, so the object must neither be copied
// nor relocated; copying is disabled.
class Splay {
  private:
    struct node {
        int value;
        node *lchild, *rchild, *parent, **root;  // root: address of the owning tree's root pointer
        size_t size, count;                      // size: elements in this subtree; count: copies of value

        node()
            : value(0), lchild(nullptr), rchild(nullptr), parent(nullptr), root(nullptr), size(0), count(0) {}

        node(const int &_value, node *_parent, node **_root)
            : value(_value), lchild(nullptr), rchild(nullptr), parent(_parent), root(_root), size(1), count(1) {}

        node(const node &) = delete;
        node &operator=(const node &) = delete;

        // Destroys the whole subtree below this node.
        ~node() {
            delete lchild;  // deleting nullptr is a no-op
            delete rchild;
        }

        // child(0) is the left child slot, anything else the right one.
        node *&child(unsigned int x) {
            return !x ? lchild : rchild;
        }

        // 0 if this node is its parent's left child, 1 otherwise.
        // Requires a non-null parent.
        unsigned relation() const {
            return this == parent->lchild ? 0 : 1;
        }

        size_t lsize() const {
            return lchild == nullptr ? 0 : lchild->size;
        }

        size_t rsize() const {
            return rchild == nullptr ? 0 : rchild->size;
        }

        // Recomputes size from the children and the local counter.
        void pushup() {
            size = lsize() + count + rsize();
        }

        // Rotates this node above its parent. Requires a non-null parent.
        void rotate() {
            node *old = parent;
            unsigned x = relation();

            // Hook this node into the grandparent in place of the old parent.
            if (old->parent != nullptr) {
                old->parent->child(old->relation()) = this;
            }
            parent = old->parent;

            // The inner subtree changes sides and moves under the old parent.
            if (child(x ^ 1) != nullptr) {
                child(x ^ 1)->parent = old;
            }
            old->child(x) = child(x ^ 1);

            child(x ^ 1) = old;
            old->parent = this;

            // The old parent is now below this node, so refresh it first.
            old->pushup();
            pushup();
        }

        // Rotates this node upwards until its parent is `target`; with the
        // default nullptr target the node becomes the root of the tree.
        void splay(node *target = nullptr) {
            while (parent != target) {
                if (parent->parent == target) {
                    rotate();
                } else if (relation() == parent->relation()) {
                    // Same side twice: rotate the parent first.
                    parent->rotate();
                    rotate();
                } else {
                    rotate();
                    rotate();
                }
            }

            if (target == nullptr) *root = this;
        }

        // Rightmost node of the left subtree. Requires a non-null left child,
        // which holds for a user value that has been splayed to the root.
        node *predecessor() {
            node *pred = lchild;

            while (pred->rchild != nullptr) {
                pred = pred->rchild;
            }

            return pred;
        }

        // Leftmost node of the right subtree. Requires a non-null right child.
        node *successor() {
            node *succ = rchild;

            while (succ->lchild != nullptr) {
                succ = succ->lchild;
            }

            return succ;
        }
    } * root;

    // Inserts one occurrence of value and splays its node to the root.
    // Returns the root, i.e. the node holding value.
    node *_insert(const int &value) {
        node **target = &root, *parent = nullptr;

        while (*target != nullptr && (*target)->value != value) {
            parent = *target;
            parent->size++;  // every node on the path gains one element

            if (value < parent->value) {
                target = &parent->lchild;
            } else {
                target = &parent->rchild;
            }
        }

        if (*target == nullptr) {
            *target = new node(value, parent, &root);
        } else {
            (*target)->count++;
            (*target)->size++;
        }

        (*target)->splay();

        return root;
    }

    // Returns the node holding value (splayed to the root) or nullptr.
    node *find(const int &value) {
        node *cur = root;

        while (cur != nullptr && value != cur->value) {
            if (value < cur->value) {
                cur = cur->lchild;
            } else {
                cur = cur->rchild;
            }
        }

        if (cur != nullptr) {
            cur->splay();
        }

        return cur;
    }

    // Removes one occurrence from node u; nullptr is ignored. u must not be a
    // sentinel: it needs a neighbour on both sides.
    void erase(node *u) {
        if (u == nullptr) return;

        // predecessor()/successor() below only look inside u's own subtrees,
        // so u has to be the root. Splaying here is a no-op when it already is.
        u->splay();

        if (u->count > 1) {
            u->count--;
            u->size--;

            return;
        }

        node *pred = u->predecessor(),
             *succ = u->successor();

        // With pred at the root and succ directly below it, u is the only
        // node left of succ and has no children of its own.
        pred->splay();
        succ->splay(pred);

        delete succ->lchild;
        succ->lchild = nullptr;

        succ->pushup();
        pred->pushup();
    }

  public:
    // Creates a tree holding only the INT_MIN / INT_MAX sentinels.
    Splay()
        : root(nullptr) {
        insert(std::numeric_limits<int>::min());
        insert(std::numeric_limits<int>::max());
    }

    // A shallow copy would delete the nodes twice and leave the nodes'
    // back-pointers aimed at the wrong object.
    Splay(const Splay &) = delete;
    Splay &operator=(const Splay &) = delete;

    ~Splay() {
        delete root;
    }

    // Adds one occurrence of value.
    void insert(const int &value) {
        _insert(value);
    }

    // Removes one occurrence of value; does nothing when value is absent.
    // value must not be one of the sentinels.
    void erase(const int &value) {
        erase(find(value));
    }

    // Number of stored elements strictly smaller than value. The INT_MIN
    // sentinel is included in that count.
    unsigned rank(const int &value) {
        node *hit = find(value);

        if (hit == nullptr) {
            // Temporarily insert value so it sits at the root, read the size
            // of its left subtree, then remove it again.
            node *fresh = _insert(value);
            unsigned res = fresh->lsize();
            erase(fresh);

            return res;
        }

        return hit->lsize();
    }

    // Largest stored element strictly smaller than value (possibly the
    // INT_MIN sentinel). value must be greater than INT_MIN. The returned
    // reference points into the tree and stays valid only while that element
    // is stored.
    const int &predecessor(const int &value) {
        node *hit = find(value);

        if (hit == nullptr) {
            node *fresh = _insert(value);
            const int &result = fresh->predecessor()->value;
            erase(fresh);  // removes only the temporary node; result's node survives
            return result;
        }

        return hit->predecessor()->value;
    }

    // Smallest stored element strictly greater than value (possibly the
    // INT_MAX sentinel). value must be smaller than INT_MAX. The returned
    // reference points into the tree and stays valid only while that element
    // is stored.
    const int &successor(const int &value) {
        node *hit = find(value);

        if (hit == nullptr) {
            node *fresh = _insert(value);
            const int &result = fresh->successor()->value;
            erase(fresh);  // removes only the temporary node; result's node survives
            return result;
        }

        return hit->successor()->value;
    }
};

// Pointer-based segment-tree node; the Splay base is the per-node value
// container. Each node owns its two children.
struct node : Splay {
    int l = 0, r = 0;                           // closed index interval [l, r]
    node *lchild = nullptr, *rchild = nullptr;  // owned children, null until built

    node() = default;

    node(const int &_l, const int &_r)
        : l(_l), r(_r) {}

    // Recursively frees both subtrees.
    ~node() {
        delete lchild;  // deleting nullptr is a no-op
        delete rchild;
    }
} * root;

int n, m, a[N];

// Allocates the segment-tree node for [l, r] into u, fills its container with
// a[l..r], and recurses into both halves unless the interval is one position.
void build(node *&u, int l, int r) {
    u = new node(l, r);

    int idx = l;
    while (idx <= r) {
        u->insert(a[idx]);
        ++idx;
    }

    if (l != r) {
        const int half = (l + r) >> 1;

        build(u->lchild, l, half);
        build(u->rchild, half + 1, r);
    }
}

// Counts the values strictly smaller than x among positions [l, r], using the
// subtree rooted at u. The "- 1" removes the INT_MIN sentinel that every
// per-node container counts in its rank.
int query_rank(node *u, int l, int r, int x) {
    const bool covered = l <= u->l && u->r <= r;

    if (covered) return u->rank(x) - 1;

    const int half = (u->l + u->r) >> 1;
    int total = 0;

    if (l <= half) total += query_rank(u->lchild, l, r, x);
    if (half < r) total += query_rank(u->rchild, l, r, x);

    return total;
}

// Finds the k-th smallest value among positions [_l, _r] by binary searching
// the value range [-1e8, 1e8]: the answer is the largest candidate with fewer
// than k values strictly below it. Returns -1 if no candidate qualifies.
int query_kth(int _l, int _r, int k) {
    int lo = -100000000;
    int hi = 100000000;
    int best = -1;

    while (lo <= hi) {
        const int cand = (lo + hi) >> 1;
        const int smaller = query_rank(root, _l, _r, cand);

        if (smaller < k) {
            best = cand;
            lo = cand + 1;
        } else {
            hi = cand - 1;
        }
    }

    return best;
}

// Replaces the value at position p by x in every node on the path from u down
// to the leaf for p. a[p] must still hold the OLD value while this runs; the
// caller updates a[p] afterwards.
void modify(node *u, int p, int x) {
    node *cur = u;

    for (;;) {
        cur->erase(a[p]);
        cur->insert(x);

        if (cur->l == cur->r) break;

        const int half = (cur->l + cur->r) >> 1;

        cur = (p <= half) ? cur->lchild : cur->rchild;
    }
}

// Largest value strictly smaller than x among positions [l, r], taken as the
// maximum over the covering nodes. INT_MIN means no such value exists.
int query_pre(node *u, int l, int r, int x) {
    if (u->l >= l && u->r <= r) return u->predecessor(x);

    const int half = (u->l + u->r) >> 1;
    int best = std::numeric_limits<int>::min();

    if (l <= half) {
        const int left = query_pre(u->lchild, l, r, x);
        if (best < left) best = left;
    }
    if (half < r) {
        const int right = query_pre(u->rchild, l, r, x);
        if (best < right) best = right;
    }

    return best;
}

// Smallest value strictly greater than x among positions [l, r], taken as the
// minimum over the covering nodes. INT_MAX means no such value exists.
int query_suc(node *u, int l, int r, int x) {
    if (u->l >= l && u->r <= r) return u->successor(x);

    const int half = (u->l + u->r) >> 1;
    int best = std::numeric_limits<int>::max();

    if (l <= half) {
        const int left = query_suc(u->lchild, l, r, x);
        if (left < best) best = left;
    }
    if (half < r) {
        const int right = query_suc(u->rchild, l, r, x);
        if (right < best) best = right;
    }

    return best;
}

// Entry point: reads n values and m operations, then answers them.
//   1 l r x -> rank of x in [l, r]
//   2 l r k -> k-th smallest value in [l, r]
//   3 p x   -> set position p to x
//   4 l r x -> predecessor of x in [l, r]
//   other   -> successor of x in [l, r]
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;

    for (int idx = 1; idx <= n; ++idx) cin >> a[idx];

    build(root, 1, n);

    while (m--) {
        int op;
        cin >> op;

        switch (op) {
            case 1: {
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_rank(root, l, r, x) + 1 << endl;
                break;
            }
            case 2: {
                int l, r, k;
                cin >> l >> r >> k;
                cout << query_kth(l, r, k) << endl;
                break;
            }
            case 3: {
                int p, x;
                cin >> p >> x;
                // modify() reads the old a[p], so the array is updated afterwards.
                modify(root, p, x);
                a[p] = x;
                break;
            }
            case 4: {
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_pre(root, l, r, x) << endl;
                break;
            }
            default: {  // op == 5
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_suc(root, l, r, x) << endl;
                break;
            }
        }
    }

    delete root;

    return 0;
}

这份代码在洛谷上开了 O2 优化之后是可以以 2.00s 的运行时间刚好卡过去 P3380 【模板】二逼平衡树(树套树) 的:R81169175

#include <iostream>
#include <chrono>
#include <limits>
#include <random>

using std::cin;
using std::cout;
const char endl = '\n';

const int N = 5e4 + 5;

// Non-rotating (split/merge) treap storing a multiset of ints; duplicates are
// kept as separate nodes. Ranks are 1-based.
//
// The tree owns its nodes through raw pointers, so copying is disabled.
class Treap {
  private:
    // Draws a non-negative heap priority from a Mersenne Twister seeded once
    // from the steady clock.
    static int random_key() {
        static std::mt19937 engine(static_cast<std::mt19937::result_type>(
            std::chrono::steady_clock::now().time_since_epoch().count()));

        // Drop one bit so the 32-bit output always fits a non-negative int.
        return static_cast<int>(engine() >> 1);
    }

    struct node {
        node *lchild, *rchild;
        int size, value, key;  // size: nodes in this subtree; key: heap priority

        node()
            : lchild(nullptr), rchild(nullptr), size(0), value(0), key(random_key()) {}

        explicit node(int _value)
            : lchild(nullptr), rchild(nullptr), size(1), value(_value), key(random_key()) {}

        node(const node &) = delete;
        node &operator=(const node &) = delete;

        // Destroys the whole subtree below this node.
        ~node() {
            delete lchild;
            delete rchild;
        }

        // Recomputes size from the children.
        inline void pushup() {
            size = 1;
            if (lchild != nullptr) size += lchild->size;
            if (rchild != nullptr) size += rchild->size;
        }
    } * root;

    // Subtree size that tolerates a null pointer.
    inline int getNodeSize(node *t) {
        return t == nullptr ? 0 : t->size;
    }

    // Splits p by position: .first receives the k smallest nodes, .second the
    // rest.
    std::pair<node *, node *> split(node *p, int k) {
        if (p == nullptr) return std::make_pair(nullptr, nullptr);

        if (k <= getNodeSize(p->lchild)) {
            auto o = split(p->lchild, k);
            p->lchild = o.second;
            p->pushup();
            o.second = p;

            return o;
        }

        auto o = split(p->rchild, k - getNodeSize(p->lchild) - 1);
        p->rchild = o.first;
        p->pushup();
        o.first = p;

        return o;
    }

    // Splits p by value: .first receives every node with value < `value`,
    // .second every node with value >= `value`.
    std::pair<node *, node *> splitByValue(node *p, int value) {
        if (p == nullptr) return std::make_pair(nullptr, nullptr);

        if (p->value < value) {
            auto o = splitByValue(p->rchild, value);
            p->rchild = o.first;
            p->pushup();
            o.first = p;

            return o;
        }

        auto o = splitByValue(p->lchild, value);
        p->lchild = o.second;
        p->pushup();
        o.second = p;

        return o;
    }

    // Joins two treaps where every value in x precedes every value in y; the
    // node with the larger key ends up on top.
    node *merge(node *x, node *y) {
        if (x == nullptr) return y;
        if (y == nullptr) return x;

        if (x->key > y->key) {
            x->rchild = merge(x->rchild, y);
            x->pushup();
            return x;
        }

        y->lchild = merge(x, y->lchild);
        y->pushup();
        return y;
    }

  public:
    Treap()
        : root(nullptr) {}

    // A shallow copy would free the same nodes twice.
    Treap(const Treap &) = delete;
    Treap &operator=(const Treap &) = delete;

    ~Treap() {
        delete root;
    }

    // Adds one occurrence of value.
    inline void insert(int value) {
        auto o = splitByValue(root, value);
        o.first = merge(o.first, new node(value));
        root = merge(o.first, o.second);
    }

    // Removes one occurrence of value; the set is left unchanged when value
    // is not present.
    inline void erase(int value) {
        auto o = splitByValue(root, value);  // o.second: everything >= value
        auto t = split(o.second, 1);         // t.first: smallest of those, or nullptr

        // t.first is a detached single node, so deleting it frees nothing
        // else. It must be put back when it is absent or holds another value.
        if (t.first != nullptr && t.first->value == value) {
            delete t.first;
            t.first = nullptr;
        }

        root = merge(o.first, merge(t.first, t.second));
    }

    // 1-based rank: number of stored elements strictly smaller than value,
    // plus one.
    inline int rank(int value) {
        auto x = splitByValue(root, value);
        int r = getNodeSize(x.first) + 1;
        root = merge(x.first, x.second);
        return r;
    }

    // k-th smallest stored element (1-based); returns 0 when k is beyond the
    // number of stored elements.
    inline int kth(int k) {
        auto x = split(root, k - 1);
        auto y = split(x.second, 1);
        Treap::node *o = y.first;
        root = merge(x.first, merge(y.first, y.second));
        return o == nullptr ? 0 : o->value;
    }

    // Largest stored element strictly smaller than x, or INT_MIN + 1 when
    // there is none.
    inline int pre(int x) {
        int k = rank(x) - 1;
        return k > 0
                 ? kth(k)
                 : std::numeric_limits<int>::min() + 1;
    }

    // Smallest stored element strictly greater than x, or INT_MAX when there
    // is none.
    inline int suc(int x) {
        // Nothing can exceed INT_MAX, and x + 1 below would overflow.
        if (x == std::numeric_limits<int>::max()) {
            return std::numeric_limits<int>::max();
        }

        int k = rank(x + 1);
        return k > getNodeSize(root)
                 ? std::numeric_limits<int>::max()
                 : kth(k);
    }
};

// Pointer-based segment-tree node; the Treap base is the per-node value
// container. Each node owns its two children.
struct node : Treap {
    int l = 0, r = 0;                           // closed index interval [l, r]
    node *lchild = nullptr, *rchild = nullptr;  // owned children, null until built

    node() = default;

    node(const int &_l, const int &_r)
        : l(_l), r(_r) {}

    // Recursively frees both subtrees.
    ~node() {
        delete lchild;  // deleting nullptr is a no-op
        delete rchild;
    }
} * root;

int n, m, a[N];

// Allocates the segment-tree node for [l, r] into u, fills its treap with
// a[l..r], and recurses into both halves unless the interval is one position.
void build(node *&u, int l, int r) {
    u = new node(l, r);

    int pos = l;
    while (pos <= r) {
        u->insert(a[pos]);
        ++pos;
    }

    if (l != r) {
        const int half = (l + r) >> 1;

        build(u->lchild, l, half);
        build(u->rchild, half + 1, r);
    }
}

// Counts the values strictly smaller than x among positions [l, r], using the
// subtree rooted at u. The per-node rank is 1-based, hence the "- 1".
int query_rank(node *u, int l, int r, int x) {
    const bool covered = l <= u->l && u->r <= r;

    if (covered) return u->rank(x) - 1;

    const int half = (u->l + u->r) >> 1;
    int total = 0;

    if (l <= half) total += query_rank(u->lchild, l, r, x);
    if (half < r) total += query_rank(u->rchild, l, r, x);

    return total;
}

// Finds the k-th smallest value among positions [_l, _r] by binary searching
// the value range [0, 1e8]: the answer is the largest candidate with fewer
// than k values strictly below it. Returns -1 if no candidate qualifies.
int query_kth(int _l, int _r, int k) {
    int lo = 0;
    int hi = 100000000;
    int best = -1;

    while (lo <= hi) {
        const int cand = (lo + hi + 1) >> 1;  // midpoint rounded up
        const int smaller = query_rank(root, _l, _r, cand);

        if (smaller < k) {
            best = cand;
            lo = cand + 1;
        } else {
            hi = cand - 1;
        }
    }

    return best;
}

// Replaces the value at position p by x in every node on the path from u down
// to the leaf for p. a[p] must still hold the OLD value while this runs; the
// caller updates a[p] afterwards.
void modify(node *u, int p, int x) {
    node *cur = u;

    for (;;) {
        cur->erase(a[p]);
        cur->insert(x);

        if (cur->l == cur->r) break;

        const int half = (cur->l + cur->r) >> 1;

        cur = (p <= half) ? cur->lchild : cur->rchild;
    }
}

// Largest value strictly smaller than x among positions [l, r], taken as the
// maximum over the covering nodes. INT_MIN + 1 means no such value exists.
int query_pre(node *u, int l, int r, int x) {
    if (u->l >= l && u->r <= r) return u->pre(x);

    const int half = (u->l + u->r) >> 1;
    int best = std::numeric_limits<int>::min() + 1;

    if (l <= half) {
        const int left = query_pre(u->lchild, l, r, x);
        if (best < left) best = left;
    }
    if (half < r) {
        const int right = query_pre(u->rchild, l, r, x);
        if (best < right) best = right;
    }

    return best;
}

// Smallest value strictly greater than x among positions [l, r], taken as the
// minimum over the covering nodes. INT_MAX means no such value exists.
int query_suc(node *u, int l, int r, int x) {
    if (u->l >= l && u->r <= r) return u->suc(x);

    const int half = (u->l + u->r) >> 1;
    int best = std::numeric_limits<int>::max();

    if (l <= half) {
        const int left = query_suc(u->lchild, l, r, x);
        if (left < best) best = left;
    }
    if (half < r) {
        const int right = query_suc(u->rchild, l, r, x);
        if (right < best) best = right;
    }

    return best;
}

// Entry point: reads n values and m operations, then answers them.
//   1 l r x -> rank of x in [l, r]
//   2 l r k -> k-th smallest value in [l, r]
//   3 p x   -> set position p to x
//   4 l r x -> predecessor of x in [l, r]
//   other   -> successor of x in [l, r]
int main() {
    std::ios::sync_with_stdio(false);
    cin.tie(nullptr);

    cin >> n >> m;

    for (int pos = 1; pos <= n; ++pos) cin >> a[pos];

    build(root, 1, n);

    while (m--) {
        int op;
        cin >> op;

        switch (op) {
            case 1: {
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_rank(root, l, r, x) + 1 << endl;
                break;
            }
            case 2: {
                int l, r, k;
                cin >> l >> r >> k;
                cout << query_kth(l, r, k) << endl;
                break;
            }
            case 3: {
                int p, x;
                cin >> p >> x;
                // modify() reads the old a[p], so the array is updated afterwards.
                modify(root, p, x);
                a[p] = x;
                break;
            }
            case 4: {
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_pre(root, l, r, x) << endl;
                break;
            }
            default: {  // op == 5
                int l, r, x;
                cin >> l >> r >> x;
                cout << query_suc(root, l, r, x) << endl;
                break;
            }
        }
    }

    delete root;

    return 0;
}