弗洛伊德（Floyd）算法，时间复杂度 $O(n^3)$，居然也可以通过？！
正解 最近公共祖先
#include <cstdio>
#include <vector>
#include <functional>
using namespace std;
// Brute-force variant: answers path queries with an O(n^3) Floyd-Warshall pass.
//
// stdin:  "n m", then n-1 lines "x y" (undirected edge, 1-based vertex ids),
//         then m lines "x y" (query, 1-based vertex ids).
// stdout: one integer per query. With w[v] = number of input edges touching v,
//         the answer is the cheapest x -> y walk in which every vertex on the
//         walk (both ends included) costs w[v]; if the edges form a tree this
//         is the unique simple path. When no cost was ever recorded for (x, y)
//         (x == y without a self-loop edge, or y unreachable), w[y] is printed.
// Exit status: 0 on success, 1 on truncated/malformed input or a vertex id
// outside [1, n].
int main() {
    int n = 0, m = 0;
    if (scanf("%d %d", &n, &m) != 2 || n < 0)
        return 1;

    // Edge endpoints are buffered because an edge's cost uses the final degree
    // of its source vertex, which is only known once every edge has been read.
    std::vector<int> edge_from, edge_to;
    std::vector<int> w(n, 0);
    for (int e = 1; e < n; ++e) {
        int x = 0, y = 0;
        if (scanf("%d %d", &x, &y) != 2)
            return 1;
        --x, --y;  // switch to 0-based ids
        if (x < 0 || x >= n || y < 0 || y >= n)
            return 1;
        edge_from.push_back(x);
        edge_to.push_back(y);
        ++w[x], ++w[y];
    }

    const int inf = 1e9;
    // f[i][j]: cost of going from i to j, counting every vertex except j.
    std::vector<std::vector<int>> f(n, std::vector<int>(n, inf));
    for (std::size_t e = 0; e < edge_from.size(); ++e) {
        const int x = edge_from[e];
        const int y = edge_to[e];
        f[x][y] = w[x];
        f[y][x] = w[y];
    }

    for (int k = 0; k < n; ++k) {
        for (int i = 0; i < n; ++i) {
            // f[i][k] is not modified while j runs (j != k), so read it once.
            // An unreachable k can never improve anything, so skip the row.
            if (i == k || f[i][k] == inf)
                continue;
            const int to_k = f[i][k];
            for (int j = 0; j < n; ++j) {
                if (j == i || j == k)
                    continue;
                const int through_k = to_k + f[k][j];  // at most 2e9, fits in int
                if (f[i][j] > through_k)
                    f[i][j] = through_k;
            }
        }
    }

    for (int q = 0; q < m; ++q) {
        int x = 0, y = 0;
        if (scanf("%d %d", &x, &y) != 2)
            return 1;
        --x, --y;
        if (x < 0 || x >= n || y < 0 || y >= n)
            return 1;
        // f excludes the destination vertex, so its own weight is added here.
        printf("%d\n", (f[x][y] == inf ? 0 : f[x][y]) + w[y]);
    }
    return 0;
}
// Intended solution: lowest common ancestor (LCA).
#include <cstdio>
#include <vector>
#include <tuple>
#include <functional>
using namespace std;
// Returns how many binary digits are needed to write x, i.e. the index of its
// highest set bit plus one: getBit(1) == 1, getBit(7) == 3, getBit(8) == 4.
// Returns 0 for x <= 0. Looping on `x > 0` instead of `x != 0` guarantees
// termination: with an arithmetic right shift a negative value never reaches 0.
int getBit(int x) {
    int res = 0;
    for (; x > 0; x >>= 1)
        ++res;
    return res;
}
// Tree path queries answered with binary lifting (lowest common ancestor).
//
// stdin:  "n m", then n-1 lines "x y" (undirected edge, 1-based vertex ids),
//         then m lines "u v" (query, 1-based vertex ids).
// stdout: one integer per query: the sum of w[x] over every vertex x on the
//         tree path from u to v (both ends included), where w[x] is the number
//         of input edges touching x. The tree is rooted at vertex 1.
// Exit status: 0 on success, 1 on truncated/malformed input, a vertex id
// outside [1, n], or edges that do not connect all n vertices.
int main() {
    int n = 0, m = 0;
    if (scanf("%d %d", &n, &m) != 2 || n < 0)
        return 1;
    if (n == 0)
        return m > 0 ? 1 : 0;  // no vertices, so no query can be valid

    // Number of lifting levels; 2^levels > n, so any depth difference fits.
    const int levels = getBit(n);

    std::vector<std::vector<int>> adj(n);
    std::vector<int> w(n, 0), depth(n, 0);
    for (int e = 1; e < n; ++e) {
        int x = 0, y = 0;
        if (scanf("%d %d", &x, &y) != 2)
            return 1;
        --x, --y;  // switch to 0-based ids
        if (x < 0 || x >= n || y < 0 || y >= n)
            return 1;
        adj[y].push_back(x);
        adj[x].push_back(y);
        ++w[x], ++w[y];
    }

    // up[v][k]:   the 2^k-th ancestor of v, or -1 if it lies above the root.
    // cost[v][k]: sum of w over the vertices walked from v towards that
    //             ancestor, including v and excluding the ancestor itself.
    std::vector<std::vector<int>> up(n, std::vector<int>(levels, -1));
    std::vector<std::vector<int>> cost(n, std::vector<int>(levels, 0));

    // Explicit-stack traversal from the root: a recursive walk would need one
    // call frame per tree level and can exhaust the call stack on a long chain.
    std::vector<char> seen(n, 0);
    std::vector<int> pending;
    pending.push_back(0);
    seen[0] = 1;
    int reached = 1;
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();
        for (int to : adj[v]) {
            if (seen[to])
                continue;  // the parent of v (or a repeated edge)
            seen[to] = 1;
            ++reached;
            depth[to] = depth[v] + 1;
            up[to][0] = v;
            cost[to][0] = w[to];
            pending.push_back(to);
        }
    }
    // n-1 edges reaching all n vertices means the input really is a tree; the
    // query code below relies on every pair having a common ancestor.
    if (reached != n)
        return 1;

    // Doubling: the 2^k-th ancestor is the 2^(k-1)-th ancestor applied twice.
    for (int k = 1; k < levels; ++k) {
        for (int v = 0; v < n; ++v) {
            const int mid = up[v][k - 1];
            if (mid != -1) {
                up[v][k] = up[mid][k - 1];
                cost[v][k] = cost[v][k - 1] + cost[mid][k - 1];
            }
        }
    }

    // Sum of w over the path a..b, both endpoints included.
    auto path_sum = [&](int a, int b) {
        if (depth[a] > depth[b]) {
            const int tmp = b;
            b = a;
            a = tmp;
        }
        int res = 0;
        // b is the deeper one: lift it to the depth of a, one set bit at a time.
        int gap = depth[b] - depth[a];
        for (int k = 0; gap != 0; ++k, gap >>= 1) {
            if (gap & 1) {
                res += cost[b][k];
                b = up[b][k];
            }
        }
        if (a == b)
            return res + w[a];  // a was an ancestor of b; count a itself
        // Lift both while their ancestors differ; they end just below the LCA.
        for (int k = levels - 1; k >= 0; --k) {
            if (up[a][k] != up[b][k]) {
                res += cost[a][k] + cost[b][k];
                a = up[a][k];
                b = up[b][k];
            }
        }
        // Add the two children of the LCA and the LCA itself.
        return res + cost[a][0] + cost[b][0] + w[up[a][0]];
    };

    while (m-- > 0) {
        int u = 0, v = 0;
        if (scanf("%d %d", &u, &v) != 2)
            return 1;
        --u, --v;
        if (u < 0 || u >= n || v < 0 || v >= n)
            return 1;
        printf("%d\n", path_sum(u, v));
    }
    return 0;
}
0 回复
0 转发
0 喜欢
2 阅读



