P3177 [HAOI2015]树上染色 树形dp

题目描述

有一棵点数为 nn 的树,树边有边权。给你一个在 00nn 之内的正整数 kk,你要在这棵树中选择 kk 个点,将其染成黑色,并将其他 的 nkn−k 个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间的距离的和的受益。问受益最大值是多少。

输入格式

第一行包含两个整数 n,kn,k
第二到 nn 行每行三个正整数 fr,to,disfr,to,dis,表示该树中存在一条长度为 disdis 的边 (fr,to)(fr,to)。输入保证所有点之间是联通的。

输出格式

输出一个正整数,表示收益的最大值。

样例输入

3 1
1 2 1
1 3 2

样例输出

3

说明/提示

对于 100%100\% 的数据,0n,k20000≤n,k≤2000

题解

树形 dpdp
分别考虑每条边对答案产生的贡献
f[x][i]f[x][i] 表示在以 xx 为根的子树中选择 ii 个黑节点对答案的贡献
每条边的贡献为 (( 边一侧黑点数 * 另一侧黑点数 + 边一侧白点数 * 另一侧白点数))*边权
枚举 sizsiz,得转移方程为 :f[x][j]=max(f[x][j],f[x][jl]+f[i>to][l]+val):f[x][j] = max(f[x][j], f[x][j - l] + f[i->to][l] + val)
其中 val=(l(kl)+(siz[i>to]l)(nksiz[i>to]+l))i>valval = (l * (k - l) + (siz[i->to] - l) * (n - k - siz[i->to] + l)) * i->val

code

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#define int long long
using namespace std;
const int N = 2e3 + 5;
int read() {
	int x = 0, f = 1; char ch;
	while(! isdigit(ch = getchar())) (ch == '-') && (f = -f);
	for(x = ch^48; isdigit(ch = getchar()); x = (x << 3) + (x << 1) + (ch ^ 48));
	return x * f;
}
template <class T> T Max(T a, T b) { return a > b ? a : b; }
template <class T> T Min(T a, T b) { return a < b ? a : b; }
struct Edge{
	int to, val;
	Edge *nxt;
	Edge(int to, int val, Edge *nxt) : to(to), val(val), nxt(nxt) {}
} *head[N];
void add(int x, int y, int z) {head[x] = new Edge{y, z, head[x]};}
int n, k, siz[N], f[N][N];
void dfs(int x, int fa) {
	siz[x] = 1; f[x][0] = f[x][1] = 0;
	for(Edge *i = head[x]; i; i = i->nxt) {
		if(i->to == fa) continue;
		dfs(i->to, x); siz[x] += siz[i->to];
		for(int j = min(siz[x], k); j >= 0; -- j) {
			for(int l = 0; l <= min(siz[i->to], j); ++ l) {
				if(f[x][j - l] == -1) continue;
				int val = (l * (k - l) + (siz[i->to] - l) * (n - k - siz[i->to] + l)) * i->val;
				f[x][j] = max(f[x][j], f[x][j - l] + f[i->to][l] + val);
			}
		}
	}
}
signed main() {
	n = read(); k = read();
	for(int i = 1, x, y, z; i < n; ++ i) x = read(), y = read(), z = read(), add(x, y, z), add(y, x, z); 
	memset(f, -1, sizeof(f)); dfs(1, 0); printf("%lld\n", f[1][k]);
	return 0;
}