In this HackerEarth Joseph and Tree problem solution: Joseph loves games about trees, and his friend Nick has invented one for him!

Initially, there is a rooted weighted tree with N vertices numbered 1 ... N. Nick guarantees that the tree is connected and that there is a unique path between any two vertices. He also gives Q queries on it, each of the following type:
  1. v and k: Let S denote the array, sorted in nondecreasing order, of shortest distances from v to every other vertex in the subtree rooted at v. The answer is the k-th element of S. If no such element exists, i.e. S has fewer than k elements, the answer is -1. Note that v itself is not counted as part of its own subtree.
All the indices in the queries are 1-based. The root of the tree is node 1.
But it turns out that Joseph has an exam tomorrow and doesn't have time to play the game, so he asks for your help!


HackerEarth Joseph and Tree problem solution


HackerEarth Joseph and Tree problem solution.

#include <bits/stdc++.h>

// Competitive-programming shorthands used throughout the file.
#define pb push_back
// NOTE(review): the one-letter macros `f` and `s` replace EVERY token `f` / `s` in the
// rest of the file (e.g. `to.f` becomes `to.first`); never use them as identifiers.
#define f first
#define s second
#define mp make_pair
#define sz(a) int((a).size())
// printf conversion for long long: the _WIN32 branch uses the Microsoft-specific
// "%I64d" spelling, every other platform the standard "%lld".
#ifdef _WIN32
#  define I64 "%I64d"
#else
#  define I64 "%lld"
#endif
// Path prefix for the input/output files that main opens when built with -DNick.
#define fname "."
// Adjacency-list entry: (neighbour, edge weight).
#define pi pair < int, int >
#define pp pop_back

typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;

// Capacity of the per-vertex arrays (vertices are numbered 1..n).
const int MAX_N = (int)1e5 + 123;
// NOTE(review): eps and inf (like ld, ull and pp above) are not referenced by the code shown.
const double eps = 1e-6;
const int inf = (int)1e9 + 123;

using namespace std;

// n: number of vertices; it also fixes the value range [0, n - 1] of the segment tree.
// q: NOTE(review): unused -- main reads the query count into a local variable instead.
int n, q;
// g[u]: list of (neighbour, edge weight) pairs of vertex u.
vector < pi > g[MAX_N];

// One node of the persistent segment tree.
//   sum  - how many inserted values fall into this node's range
//   l, r - indices of the child nodes inside `t`, or -1 when that child is empty
struct tree {
    int sum = 0;
    int l = -1;
    int r = -1;
};

// Pool holding the nodes of every version; nodes reference each other by index so the
// vector may reallocate freely.
vector < tree > t;

int update(int x, int v, int tl = 0, int tr = n - 1) {
  int now = sz(t);
  t.pb(tree());
    if (v != -1)
        t[now] = t[v];
    if (tl == tr) {
        t[now].sum++;
        return now;
    }
    int tm = (tl + tr) / 2;
    if (x <= tm) {
        int son = update(x, (v == -1 ? -1 : t[v].l), tl, tm);
        t[now].l = son;
    }
    else {
        int son = update(x, (v == -1 ? -1 : t[v].r), tm + 1, tr);
        t[now].r = son;
    }
    t[now].sum = 0;
    if (t[now].l != -1)
        t[now].sum += t[t[now].l].sum;
    if (t[now].r != -1)
        t[now].sum += t[t[now].r].sum;
    return now;
}

int find_kth(int L, int R, int k, int tl = 0, int tr = n - 1) {
    if (tl == tr)
        return tl;
    int tm = (tl + tr) / 2;
    int left = 0;
    if (R != -1 && t[R].l != -1)
        left += t[t[R].l].sum;
    if (L != -1 && t[L].l != -1)
        left -= t[t[L].l].sum;

    if (k <= left)
        return find_kth((L == -1 ? -1 : t[L].l), (R == -1 ? -1 : t[R].l), k, tl, tm);
    k -= left;
    return find_kth((L == -1 ? -1 : t[L].r), (R == -1 ? -1 : t[R].r), k, tm + 1, tr);
}

// st: vertices in the order dfs() first visits them (pre-order / Euler order).
vector < int > st;
// dist[v]: distance from the root (vertex 1) as computed by dfs(); main later overwrites
// each entry with the index of that distance in the sorted, de-duplicated `uniq`.
ll dist[MAX_N];
// l[v], r[v]: first and last positions in `st` covered by v's subtree (st[l[v]] == v).
int l[MAX_N], r[MAX_N];

void dfs(int v, int pr = -1, ll all = 0) {
    dist[v] = all;
    l[v] = sz(st);
    st.pb(v);
    for (auto to : g[v])
        if (to.f != pr)
            dfs(to.f, v, all + to.s);
    r[v] = sz(st) - 1;
}

