TIOJ 2242:抽水機

TIOJ 2242:抽水機


題目大意:有一棵 $n$ 個節點的樹,每個點有水量 $a_i$,有 $q$ 次抽水的操作,每次會選一個節點,輸出這個節點的水量,並把水抽走,其他節點的水會往這個節點移動一格。

解法:首先把樹做輕重鏈剖分,變成很多條鏈,對每條鏈開一個 treap,維護每個節點的水量和是否有水。在鏈與鏈之間建邊,每次抽水時直接暴力 dfs;若更新完後這條鏈的點都沒水了,就把父節點連向自己的邊刪掉,反之則加上這條邊。

如果抽到有水的點,可以直接抽;如果沒水的話,就要找到離目前有水連通塊最近的點。可以開一個 set 維護現在還有哪些點有水,按照 dfn 排序時,dfn 最小的點就是目前連通塊最上面的點。如果抽水的點沒有任何祖先是有水的,那離它最近的有水點就是 dfn 最小的點,此時等同於在這個點的父節點抽水;否則可以用倍增找到抽水點往上、最高的那個沒水祖先,在那裡抽水即可。

每次操作的複雜度是有水鏈的數量乘上一個 log,而且每次操作完,有水連通塊的葉子會不見。我沒有特別分析樹剖後的複雜度,不過我相信是好的。

$\text{Code:}$

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,popcnt,sse4,abm")
#include <bits/stdc++.h>
using namespace std;

// Debug printing helpers. When WAIMAI is defined (local testing), debug(x, y, ...)
// prints "[x, y, ...] : " (the argument text) followed by the comma-separated values
// and a newline. Otherwise debug(...) expands to the no-op expression 7122.
#ifdef WAIMAI
#define debug(HEHE...) cout << "[" << #HEHE << "] : ", dout(HEHE)
// Base case: end the line once every argument has been printed.
void dout() {cout << '\n';}
// Prints the first argument, then ", " if more arguments remain, and recurses on the rest.
template<typename T, typename...U>
void dout (T t, U...u) {cout << t << (sizeof... (u) ? ", " : ""), dout (u...);}
#else
#define debug(...) 7122
#endif

// Contest shorthands used throughout the file.
#define ll long long
// Fast iostream: unsync from C stdio and untie cin from cout.
#define Waimai ios::sync_with_stdio(false), cin.tie(0)
// Inclusive loop x = a..b; the bound b is evaluated once and stored in I.
#define FOR(x,a,b) for (int x = a, I = b; x <= I; x++)
// Note: pb maps to emplace_back, not push_back.
#define pb emplace_back
#define F first
#define S second

// Capacity of every per-vertex array (presumably n <= 1e5, plus a little slack).
constexpr int SIZE = 100000 + 5;
// Highest binary-lifting level: floor(log2(SIZE)).
constexpr int H = __lg(SIZE);

// Problem input: vertex count, query count, initial water per vertex, tree adjacency.
int n;
int q;
int a[SIZE];
vector<int> adj[SIZE];

// Fixed-seed generator for treap priorities, so runs are reproducible.
mt19937 rng(7122);
struct Treap {
    int sz = 1;
    unsigned prior = rng();
    ll val = 0;
    bool is = 1;
    Treap *ls = nullptr, *rs = nullptr;
    Treap(ll val, bool is) : val(val), is(is) {}
    void pull() {
        sz = (ls ? ls->sz : 0) + (rs ? rs->sz : 0) + 1;
    }
};
// Concatenates treap a followed by treap b (every node of a precedes every node of b).
// The root is the node with the smaller priority; on equal priorities b wins.
Treap *merge(Treap *a, Treap *b) {
    if (!a) return b;
    if (!b) return a;
    if (b->prior <= a->prior) {
        b->ls = merge(a, b->ls);
        b->pull();
        return b;
    }
    a->rs = merge(a->rs, b);
    a->pull();
    return a;
}
// Splits t by position: the first k nodes end up in a, the remaining ones in b.
// t is taken by value, so a or b may safely alias the caller's source variable.
void split(Treap *t, Treap *&a, Treap *&b, int k) {
    if (!t) {
        a = nullptr;
        b = nullptr;
        return;
    }
    const int left = t->ls ? t->ls->sz : 0;
    if (k > left) {
        // t and its whole left subtree go to a; keep splitting the right subtree.
        a = t;
        split(t->rs, a->rs, b, k - left - 1);
        a->pull();
        return;
    }
    // t and its whole right subtree go to b; keep splitting the left subtree.
    b = t;
    split(t->ls, a, b->ls, k);
    b->pull();
}
// Returns the node at 1-based position k of treap t (k must lie in 1..t->sz).
Treap *kth(Treap *t, int k) {
    while (true) {
        const int left = t->ls ? t->ls->sz : 0;
        if (k == left + 1) return t;
        if (k <= left) {
            t = t->ls;
        } else {
            k -= left + 1;
            t = t->rs;
        }
    }
}

