In this HackerEarth problem, Mojtaba's Trees and Arpa's Queries (March HourStorm), Mojtaba has two trees, each with n vertices. Arpa has q queries, each of the form v, u, x, y. Let s be the set of vertices on the path from v to u in the first tree, and p the set of vertices on the path from x to y in the second tree. For each query, Mojtaba has to compute the size of s ∩ p, i.e. the number of vertices that lie on both paths. Help him!


HackerEarth Mojtaba's Trees and Arpa's Queries (March HourStorm) problem solution


HackerEarth Mojtaba's Trees and Arpa's Queries (March HourStorm) problem solution.

#include <bits/stdc++.h>
using namespace std; 

// Capacity bounds: vertex slots (with slack for the virtual root 0) and the
// number of binary-lifting levels stored per vertex.
const int maxN = 300100;
const int maxL = 20;

using pii = pair<int, int>;

// c[t][v]: adjacency list of vertex v in tree t (t = 0: first tree, t = 1: second tree).
vector<int> c[2][maxN];
// que[x]: signed terms (+/- query id, first-tree vertex) evaluated when the
// walk over the second tree reaches vertex x.
vector<pii> que[maxN];

// First-tree Euler-tour interval of v is [st[v], en[v]).
int st[maxN];
int en[maxN];
// Segment tree over Euler-tour positions (range add, point query).
int seg[4 * maxN];

// h[t][v]: depth of v in tree t; par[t][v][k]: 2^k-th ancestor of v in tree t.
int h[2][maxN];
int par[2][maxN][maxL];

// Fills depths h[t][*] and the binary-lifting table par[t][*][*] for every
// vertex of tree t reachable from u. Relies on par[t][u][0] and h[t][u]
// already holding u's parent and depth (for the virtual root 0 both are the
// arrays' zero-initialised value 0).
//
// Iterative on purpose: a recursive walk is as deep as the tree, and a
// path-shaped tree with ~3*10^5 vertices can overflow the default call stack.
// A vertex is expanded only after its parent, so all of its ancestors' table
// rows are complete when its own row is built.
void dfs_lca(int t, int u) {
    vector<int> pending(1, u);
    while (!pending.empty()) {
        const int v = pending.back();
        pending.pop_back();

        for (int k = 1; k < maxL; k++)
            par[t][v][k] = par[t][par[t][v][k - 1]][k - 1];

        for (int x : c[t][v])
            if (x != par[t][v][0]) {  // skip the edge back to the parent
                par[t][x][0] = v;
                h[t][x] = h[t][v] + 1;
                pending.push_back(x);
            }
    }
}

// Assigns first-tree (c[0]) Euler-tour intervals: st[v] is v's entry index and
// en[v] is one past the last index used inside v's subtree, so v's subtree
// occupies [st[v], en[v]). p is u's parent (neighbour to skip). The counter is
// static and therefore continues across calls.
//
// Iterative (explicit stack, same child order as the recursive DFS, hence the
// same st/en values) so deep trees cannot overflow the call stack.
void dfs_time(int u, int p) {
    static int ind = 0;

    struct Frame {
        int v;        // current vertex
        int parent;   // neighbour to skip
        size_t next;  // index of the next neighbour to examine
    };
    vector<Frame> frames;

    st[u] = ind++;
    frames.push_back(Frame{u, p, 0});
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < c[0][top.v].size()) {
            const int x = c[0][top.v][top.next++];
            if (x != top.parent) {
                const int from = top.v;  // copy before push_back may invalidate `top`
                st[x] = ind++;
                frames.push_back(Frame{x, from, 0});
            }
        } else {
            en[top.v] = ind;
            frames.pop_back();
        }
    }
}

int get_lca(int t, int u, int v) { 
    if( h[t][u] < h[t][v] ) swap(u, v);
    
    int diff = h[t][u] - h[t][v];
    for(int k = 0; k < maxL; k++)
        if( (diff>>k) & 1 ) 
            u = par[t][u][k];

    if( u == v ) return u; 

    for(int k = maxL - 1; k >= 0; k--)
        if( par[t][u][k] != par[t][v][k] ) { 
            u = par[t][u][k];
            v = par[t][v][k];
        }

    return par[t][u][0];
}

