In this HackerEarth 3B - Bear and Special Dfs problem, Bear recently studied the depth-first traversal of a tree in his Data Structures class. He learnt the following algorithm for doing a DFS on a tree.

// Plain DFS from the problem statement: marks v, then recurses into every
// neighbour except the parent p, so on a tree it marks the whole subtree of v.
void dfs(int v,int p){
    vis[v] = 1;
    for(int i=0;i<G[v].size();++i){
        // Skip the edge back to the parent so the walk never goes upward.
        if(G[v][i]!=p)
            dfs(G[v][i],v);
    }
}
Now, as homework, he has been given an array A, a tree G as an adjacency list, and Q queries to answer about the following modified algorithm.

// Modified DFS from the problem statement. Every call counts a visit of v and
// then scans all children A[v] times. dfs(child) is therefore invoked A[v]
// times per invocation of dfs(v). Starting from dfs(1,-1), vis[u] ends up as
// the product of A[w] over all proper ancestors w of u (and vis[1] == 1).
void dfs(int v,int p){
    vis[v] += 1;
    // Repeat the full child scan A[v] times.
    for(int j=1;j<=A[v];++j){
        for(int i=0;i<G[v].size();++i){
            if(G[v][i]!=p)
                dfs(G[v][i],v);
        }
    }
}
The Q queries can be of the following two types:
  1. v x : Update A[v] = x. Note that updates are permanent.
  2. v : Initialize vis[u] = 0 for all $u \in [1, N]$. Run the above algorithm with the call dfs(1,-1). Compute the sum $\sum_{u \in \mathrm{subtree}(v)} vis[u]$, where the subtree of v includes v itself. As the result may be quite large, print it modulo $10^9 + 7$.

HackerEarth 3B - Bear and Special Dfs problem solution


HackerEarth 3B - Bear and Special Dfs problem solution.

#include<bits/stdc++.h>
#include<iostream>
using namespace std;
// Debug helper: redirect stdin/stdout to local files (left commented out in main).
#define fre     freopen("0.in","r",stdin);freopen("0.out","w",stdout)
// NOTE(review): this function-like macro replaces every later abs(...) in this file,
// including calls that were meant to reach std::abs.
#define abs(x) ((x)>0?(x):-(x))
// Answers are reported modulo this prime.
#define MOD 1000000007
#define LL signed long long int
#define pii pair<int,int>
#define scan(x) scanf("%d",&x)
#define print(x) printf("%d\n",x)
#define scanll(x) scanf("%lld",&x)
#define printll(x) printf("%lld\n",x)
// Inclusive loop: i runs from `from` to `to`.
#define rep(i,from,to) for(int i=(from);i <= (to); ++i)
// Adjacency lists of the tree, indexed by vertex (main reads 1-based vertices).
// NOTE(review): sizes presumably assume N <= 1e5 — confirm against the constraints.
vector<int> G[2*100000+5];
// L[v]..R[v] is the contiguous Euler-tour range of v's subtree (filled by dfs);
// A holds the current per-vertex repeat counts.
int L[100000+5],R[100000+5], A[100000+5];
// Segment tree over Euler-tour positions. tree[i] is the node's sum already
// scaled by its own multiplier lazy[i]. Multipliers are never pushed down;
// query() re-applies them on the way back up.
LL tree[5*100000+5];
LL lazy[5*100000+5];
// Euler-tour position counter used by dfs.
int T = 0;
// Euler tour of the tree rooted at v (parent p). Assigns each vertex its
// preorder position L[] (numbered from T+1) and sets R[] to the last position
// used inside its subtree. Afterwards the subtree of any vertex w occupies
// exactly the positions L[w]..R[w].
//
// Uses an explicit stack instead of recursion. A path-shaped tree would
// otherwise recurse once per vertex and can overflow the call stack. Children
// are still visited in adjacency-list order, so the numbering is unchanged.
void dfs(int v,int p){
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> frames;
    L[v] = ++T;
    frames.push_back(Frame{v, p, 0});
    while(!frames.empty()){
        Frame &top = frames.back();
        if(top.next == G[top.node].size()){
            // All neighbours handled: the subtree of top.node is complete.
            R[top.node] = T;
            frames.pop_back();
            continue;
        }
        const int u = G[top.node][top.next++];
        if(u == top.parent)
            continue;
        L[u] = ++T;
        // The Frame argument is built before push_back runs, so reading
        // top.node here is safe even if the vector reallocates.
        frames.push_back(Frame{u, top.node, 0});
    }
}

// Modular exponentiation by repeated squaring: returns base^exponent mod modulus,
// in [0, modulus). main() uses it with modulus = MOD (prime) and
// exponent = MOD - 2 to obtain a modular inverse via Fermat's little theorem.
//
// base is first reduced into [0, modulus). Large bases therefore cannot
// overflow base * base, and negative bases cannot produce a negative result.
// Requirements: modulus > 0, and (modulus - 1)^2 must fit in long long (true for MOD).
// A non-positive exponent yields 1 % modulus.
long long pow(long long base, long long exponent, long long modulus)
{
    base %= modulus;
    if (base < 0)
        base += modulus;
    // 1 % modulus so that modulus == 1 correctly yields 0.
    long long result = 1 % modulus;
    for (; exponent > 0; exponent >>= 1) {
        if (exponent & 1)
            result = result * base % modulus;
        base = base * base % modulus;
    }
    return result;
}

