1904D1 - Set To Max (Easy Version)/1904D2 - Set To Max (Hard Version)
Editorial 1:
```c++
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
// Contest shorthands used throughout the solution.
#define pb push_back
#define ff first
#define ss second
typedef long long ll;
typedef long double ld;
typedef pair<int, int> pii;
typedef pair<ll, ll> pll;
typedef pair<ld, ld> pld;
const int INF = 1e9;
const ll LLINF = 1e18;
const int MOD = 1e9 + 7;
// Ordered set with order statistics (find_by_order / order_of_key) from GNU pb_ds.
template<class K> using sset = tree<K, null_type, less<K>, rb_tree_tag, tree_order_statistics_node_update>;
// Ceiling of a / b for signed integers (b != 0). C++ division truncates toward
// zero, which is already the ceiling when the exact quotient is negative; when
// (a ^ b) > 0 and the division is inexact we must round up by one.
inline ll ceil0(ll a, ll b) {
    const ll truncated = a / b;
    const bool xor_positive = (a ^ b) > 0;
    const bool inexact = (a % b) != 0;
    return truncated + ((xor_positive && inexact) ? 1 : 0);
}
// Fast iostream setup: stop synchronising C++ streams with C stdio and untie
// cin from cout so that reading input does not flush output first.
void setIO() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
}
const int MAXN = 2e5;
const int MAXQ = 2e5;
// Lazy segment tree (range add, range max) over Euler-tour positions:
// seg = node maximum, tag = pending addition for the children.
int seg[4*MAXN + 5];
int tag[4*MAXN + 5];
// Euler-tour counter; after dfs1, positions 0 .. tim - 1 are in use.
int tim;
vector<int> g[MAXN + 5];
// in[v] / out[v]: first and last tour position inside v's subtree.
int in[MAXN + 5], out[MAXN + 5];
// Hands the pending addition of node `cur` down to its children (2*cur+1 and
// 2*cur+2): both their maxima and their own pending tags grow by that amount.
void push_down(int cur){
    const int pending = tag[cur];
    if(pending == 0) return;
    const int left = cur*2 + 1, right = cur*2 + 2;
    seg[left] += pending;
    tag[left] += pending;
    seg[right] += pending;
    tag[right] += pending;
    tag[cur] = 0;
}
// Range add: adds v to every position in [l, r] (inclusive). Node `cur` covers
// [ul, ur]; the root (cur = 0) covers [0, tim - 1].
void update(int l, int r, int v, int ul = 0, int ur = tim - 1, int cur = 0){
    if(l <= ul && ur <= r){
        seg[cur] += v;
        tag[cur] += v;
        return;
    }
    push_down(cur);
    int mid = (ul + ur)/2;
    if(l <= mid) update(l, r, v, ul, mid, cur*2 + 1);
    if(r > mid) update(l, r, v, mid + 1, ur, cur*2 + 2);
    seg[cur] = max(seg[cur*2 + 1], seg[cur*2 + 2]);
}
// Range max over positions [l, r] (inclusive); same node layout as update().
int query(int l, int r, int ul = 0, int ur = tim - 1, int cur = 0){
    if(l <= ul && ur <= r) return seg[cur];
    push_down(cur);
    int mid = (ul + ur)/2;
    if(r <= mid) return query(l, r, ul, mid, cur*2 + 1);
    if(l > mid) return query(l, r, mid + 1, ur, cur*2 + 2);
    return max(query(l, r, ul, mid, cur*2 + 1), query(l, r, mid + 1, ur, cur*2 + 2));
}
// Euler tour: in[x] is x's preorder position and out[x] the last position used
// inside x's subtree, so that subtree is exactly [in[x], out[x]].
// p is the parent (0 = none).
void dfs1(int x, int p = 0){
    in[x] = tim++;
    for(int i : g[x]){
        if(i == p) continue;
        dfs1(i, x);
    }
    out[x] = tim - 1;
}
// que[x]: queries asked at vertex x, as (query index, removed vertices).
vector<pair<int, vector<int>>> que[MAXN + 5];
// nxt[v]: the child of v that dfs2 is currently descending into.
int nxt[MAXN + 5];
int ans[MAXQ + 5];
int n, q;
void dfs2(int x, int p = 0){ for(auto &i : que[x]){ vector skip; bool found = false; for(int j : i.ss){ if(j == x){ found = true; break; } if(in[j] <= in[x] && in[x] <= out[j]){ skip.pb({0, in[nxt[j]] — 1}); skip.pb({out[nxt[j]] + 1, tim — 1}); } else { skip.pb({in[j], out[j]}); } } if(found) continue; sort(skip.begin(), skip.end()); int prv = 0; for(pii j : skip){ if(prv < j.ff) ans[i.ff] = max(ans[i.ff], query(prv, j.ff — 1)); prv = max(prv, j.ss + 1); } if(prv <= tim — 1) ans[i.ff] = max(ans[i.ff], query(prv, tim — 1)); } update(0, tim — 1, 1); for(int i : g[x]){ if(i == p) continue; update(in[i], out[i], -2); nxt[x] = i; dfs2(i, x); update(in[i], out[i], 2); } update(0, tim — 1, -1); }
int main(){ setIO(); cin >> n >> q; for(int i = 0; i < n — 1; i++){ int a, b; cin >> a >> b; g[a].pb(b); g[b].pb(a); } tim = 0; dfs1(1); for(int i = 0; i < q; i++){ int x, k; cin >> x >> k; vector v(k); for(int j = 0; j < k; j++) cin >> v[j]; que[x].pb({i, v}); } for(int i = 2; i <= n; i++) update(in[i], out[i], 1); dfs2(1); for(int i = 0; i < q; i++){ cout << ans[i] << endl; } } ```
Editorial 2:
In a tree, one of the farthest nodes from some node $$$x$$$ is one of the two endpoints of the diameter.
Let's try to find the diameter of the connected component containing node $$$x$$$ after the nodes $$$a_{1 \dots k}$$$ are removed.
Consider an Euler tour of the tree and order the nodes by their entry time (preorder), so that every subtree occupies a contiguous range of this order. When $$$k$$$ nodes are removed, the nodes still reachable from $$$x$$$ form $$$O(k)$$$ contiguous intervals in the tour.
Let's build a segtree/sparse table where each node stores the diameter (as a pair of endpoints) of the set of nodes whose $$$in$$$ values lie in the range $$$[l, r]$$$. When two such sets are merged, the endpoints of the new diameter are two of the four old endpoints, so we can enumerate all $$$\binom{4}{2}$$$ choices and take the best one.
To answer a query, we first generate the list of banned intervals (just like in solution 1) and use it to generate the list of unbanned intervals. Then we query our segtree for the diameter of each of these ranges. Finally, we combine the answers of the separate queries to obtain the diameter of the connected component containing $$$x$$$. Since the farthest node from $$$x$$$ is one of the two endpoints, it suffices to check the distance from $$$x$$$ to each of them.
Final complexity is $$$O(n \log^2 n + \sum k \log n)$$$.
#include <bits/stdc++.h>
// Helper macros. Arguments are parenthesized so that expressions such as
// sz(*p) or all(*p) expand correctly.
#define sz(x) ((int)(x).size())
#define all(x) (x).begin(), (x).end()
#define pb push_back
#define eb emplace_back
// MX sizes the segment tree (vertex bound plus slack); int_max is a large
// sentinel value (0x3f3f3f3f).
const int MX = 2e5 +10, int_max = 0x3f3f3f3f;
using namespace std;
//lca template start
// Per-vertex heavy-light decomposition data; solve() sizes each to n + 1.
vector<int> dep;   // depth: dep[child] = dep[parent] + 1, with dep[0] = 0 for the virtual parent 0
vector<int> sz;    // subtree size
vector<int> par;   // parent (0 for the root)
vector<int> head;  // top vertex of the heavy chain containing the vertex
vector<int> tin;   // tour position on entry
vector<int> tout;  // one past the last tour position in the subtree
vector<int> tour;  // tour[t] = vertex at tour position t
// Adjacency lists; dfs() removes the parent and moves the heavy child to the front.
vector<vector<int>> adj;
int n;    // number of vertices
int ind;  // next free tour position, advanced by dfs2
int q;    // number of queries
// Heavy-light preprocessing, pass 1 (recursive). For every vertex in the subtree
// of x (p = parent, 0 for the root) it fills sz[] (subtree size), dep[]
// (dep[p] + 1) and par[]. It also moves the child with the largest subtree into
// adj[x][0] and finally erases the parent from adj[x], so afterwards adj[x]
// holds only children, heavy child first.
void dfs(int x, int p){
sz[x] = 1;
dep[x] = dep[p] + 1;
par[x] = p;
for(auto &i : adj[x]){
if(i == p) continue;
dfs(i, x);
sz[x] += sz[i];
// i is a reference into adj[x], so this swap exchanges list slots: the current
// child moves to slot 0 when slot 0 holds the parent or a strictly smaller subtree.
if(adj[x][0] == p || sz[i] > sz[adj[x][0]]) swap(adj[x][0], i);
}
// The root (p == 0) has no parent entry to remove.
if(p != 0) adj[x].erase(find(all(adj[x]), p));
}
// Heavy-light preprocessing, pass 2: lays the vertices out in DFS order, visiting
// adj[x][0] (the heavy child) first, so that every heavy chain is contiguous.
// tour[t] is the vertex at position t; x's subtree is [tin[x], tout[x]).
void dfs2(int x, int p){
    tin[x] = ind;
    tour[ind] = x;
    ++ind;
    for(int child : adj[x]){
        if(child == p) continue;
        const bool heavy = (child == adj[x][0]);
        head[child] = heavy ? head[x] : child;
        dfs2(child, x);
    }
    tout[x] = ind;
}
// Returns the k-th ancestor of u (k = 0 gives u itself), or -1 when dep[u] <= k.
// Jumps over whole heavy chains, then indexes into the tour, where the rest of
// the current chain occupies consecutive positions above u.
int k_up(int u, int k){
    if(k >= dep[u]) return -1;
    for(;;){
        const int above_in_chain = dep[u] - dep[head[u]];
        if(k <= above_in_chain) break;
        k -= above_in_chain + 1;
        u = par[head[u]];
    }
    return tour[tin[u] - k];
}
// Lowest common ancestor via heavy-light decomposition: lift whichever endpoint
// sits on the chain with the deeper head until both lie on the same chain; the
// shallower of the two is then the LCA.
int lca(int a, int b){
    while(head[a] != head[b]){
        if(dep[head[b]] < dep[head[a]]) swap(a, b);
        b = par[head[b]];
    }
    return dep[a] > dep[b] ? b : a;
}
// Number of edges on the tree path between a and b.
int dist(int a, int b){
    const int meet = lca(a, b);
    return (dep[a] - dep[meet]) + (dep[b] - dep[meet]);
}
//lca template end
//segtree template start
#define ff first
#define ss second
int dist(pair<int, int> a){
return dist(a.ff, a.ss);
}
pair<int, int> merge(pair<int, int> a, pair<int, int> b){
auto p = max(pair(dist(a), a), pair(dist(b), b));
for(auto x : {a.ff, a.ss}){
for(auto y : {b.ff, b.ss}){
if(x == 0 || y == 0) continue;
p = max(p, pair(dist(pair(x, y)), pair(x, y)));
}
}
return p.ss;
}
// mx[k]: diameter (endpoint pair) of the vertices whose tour positions are
// covered by segment-tree node k.
pair<int, int> mx[MX*4];
// Children of segment-tree node k (root is 1). The argument is parenthesized so
// that e.g. LC(k + 1) expands to 2*(k + 1) rather than 2*k + 1.
#define LC(k) (2*(k))
#define RC(k) (2*(k) +1)
// Point set for the diameter segment tree: the leaf for tour position p becomes
// the single-vertex diameter (tour[p], tour[p]) and its ancestors are re-merged.
// Node k covers the half-open range [L, R). The parameter v is not used.
void update(int p, int v, int k, int L, int R){
    if(R - L == 1){
        const int vert = tour[p];
        mx[k] = pair<int, int>(vert, vert);
        return ;
    }
    const int mid = (L + R)/2;
    if(p >= mid) update(p, v, RC(k), mid, R);
    else update(p, v, LC(k), L, mid);
    mx[k] = merge(mx[LC(k)], mx[RC(k)]);
}
// Appends to `ret` the stored diameters of the canonical segment-tree nodes that
// exactly cover the half-open tour range [qL, qR). Node k covers [L, R).
void query(int qL, int qR, vector<pair<int, int>>& ret, int k, int L, int R){
    const bool disjoint = (qR <= L) || (R <= qL);
    if(disjoint) return ;
    const bool inside = (qL <= L) && (R <= qR);
    if(inside){
        ret.push_back(mx[k]);
        return ;
    }
    const int mid = (L + R)/2;
    query(qL, qR, ret, LC(k), L, mid);
    query(qL, qR, ret, RC(k), mid, R);
}
//segtree template end
int query(vector<int> arr, int x){
vector<pair<int, int>> banned, ret;
for(int u : arr){
if(lca(u, x) == u){
u = k_up(x, dep[x] - dep[u] - 1);
banned.push_back({0, tin[u]});
banned.push_back({tout[u], n});
}else{
banned.push_back({tin[u], tout[u]});
}
}
sort(all(banned), [&](pair<int, int> a, pair<int, int> b){
return (a.ff < b.ff) || (a.ff == b.ff && a.ss > b.ss);
});
vector<pair<int, int>> tbanned; //remove nested intervals
int mx = 0;
for(auto [a, b] : banned){
if(b <= mx) continue;
else if(a != b){
tbanned.pb({a, b});
mx = b;
}
}
banned = tbanned;
int tim = 0;
for(auto [a, b] : banned){
if(tim < a)
query(tim, a, ret, 1, 0, n);
tim = b;
}
if(tim < n)
query(tim, n, ret, 1, 0, n);
pair<int, int> dia = pair(x, x);
for(auto p : ret) dia = merge(dia, p);
int ans = max(dist(x, dia.ff), dist(x, dia.ss));
return ans;
}
// Reads the tree, builds the heavy-light decomposition and the diameter segment
// tree over tour positions [0, n), then answers the q queries one per line.
void solve(){
    cin >> n >> q;
    const vector<int> zeros(n+1, 0);
    dep = zeros; sz = zeros; par = zeros; head = zeros;
    tin = zeros; tout = zeros; tour = zeros;
    adj.assign(n+1, vector<int>());
    for(int e = 1; e < n; e++){
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    dfs(1, 0);
    head[1] = 1;
    dfs2(1, 0);
    for(int v = 1; v <= n; v++) update(tin[v], dep[v], 1, 0, n);
    for(int t = 0; t < q; t++){
        int x, k;
        cin >> x >> k;
        vector<int> arr(k);
        for(int& y : arr) cin >> y;
        cout << query(arr, x) << "\n";
    }
}
// Entry point: fast I/O, then solve() once per test case (T is fixed to 1;
// read T here to enable multiple test cases).
signed main(){
    cin.tie(nullptr);
    ios_base::sync_with_stdio(false);
    const int T = 1;
    for(int tc = 1; tc <= T; ++tc){
        solve();
    }
    return 0;
}
Can we represent the conditions as a graph?
Let's rewrite the condition that node $$$a$$$ must be smaller than node $$$b$$$ as a directed edge from $$$a$$$ to $$$b$$$. Then, we can assign each node a value based on a topological sort of this new directed graph. If this directed graph has a cycle, it is clear that there is no way to order the nodes.
With this in mind, we can try to construct a graph that would have these properties. Once we have the graph, we can topological sort to find the answer.
For now, let's consider the problem with only type 1 requirements (type 2 requirements can be handled very similarly).
Thus, the problem reduces to "given a path and a node, add a directed edge from the node to every node in that path." To do this, we can use binary lifting. For each node $$$a$$$, create $$$k = O(\log n)$$$ dummy nodes, the $$$i$$$th of which represents the minimum number on the path between node $$$a$$$ and the $$$2^i$$$th parent of $$$a$$$. Now, we draw a directed edge from the $$$i$$$th dummy node of $$$a$$$ to the $$$(i-1)$$$th dummy node of $$$a$$$ and to the $$$(i-1)$$$th dummy node of the $$$2^{i-1}$$$th parent of $$$a$$$.
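Now, to add an edge from any node to a vertical path of the tree, we cover the path greedily with power-of-two segments. We repeatedly add an edge from that node to the dummy node covering the longest segment that still fits in the remaining part of the path. A general path is first split at the LCA into two vertical paths. This adds $$$O(\log n)$$$ edges per requirement.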
The final complexity is $$$O((n+m)\log n)$$$ time and $$$O((n+m)\log n)$$$ memory.
#pragma GCC optimize("O3,unroll-loops")
#pragma GCC target("avx,avx2,fma")
#pragma GCC target("sse4,popcnt,abm,mmx,tune=native")
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
// Contest shorthands used throughout the solution.
#define pb push_back
#define ff first
#define ss second
using ll = long long;
using ld = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using pld = pair<ld, ld>;
const int INF = 1e9;       // large int sentinel
const ll LLINF = 1e18;     // large long long sentinel
const int MOD = 1e9 + 7;
// Ordered set with order statistics (find_by_order / order_of_key) from GNU pb_ds.
template<class K> using sset = tree<K, null_type, less<K>, rb_tree_tag, tree_order_statistics_node_update>;
// Ceiling of a / b for signed integers (b != 0). C++ division truncates toward
// zero, which is already the ceiling when the exact quotient is negative; when
// (a ^ b) > 0 and the division is inexact we must round up by one.
inline ll ceil0(ll a, ll b) {
    const ll truncated = a / b;
    const bool xor_positive = (a ^ b) > 0;
    const bool inexact = (a % b) != 0;
    return truncated + ((xor_positive && inexact) ? 1 : 0);
}
// Fast iostream setup: stop synchronising C++ streams with C stdio and untie
// cin from cout so that reading input does not flush output first.
void setIO() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
}
const int MAXN = 2e5;  // bound on the vertex count (array sizes)
const int MAXQ = 2e5;  // bound on the query count (size of ans)
// Lazy segment tree over Euler-tour positions: seg = node maximum,
// tag = pending addition not yet pushed to the children.
int seg[4*MAXN + 5], tag[4*MAXN + 5];
// Euler-tour counter; after dfs1, positions 0 .. tim - 1 are in use.
int tim;
vector<int> g[MAXN + 5];   // adjacency lists
int in[MAXN + 5];          // tour position of each vertex
int out[MAXN + 5];         // last tour position inside each vertex's subtree
// Hands the pending addition of node `cur` down to its children (2*cur+1 and
// 2*cur+2): both their maxima and their own pending tags grow by that amount.
void push_down(int cur){
    const int pending = tag[cur];
    if(pending == 0) return;
    const int left = cur*2 + 1, right = cur*2 + 2;
    seg[left] += pending;
    tag[left] += pending;
    seg[right] += pending;
    tag[right] += pending;
    tag[cur] = 0;
}
// Range add: adds v to every position in [l, r] (inclusive). Node `cur` covers
// [ul, ur] with children 2*cur+1 / 2*cur+2; the root 0 covers [0, tim - 1].
void update(int l, int r, int v, int ul = 0, int ur = tim - 1, int cur = 0){
    if(ul >= l && ur <= r){
        tag[cur] += v;
        seg[cur] += v;
        return;
    }
    push_down(cur);
    const int mid = (ul + ur)/2;
    const int lc = cur*2 + 1, rc = cur*2 + 2;
    if(mid >= l) update(l, r, v, ul, mid, lc);
    if(mid < r) update(l, r, v, mid + 1, ur, rc);
    seg[cur] = max(seg[lc], seg[rc]);
}
// Range max over positions [l, r] (inclusive); same node layout as update().
int query(int l, int r, int ul = 0, int ur = tim - 1, int cur = 0){
    if(ul >= l && ur <= r) return seg[cur];
    push_down(cur);
    const int mid = (ul + ur)/2;
    const int lc = cur*2 + 1, rc = cur*2 + 2;
    const bool needs_left = l <= mid, needs_right = r > mid;
    if(!needs_right) return query(l, r, ul, mid, lc);
    if(!needs_left) return query(l, r, mid + 1, ur, rc);
    return max(query(l, r, ul, mid, lc), query(l, r, mid + 1, ur, rc));
}
// Euler tour: in[x] is x's preorder position and out[x] the last position used
// inside x's subtree, so that subtree is exactly [in[x], out[x]].
// p is the parent (0 = none).
void dfs1(int x, int p = 0){
    in[x] = tim;
    ++tim;
    for(const int child : g[x]){
        if(child != p) dfs1(child, x);
    }
    out[x] = tim - 1;
}
// que[x]: queries asked at vertex x, stored as (query index, removed vertices).
vector<pair<int, vector<int>>> que[MAXN + 5];
// nxt[v]: the child of v that dfs2 is currently descending into.
int nxt[MAXN + 5];
int ans[MAXQ + 5];   // answer for each query index
int n, q;            // vertex count, query count
void dfs2(int x, int p = 0){
for(auto &i : que[x]){
vector<pii> skip;
bool found = false;
for(int j : i.ss){
if(j == x){
found = true;
break;
}
if(in[j] <= in[x] && in[x] <= out[j]){
skip.pb({0, in[nxt[j]] - 1});
skip.pb({out[nxt[j]] + 1, tim - 1});
} else {
skip.pb({in[j], out[j]});
}
}
if(found) continue;
sort(skip.begin(), skip.end());
int prv = 0;
for(pii j : skip){
if(prv < j.ff) ans[i.ff] = max(ans[i.ff], query(prv, j.ff - 1));
prv = max(prv, j.ss + 1);
}
if(prv <= tim - 1) ans[i.ff] = max(ans[i.ff], query(prv, tim - 1));
}
update(0, tim - 1, 1);
for(int i : g[x]){
if(i == p) continue;
update(in[i], out[i], -2);
nxt[x] = i;
dfs2(i, x);
update(in[i], out[i], 2);
}
update(0, tim - 1, -1);
}
// Reads the tree and the queries, gives every tour position its depth
// (distance from root 1), runs the rerooting DFS and prints one answer per line.
int main(){
    setIO();
    cin >> n >> q;
    for(int i = 0; i < n - 1; i++){
        int a, b;
        cin >> a >> b;
        g[a].pb(b);
        g[b].pb(a);
    }
    tim = 0;
    dfs1(1);
    for(int i = 0; i < q; i++){
        int x, k;
        cin >> x >> k;
        vector<int> v(k);
        for(int j = 0; j < k; j++) cin >> v[j];
        que[x].pb({i, std::move(v)});
    }
    // +1 on the subtree of every non-root vertex: position in[v] ends up holding depth(v).
    for(int i = 2; i <= n; i++) update(in[i], out[i], 1);
    dfs2(1);
    // '\n' rather than endl: endl flushes the stream after every one of the q answers.
    for(int i = 0; i < q; i++){
        cout << ans[i] << '\n';
    }
}