// root[i]: index in `t` of the persistent version containing the (compressed) distances
// of st[0..i]; built in main.
int root[MAX_N];
// uniq: sorted, de-duplicated root distances; maps a compressed index back to a distance.
vector < ll > uniq;

// Answers one query: the k-th smallest distance from v to one of its proper descendants,
// or -1 when v has fewer than k descendants. After main's coordinate compression, dist[]
// holds indices into `uniq`, so the real distances are read back through `uniq`.
ll get(int v, int k) {
    // st[l[v]] is v itself and st[l[v] + 1 .. r[v]] are its descendants, so version r[v]
    // minus version l[v] contains exactly the descendants' distances.
    const int descendants = r[v] - l[v];
    if (descendants < k)
        return -1;
    const int kthIndex = find_kth(root[l[v]], root[r[v]], k);
    const ll kthDistanceFromRoot = uniq[kthIndex];
    return kthDistanceFromRoot - uniq[dist[v]];
}

int main() {
    #ifdef Nick
    freopen(fname"in", "r", stdin);
    freopen(fname"out", "w", stdout);
    #endif
    scanf("%d", &n);
    for (int i = 1, u, v, w; i < n; i++) {
        scanf("%d%d%d", &u, &v, &w);
        g[u].pb(mp(v, w)), g[v].pb(mp(u, w));
    }
    
    dfs(1);

    for (int i = 1; i <= n; i++)
        uniq.pb(dist[i]);
    sort(uniq.begin(), uniq.end());
    uniq.resize(unique(uniq.begin(), uniq.end()) - uniq.begin());
    for (int i = 1; i <= n; i++)
        dist[i] = lower_bound(uniq.begin(), uniq.end(), dist[i]) - uniq.begin();

    for (int i = 0, last = -1; i < sz(st); i++) {
        last = root[i] = update(dist[st[i]], last);
    }

    int query;
    scanf("%d", &query);
    for (int i = 1, v, k; i <= query; i++) {
        scanf("%d%d", &v, &k);
        printf(I64"\n", get(v, k));
    }

    return 0;
}

Second solution

import java.io.*;
import java.util.*;

/**
 * Solves "Joseph and Tree": for a query (v, k), prints the k-th smallest distance from v to
 * a proper descendant of v in a weighted tree rooted at vertex 1, or -1 when v has fewer
 * than k descendants.
 *
 * <p>The distance from v to a descendant u is depth[u] - depth[v]. Vertices are numbered
 * by DFS pre-order, so v's subtree occupies entry times timeIn[v]..timeOut[v], with v itself
 * at timeIn[v]. A persistent segment tree over the ranks of the distinct depths keeps one
 * version per prefix of that order. Version timeOut[v] minus version timeIn[v] is the
 * multiset of depths of v's proper descendants. Its k-th smallest value is found by
 * descending both versions together.
 */
public class AugClashJosephAndTree {

    /** Number of vertices; vertices are numbered 1..N and the root is vertex 1. */
    static int N;
    /** adj[u].get(i) is the i-th neighbour of u; weight[u].get(i) is that edge's weight. */
    static ArrayList<Integer>[] adj, weight;
    /** depth[u]: sum of edge weights on the path from vertex 1 to u. */
    static long[] depth;
    /** dfsOrder[t]: vertex whose pre-order entry time is t (1-based). */
    static int[] dfsOrder;
    /** timeIn[u]: pre-order entry time of u; timeOut[u]: largest entry time in u's subtree. */
    static int[] timeIn, timeOut;
    /** Last entry time handed out by {@link #dfs}. */
    static int time;
    /** Maps every distinct depth to its 1-based rank in increasing order. */
    static TreeMap<Long, Integer> dist;
    /** revMap[rank]: the depth that has this rank (inverse of {@link #dist}). */
    static long[] revMap;
    /** Node pool of the persistent segment tree; nodes refer to each other by index. */
    static ArrayList<Node> nodes;
    /** root[i]: version holding the depth ranks of dfsOrder[1..i]; root[0] is empty. */
    static int[] root;

    /**
     * Reads the tree and the queries from standard input and prints one answer per line.
     *
     * @throws IllegalArgumentException if an input value violates the problem constraints
     */
    public static void main(String[] args) {
        InputReader in = new InputReader(System.in);
        PrintWriter out = new PrintWriter(System.out);

        readTree(in);

        depth = new long[N + 1];
        timeIn = new int[N + 1];
        timeOut = new int[N + 1];
        dfsOrder = new int[N + 1];
        time = 0;
        dfs(1, -1, 0);

        int cnt = compressDepths();
        buildVersions(cnt);

        int queries = in.nextInt();
        check(1, queries, (int) 1e5);

        while (queries-- > 0) {
            int v = in.nextInt();
            int k = in.nextInt();
            check(1, v, N);
            check(1, k, (int) 1e9);
            out.println(solve(v, k, cnt));
        }

        out.close();
    }