// Initialise the segment tree node i covering positions [si, sj]. Each position
// starts with value 1 (the root is visited once), and every node starts with the
// identity multiplier 1. Internal sums are plain counts, so no reduction is needed.
void build(int i,int si,int sj){
    lazy[i] = 1;
    if(si == sj){
        tree[i] = 1;
        return;
    }
    const int mid = (si + sj) / 2;
    const int leftChild = 2 * i;
    const int rightChild = leftChild + 1;
    build(leftChild, si, mid);
    build(rightChild, mid + 1, sj);
    tree[i] = tree[leftChild] + tree[rightChild];
}
// Multiply every position in [qi, qj] (a sub-range of this node's [si, sj]) by x,
// modulo MOD. A node whose range matches the query exactly absorbs the factor
// into both its sum and its multiplier lazy[i]. The multiplier is never pushed
// down. Instead each ancestor recomputes its sum as (left + right) * lazy[i].
void update(int i,int si,int sj,int qi,int qj,int x){
    if(si == qi && sj == qj){
        tree[i] = tree[i] * x % MOD;
        lazy[i] = lazy[i] * x % MOD;
        return;
    }
    const int mid = (si + sj) / 2;
    const int leftChild = 2 * i;
    const int rightChild = leftChild + 1;
    // Clip the query to each half it overlaps. Left is handled before right,
    // as in a split range.
    if(qi <= mid)
        update(leftChild, si, mid, qi, std::min(qj, mid), x);
    if(qj > mid)
        update(rightChild, mid + 1, sj, std::max(qi, mid + 1), qj, x);
    tree[i] = (tree[leftChild] + tree[rightChild]) * lazy[i] % MOD;
}
// Sum of positions [qi, qj] (a sub-range of this node's [si, sj]) modulo MOD.
// Because update() never pushes multipliers down, a partial answer collected
// from the children must be scaled by this node's own lazy[i] on the way up.
LL query(int i,int si,int sj,int qi,int qj){
    if(si == qi && sj == qj)
        return tree[i];
    const int mid = (si + sj) / 2;
    LL partial = 0;
    if(qi <= mid)
        partial += query(2 * i, si, mid, qi, std::min(qj, mid));
    if(qj > mid)
        partial += query(2 * i + 1, mid + 1, sj, std::max(qi, mid + 1), qj);
    return partial * lazy[i] % MOD;
}
int N;
// Multiply Euler-tour positions [i, j] by x over the whole tree [1, N].
// An empty range (i > j), such as the strict descendants of a leaf, is a no-op.
void update(int i,int j,int x){
    if(i > j)
        return;
    update(1, 1, N, i, j, x);
}

int main(){
    //fre;
    int Q,a,b,c,x,v;
    cin>>N>>Q;
    rep(i,1,N-1){
        scan(a);
        scan(b);
        G[a].push_back(b);
        G[b].push_back(a);
    }
    dfs(1,0);
    build(1,1,N);
    rep(i,1,N)scan(A[i]),update(L[i]+1,R[i],A[i]);
    while(Q--){
        scan(c);
        if(c==1){
            //update
            scan(v);
            scan(x);
            update(L[v]+1,R[v],pow(A[v],MOD-2,MOD));
            A[v] = x;
            update(L[v]+1,R[v],A[v]);
        }
        else{
            //query
            scan(v);
            printll(query(1,1,N,L[v],R[v]));
        }
    }
}

Second solution

#include <vector>
#include <list>
#include <map>
#include <set>
#include <queue>
#include <deque>
#include <stack>
#include <bitset>
#include <algorithm>
#include <functional>
#include <numeric>
#include <utility>
#include <sstream>
#include <iostream>
#include <iomanip>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <limits>
#include <string>
#include <cassert>

using namespace std;

// Shorthand for passing a container's full iterator range.
#define all(X) (X).begin(), (X).end()

typedef long long LL;
typedef pair<int, int> PII;
typedef vector<int> VI;
typedef vector<VI> VVI;

// Answers are reported modulo this prime. 1e9 + 7 is a double literal, but its
// conversion to int is exact.
const int mod = 1e9 + 7;
// Array bound for vertex-indexed arrays.
// NOTE(review): presumably N <= 1e5 — confirm against the constraints.
const int maxn = 100001;

// Segment tree sums and pending multipliers. relax() also writes lazy[] at the
// children of leaf nodes, so the storage must exceed the usual 4 * N; 16 * maxn
// is used here.
int tree[maxn << 4], lazy[maxn << 4];