void query(int i, int u, int v, int x, int y) { 
    que[x].push_back( pii(i, u) );
    que[x].push_back( pii(-i, v) );
    que[y].push_back( pii(-i, u) );
    que[y].push_back( pii(i, v) );
}

// Euler-tour range size; main reads the vertex count into it and then adds 1
// for the virtual root 0. Also the default upper bound of the segment tree.
int n = 0;
// ans[i]: accumulated answer of query i (queries are numbered from 1).
int ans[maxN] = {};

// Adds qv to every position in [ql, qr). The value is stored (without
// push-down) on each node whose interval [xl, xr) lies fully inside the
// range; seg_get recovers it by summing along a root-to-leaf path.
void seg_add(int ql, int qr, int qv, int xl=0, int xr=n, int ind=1) {
    const bool disjoint = qr <= xl || xr <= ql;
    if (disjoint)
        return;

    const bool covered = ql <= xl && xr <= qr;
    if (covered) {
        seg[ind] += qv;
        return;
    }

    const int mid = (xl + xr) / 2;
    seg_add(ql, qr, qv, xl, mid, 2 * ind);
    seg_add(ql, qr, qv, mid, xr, 2 * ind + 1);
}

// Value at position qp: the sum of every seg_add amount whose range contains
// qp, accumulated along the root-to-leaf path to qp.
int seg_get(int qp, int xl=0, int xr=n, int ind=1) {
    int total = 0;
    for (;;) {
        total += seg[ind];
        if (xr - xl == 1)
            return total;

        const int mid = (xl + xr) / 2;
        if (qp < mid) {
            xr = mid;
            ind = 2 * ind;
        } else {
            xl = mid;
            ind = 2 * ind + 1;
        }
    }
}

// Walks the second tree (c[1]) from u (p is u's parent, to be skipped).
// While the walk is at vertex a, seg holds +1 over the first-tree Euler
// interval [st, en) of every second-tree ancestor of a (a included). So
// seg_get(st[b]) counts vertices that are ancestors of a in the second tree
// and ancestors of b in the first tree. Each pending term (+/-id, b) in que[a]
// adds or subtracts that count to ans[id].
//
// Iterative (explicit stack, same enter/leave order as the recursive DFS) so
// a deep, path-like tree cannot overflow the call stack.
void dfs_solve(int u, int p) {
    struct Frame {
        int v;        // current vertex
        int parent;   // neighbour to skip
        size_t next;  // index of the next neighbour to examine
    };

    // Entering v: mark v's first-tree subtree, then evaluate v's pending terms.
    auto enter = [](int v) {
        seg_add(st[v], en[v], 1);
        for (const auto &term : que[v]) {
            const int id = abs(term.first);
            const int cnt = seg_get(st[term.second]);
            if (term.first < 0)
                ans[id] -= cnt;
            else
                ans[id] += cnt;
        }
    };

    vector<Frame> frames;
    enter(u);
    frames.push_back(Frame{u, p, 0});
    while (!frames.empty()) {
        Frame &top = frames.back();
        if (top.next < c[1][top.v].size()) {
            const int x = c[1][top.v][top.next++];
            if (x != top.parent) {
                const int from = top.v;  // copy before push_back may invalidate `top`
                enter(x);
                frames.push_back(Frame{x, from, 0});
            }
        } else {
            // Leaving the vertex: undo its mark.
            seg_add(st[top.v], en[top.v], -1);
            frames.pop_back();
        }
    }
}

