TIOJ 2229:G. 矩陣相乘

TIOJ 2229G. 矩陣相乘


題目大意:有一個大小 $n\times n$ 的矩陣 $A$ 和 矩陣 $B$,$A \times B = C$,且是在模 $P$ 的運算下,$C$ 裡的非 $0$ 元素小於等於 $2\times n$ 個,求所有非 $0$ 元素位置與值。

解法:$A\times B = C$,乘一次需要 $O(n^3)$,會太久,可以知道 $(A\times B)\times V = A\times(B\times V) = C\times V$,$V$ 為一個 $n\times 1$ 的矩陣,這樣就可將時間壓到 $O(n^2)$。那要如何從 $V$ 得知 $C$ 哪裡有非 $0$ 元素呢?假設 $V_i\ != 0$,則 $C_i$ 一定有非 $0$ 的元素,但如果 $V_i = 0$,$C_i$ 的元素不一定全部都是 $0$,有可能是與 $V$ 運算過後剛好被 $P$ 整除,所以要讓 $V$ 為隨機的數,且要做多次,才可降低出錯機率。當我們知道 $C$ 有哪些行有非 $0$ 元素,假設現在在第 $i$ 行,那我們就把 $A_i$ 加入新的 $A$ 裡,$B$ 則是要分成左半部與右半部,之後遞迴 $(A', B_l), (A', B_r)$。只是如果每次都直接傳入矩陣會 $\text{MLE}$,所以做法要改成傳入 $A$ 有哪些行要被用到與 $B$ 有哪些列要被用到。

$\text{Code:}$

#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;

#define Waimai ios::sync_with_stdio(false),cin.tie(0)
#define FOR(x,a,b) for(int x=a;x<=b;x++)
#define pb emplace_back
#define F first
#define S second

#define rnd(l,r) uniform_int_distribution<int> (l, r) (seed);
mt19937 seed (time (NULL));

const int SIZE = 2805;
const int times = 5;

int subtask, N, mod;
int A[SIZE][SIZE], B[SIZE][SIZE];
vector<tuple<int, int, int>> ans;

void dfs (vector<int> idA, vector<int> idB) {
    int n = idA.size() - 1, m = idB.size() - 1;
    if (m == 1) {
        FOR (i, 1, n) {
            long long sum = 0;
            FOR (j, 1, N) sum += 1ll * A[idA[i]][j] * B[j][idB[1]];
            sum %= mod;
            if (sum) {
                ans.pb (idA[i], idB[1], sum);
            }
        }
        return;
    }
    vector<int> v[times + 1];
    FOR (i, 1, times) {
        v[i].resize (m + 1);
        FOR (j, 1, m) v[i][j] = rnd (0, mod - 1);
        vector<int> tmp (N + 1);
        FOR (j, 1, N) {
            long long sum = 0;
            FOR (k, 1, m) sum += 1ll * B[j][idB[k]] * v[i][k];
            tmp[j] = sum % mod;
        }
        v[i] = tmp;
        tmp.resize (n + 1);
        FOR (j, 1, n) {
            long long sum = 0;
            FOR (k, 1, N) sum += 1ll * A[idA[j]][k] * v[i][k];
            tmp[j] = sum % mod;
        }
        v[i] = tmp;
    }
    vector<int> newidA (1, 0), lidB (1, 0), ridB (1, 0);
    FOR (i, 1, n) {
        bool ok = 0;
        FOR (j, 1, times) ok |= v[j][i] != 0;
        if (!ok) {
            continue;
        }
        newidA.pb (idA[i]);
    }
    if (newidA.back()) {
        FOR (j, 1, m / 2) lidB.pb (idB[j]);
        FOR (j, m / 2 + 1, m) ridB.pb (idB[j]);
        dfs (newidA, lidB);
        dfs (newidA, ridB);
    }
}

void solve() {
    cin >> subtask >> N >> mod;
    FOR (i, 1, N) FOR (j, 1, N) cin >> A[i][j];
    FOR (i, 1, N) FOR (j, 1, N) cin >> B[i][j];
    vector<int> idA (N + 1), idB (N + 1);
    iota (idA.begin(), idA.end(), 0);
    iota (idB.begin(), idB.end(), 0);
    dfs (idA, idB);
    sort (ans.begin(), ans.end());
    for (auto [x, y, t] : ans) {
        cout << x << ' ' << y << ' ' << t << '\n';
    }
}

int main() {
    Waimai;
    solve();
}

我的分享就到這裡結束了,如果喜歡我的 $\text{Blog}$,歡迎追蹤!

留言

熱門文章