CSES 2086:Subarray Squares

CSES 2086:Subarray Squares


題目大意:給一個長度為 $n$ 的陣列 $X$,要將此陣列分成連續的 $k$ 份,每份的數值是該份元素總和的平方,求 $k$ 份數值總和的最小值。

解法:無情斜率優化 …。

$\text{Code:}$

///
/// dp (k, i) = min (dp (k - 1, j) + (pre[i] - pre[j]) ^ 2)
///           = min (-2 * pre[j] * pre[i] + dp (k - 1, j) + pre[j] ^ 2) + pre[i] ^ 2
///
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
 
// NOTE: from this point on every `int` token in this program is replaced by
// `long long`; this is why main() below is declared with int32_t.
#define int long long
 
// Capacity of the pre[] and dp[][] arrays declared below.
const int SIZE = 3005;
 
/// A line y = m * x + b kept in the hull deques, tagged with the prefix
/// index it was created from.
struct Line {
    int m;   ///< slope
    int b;   ///< intercept
    int id;  ///< index i of the dp entry this line was built from
};
 
// Number of array elements; read in solve().
int n;
// pre[i] = sum of the first i input values (pre[0] stays 0); filled in solve().
int pre[SIZE];
// dp[k][i] = value computed for the first i elements split into k parts;
// row 1 is filled in solve(), rows >= 2 in get_dq().
int dp[SIZE][SIZE];
// `last` holds the lines produced for the previous dp row, `dq` the lines
// being collected for the row currently computed by get_dq().
deque<Line> last, dq;
 
void get_dq (int k) {
    for (int i = k; i <= n; i++) {
        int m = 2 * pre[i];
        dp[k][i] = pre[i] * pre[i];
        while (last.size() > 1 && last[1].id < i) {
            auto l = last.front();
            last.pop_front();
            if (l.m * pre[i] + l.b > last.front().m * pre[i] + last.front().b) {
                last.push_front (l);
                break;
            }
        }
        dp[k][i] -= last.front().m * pre[i] + last.front().b;
        int num = -dp[k][i] - pre[i] * pre[i];
        while (dq.size() > 1) {
            auto l = dq.back();
            dq.pop_back();
            if (m * (dq.back().b - l.b) + dq.back().m * (l.b - num) < l.m * (dq.back().b - num)) {
                dq.push_back (l);
                break;
            }
        }
        dq.push_back ({m, num, i});
    }
}
 
void solve() {
    int k;
    cin >> n >> k;
    for (int i = 1; i <= n; i++) {
        cin >> pre[i];
        pre[i] += pre[i - 1];
    }
 
    for (int i = 1; i <= n; i++) {
        dp[1][i] = pre[i] * pre[i];
        dq.push_back ({2 * pre[i], -dp[1][i] - pre[i]*pre[i], i});
    }
 
    for (int i = 2; i <= k; i++) {
        last = dq, dq = deque<Line>();
        get_dq (i);
    }
    cout << dp[k][n] << '\n';
}
 
/// Program entry point: configures iostream and runs solve() once.
/// Declared with int32_t because `int` is redefined as a macro in this file.
int32_t main() {
    ios::sync_with_stdio (false);
    cin.tie (nullptr);
    solve();
    return 0;
}

Upd:可以 aliens 優化。

#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,popcnt,sse4,abm")
#include <bits/stdc++.h>
using namespace std;
 
#ifdef WAIMAI
// Debug build: debug(a, b) prints "[a, b] : " (the stringized argument list)
// followed by the argument values via dout().
#define debug(HEHE...) cout << "[" << #HEHE << "] : ", dout(HEHE)
// Recursion terminator: no arguments left, finish the output line.
void dout() {cout << '\n';}
// Prints the first argument, then ", " if more arguments follow, and
// recurses on the remaining ones.
template<typename T, typename...U>
void dout (T t, U...u) {cout << t << (sizeof... (u) ? ", " : ""), dout (u...);}
#else
// Non-debug build: debug(...) expands to an inert integer literal.
#define debug(...) 7122
#endif
 