int main() { 
    ios::sync_with_stdio(false);
    cin.tie(0);

    int q;
    cin >> n >> q; 

    for(int t = 0; t < 2; t++) {
        c[t][0].push_back(1);
        for(int i = 0; i + 1 < n; i++) { 
            int u, v;
            cin >> u >> v;
            c[t][u].push_back(v);
            c[t][v].push_back(u);
        }
    }

    n++;

    dfs_lca(0, 0);
    dfs_lca(1, 0);

    dfs_time(0, -1);

    for(int i = 1; i <= q; i++) {
        int u, v, x, y;
        cin >> u >> v >> x >> y;

        int w = get_lca(0, u, v);
        int z = get_lca(1, x, y);


        query(i, u, par[0][w][0], x, par[1][z][0]);
        query(i, v, w, x, par[1][z][0]);
        query(i, v, w, y, z);
        query(i, u, par[0][w][0], y, z);
    }

    dfs_solve(0, -1);

    for(int i = 1; i <= q; i++)
        cout << ans[i] << '\n';

    return 0;
}

Second solution

#include<bits/stdc++.h>
using namespace std;

// Capacity bound on vertices and number of binary-lifting levels.
const int maxn = 300017;
const int lg = 19;

// Deferred term stored at a first-tree vertex: when dfs reaches that vertex,
// ans[i] += z * (count over the second-tree path x..y).
struct Q{
    int x, y, i, z;
};

