我是菜鸡,我是菜鸡,我是菜鸡。。。。重要的事说三遍
算是第一次做树形dp的题吧,不太难。。
园林构成一棵树,root为1,Hi从root出发,有k个园林必须玩,每个园林游玩后会得到权值w[i],最多玩M个园林。
经过的园林必须玩,问可得到的最大权值和。
题目链接:http://hihocoder.com/problemset/problem/1104
经典的树形dp,dp[i][j],表示以i为根的子树选j个结点可得的最大权值和。
这样dp[root][num] = g[son_num][num-1] + w[root];
g[son_num][num-1]类似于对root的儿子结点的子树做背包,g[i][j] = max( g[i-1][j-p] + dp[child_id][p] );
题解不错:http://hihocoder.com/discuss/question/2743
对于必须玩的k个点,他的处理方法我没有看懂。。。。discussion里的不错。。。游玩结点i,那么其祖先结点fa[i]必要也游玩,将这些点设为must,
将must构成的子树缩成一个点,rebuild a new tree,就成了上面说的最普通的树形dp了。
#include<bits/stdc++.h> using namespace std;
const int maxn = + ;
int n, k, m; vector<int> G[maxn], G1[maxn];
int fa[maxn], w[maxn], must[maxn], vis[maxn];
int f[maxn][maxn]; void init(){
for( int i = ; i < maxn; ++i ){
G[i].clear();
G1[i].clear();
}
memset( fa, -, sizeof(fa) );
memset( f, -, sizeof(f) );
memset( must, , sizeof(must) );
} void dfs( int u, int pa ){
fa[u] = pa;
for( int i = ; i < G[u].size(); ++i ){
int v = G[u][i];
if( v == pa )
continue;
dfs( v, u );
}
} void dfs1( int u, int pa ){
if( pa != - ){
if( vis[pa] && !vis[u] )
G1[].push_back(u);
else if( !vis[pa] && !vis[u] )
G1[pa].push_back(u);
} for( int i = ; i < G[u].size(); ++i ){
int v = G[u][i];
if( v == pa )
continue;
dfs1( v, u );
}
} int dp( int root, int num ){
if( num == )
return ;
if( f[root][num] != - )
return f[root][num]; int g[maxn][maxn], son_num = G1[root].size();
memset( g, , sizeof(g) );
for( int i = ; i <= son_num; ++i ){
int child = G1[root][i-];
for( int j = ; j < num; ++j ){
for( int p = ; p <= j; ++p ){
g[i][j] = max( g[i][j], g[i-][j-p] + dp( child, p ) );
}
}
} f[root][num] = g[son_num][num-] + w[root];
return f[root][num];
} void print( int num ){
for( int i = ; i < G1[num].size(); ++i )
cout << G1[num][i] << " ";
cout << endl;
} void solve(){
dfs(, -); //must节点压缩
int ans = , cnt = ;
memset( vis, , sizeof(vis) );
for( int i = ; i <= n; ++i ){
if(must[i]){
int u = i;
while(u != - && !vis[u]){
vis[u] = ;
cnt++, ans += w[u];
u = fa[u];
}
}
}
//cout << "ans: " << ans << endl; if( cnt > m ){
printf("-1\n");
return;
} //rebuild tree
dfs1( , - ); w[] = ans;
//cout << "m-cnt: " << m - cnt << endl;
ans = dp( , m-cnt+ );
//cout << "f: " << f[8][3] << endl; printf("%d\n", ans);
} int main(){
//freopen("1.in", "r", stdin);
init();
scanf("%d%d%d", &n, &k, &m);
for( int i = ; i <= n; ++i ){
scanf("%d", &w[i]);
} int t;
for( int i = ; i <= k; ++i ){
scanf("%d", &t);
must[t] = ;
} int a, b;
for( int i = ; i <= n-; ++i ){
scanf("%d%d", &a, &b);
G[a].push_back(b), G[b].push_back(a);
} solve(); return ;
}
当然,也可以不需要g数组,直接当一维背包来做
dp[root][x] = max( dp[root][x], dp[root][x-y] + dp[child_id][y] );
#include<bits/stdc++.h> using namespace std;
const int maxn = + ;
int n, k, m; vector<int> G[maxn], G1[maxn];
int fa[maxn], w[maxn], must[maxn], vis[maxn];
int f[maxn][maxn]; void init(){
for( int i = ; i < maxn; ++i ){
G[i].clear();
G1[i].clear();
}
memset( fa, -, sizeof(fa) );
memset( f, -, sizeof(f) );
memset( must, , sizeof(must) );
} void dfs( int u, int pa ){
fa[u] = pa;
for( int i = ; i < G[u].size(); ++i ){
int v = G[u][i];
if( v == pa )
continue;
dfs( v, u );
}
} void dfs1( int u, int pa ){
if( pa != - ){
if( vis[pa] && !vis[u] )
G1[].push_back(u);
else if( !vis[pa] && !vis[u] )
G1[pa].push_back(u);
} for( int i = ; i < G[u].size(); ++i ){
int v = G[u][i];
if( v == pa )
continue;
dfs1( v, u );
}
} void dp(int root, int pa){
f[root][] = w[root];
for( int i = ; i < G1[root].size(); ++i ){
int v = G1[root][i];
if( v == pa )
continue;
dp( v, root );
for( int x = m; x >= ; --x ){
for( int y = ; y < x; ++y ){
f[root][x] = max( f[root][x], f[root][x-y] + f[v][y] );
}
}
}
} void solve(){
dfs(, -); //must½ÚµãѹËõ
int ans = , cnt = ;
memset( vis, , sizeof(vis) );
for( int i = ; i <= n; ++i ){
if(must[i]){
int u = i;
while(u != - && !vis[u]){
vis[u] = ;
cnt++, ans += w[u];
u = fa[u];
}
}
}
//cout << "ans: " << ans << endl; if( cnt > m ){
printf("-1\n");
return;
} //rebuild tree
dfs1( , - ); w[] = ans;
dp(, -);
printf("%d\n", f[][m-cnt+]);
} int main(){
//freopen("1.in", "r", stdin);
init();
scanf("%d%d%d", &n, &k, &m);
for( int i = ; i <= n; ++i ){
scanf("%d", &w[i]);
} int t;
for( int i = ; i <= k; ++i ){
scanf("%d", &t);
must[t] = ;
} int a, b;
for( int i = ; i <= n-; ++i ){
scanf("%d%d", &a, &b);
G[a].push_back(b), G[b].push_back(a);
} solve(); return ;
}