// Shorthand for the 64-bit integer type used for costs.
#define ll long long
// Turns off synchronization between C++ streams and C stdio and unties cin.
#define Waimai ios::sync_with_stdio(false), cin.tie(0)
// Inclusive loop for x = a..b; the bound b is evaluated once into I.
#define FOR(x,a,b) for (int x = a, I = b; x <= I; x++)
// Append to a container by constructing the element in place.
#define pb emplace_back
// Shorthands for the members of std::pair.
#define F first
#define S second
 
// dp[i] = min [ dp[j] + (pre[i] - pre[j])^2 ] + penalty
//       = min [ -2*pre[j] * pre[i] + dp[j]+pre[j]^2 ] + pre[i]^2 + penalty
 
// Large sentinel value; not referenced by the code visible in this listing.
const ll INF = 4e18;
// Capacity of the pre[] and dp[] arrays.
const int SIZE = 3005;
 
// Number of elements and number of parts; both read in solve().
int n, k;
// pre[i] = sum of the first i input values (pre[0] stays 0); filled in solve().
int pre[SIZE];
// dp[i].first  = penalised cost computed for the first i elements,
// dp[i].second = number of transitions (parts) used to reach it; filled by ok().
pair<ll, int> dp[SIZE];
 
// Evaluates the line of candidate i at abscissa x:
// value = -2 * pre[i] * x + dp[i].first + pre[i]^2, paired with the part
// count obtained by taking one more part after i.
pair<ll, int> cal(int i, int x) {
    const ll slope = -2ll * pre[i];
    const ll intercept = dp[i].first + 1ll * pre[i] * pre[i];
    const int parts = dp[i].second + 1;
    return {slope * x + intercept, parts};
}
// Returns true when p1 is not better than p2, i.e. when (value, count) of p1
// is lexicographically greater than or equal to that of p2.
bool cmp(pair<ll, int> p1, pair<ll, int> p2) {
    return p1 >= p2;
}
bool del(int i, int j, int k) {
    int mi = -pre[i]; ll ki = dp[i].F + 1ll*pre[i]*pre[i];
    int mj = -pre[j]; ll kj = dp[j].F + 1ll*pre[j]*pre[j];
    int mk = -pre[k]; ll kk = dp[k].F + 1ll*pre[k]*pre[k];
    __int128 a = (__int128) (kk - ki) * (mi - mj);
    __int128 b = (__int128) (kj - ki) * (mi - mk);
    return a < b || (a == b && dp[j].S >= dp[k].S);
}
 
// Runs the 1-D DP with an extra cost `pen` added per part, filling dp[1..n],
// and reports whether the resulting solution for the whole array uses at
// most k parts.
bool ok(ll pen) {
    deque<int> hull;
    hull.push_back(0);
    for (int i = 1; i <= n; ++i) {
        const int x = pre[i];

        // Discard the front candidate while the second one is at least as good at x.
        while (hull.size() >= 2 && cmp(cal(hull[0], x), cal(hull[1], x))) {
            hull.pop_front();
        }

        pair<ll, int> best = cal(hull.front(), x);
        best.first += 1ll * x * x + pen;
        dp[i] = best;

        // Discard back candidates made unnecessary by candidate i.
        while (hull.size() >= 2 && del(hull[hull.size() - 2], hull.back(), i)) {
            hull.pop_back();
        }
        hull.push_back(i);
    }
    return dp[n].second <= k;
}
 
void solve() {
    cin >> n >> k;
    FOR (i, 1, n) cin >> pre[i], pre[i] += pre[i - 1];
    ll l = 0, r = 1e17;
    while (l < r) {
        ll mid = (l + r) / 2;
        if (ok(mid)) r = mid;
        else l = mid + 1;
    }
    ok(l);
    cout << dp[n].F - l*k << '\n';
}
 
int main() {
    Waimai;
    solve();
}

我的分享就到這裡結束了,如果喜歡我的 $\text{Blog}$,歡迎追蹤!

留言

熱門文章