    /** Reads N and the N - 1 weighted edges into {@link #adj} and {@link #weight}. */
    @SuppressWarnings("unchecked") // generic array creation for the adjacency lists
    private static void readTree(InputReader in) {
        N = in.nextInt();
        check(1, N, (int) 1e5);

        adj = new ArrayList[N + 1];
        weight = new ArrayList[N + 1];
        for (int i = 1; i <= N; i++) {
            adj[i] = new ArrayList<Integer>();
            weight[i] = new ArrayList<Integer>();
        }

        for (int i = 1; i < N; i++) {
            int a = in.nextInt();
            int b = in.nextInt();
            int w = in.nextInt();
            check(1, a, N);
            check(1, b, N);
            check(1, w, (int) 1e9);
            adj[a].add(b);
            weight[a].add(w);
            adj[b].add(a);
            weight[b].add(w);
        }
    }

    /**
     * Fills {@link #dist} and {@link #revMap} so the distinct depths get ranks 1..cnt in
     * increasing order.
     *
     * @return cnt, the number of distinct depths
     */
    private static int compressDepths() {
        dist = new TreeMap<Long, Integer>();
        for (int i = 1; i <= N; i++)
            dist.put(depth[i], 0);

        revMap = new long[dist.size() + 1];

        int cnt = 0;
        // Ranks are written through the iterator's entries. Calling put() while iterating
        // keySet() is a map modification whose effect on the iteration TreeMap leaves undefined.
        for (Map.Entry<Long, Integer> entry : dist.entrySet()) {
            entry.setValue(++cnt);
            revMap[cnt] = entry.getKey();
        }
        return cnt;
    }

    /** Builds root[0..N]: root[i] is version i - 1 plus the depth rank of dfsOrder[i]. */
    private static void buildVersions(int cnt) {
        root = new int[N + 1];
        root[0] = 0;

        nodes = new ArrayList<Node>();
        nodes.add(new Node()); // index 0: the empty tree used as version 0

        for (int i = 1; i <= N; i++) {
            root[i] = update(root[i - 1], 1, cnt, dist.get(depth[dfsOrder[i]]));
        }
    }

    /**
     * Answers one query.
     *
     * @param v   query vertex
     * @param k   1-based rank of the wanted distance
     * @param cnt number of distinct depths (value range of the segment tree)
     * @return the k-th smallest distance from v to a proper descendant, or -1 if v has
     *     fewer than k descendants
     */
    static long solve(int v, int k, int cnt) {
        // Entry times timeIn[v] + 1 .. timeOut[v] are exactly v's proper descendants.
        int size = timeOut[v] - timeIn[v];
        if (size < k)
            return -1;
        int ans = findKth(1, cnt, root[timeIn[v]], root[timeOut[v]], k);
        return revMap[ans] - depth[v];
    }

    /**
     * Returns the k-th smallest rank in version {@code rightRoot} minus version
     * {@code leftRoot}, restricted to the rank range [start, end]. Expects 1 <= k <= the
     * size of that difference.
     */
    static int findKth(int start, int end, int leftRoot, int rightRoot, int k) {
        if (start == end)
            return start;
        int mid = (start + end) >> 1;
        // Number of values of the difference that lie in [start, mid].
        int leftSum = sum(left(rightRoot)) - sum(left(leftRoot));
        if (leftSum >= k)
            return findKth(start, mid, left(leftRoot), left(rightRoot), k);
        else
            return findKth(mid + 1, end, right(leftRoot), right(rightRoot), k - leftSum);
    }

    /**
     * Creates a new version equal to {@code prevRoot} plus one occurrence of rank {@code x},
     * copying only the nodes on the path to x's leaf.
     *
     * @return index in {@link #nodes} of the new version's root
     */
    static int update(int prevRoot, int start, int end, int x) {
        Node now = new Node();
        now.sum = sum(prevRoot);
        now.left = left(prevRoot);
        now.right = right(prevRoot);
        nodes.add(now);

        int idx = nodes.size() - 1;

        if (start == end) {
            now.sum++;
            return idx;
        }

        int mid = (start + end) >> 1;

        if (x <= mid) {
            int leftSon = left(prevRoot);
            now.left = update(leftSon, start, mid, x);
        } else {
            int rightSon = right(prevRoot);
            now.right = update(rightSon, mid + 1, end, x);
        }

        now.sum = sum(now.left) + sum(now.right);
        return idx;
    }

