TIOJ 2268:紗霧與正宗

TIOJ 2268:紗霧與正宗


題目大意:給一個大小為 $n$ 的陣列 $a_1\sim a_n$,可對陣列元素進行 $k$ 輪替換,每一輪給 $b_i$,可將 $a$ 中某一個元素改成 $b_i$,也可以不改,問最後最大的最大區連續和可以是多少。

解法:首先觀察到選固定區間 $[l, r]$ 可得的最大值等於 $a_l\sim a_r,b_1\sim b_k$ 中前 $r-l+1$ 大的加總,可以用持久化線段樹 $O(\log(n+k))$ 算。

接下來定義一個函數 $\text{que}(l,r)$,代表取 $[l,r]$ 這個區間的答案,若 $l>r$ 則回傳 $-\text{inf}-l+r$,定義一個矩陣 $A$,$A_{i,j}=\text{que}(i,j)$,發現它有完全單調的性質,會得到最大值發生的位置 $R(i)$ 滿足 $R(1)\leq R(2)\leq\dots\leq R(n)$,最後只要取 $\text{que}(i,R(i))$ 的最大值就好了。

如果用轉移點單調算,會是 $\log^2$ 的,但是這題的 $n=10^6$,所以要想辦法優化。有一個叫 $\text{SMAWK}$ 演算法的東西,可以讓複雜度少一個 $\log$,大致上來講就是當 $n\geq m$ 時,算奇數行的答案,然後偶數行的用上一行與下一行的位置算;當 $n<m$ 時,可以用一個 $\text{stack}$ 存哪些列要留下來。

總時間複雜度 $O((n+k)\log(n+k))$。

$\text{Code:}$

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx2,popcnt,sse4,abm")
#include <bits/stdc++.h>
using namespace std;

using ll = long long;

const ll INF = 1e18;
const int N = 1e6 + 5;
const int T = 3e7 + 5;

int n, m, k;
ll ans;
int a[N], b[N], lis[2 * N];

struct Nd {
    ll sum;
    int cnt;
    Nd() = default;
} nd[8 * N];

int cnt[2 * N];
void build(int pos, int l, int r) {
    if (l == r) {
        nd[pos].cnt = cnt[l];
        nd[pos].sum = 1ll * cnt[l] * lis[l];
        return;
    }
    int mid = (l + r) / 2;
    build(pos << 1, l, mid);
    build(pos << 1 | 1, mid + 1, r);
    nd[pos].cnt = nd[pos << 1].cnt + nd[pos << 1 | 1].cnt;
    nd[pos].sum = nd[pos << 1].sum + nd[pos << 1 | 1].sum;
}

struct Node {
    ll sum;
    int cnt, ls, rs;
    Node() = default;
} node[T];

int sz, root[N];
void upd(int &pos, int l, int r, int p) {
    node[++sz] = node[pos];
    pos = sz;
    if (l == r) {
        node[pos].cnt++;
        node[pos].sum += lis[p];
        return;
    }
    int mid = (l + r) / 2;
    if (p <= mid) upd(node[pos].ls, l, mid, p);
    else upd(node[pos].rs, mid + 1, r, p);
    node[pos].sum = node[node[pos].ls].sum + node[node[pos].rs].sum;
    node[pos].cnt = node[node[pos].ls].cnt + node[node[pos].rs].cnt;
}
ll que(int pos, int pL, int pR, int l, int r, int k) {
    if (l == r) return 1ll * lis[l] * k;
    int mid = (l + r) / 2;
    int rcnt = nd[pos << 1 | 1].cnt + node[node[pR].rs].cnt - node[node[pL].rs].cnt;
    if (rcnt <= k) {
        ll rsum = nd[pos << 1 | 1].sum + node[node[pR].rs].sum - node[node[pL].rs].sum;
        return (k > rcnt ? que(pos << 1, node[pL].ls, node[pR].ls, l, mid, k - rcnt) : 0) + rsum;
    }
    return que(pos << 1 | 1, node[pL].rs, node[pR].rs, mid + 1, r, k);
}
inline ll que(int l, int r) {
    if (l > r) return -INF - l + r;
    return que(1, root[l - 1], root[r], 1, m, r - l + 1);
}

const int lim = 2;

int ansR[N];
void smawk() {
    auto rec = [&](auto rec, vector<int> &row, vector<int> &col)->void {
        int n = row.size(), m = col.size();
        if (min(n, m) <= lim) {
            int last = 0;
            for (int i = 0; i < n; i++) {
                int best = -1;
                ll mx = -2 * INF;
                for (int j = last; j < m; j++) {
                    ll val = que(row[i], col[j]);
                    if (val > mx) {
                        mx = val;
                        best = j;
                    }
                }
                ansR[row[i]] = col[best];
                last = best;
            }
            return;
        }
        if (n >= m) {
            vector<int> odd;
            for (int i = 0; i < n; i += 2) odd.emplace_back(row[i]);
            rec(rec, odd, col);
            int l = 0, r = 0;
            for (int i = 1; i < n; i += 2) {
                int lp = ansR[row[i - 1]];
                int rp = (i == n - 1 ? col[m - 1] : ansR[row[i + 1]]);
                while (col[l] < lp) l++;
                while (col[r] < rp) r++;
                ll mx = -2 * INF;
                for (int j = l; j <= r; j++) {
                    ll val = que(row[i], col[j]);
                    if (val > mx) {
                        mx = val;
                        ansR[row[i]] = col[j];
                    }
                }
            }
            return;
        }
        vector<int> red;
        for (int j = 0; j < m; j++) {
            while (red.size()) {
                int k = red.size() - 1;
                if (que(row[k], red.back()) >= que(row[k], col[j])) break;
                red.pop_back();
            }
            if (red.size() < n) red.emplace_back(col[j]);
        }
        rec(rec, row, red);
    };
    vector<int> row(n), col(n);
    iota(row.begin(), row.end(), 1);
    iota(col.begin(), col.end(), 1);
    rec(rec, row, col);
}

int main() {
    ios::sync_with_stdio(false), cin.tie(0);
    int tt;
    cin >> tt;
    while (tt--) {
        cin >> n >> k;
        for (int i = 1; i <= n; i++) cin >> a[i], lis[i] = a[i];
        for (int i = 1; i <= k; i++) cin >> b[i], lis[n + i] = b[i];
        sort(lis + 1, lis + n + k + 1);
        m = unique(lis + 1, lis + n + k + 1) - lis - 1;
        fill(cnt + 1, cnt + m + 1, 0);
        for (int i = 1; i <= k; i++) {
            b[i] = lower_bound(lis + 1, lis + m + 1, b[i]) - lis;
            cnt[b[i]]++;
        }
        build(1, 1, m);
        for (int i = 1; i <= n; i++) {
            root[i] = root[i - 1];
            a[i] = lower_bound(lis + 1, lis + m + 1, a[i]) - lis;
            upd(root[i], 1, m, a[i]);
        }
        smawk();
        ans = -INF;
        for (int i = 1; i <= n; i++) ans = max(ans, que(i, ansR[i]));
        cout << ans << '\n';
    }
}

我的分享就到這裡結束了,如果喜歡我的 $\text{Blog}$,歡迎追蹤!

留言

熱門文章