// One heavy chain of the decomposition.
//   chain : treap over the chain's own vertices, positions 1..len, top to bottom
//   nodes : nodes[0] is the attachment vertex in the parent chain (0 for the root chain),
//           nodes[1..] are the chain's vertices, top to bottom
//   adj   : ids of child chains that the pumping walk should visit
//   l, r  : range of positions currently holding water (l > r means none)
struct Chain {
    Treap *chain;
    vector<int> nodes;
    set<int> adj;
    int l, r;
    bool empty() {
        return r < l;
    }
} chain[SIZE];

// Number of chains created so far; chain ids start at 1.
int chain_sz;
// Preorder numbering: dfcnt is the counter, dfn[v] the index of v, vert[i] the vertex with index i.
int dfcnt;
int dfn[SIZE];
int vert[SIZE];
// Rooted tree: pa[v] parent (0 for the root), weight[v] subtree size, h[v] depth.
int pa[SIZE];
int weight[SIZE];
int h[SIZE];
// Decomposition: chain_id[v] is v's chain, chain_pos[v] its 1-based position in that chain.
int chain_id[SIZE];
int chain_pos[SIZE];

// First pass: assigns preorder indices (dfn/vert), parents, depths and subtree sizes.
// Neighbours that already have a dfn (in a tree, only the parent) are skipped.
void dfs1(int pos) {
    ++dfcnt;
    dfn[pos] = dfcnt;
    vert[dfcnt] = pos;
    weight[pos] = 1;
    for (int child : adj[pos]) {
        if (dfn[child] != 0) continue;
        pa[child] = pos;
        h[child] = h[pos] + 1;
        dfs1(child);
        weight[pos] += weight[child];
    }
}
void dfs2(int pos, int tid) {
    chain_id[pos] = tid;
    chain_pos[pos] = chain[tid].nodes.size();
    chain[tid].chain = merge(chain[tid].chain, new Treap(a[pos], 1));
    chain[tid].nodes.pb(pos);
    chain[tid].l = 1, chain[tid].r = chain_pos[pos];
    int son = 0;
    for (int np : adj[pos]) if (dfn[pos] < dfn[np] && weight[np] > weight[son]) son = np;
    if (son) dfs2(son, tid);
    for (int np : adj[pos]) if (dfn[pos] < dfn[np] && np != son) {
        chain[++chain_sz].nodes.pb(pos);
        chain[tid].adj.insert(chain_sz);
        dfs2(np, chain_sz);
    }
}

// dfn values of the vertices that currently hold water; the smallest is the topmost one.
std::set<int> all;
// Binary lifting: to[v][j] is the 2^j-th ancestor of v, or 0 once past the root.
int to[SIZE][H + 1];

