题目
536. 保卫王国
算法标签: 树形 d p dp dp, 树上倍增, L C A LCA LCA
思路
在原问题基础上还有每个点的限制, 要求某些点必须选择, 某些节点不能选择
将问题简化, 如果是一维问题, 可以前后缀分解来做
将问题回到树形问题, 首先考虑只有一个点选或者不选, 可以将整个树分为两部分, 也是类似于前后缀分解的方法满足一个点的限制
对于两个点以上情况, 使用倍增求解
对于两个点来说选与不选是两种情况, 然后对于 l c a lca lca也有两种情况, 选择或者不选择
代码
#include <iostream>
#include <algorithm>
#include <cstring>
#include <vector>using namespace std;typedef long long LL;
const int N = 1e5 + 10, M = N << 1, K = 17;
const LL INF = 1e18;int n, m, p[N];
string type;
vector<int> head[N];
LL f[N][2], g[N][2], w[N][K][2][2];
int fa[N][K], depth[N];void add(int u, int v) {head[u].push_back(v);
}void dfs_f(int u, int father) {f[u][1] = p[u];for (int v: head[u]) {if (v == father) continue;dfs_f(v, u);f[u][0] += f[v][1];f[u][1] += min(f[v][0], f[v][1]);}
}void dfs_g(int u, int father) {for (int v: head[u]) {if (v == father) continue;g[v][0] = g[u][1] + f[u][1] - min(f[v][0], f[v][1]);g[v][1] = min(g[v][0], g[u][0] + f[u][0] - f[v][1]);dfs_g(v, u);}
}void dfs_fa(int u, int father) {fa[u][0] = father;for (int k = 1; k < K; ++k)fa[u][k] = fa[fa[u][k - 1]][k - 1];for (int v: head[u]) {if (v == father) continue;depth[v] = depth[u] + 1;dfs_fa(v, u);}
}void dfs_w(int u, int father) {for (int v: head[u]) {if (v == father) continue;w[v][0][0][0] = INF;w[v][0][0][1] = f[u][1] - min(f[v][0], f[v][1]);w[v][0][1][0] = f[u][0] - f[v][1];w[v][0][1][1] = f[u][1] - min(f[v][0], f[v][1]);for (int k = 1; k < K; ++k) {int anc = fa[v][k - 1];for (int x = 0; x < 2; ++x) {for (int y = 0; y < 2; ++y) {w[v][k][x][y] = INF;for (int z = 0; z < 2; ++z) {w[v][k][x][y] = min(w[v][k][x][y], w[v][k - 1][x][z] + w[anc][k - 1][z][y]);}}}}dfs_w(v, u);}
}int lca(int a, int b) {if (depth[a] < depth[b]) swap(a, b);for (int k = K - 1; k >= 0; --k)if (depth[fa[a][k]] >= depth[b])a = fa[a][k];if (a == b) return a;for (int k = K - 1; k >= 0; --k)if (fa[a][k] != fa[b][k])a = fa[a][k], b = fa[b][k];return fa[a][0];
}LL solve(int a, int x, int b, int y) {if (depth[a] < depth[b]) swap(a, b), swap(x, y);if (!x && !y && fa[a][0] == b) return -1;LL sa[2] = {INF, INF}, sb[2] = {INF, INF};sa[x] = f[a][x];sb[y] = f[b][y];for (int k = K - 1; k >= 0; --k) {if (depth[fa[a][k]] >= depth[b]) {LL na[2] = {INF, INF};for (int u = 0; u < 2; ++u) {for (int v = 0; v < 2; ++v) {na[v] = min(na[v], sa[u] + w[a][k][u][v]);}}memcpy(sa, na, sizeof na);a = fa[a][k];}}if (a == b) return sa[y] + g[b][y];for (int k = K - 1; k >= 0; --k) {if (fa[a][k] != fa[b][k]) {LL na[2] = {INF, INF}, nb[2] = {INF, INF};for (int u = 0; u < 2; ++u) {for (int v = 0; v < 2; ++v) {na[v] = min(na[v], sa[u] + w[a][k][u][v]);nb[v] = min(nb[v], sb[u] + w[b][k][u][v]);}}memcpy(sa, na, sizeof na);memcpy(sb, nb, sizeof nb);a = fa[a][k];b = fa[b][k];}}int l = fa[a][0];LL res0 = f[l][0] - f[a][1] - f[b][1] + sa[1] + sb[1] + g[l][0];LL res1 = f[l][1] - min(f[a][0], f[a][1]) - min(f[b][0], f[b][1])+ min(sa[0], sa[1]) + min(sb[0], sb[1]) + g[l][1];return min(res0, res1);
}int main() {ios::sync_with_stdio(0);cin.tie(0), cout.tie(0);cin >> n >> m >> type;for (int i = 1; i <= n; ++i) cin >> p[i];for (int i = 1; i < n; ++i) {int u, v;cin >> u >> v;add(u, v);add(v, u);}dfs_f(1, 0);dfs_g(1, 0);depth[1] = 1;dfs_fa(1, 0);// 初始化w数组for (int i = 1; i <= n; ++i)for (int k = 0; k < K; ++k)for (int x = 0; x < 2; ++x)for (int y = 0; y < 2; ++y)w[i][k][x][y] = INF;dfs_w(1, 0);while (m--) {int a, x, b, y;cin >> a >> x >> b >> y;cout << solve(a, x, b, y) << "\n";}return 0;
}