感觉自己树上贪心太差 (NOIP 血的教训),所以来练几道。
为了方便叙述,我们将 $m$个点称为放火点,有炸药的点叫做炸药点。
首先考虑一下能不能 $dp$
记 $dp[i][j]$表示第 $i$个点时间为 $j$所需要的最小放火点数,所求即为第一个 $dp[1][j]\leq m$。
这状态都爆了,这怎么转移呀,不能光枚举儿子节点呀。那枚举所有节点也会 $gg$呀,貌似是个 $O(n^3)$
然后另一个比较明显的思路是可以二分答案 $now$表示当前时间,然后算最少需要多少个放火点。
一个贪心的想法是肯定是把放火点放得离根越近越好。
但是有时候迫不得已的时候必须要把当前点设为放火点,比如说子树最远的炸药点距离 $=now$的时候。但子树之间的放火点也有可能相互影响。
所以我们设 $mn[u]$表示 $u$的子树中离 $u$最近的放火点,$mx[u]$表示离 $u$最远的未被点着的炸药点。
然后转移就比较简单了直接看代码吧,主要是要找到这些状态比较麻烦。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<vector>
#include<ctype.h>
#include<queue>
#include<map>
#define Re register
#define fo(i, a, b) for (Re int i = (a); i <= (b); ++i)
#define fd(i, a, b) for (Re int i = (a); i >= (b); --i)
#define edge(i, u) for (Re int i = head[u], v = e[i].v; i; i = e[i].nxt, v = e[i].v)
#define pb push_back
#define F first
#define S second
#define ll long long
#define inf 1000000007
#define mp std::make_pair
#define eps 1e-4
#define mod 989381
#define lowbit(x) (x & -x)
#define N 2000005
#define clr(arr) memset(arr, 0, sizeof arr)
#define bset std::bitset<N>
inline int read ()
{
Re int x = 0; Re bool flag = 0; Re char ch = getchar();
while (!isdigit(ch) && ch != '-') ch = getchar();
if (ch == '-') flag = 1, ch = getchar();
while (isdigit(ch)) x = (x << 3) + (x << 1) + (ch ^ 48), ch = getchar();
if (flag) x = -x; return x;
}
int n, m, a[N], now;
int head[N], tot;
struct edge {
int nxt, v;
} e[N << 1];
inline void addedge (int u, int v)
{
e[++tot] = (edge) {head[u], v};
head[u] = tot;
}
int mx[N], mn[N], in[N], ans;
inline void dfs (int u, int fa)
{
mx[u] = -inf; mn[u] = inf;
edge (i, u)
{
if (v == fa) continue;
dfs(v, u);
mx[u] = std::max(mx[v] + 1, mx[u]);
mn[u] = std::min(mn[v] + 1, mn[u]);
}
if (a[u] && mn[u] > now) mx[u] = std::max(mx[u], 0);
if (mx[u] + mn[u] <= now) mx[u] = -inf;
if (mx[u] == now) ++ans, mx[u] = -inf, mn[u] = 0;
}
inline bool check ()
{
ans = 0;
dfs(1, 0);
if (mx[1] >= 0) ++ans;
return m >= ans;
}
int main ()
{
n = read(); m = read();
fo (i, 1, n) a[i] = read();
fo (i, 2, n)
{
int u = read(), v = read();
addedge(u, v);
addedge(v, u);
}
int l = 0, r = n;
while (l < r)
{
now = l + r >> 1;
if (!check()) l = now + 1; else r = now;
}
printf("%d", l);
return 0;
}
0 条评论