bool in(int pos) {
    int tid = chain_id[pos];
    pos = chain_pos[pos];
    return chain[tid].l <= pos && pos <= chain[tid].r;
}
// Returns the k-th ancestor of pos (k >= 0) by walking the set bits of k.
int jump(int pos, int k) {
    for (int bit = 0; k != 0; ++bit, k >>= 1) {
        if (k & 1) pos = to[pos][bit];
    }
    return pos;
}
// Maps a query vertex to the vertex where the pumping step is applied.
// - pos itself if it holds water;
// - otherwise let top be the topmost watery vertex (smallest dfn in `all`):
//   pa[top] when pos is not inside top's subtree, else the highest dry ancestor of pos
//   at depth >= depth(top), found by binary lifting.
// Assumes `all` is non-empty and that the watery vertices form a connected set
// containing top (so the dry part of the path from pos up to top is contiguous).
int find_pos(int pos) {
    if (in(pos)) return pos;
    const int top = vert[*all.begin()];
    // jump() is only evaluated when pos is strictly deeper than top.
    const bool under_top = h[pos] > h[top] && jump(pos, h[pos] - h[top]) == top;
    if (!under_top) return pa[top];
    for (int i = H; i >= 0; --i) {
        const int up = to[pos][i];
        if (h[up] >= h[top] && !in(up)) pos = up;
    }
    return pos;
}
// Marks position p of chain tid as holding water by recording its vertex's dfn in `all`.
void add(int tid, int p) {
    const int vertex = chain[tid].nodes[p];
    all.insert(dfn[vertex]);
}
// Marks position p of chain tid as dry by removing its vertex's dfn from `all`.
// set::erase(key) is already a no-op when the key is absent, so the former
// count()-then-erase() double lookup is unnecessary.
void del(int tid, int p) {
    int pos = chain[tid].nodes[p];
    all.erase(dfn[pos]);
}

