TIOJ 2229:G. 矩陣相乘
TIOJ 2229:G. 矩陣相乘
題目大意:有一個大小 $n\times n$ 的矩陣 $A$ 和 矩陣 $B$,$A \times B = C$,且是在模 $P$ 的運算下,$C$ 裡的非 $0$ 元素小於等於 $2\times n$ 個,求所有非 $0$ 元素位置與值。
解法:$A\times B = C$,乘一次需要 $O(n^3)$,會太久,可以知道 $(A\times B)\times V = A\times(B\times V) = C\times V$,$V$ 為一個 $n\times 1$ 的矩陣,這樣就可將時間壓到 $O(n^2)$。那要如何從 $V$ 得知 $C$ 哪裡有非 $0$ 元素呢?假設 $V_i\ != 0$,則 $C_i$ 一定有非 $0$ 的元素,但如果 $V_i = 0$,$C_i$ 的元素不一定全部都是 $0$,有可能是與 $V$ 運算過後剛好被 $P$ 整除,所以要讓 $V$ 為隨機的數,且要做多次,才可降低出錯機率。當我們知道 $C$ 有哪些行有非 $0$ 元素,假設現在在第 $i$ 行,那我們就把 $A_i$ 加入新的 $A$ 裡,$B$ 則是要分成左半部與右半部,之後遞迴 $(A', B_l), (A', B_r)$。只是如果每次都直接傳入矩陣會 $\text{MLE}$,所以做法要改成傳入 $A$ 有哪些行要被用到與 $B$ 有哪些列要被用到。
$\text{Code:}$
#pragma GCC optimize("O3")
#include <bits/stdc++.h>
using namespace std;
#define Waimai ios::sync_with_stdio(false),cin.tie(0)
#define FOR(x,a,b) for(int x=a;x<=b;x++)
#define pb emplace_back
#define F first
#define S second
#define rnd(l,r) uniform_int_distribution<int> (l, r) (seed);
mt19937 seed (time (NULL));
const int SIZE = 2805;
const int times = 5;
int subtask, N, mod;
int A[SIZE][SIZE], B[SIZE][SIZE];
vector<tuple<int, int, int>> ans;
void dfs (vector<int> idA, vector<int> idB) {
int n = idA.size() - 1, m = idB.size() - 1;
if (m == 1) {
FOR (i, 1, n) {
long long sum = 0;
FOR (j, 1, N) sum += 1ll * A[idA[i]][j] * B[j][idB[1]];
sum %= mod;
if (sum) {
ans.pb (idA[i], idB[1], sum);
}
}
return;
}
vector<int> v[times + 1];
FOR (i, 1, times) {
v[i].resize (m + 1);
FOR (j, 1, m) v[i][j] = rnd (0, mod - 1);
vector<int> tmp (N + 1);
FOR (j, 1, N) {
long long sum = 0;
FOR (k, 1, m) sum += 1ll * B[j][idB[k]] * v[i][k];
tmp[j] = sum % mod;
}
v[i] = tmp;
tmp.resize (n + 1);
FOR (j, 1, n) {
long long sum = 0;
FOR (k, 1, N) sum += 1ll * A[idA[j]][k] * v[i][k];
tmp[j] = sum % mod;
}
v[i] = tmp;
}
vector<int> newidA (1, 0), lidB (1, 0), ridB (1, 0);
FOR (i, 1, n) {
bool ok = 0;
FOR (j, 1, times) ok |= v[j][i] != 0;
if (!ok) {
continue;
}
newidA.pb (idA[i]);
}
if (newidA.back()) {
FOR (j, 1, m / 2) lidB.pb (idB[j]);
FOR (j, m / 2 + 1, m) ridB.pb (idB[j]);
dfs (newidA, lidB);
dfs (newidA, ridB);
}
}
void solve() {
cin >> subtask >> N >> mod;
FOR (i, 1, N) FOR (j, 1, N) cin >> A[i][j];
FOR (i, 1, N) FOR (j, 1, N) cin >> B[i][j];
vector<int> idA (N + 1), idB (N + 1);
iota (idA.begin(), idA.end(), 0);
iota (idB.begin(), idB.end(), 0);
dfs (idA, idB);
sort (ans.begin(), ans.end());
for (auto [x, y, t] : ans) {
cout << x << ' ' << y << ' ' << t << '\n';
}
}
int main() {
Waimai;
solve();
}
我的分享就到這裡結束了,如果喜歡我的 $\text{Blog}$,歡迎追蹤!
留言
張貼留言