0%

Hdu 6060 - RXD and dividing (dfs)

题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=6060

题目大意:
一颗以结点11为根的树,对2n2-n的结点作一个划分,但至多使用k个集合,对每个分块产生的集合中加入根节点1,然后计算集合内结点互相可达最少需要的权值的和(非最小生成树,可能需要经过其他结点),求权值和的最大可能值

分析:
若要使权值和最大,考虑每一条边的贡献次数,如果从父亲向下的一条边,子树大小小于k,设为x,下面至多被分为x个分块,该边被贡献x次,否则至多分成k个分块,该边贡献k次,所以每条边最大情况下应该被计算min(x,k)min(x,k)次,x为该边下方结点的子树大小

代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include<bits/stdc++.h>

using namespace std;

const int N = 1000006;
const int M = N * 2;
const long long INF = 1e11 + 10;

int n, k;
long long dis[N];
pair<long long, int> son[N];
long long ans;
int fa[N];
int use[N];
int sz[N];

struct Edge{
int v, nxt;
long long w;
}e[M];
int h[N];
int cnt;

struct Node{
int u, w;

Node(){}
Node(int x, int y) : u(x), w(y){}
bool operator <(const Node &x)const{
return w > x.w;
}
};

void init(){
memset(h, -1, sizeof(h));
cnt = 0;
}

void add(int x, int y, long long z){
e[cnt].v = y;
e[cnt].w = z;
e[cnt].nxt = h[x];
h[x] = cnt++;
}

void dij(){
for(int i = 2; i <= n; i++) dis[i] = INF;
dis[1] = 0;

priority_queue<Node> Q;
Q.push(Node(1, dis[1]));
while(!Q.empty()){
Node temp = Q.top(); Q.pop();
int u = temp.u;
if(temp.w > dis[u]) continue;
for(int k = h[u]; ~k; k = e[k].nxt){
int v = e[k].v;
if(dis[v] > dis[u] + e[k].w){
dis[v] = dis[u] + e[k].w;
Q.push(Node(v, dis[v]));
}
}
}
}

int q[N];
void bfs(){
memset(fa, -1, sizeof(fa));
fa[1] = 0;

int be = 0, ed = 0;
q[ed++] = 1;
while(be < ed){
int u = q[be++];
for(int k = h[u]; ~k; k = e[k].nxt){
int v = e[k].v;
if(fa[v] != -1) continue;
fa[v] = u;
q[ed++] = v;
}
}

for(int i = n - 1; i >= 0; i--){
sz[q[i]]++;
sz[fa[q[i]]] += sz[q[i]];
}
}

int find(int x){
if(fa[x] == 1) return fa[x] = x;
else return x == fa[x] ? x : fa[x] = find(fa[x]);
}

void solve(int u){
for(int kk = h[u]; ~kk; kk = e[kk].nxt){
int v = e[kk].v;
if(fa[v] != u) continue;
if(sz[v] > k){
ans -= (sz[v] - k) * (dis[v] - dis[fa[v]]);
solve(v);
}
}
}

int main(){
while(~scanf("%d%d", &n, &k)){
init();
ans = 0;
memset(use, 0, sizeof(use));
memset(sz, 0, sizeof(sz));

int u, v;
long long w;
for(int i = 1; i < n; i++){
scanf("%d%d%lld", &u, &v, &w);
add(u, v, w);
add(v, u, w);
}

dij();

for(int i = 2; i <= n; i++)
ans += dis[i];

bfs();

solve(1);

printf("%lld\n", ans);
}

return 0;
}