int n, q;                   // vertex count, query count
int par[2][lg][maxn];       // par[t][k][v]: 2^k-th ancestor of v in tree t
int st[maxn];               // second-tree Euler interval of v is [st[v], ft[v])
int ft[maxn];
int h[2][maxn];             // h[t][v]: depth of v in tree t
int ans[maxn];              // per-query answers (0-based)
int iman[maxn];             // Fenwick tree (range add, point query)
vector<int> g[2][maxn];     // adjacency lists of both trees
vector<Q> assign[maxn];     // terms evaluated when dfs reaches a first-tree vertex
// Sets par[id][0][*] (parent) and h[id][*] (depth) for tree id by traversal
// from v (default: vertex 0, whose parent entry keeps its zero-initialised
// value 0). Iterative because the recursive walk is as deep as the tree, up
// to ~3*10^5 frames on a path, which can overflow the call stack.
void make_par(int id, int v = 0){
    vector<int> todo(1, v);
    while(!todo.empty()){
        const int w = todo.back();
        todo.pop_back();
        for(int u : g[id][w])
            if(u != par[id][0][w]){   // skip the edge back to the parent
                par[id][0][u] = w;
                h[id][u] = h[id][w] + 1;
                todo.push_back(u);
            }
    }
}
// Second-tree (g[1]) Euler tour from v: v's subtree occupies [st[v], ft[v]).
// Parents come from par[1][0], so make_par(1) must run first. The counter is
// static and continues across calls. Iterative with the same child order as
// the recursive version (same st/ft values), avoiding deep recursion.
void get_st(int v = 0){
    static int timer = 0;
    // (vertex, index of the next neighbour to examine)
    vector<pair<int, size_t>> frames;
    st[v] = timer++;
    frames.push_back(make_pair(v, size_t(0)));
    while(!frames.empty()){
        const int w = frames.back().first;
        size_t &next = frames.back().second;
        if(next < g[1][w].size()){
            const int u = g[1][w][next++];
            if(u != par[1][0][w]){
                st[u] = timer++;
                frames.push_back(make_pair(u, size_t(0)));  // `next` not used after this
            }
        } else {
            ft[w] = timer;
            frames.pop_back();
        }
    }
}
// Lowest common ancestor of v and u in tree id, via the binary-lifting table
// par[id][k][*] and depths h[id][*].
int lca(int id, int v, int u){
    // Keep u as the deeper endpoint.
    if(h[id][u] < h[id][v])
        swap(v, u);

    // Lift u to v's depth. The gap is re-read each step; a jump of 2^k only
    // clears bit k, so the higher bits still to be tested are unaffected.
    for(int k = 0; k < lg; k++)
        if(((h[id][u] - h[id][v]) >> k) & 1)
            u = par[id][k][u];

    if(v == u)
        return v;

    for(int k = lg - 1; k >= 0; k--){
        const int av = par[id][k][v];
        const int au = par[id][k][u];
        if(av != au){
            v = av;
            u = au;
        }
    }
    return par[id][0][v];
}
// Fenwick prefix sum over positions 0..p (stored 1-based). With range updates
// done as difference pairs by majid(l, r, v), this is the value at position p.
int hamid(int p){
    int sum = 0;
    for(int i = p + 1; i != 0; i -= i & -i)
        sum += iman[i];
    return sum;
}
// Fenwick point update: adds v at position p (stored 1-based at p + 1).
void majid(int p, int v){
    int i = p + 1;
    while(i < maxn){
        iman[i] += v;
        i += i & -i;
    }
}
// Adds v to every position in [l, r) as two point updates on the difference
// array: +v at l, -v at r.
void majid(int l, int r, int v){
    majid(l, v);
    majid(r, -v);
}
// Walks the first tree (g[0]) from v. While at vertex w, the Fenwick tree
// holds +1 on the second-tree interval [st, ft) of every first-tree ancestor
// of w (w included). So hamid(st[x]) counts vertices that are ancestors of w
// in the first tree and of x in the second tree. For each deferred term, the
// intersection with the second-tree path x..y is
// cnt(x) + cnt(y) - cnt(lca) - cnt(parent(lca)), scaled by the term's sign z.
//
// Iterative (explicit stack, same enter/leave order as the recursive DFS) so
// a deep, path-like tree cannot overflow the call stack.
// NOTE(review): lca(1, x, y) is recomputed for each of a query's (up to)
// four terms; computing it once when the query is read would save work.
void dfs(int v = 0){
    // Entering w: mark w's second-tree subtree, then evaluate w's terms.
    auto enter = [](int w){
        majid(st[w], ft[w], +1);
        for(const auto &t : assign[w]){
            const int p = lca(1, t.x, t.y);
            ans[t.i] += t.z * (hamid(st[t.x]) + hamid(st[t.y]) - hamid(st[p]) - (p ? hamid(st[ par[1][0][p] ]) : 0));
        }
    };

    // (vertex, index of the next neighbour to examine)
    vector<pair<int, size_t>> frames;
    enter(v);
    frames.push_back(make_pair(v, size_t(0)));
    while(!frames.empty()){
        const int w = frames.back().first;
        size_t &next = frames.back().second;
        if(next < g[0][w].size()){
            const int u = g[0][w][next++];
            if(u != par[0][0][w]){
                enter(u);
                frames.push_back(make_pair(u, size_t(0)));  // `next` not used after this
            }
        } else {
            // Leaving w: undo its mark.
            majid(st[w], ft[w], -1);
            frames.pop_back();
        }
    }
}
int main(){
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> q;

    // Both trees: n - 1 edges each, converted to 0-based vertices.
    for(int t = 0; t < 2; t++)
        for(int e = 0; e + 1 < n; e++){
            int a, b;
            cin >> a >> b;
            --a, --b;
            g[t][a].push_back(b);
            g[t][b].push_back(a);
        }

    // Parents and depths, then the binary-lifting levels, for each tree.
    for(int t = 0; t < 2; t++){
        make_par(t);
        for(int k = 1; k < lg; k++)
            for(int w = 0; w < n; w++){
                const int half = par[t][k - 1][w];
                par[t][k][w] = par[t][k - 1][half];
            }
    }

    // In the first tree, path(v, u) = anc(v) + anc(u) - anc(p) - anc(parent(p))
    // with p = lca(v, u); each part becomes a deferred term at that vertex.
    for(int i = 0; i < q; i++){
        int v, u, x, y;
        cin >> v >> u >> x >> y;
        --v, --u, --x, --y;
        const int p = lca(0, v, u);
        assign[v].push_back({x, y, i, +1});
        assign[u].push_back({x, y, i, +1});
        assign[p].push_back({x, y, i, -1});
        if(p != 0)
            assign[ par[0][0][p] ].push_back({x, y, i, -1});
    }

    get_st();
    dfs();

    for(int i = 0; i < q; i++)
        cout << ans[i] << '\n';
}