1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
| #include <bits/stdc++.h> using namespace std;
// Fixed capacity of every per-vertex table below. Vertices are indexed 1..n,
// so n must stay below N (NOTE(review): presumably the input guarantees n <= 1e5).
constexpr int N = 100000 + 10;

int n;    // number of vertices read from input
int m;    // number of marked vertices listed in input
int d;    // distance threshold a vertex's farthest marked vertex must not exceed

// a[u]: largest child-branch distance from u down to a marked vertex in u's
//       subtree, excluding u itself (0 = none); maintained via change().
int a[N];
// b[u]: second-largest such child-branch distance (ties with a[u] are kept).
int b[N];
// fa[u]: parent of u in the rooted tree built by dfs1 (0 for main's root).
int fa[N];
// mp[u]: 1 if u was listed as one of the m marked vertices, else 0.
int mp[N];
// ans: number of vertices whose farthest marked vertex is within d.
int ans;
// c[u]: largest distance from u to a marked vertex outside u's subtree
//       (0 = none); filled top-down by dfs2.
int c[N];

// g[u]: adjacency list of u (the tree's edges, stored in both directions).
std::vector<int> g[N];
// Offers a child-branch distance `len` to vertex u, keeping the two largest
// offers seen so far in a[u] >= b[u]. On a tie with the current maximum the
// old maximum is demoted into b[u], so two equally long branches are both
// remembered (dfs2 relies on this to find an equal sibling branch).
// The parameter is named `len` so it does not shadow the global array c[].
void change(int u, int len) {
    if (a[u] <= len) {
        b[u] = a[u];
        a[u] = len;
    } else if (b[u] <= len) {
        b[u] = len;
    }
}
// Bottom-up pass over the subtree hanging from u, entered from `pre`.
// Sets fa[x] for every visited vertex. For each vertex it also collects the
// two largest child-branch distances to a marked vertex in a[x] >= b[x]
// (0 = no such branch), via change(). A child v offers a[v] + 1 if it has a
// marked vertex below it, otherwise 1 if v itself is marked.
//
// An explicit stack replaces recursion so that a path-shaped tree (depth ~n)
// cannot overflow the call stack. Contributions are applied in reverse
// preorder, so every child is finalized before its parent reads it. Because
// change() keeps a top-two multiset, the processing order does not affect a/b.
void dfs1(int u, int pre) {
    fa[u] = pre;
    vector<int> order;       // preorder of the subtree rooted at u
    vector<int> pending{u};  // explicit DFS stack
    while (!pending.empty()) {
        int x = pending.back();
        pending.pop_back();
        order.push_back(x);
        for (int y : g[x]) {
            if (y == fa[x]) continue;
            fa[y] = x;
            pending.push_back(y);
        }
    }
    // Reverse preorder: all descendants of v are handled before v itself.
    // order[0] is u, which does not contribute anything to pre.
    for (int i = static_cast<int>(order.size()) - 1; i > 0; --i) {
        int v = order[i];
        if (a[v]) change(fa[v], a[v] + 1);
        else if (mp[v]) change(fa[v], 1);
    }
}
// Top-down pass from u. For every vertex v reached below u, c[v] becomes the
// largest distance from v to a marked vertex outside v's subtree (0 = none).
// When v is a child of x, the candidates are:
//   - continuing upward through x:               c[x] + 1
//   - x itself, if it is marked:                 1
//   - x's best downward branch that is not v's:  a[x] or b[x], plus 1
// Requires dfs1 to have filled fa, a and b.
//
// Iterative, so a deep (path-shaped) tree cannot overflow the call stack.
// c[x] is final by the time x is popped, because its parent set it just
// before pushing x.
void dfs2(int u) {
    vector<int> pending{u};
    while (!pending.empty()) {
        int x = pending.back();
        pending.pop_back();
        for (int v : g[x]) {
            if (v == fa[x]) continue;
            if (c[x]) c[v] = c[x] + 1;
            if (mp[x]) c[v] = max(c[v], 1);
            // Length of the branch through v as dfs1 offered it to x (0 = none).
            int through_v = a[v] ? a[v] + 1 : (mp[v] ? 1 : 0);
            // If v's branch is x's maximum, use the runner-up instead. Ties
            // were kept in b[x], so an equally long sibling branch is still seen.
            int best_other = (through_v == a[x]) ? b[x] : a[x];
            if (best_other) c[v] = max(c[v], best_other + 1);
            pending.push_back(v);
        }
    }
}
int main() { scanf("%d%d%d", &n, &m, &d); for (int i = 1; i <= m; i++) { int u; scanf("%d", &u); mp[u] = 1; } for (int i = 1; i < n; i++) { int u, v; scanf("%d%d", &u, &v); g[u].push_back(v); g[v].push_back(u); } dfs1(1, 0); dfs2(1); for (int i = 1; i <= n; i++) if (max(c[i], a[i]) <= d) ans++; printf("%d", ans); return 0; }
|