    /**
     * Depth-first traversal from {@code curr} (arrived at from {@code parent}, -1 for none)
     * that fills {@link #depth}, {@link #timeIn}, {@link #timeOut} and {@link #dfsOrder}.
     * Neighbours are visited in adjacency-list order, exactly as a recursive pre-order DFS
     * would, so the subtree of u occupies entry times timeIn[u]..timeOut[u].
     *
     * <p>An explicit stack replaces recursion. On a path-shaped tree with N = 1e5 vertices,
     * recursion needs 1e5 nested frames and overflows the default JVM thread stack.
     *
     * @param pathLength depth assigned to {@code curr}
     */
    static void dfs(int curr, int parent, long pathLength) {
        // A tree on adj.length - 1 vertices is at most that deep, which bounds the stack.
        int capacity = adj.length;
        int[] stackNode = new int[capacity];
        int[] stackParent = new int[capacity];
        int[] nextEdge = new int[capacity]; // index of the next neighbour to examine
        int top = 0;

        depth[curr] = pathLength;
        timeIn[curr] = ++time;
        dfsOrder[time] = curr;
        stackNode[0] = curr;
        stackParent[0] = parent;

        while (top >= 0) {
            int node = stackNode[top];
            if (nextEdge[top] == adj[node].size()) {
                // Every vertex entered since timeIn[node] belongs to node's subtree.
                timeOut[node] = time;
                top--;
                continue;
            }
            int i = nextEdge[top]++;
            int child = adj[node].get(i);
            if (child == stackParent[top])
                continue;
            depth[child] = depth[node] + weight[node].get(i);
            timeIn[child] = ++time;
            dfsOrder[time] = child;
            top++;
            stackNode[top] = child;
            stackParent[top] = node;
            nextEdge[top] = 0;
        }
    }

    /** Persistent segment tree node; child index -1 means an empty subtree. */
    static class Node {
        int sum, left, right;

        Node() {
            sum = 0;
            left = -1;
            right = -1;
        }
    }

    /** Count stored at node {@code idx}, treating -1 as an empty node. */
    static int sum(int idx) {
        return idx == -1 ? 0 : nodes.get(idx).sum;
    }

    /** Left child of node {@code idx}, or -1 for an empty node. */
    static int left(int idx) {
        return idx == -1 ? -1 : nodes.get(idx).left;
    }

    /** Right child of node {@code idx}, or -1 for an empty node. */
    static int right(int idx) {
        return idx == -1 ? -1 : nodes.get(idx).right;
    }

    /**
     * Validates a constraint from the problem statement.
     *
     * @throws IllegalArgumentException if {@code key} is outside [start, end]
     */
    static void check(int start, int key, int end) {
        if (key < start || key > end)
            throw new IllegalArgumentException(
                    "value " + key + " outside [" + start + ", " + end + "]");
    }

    /** Buffered reader of whitespace-separated integers from a byte stream. */
    static class InputReader {

        private final InputStream stream;
        private final byte[] buf = new byte[8192];
        private int curChar, snumChars;
        private SpaceCharFilter filter;

        public InputReader(InputStream stream) {
            this.stream = stream;
        }

        /**
         * Returns the next byte of input, or -1 when the stream is exhausted.
         *
         * @throws InputMismatchException on a read error, or when called again after the
         *     stream reported end of input
         */
        public int snext() {
            if (snumChars == -1)
                throw new InputMismatchException();
            if (curChar >= snumChars) {
                curChar = 0;
                try {
                    snumChars = stream.read(buf);
                } catch (IOException e) {
                    throw new InputMismatchException();
                }
                if (snumChars <= 0)
                    return -1;
            }
            return buf[curChar++];
        }

        /**
         * Skips whitespace and parses an optionally negative decimal integer (without an
         * overflow check).
         *
         * @throws InputMismatchException on a non-digit character
         */
        public int nextInt() {
            int c = snext();
            while (isSpaceChar(c)) {
                c = snext();
            }
            int sgn = 1;
            if (c == '-') {
                sgn = -1;
                c = snext();
            }
            int res = 0;
            do {
                if (c < '0' || c > '9')
                    throw new InputMismatchException();
                res *= 10;
                res += c - '0';
                c = snext();
            } while (!isSpaceChar(c));
            return res * sgn;
        }

        /** Whitespace test; -1 (end of input) also counts as a separator. */
        public boolean isSpaceChar(int c) {
            if (filter != null)
                return filter.isSpaceChar(c);
            return c == ' ' || c == '\n' || c == '\r' || c == '\t' || c == -1;
        }

        /** Optional custom separator predicate. */
        public interface SpaceCharFilter {
            public boolean isSpaceChar(int ch);
        }
    }
}