In this HackerEarth Tri-State-Area (May Circuits '19) problem, we are given a weighted tree T and an integer MAXW. Count the weighted graphs, with non-negative edge weights of at most MAXW, for which T is a minimum spanning tree (MST). Output the result modulo 987654319.


HackerEarth Tri-State-Area (May Circuits '19) problem solution


HackerEarth Tri-State-Area (May Circuits '19) problem solution.

#include<bits/stdc++.h>
using namespace std;
// Array bound and answer modulus. main reduces exponents modulo Mod - 1,
// which relies on Mod being prime.
const int N = 300005;
const int Mod = 987654319;

int n;         // number of vertices
int MAXW;      // maximum allowed edge weight
int tot = 1;   // running answer (mod Mod)

int V[N], U[N], W[N];  // edge i: endpoints V[i], U[i] and weight W[i]
int P[N];              // DSU: parent, or -(component size) for a root
int T[N];              // edge indices, sorted by weight in main
// Sort predicate for edge indices: orders them by ascending weight W[].
bool CMP(int i, int j) { return W[j] > W[i]; }
// Disjoint-set "find" with full path compression.
// P[v] < 0 marks v as a root (main stores -(component size) there);
// otherwise P[v] is v's parent. Every node on the path is re-pointed at the root.
// Written iteratively: main links roots without union by size, so parent
// chains can grow to O(n) length. Recursing along such a chain could overflow
// the call stack.
int Find(int v)
{
    int root = v;
    while (P[root] >= 0)
        root = P[root];
    // Second pass: compress the path v -> ... -> root.
    while (v != root)
    {
        const int next = P[v];
        P[v] = root;
        v = next;
    }
    return root;
}
// Binary exponentiation: returns a^b modulo Mod, consuming b one bit at a time.
inline int Power(int a, int b)
{
    long long base = a;
    long long result = 1;
    while (b != 0)
    {
        if (b & 1)
            result = result * base % Mod;
        base = base * base % Mod;
        b >>= 1;
    }
    return (int)result;
}
// Reads the tree and processes its edges in ascending weight order
// (Kruskal-style), using the DSU stored in P[].
// When an edge of weight w joins components of sizes a and b, every one of
// the a*b cross pairs has w as the heaviest edge on its tree path. All of
// them except the edge itself are non-tree pairs. Each such pair has
// MAXW - w + 2 admissible options (presumably a weight in [w, MAXW] or no
// edge at all; NOTE(review): confirm against the statement).
// The answer is the product of (MAXW - w + 2)^(a*b - 1) over the tree edges,
// modulo Mod.
int main()
{
    scanf("%d%d", &n, &MAXW);
    for (int i = 1; i < n; i++)
        scanf("%d%d%d", &V[i], &U[i], &W[i]), T[i] = i;
    memset(P, -1, sizeof(P));
    sort(T + 1, T + n, CMP);
    for (int i = 1; i < n; i++)
    {
        const int e = T[i];
        if (W[e] > MAXW)
            tot = 0;  // a tree edge heavier than MAXW makes the count zero
        int v = Find(V[e]), u = Find(U[e]);
        // Roots store -(component size), so the product of the two is a*b.
        const long long pairs = 1LL * P[u] * P[v] - 1;
        // Compute the base in 64 bits (MAXW - W + 2 can overflow int when MAXW is
        // close to INT_MAX), then normalise it into [0, Mod).
        const long long base = ((1LL * MAXW - W[e] + 2) % Mod + Mod) % Mod;
        // Fermat (Mod prime): for k > 0, x^k == x^((k - 1) % (Mod - 1) + 1) for
        // every x. Unlike k % (Mod - 1), this stays correct when x == 0 (mod Mod).
        const long long expo = pairs == 0 ? 0 : (pairs - 1) % (Mod - 1) + 1;
        tot = tot * 1LL * Power((int)base, (int)expo) % Mod;
        P[v] += P[u]; P[u] = v;
    }
    return !printf("%d\n", tot);
}

Second solution

#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace __gnu_pbds;
using namespace std;
using ll = long long;
const int maxn = 300014;                       // 3e5 + 14: array bound with slack
const int mod = 987654319;                     // answer modulus (inverted via Fermat in main)
const ll inf = 2'000'000'000'000'000'000LL;    // 2e18; -inf is used as a sentinel assumed below every edge weight


template<typename T>
struct MOS{
    // Order-statistics multiset. Each stored value x is kept as the pair
    // (x, k), where k is the per-value occurrence index tracked in `cnt`, so
    // duplicates become distinct keys of the pb_ds tree.
    tree<pair<T, int>, null_type, less<pair<T, int>>, rb_tree_tag, tree_order_statistics_node_update> os;
    map<T, int> cnt;

    // Number of stored values (duplicates included).
    int size(){
        return static_cast<int>(os.size());
    }
    // Number of stored values strictly less than x.
    int get(const T &x){
        const pair<T, int> probe(x, 0);
        return os.order_of_key(probe);
    }
    // v == true adds one copy of x; v == false removes one copy of x.
    void insert(const T &x, bool v = 1){
        int &c = cnt[x];
        if(!v){
            --c;
            os.erase(make_pair(x, c));
            return;
        }
        os.insert(make_pair(x, c));
        ++c;
    }
    // Empties the multiset.
    void clear(){
        cnt.clear();
        os.clear();
    }
};

int n;                            // number of vertices
int sz[maxn];                     // subtree sizes written by get_sz
ll mxw;                           // maximum allowed edge weight
vector<pair<int, ll> > g[maxn];   // adjacency lists: (neighbor, edge weight)
bool bl[maxn];                    // vertices already chosen as centroids
// Computes sz[x] (subtree size when the component is rooted at v) for every
// vertex x reachable from v without entering a vertex marked in bl, and returns
// sz[v]. p is v's parent, which is not revisited.
// Iterative: a recursive walk can go n (3e5) levels deep on path-like trees
// and overflow the call stack.
int get_sz(int v, int p = -1){
    // BFS order with parents: every vertex appears after its parent, so all of
    // its descendants appear after it too.
    vector<pair<int, int> > order;  // (vertex, parent)
    order.emplace_back(v, p);
    for(size_t i = 0; i < order.size(); i++){
        const int x = order[i].first, par = order[i].second;
        sz[x] = 1;
        for(const auto &[u, w] : g[x])
            if(!bl[u] && u != par)
                order.emplace_back(u, x);
    }
    // Reverse order: each vertex's size is complete before it is added to its parent.
    for(size_t i = order.size(); i-- > 1; )
        sz[order[i].second] += sz[order[i].first];
    return sz[v];
}
MOS<long long> cnt;  // multiset of centroid-to-vertex path maxima (plus a -inf sentinel), shared by solve/dfs/dfs_add
void dfs_add(int v, bool add, ll pat, int p = -1){
    if(bl[v])
        return ;
    cnt.insert(pat, add);
    for(auto [u, w] : g[v])
        if(u != p)
            dfs_add(u, add, max(pat, w), v);
}
// Binary exponentiation: returns a^b modulo `mod`.
int po(ll a, int b){
    ll base = a % mod;  // reduce first so base * base cannot overflow 64 bits
    ll result = 1;
    while(b != 0){
        if(b & 1)
            result = result * base % mod;
        base = base * base % mod;
        b >>= 1;
    }
    return (int) result;
}
int dfs(int v, bool eq, ll pat, int p = -1){
    if(bl[v])
        return 1;
    int ans = po(mxw - pat + 2, eq ? cnt.get(pat + 1) - cnt.get(pat) : cnt.get(pat));
    for(auto [u, w] : g[v])
        if(u != p)
            ans = (ll) ans * dfs(u, eq, max(pat, w), v) % mod;
    return ans;
}
// One step of centroid decomposition on the component containing `root`
// (vertices not yet marked in bl). Returns, modulo `mod`, the product over
// every unordered pair of distinct vertices {x, y} in that component of
// (mxw - maxW(x, y) + 2). Here maxW(x, y) is the heaviest edge weight on the
// tree path between x and y. Pairs through the centroid are counted here; all
// other pairs are counted by recursing into each sub-component.
int solve(int root = 0){
    // Locate the centroid: starting at root, repeatedly step to a neighbor
    // whose subtree holds more than half of the component.
    int cen = root, all = get_sz(root), p = -1;
    bool br = 0;
    // `br ^= 1` is true on entry. A step resets br to 0, so the next toggle
    // keeps looping; a pass without a step toggles br back to 0 and exits.
    while(br ^= 1)
        for(auto [u, w] : g[cen])
            if(!bl[u] && u != p && sz[u] > all / 2){
                p = cen, cen = u, br = 0;
                break;
            }
    bl[cen] = 1;
    int ans = 1;
    // cnt holds centroid-to-vertex path maxima of the vertices inserted so
    // far; -inf stands for the centroid itself.
    cnt.clear();
    cnt.insert(-inf);
    // Pass 1: pairs in different subtrees whose path maxima are equal. Each
    // subtree is matched (dfs with eq = 1) only against earlier subtrees, so
    // each such pair is counted once.
    for(auto [u, w] : g[cen]){
        ans = (ll) ans * dfs(u, 1, w) % mod;
        dfs_add(u, 1, w);
    }
    // Pass 2: pairs whose path maxima differ, including pairs with the
    // centroid (via -inf). With subtree u temporarily removed, each vertex of
    // u is matched against strictly smaller values elsewhere. Each pair is
    // therefore charged once, to the endpoint with the larger maximum.
    for(auto [u, w] : g[cen]){
        dfs_add(u, 0, w);
        ans = (ll) ans * dfs(u, 0, w) % mod;
        dfs_add(u, 1, w);
    }
    // Recurse into each remaining component around the centroid.
    for(auto [u, w] : g[cen])
        if(!bl[u])
            ans = (ll) ans * solve(u) % mod;
    // NOTE(review): fires if some factor is a multiple of `mod` (product becomes 0).
    assert(ans > 0);
    return ans;
}
// Reads the tree (1-indexed vertices), then prints
// solve() / prod over tree edges of (mxw - w + 2), modulo `mod`.
// solve() multiplies one factor per unordered vertex pair, which includes
// every tree edge (its path maximum is its own weight), so the tree-edge
// factors are divided back out. Prints 0 if any tree edge exceeds mxw.
int main(){
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> mxw;
    int inT = 1;  // product of the tree-edge factors, mod `mod`
    for(int i = 1; i < n; i++){
        int v, u;
        ll w;
        cin >> v >> u >> w;
        v--, u--;
        g[v].push_back({u, w});
        g[u].push_back({v, w});
        if(w > mxw)
            return cout << "0\n", 0;
        inT = (mxw - w + 2) % mod * inT % mod;
    }
    // Modular inverse via Fermat's little theorem: valid when `mod` is prime
    // and inT != 0. Computed once (the debug print to stderr is gone).
    const int invT = po(inT, mod - 2);
    assert(invT > 0);
    cout << (ll) invT * solve() % mod << '\n';
}