HackerEarth Tree query problem: a tree is a simple graph in which every two vertices are connected by exactly one path. You are given a tree with n vertices, rooted at vertex 1, and a lamp is placed on each vertex of the tree. Initially, every lamp is On.

You are given q queries of the following two types:
  1. Type 1 (1 v): Switch the lamp placed on vertex v, that is, toggle it from On to Off or from Off to On.
  2. Type 2 (2 v): Determine the number of vertices in the subtree of v that are connected to v when only vertices whose lamps are On are considered. In other words, count the vertices u in the subtree of v such that u can reach v using only vertices whose lamps are On (the answer is 0 when the lamp on v itself is Off).

HackerEarth Tree query problem solution


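First solution: flatten the tree with an Euler tour and let every Off lamp add 1 to all positions of its subtree. For a vertex v whose lamp is On, the answer is the number of positions in v's subtree range that hold the minimum value (which is v's own value).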

#include<bits/stdc++.h>
using namespace std;

#define ll long long
#define pb push_back

const int maxn = 5e5 + 20;

// adj[v]: neighbours of vertex v (0-indexed, undirected tree edges).
vector<int> adj[maxn];

// st[v] / ft[v]: after dfs(), the subtree of v occupies the half-open
// Euler-tour interval [st[v], ft[v]).
// now: last Euler-tour position handed out (starts at -1).
// is[v]: 1 while the lamp on v is Off, 0 while it is On.
int st[maxn] , ft[maxn] , now = -1 , is[maxn];

/// Assigns preorder positions to the subtree rooted at v, whose parent is p.
///
/// After the call, st[u] is u's preorder index and ft[u] is one past the last
/// index used inside u's subtree. Positions are taken from the global counter
/// `now`.
///
/// Uses an explicit stack instead of recursion. A path-shaped tree with ~5e5
/// vertices would otherwise recurse that many levels deep and can overflow
/// the default stack. Children are visited in adjacency-list order, so the
/// numbering matches the recursive formulation exactly.
void dfs(int v , int p = -1)
{
    struct Frame { int v , p; size_t next; };   // vertex, parent, next adj index
    vector<Frame> stk;

    st[v] = ++now;
    stk.push_back({v , p , 0});
    while(!stk.empty())
    {
        Frame &f = stk.back();
        if(f.next < adj[f.v].size())
        {
            const int u = adj[f.v][f.next++];
            if(u != f.p)
            {
                const int parent = f.v;   // copy first: push_back may reallocate
                st[u] = ++now;
                stk.push_back({u , parent , 0});
            }
        }
        else
        {
            // Every vertex of f.v's subtree has been numbered.
            ft[f.v] = now + 1;
            stk.pop_back();
        }
    }
}

int mn[maxn * 4] , t[maxn * 4] , Add[maxn * 4] , n;

void build(int s = 0 , int e = n , int v = 1)
{
    t[v] = e - s;
    if(e - s < 2)
        return;
    
    int m = (s + e) / 2;
    build(s , m , 2 * v);
    build(m , e , 2 * v + 1);
}

void add(int l , int r , int val , int s = 0 , int e = n , int v = 1)
{
    if(l <= s && e <= r)
    {
        mn[v] += val;
        Add[v] += val;
        return;
    }
    if(r <= s || e <= l)
        return;

    int m = (s + e) / 2;

    add(l , r , val , s , m , 2 * v);
    add(l , r , val , m , e , 2 * v + 1);

    mn[v] = min(mn[2 * v] , mn[2 * v + 1]);
    t[v] = 0;
    if(mn[v] == mn[2 * v])
        t[v] += t[2 * v];
    if(mn[v] == mn[2 * v + 1])
        t[v] += t[2 * v + 1];
    mn[v] += Add[v];
}

pair<int , int> get(int l , int r , int s = 0 , int e = n , int v = 1)
{
    if(l <= s && e <= r)
        return make_pair(mn[v] , t[v]);
    if(r <= s || e <= l)
        return make_pair(1e9 , 0);
    
    int m = (s + e) / 2;
    auto x = get(l , r , s , m , 2 * v);
    auto y = get(l , r , m , e , 2 * v + 1);

    pair<int , int> ans = {1e9 , 0};
    ans.first = min(x.first , y.first);
    if(ans.first == x.first)
        ans.second += x.second;
    if(ans.first == y.first)
        ans.second += y.second;

    ans.first += Add[v];
    return ans;
}

/// Reads the tree and answers the queries.
///
/// Type 1 toggles the lamp on v. Type 2 prints how many vertices of v's
/// subtree reach v through On lamps only. is[v] == 1 marks an Off lamp; each
/// Off lamp adds 1 across its subtree range. So when v is On, the answer is
/// the count of positions in [st[v], ft[v]) that hold the range minimum.
/// Stops cleanly on truncated input instead of indexing with garbage values.
int main()
{
    int q = 0;
    if(scanf("%d%d", &n, &q) != 2)
        return 0;

    for(int i = 0; i < n - 1; i++)
    {
        int a , b;
        if(scanf("%d%d", &a, &b) != 2)
            return 0;
        a-- , b--;

        adj[a].pb(b);
        adj[b].pb(a);
    }

    dfs(0);
    build();

    while(q--)
    {
        int type , v;
        // Use scanf like the rest of main. Interleaving stdio-synchronised
        // cin in this hot loop is noticeably slower on large inputs.
        if(scanf("%d%d", &type, &v) != 2)
            break;
        v--;

        if(type == 1)
        {
            add(st[v] , ft[v] , is[v]? -1 : 1);
            is[v] ^= 1;
        }
        else
        {
            if(is[v])
                printf("0\n");   // v itself is Off: nothing can reach it
            else
                printf("%d\n", get(st[v] , ft[v]).second);
        }
    }
}

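Second solution: the same Euler-tour idea, using a segment tree with lazy propagation that stores (minimum, count of minimum) per node.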

#include <bits/stdc++.h>
using namespace std;
 
using ll = long long;
const int maxn = 5e5 + 14;
int n, q;

// Segment-tree node summary:
// m = minimum value on the node's range,
// n = number of positions in that range holding the minimum.
struct node {
    int m, n;
};

// s[]: tree storage.
// emp: sentinel with a huge minimum and zero count, returned for ranges
//      outside a query so it never affects a merge.
node s[maxn << 2], emp = {1717171717, 0};

// Merges two summaries: keep the smaller minimum; on a tie, add the counts.
node operator &(const node &a,const node &b){
    if(a.m == b.m)
        return {a.m, a.n + b.n};
    return a.m < b.m ? a : b;
}
// g[v]: neighbours of vertex v (0-indexed).
vector<int>g[maxn];
// st[v], en[v]: half-open Euler-tour range [st[v], en[v]) of v's subtree.
// Time: next free Euler-tour position.
// lazy[id]: pending addition not yet pushed to id's children.
int st[maxn],en[maxn],Time,lazy[maxn<<2];

// Pushes node id's pending addition down to both children.
void shift(int id){
    const int pending = lazy[id];
    if(pending == 0)
        return;
    for(int child : {id << 1, id << 1 | 1}){
        lazy[child] += pending;
        s[child].m += pending;
    }
    lazy[id] = 0;
}

// Returns the merged {minimum, count} summary of positions [ql, qr), searching
// within node id, which covers [l, r).
node get(int ql,int qr,int l=0,int r=n,int id=1){
    const bool disjoint = qr <= l || r <= ql;
    if(disjoint)
        return emp;
    const bool inside = ql <= l && r <= qr;
    if(inside)
        return s[id];
    shift(id);
    const int mid = (l + r) >> 1;
    const node lhs = get(ql, qr, l, mid, id << 1);
    const node rhs = get(ql, qr, mid, r, id << 1 | 1);
    return lhs & rhs;
}

// Adds v to every position in [ql, qr).
void add(int ql,int qr,int v=1,int l=0,int r=n,int id=1){
    if(qr <= l || r <= ql)
        return;
    if(ql <= l && r <= qr){
        lazy[id] += v;
        s[id].m += v;
        return;
    }
    shift(id);
    const int mid = (l + r) >> 1;
    add(ql, qr, v, l, mid, id << 1);
    add(ql, qr, v, mid, r, id << 1 | 1);
    s[id] = s[id << 1] & s[id << 1 | 1];
}

// Sets every node's count to its range length. Minima are not written here:
// they rely on zero-initialised globals, so all positions start at 0.
void build(int l=0,int r=n,int id=1){
    s[id].n = r - l;
    if(r - l <= 1)
        return;
    const int mid = (l + r) >> 1;
    build(l, mid, id << 1);
    build(mid, r, id << 1 | 1);
}
void sten(int v=0,int p=-1){
    st[v]=Time++;
    for(auto &u:g[v])
        if(u!=p)
            sten(u,v);
    en[v]=Time;
}
void addVer(int v, int x){
    add(st[v], en[v], x);
}
// state[v] is true while the lamp on v is Off (all lamps start On).
bool state[maxn];

// Reads the tree and answers the queries.
// Each Off lamp adds 1 across its subtree range, so a position's value is the
// number of Off vertices on its root path. For an On vertex v, the answer is
// how many positions of v's subtree share the range minimum (v's own value),
// i.e. have no Off vertex between them and v.
int main(){
    ios::sync_with_stdio(0), cin.tie(0);
    if(!(cin >> n >> q))
        return 0;
    for(int i = 1; i < n; i++){
        int v, u;
        if(!(cin >> v >> u))
            return 0;
        v--, u--;
        g[v].push_back(u);
        g[u].push_back(v);
    }
    sten();
    build();
    while(q--){
        int t, v;
        if(!(cin >> t >> v))
            break;   // truncated input: stop instead of indexing with garbage
        v--;
        if(t == 1){
            state[v] ^= 1;
            addVer(v, state[v] ? +1 : -1);
        }else{
            // An Off vertex reaches nothing, so skip the tree query entirely.
            cout << (state[v] ? 0 : get(st[v], en[v]).n) << '\n';
        }
    }
}