// Modular exponentiation by repeated squaring: returns base^power mod `mod`,
// in [0, mod). main() calls it with power = mod - 2 to obtain a modular inverse
// (Fermat's little theorem; mod is prime).
int p(int base, int power) {
    // A negative power is unsupported. Right-shifting a negative int never
    // reaches 0, so a `while (power)` loop would never terminate.
    assert(power >= 0);
    // Normalise into [0, mod) so a negative base cannot produce a negative result.
    base %= mod;
    if (base < 0) base += mod;
    int res = 1;
    while (power > 0) {
        if (power & 1) res = res * 1LL * base % mod;
        power >>= 1;
        base = base * 1LL * base % mod;
    }
    return res;
}

// Product of a and b modulo `mod`. The multiplication is done in 64 bits so it
// cannot overflow int.
int mul(int a, int b) {
    const long long product = static_cast<long long>(a) * b;
    return static_cast<int>(product % mod);
}

// Build the segment tree node idx covering [l, r]. Every position starts at 1,
// so an internal node's sum is its element count, and every multiplier starts
// at the identity 1.
void init(int l, int r, int idx) {
    lazy[idx] = 1;
    if (l == r) {
        tree[idx] = 1;
        return;
    }
    const int mid = (l + r) / 2;
    const int leftChild = idx << 1;
    const int rightChild = leftChild | 1;
    init(l, mid, leftChild);
    init(mid + 1, r, rightChild);
    tree[idx] = tree[leftChild] + tree[rightChild];
}

// Apply the pending multiplier of `node`: scale the node's own sum, pass the
// factor down to both children, and reset the node's multiplier to 1.
// Only the identity factor 1 may be skipped. A factor of 0 (for example from an
// A value that is a multiple of mod) must still zero the sum and propagate;
// checking `> 1` would silently drop it.
void relax(int node) {
    if (lazy[node] != 1) {
        const int factor = lazy[node];
        tree[node] = mul(tree[node], factor);
        lazy[node << 1] = mul(lazy[node << 1], factor);
        lazy[(node << 1) | 1] = mul(lazy[(node << 1) | 1], factor);
        lazy[node] = 1;
    }
}

void update(int s, int e, int l, int r, int idx, int val) {
    relax(idx);
    if (r < s || l > e || l > r) return;
    if (s >= l && e <= r) {
        lazy[idx] = mul(lazy[idx], val);
        relax(idx);
        return;
    }
    update(s, (s + e) / 2, l, r, idx * 2, val);
    update((s + e) / 2 + 1, e, l, r, idx * 2 + 1, val);
    tree[idx] = (tree[idx * 2] + tree[idx * 2 + 1]) % mod;
}

int read(int s, int e, int l, int r, int idx) {
    relax(idx);
    if (r < s || l > e || l > r) return 0;
    if (s >= l && e <= r) return tree[idx];
    int res = read(s, (s + e) / 2, l, r, idx * 2);
    res += read((s + e) / 2 + 1, e, l, r, idx * 2 + 1);
    tree[idx] = (tree[idx * 2] + tree[idx * 2 + 1]) % mod;
    if (res >= mod) res -= mod;
    return res;
}

// st[u]..ed[u] is the Euler-tour interval of u's subtree (filled by dfs);
// dis is the running tour counter.
int st[maxn], ed[maxn], dis;
// Adjacency lists of the tree, indexed by vertex.
VI e[maxn];

// Euler tour of the tree rooted at u (parent par). Gives each vertex its
// preorder position st[] (numbered from dis+1) and sets ed[] to the last
// position inside its subtree, so subtree(w) == positions st[w]..ed[w].
//
// Uses an explicit stack instead of recursion. A path-shaped tree would
// otherwise recurse once per vertex and can overflow the call stack. Neighbours
// are still visited in adjacency order, so the numbering is identical.
void dfs(int u, int par) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index of the next neighbour of `node` to examine
    };
    std::vector<Frame> frames;
    st[u] = ++dis;
    frames.push_back(Frame{u, par, 0});
    while (!frames.empty()) {
        Frame &cur = frames.back();
        if (cur.next == e[cur.node].size()) {
            // Every neighbour handled: the subtree of cur.node is complete.
            ed[cur.node] = dis;
            frames.pop_back();
            continue;
        }
        const int v = e[cur.node][cur.next++];
        if (v == cur.parent) continue;
        st[v] = ++dis;
        // The Frame argument is built before push_back runs, so reading
        // cur.node is safe even if the vector reallocates.
        frames.push_back(Frame{v, cur.node, 0});
    }
}

int A[maxn];

int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    int N, Q, x, y, v, t;
    cin >> N >> Q;
    for (int i = 1; i < N; ++i) {
        cin >> x >> y;
        e[x].emplace_back(y);
        e[y].emplace_back(x);
    }
    init(1, N, 1);
    dfs(1, -1);
    for (int i = 1; i <= N; ++i) {
        cin >> A[i];
        update(1, N, st[i] + 1, ed[i], 1, A[i]);
    }
    while (Q--) {
        cin >> t >> v;
        if (t == 1) {
            cin >> x;
            int upVal = x * 1LL * p(A[v], mod - 2) % mod;
            A[v] = x;
            update(1, N, st[v] + 1, ed[v], 1, upVal);
        } else {
            cout << read(1, N, st[v], ed[v], 1) << "\n";
        }
    }
    return 0;
}