// Water currently at vertex pos: its treap value when pos lies in its chain's watery
// range, otherwise 0.
ll que(int pos) {
    if (!in(pos)) return 0;
    Treap *node = kth(chain[chain_id[pos]].chain, chain_pos[pos]);
    return node->val;
}
// Applies one pumping step centred at vertex pos. The water currently at pos is
// discarded. At the top level the caller has already printed it via que. In
// recursive calls it has already been added to the neighbouring chain before the call.
// The water on every other vertex moves one edge toward pos.
// The walk proceeds chain by chain:
//   1. shift the water inside pos's own chain toward position p;
//   2. unless we came from the parent chain, add the attachment vertex's water
//      (nodes[0], in the parent chain) to this chain's top and recurse upward;
//   3. for every child chain listed in adj (except the one we came from), add its top
//      vertex's water to its attachment vertex here and recurse downward;
//   4. keep the parent chain's adj in sync with whether this chain still has water.
// last: id of the chain this call came from (0 for the top-level call), used so the
// walk never goes back along the edge it arrived on.
void dfs(int pos, int last) {
    int tid = chain_id[pos], p = chain_pos[pos];
    // Step 1: only a chain that currently holds water needs an in-chain shift.
    if (chain[tid].empty() == 0) {
        Treap *node, *tl, *tr;
        node = chain[tid].chain;
        // tl = positions 1..p-1, node = position p, tr = positions p+1..end.
        split(node, tl, node, p - 1);
        split(node, node, tr, 1);
        // Drop the node at p; its replacement is built below from the two neighbours.
        delete node;
        ll nval = 0;
        bool nis = 0;
        if (tl != nullptr) {
            // Detach position p-1 (its water flows into p), reset that node to a dry
            // node and put it back at the front, so positions 1..p-2 each move one step
            // toward p.
            Treap *tmp;
            split(tl, tl, tmp, p - 2);
            nval += tmp->val, nis |= tmp->is;
            *tmp = Treap(0, 0);
            tl = merge(tmp, tl);
        }
        if (tr != nullptr) {
            // Mirror image: position p+1 flows into p, and the reset node is appended at
            // the back, so the remaining right-hand positions move one step toward p.
            Treap *tmp;
            split(tr, tmp, tr, 1);
            nval += tmp->val, nis |= tmp->is;
            *tmp = Treap(0, 0);
            tr = merge(tr, tmp);
        }
        // Move the watery range [l, r] along with the shift and keep `all` in sync.
        // If the range was exactly {p}, the chain becomes dry (l > r). Otherwise each
        // end moves one step toward p, and an end already at p stays put.
        if (chain[tid].l == p && chain[tid].r == p) {
            del(tid, p);
            chain[tid].l++;
        } else {
            if (chain[tid].l < p) del(tid, chain[tid].l++);
            else if (chain[tid].l > p) add(tid, --chain[tid].l);
            if (chain[tid].r > p) del(tid, chain[tid].r--);
            else if (chain[tid].r < p) add(tid, ++chain[tid].r);
        }
        // Reassemble: left part, new node at p (water from both neighbours), right part.
        chain[tid].chain = new Treap(nval, nis);
        chain[tid].chain = merge(tl, chain[tid].chain);
        chain[tid].chain = merge(chain[tid].chain, tr);
    }
    // Step 2 (upward). nodes[0] == 0 marks the root chain, which has no parent.
    if (chain[tid].nodes[0] != 0 && last != chain_id[chain[tid].nodes[0]]) {
        // Note: this local pa (the attachment vertex) shadows the global parent array.
        int pa = chain[tid].nodes[0], pid = chain_id[pa];
        // The attachment vertex's water is added to this chain's top (position 1).
        Treap *tmp = kth(chain[pid].chain, chain_pos[pa]), *top = kth(chain[tid].chain, 1);
        top->val += tmp->val, top->is |= tmp->is;
        if (top->is) {
            // Make the watery range include position 1. Start it there if the chain
            // was dry; otherwise grow it upward by one position.
            if (chain[tid].empty()) {
                chain[tid].l = chain[tid].r = 1;
                add(tid, 1);
            } else if (chain[tid].l > 1) {
                add(tid, --chain[tid].l);
            }
        }
        // If the parent chain holds water, step 1 of this call replaces pa's node there.
        dfs(pa, tid);
    }
    // Step 3 (downward). Iterate over a copy, because each recursive call may insert or
    // erase its own id in chain[tid].adj when it finishes (its step 4).
    set<int> s = chain[tid].adj;
    for (int sid : s) if (sid != last) {
        // This p shadows the outer p: it is the position where chain sid attaches here.
        int p = chain_pos[chain[sid].nodes[0]];
        // The child's top vertex (its position 1) flows into the attachment vertex.
        Treap *tmp = kth(chain[tid].chain, p), *top = kth(chain[sid].chain, 1);
        tmp->val += top->val, tmp->is |= top->is;
        if (tmp->is) {
            // Make the watery range cover the attachment position p.
            if (chain[tid].empty()) {
                chain[tid].l = chain[tid].r = p;
                add(tid, p);
            } else {
                if (chain[tid].l > p) add(tid, --chain[tid].l);
                if (chain[tid].r < p) add(tid, ++chain[tid].r);
            }
        }
        // Continue in the child chain from its top vertex. If chain sid holds water,
        // step 1 of that call replaces its top node.
        dfs(chain[sid].nodes[1], tid);
    }
    // Step 4: list this chain in its parent's adj exactly while it still has water.
    if (chain[tid].nodes[0] != 0) {
        int pid = chain_id[chain[tid].nodes[0]];
        if (chain[tid].empty() == 0) chain[pid].adj.insert(tid);
        else if (chain[pid].adj.count(tid)) chain[pid].adj.erase(tid);
    }
}

void solve() {
    cin >> n;
    FOR (i, 2, n) {
        int a, b;
        cin >> a >> b;
        adj[a].pb(b);
        adj[b].pb(a);
    }
    FOR (i, 1, n) cin >> a[i];

    {
        h[1] = 1, dfs1(1);
        chain[++chain_sz].nodes.pb(0), dfs2(1, 1);
        FOR (i, 1, n) all.insert(i), to[i][0] = pa[i];
        FOR (j, 1, H) FOR (i, 1, n) to[i][j] = to[to[i][j - 1]][j - 1];
    }

    cin >> q;
    while (q--) {
        int pos;
        cin >> pos;
        if (all.size() == 0) {
            cout << "0 ";
            continue;
        }
        pos = find_pos(pos);
        cout << que(pos) << ' ';
        dfs(pos, 0);
    }
}

// Entry point: enable fast iostream, then solve the single test case.
int main() {
    Waimai;
    solve();
    return 0;
}

我的分享就到這裡結束了,如果喜歡我的 $\text{Blog}$,歡迎追蹤!

留言

熱門文章