In this HackerEarth Simple Sum problem solution, you are given an array of $N$ integers $A_1, A_2, \ldots, A_N$. You have to find the simple sum of this array, defined as $\sum_{i=1}^{N} \sum_{j=i}^{N} \max(A_i, A_{i+1}, \ldots, A_j) \cdot (A_i \mid A_j)$, where $\mid$ denotes the bitwise OR operator.


HackerEarth Simple Sum problem solution


HackerEarth Simple Sum problem solution.

#include <bits/stdc++.h>
using namespace std;
// Shorthand macros. NOTE: `ll` is an *unsigned* 64-bit type here, so all
// arithmetic on A, cnt and ans below wraps modulo 2^64.
#define ll long long unsigned
// NOTE(review): pb, fr, maxval, maxn, pi and f are not used anywhere in this program.
#define pb push_back
#define fr freopen("in.txt","r",stdin)
// rep: i runs over [0, n); frep: i runs over [1, n] (arguments are not parenthesized).
#define rep(i,n) for(int i=0;i<n;i++)
#define frep(i,n) for(int i=1;i<=n;i++)
#define maxval 100011
#define maxn 300011
#define pi pair<int,int>
#define f first
// NOTE(review): this rewrites every identifier `s` to `second`, including the
// loop variable `s` in main (it still compiles because the rename is consistent).
#define s second
// Number of low bits examined per value; assumes every A[i] < 2^15 — confirm
// against the problem's value bound.
#define MAXBITS 15
// A[1..N]: input values (1-indexed; A[0] stays 0).
ll A[100011];
// dp[i][k]: index of a maximum of A[i .. i + 2^k - 1] (sparse table filled in main;
// on ties the index from the right half is kept).
int dp[100011][20];
// cnt[i][b]: how many of A[1..i] have bit b set (prefix counts; row 0 is all zero).
ll cnt[100011][20];
// Running total of the Simple Sum, accumulated by calc().
ll ans = 0;
int query(int i,int j) {
    int len = j-i+1;
    len = log2(len);
    int p = dp[i][len];
    int q = dp[j-(1<<len)+1][len];
    if(A[p]>A[q]) return p;
    return q;
}
void calc(int i,int j) {
    if(i==j) {
        ans+=A[i]*A[i];
        return;
    }
    if(i>j) return;
    int m = query(i,j);
    if(m-i<=j-m) {
        for(int k=i;k<=m;k++) {
            rep(p,MAXBITS) {
                if(A[k]&(1<<p)) {
                    ans+=A[m]*(1LL<<p)*(ll)(j-m+1);
                } else{
                    ans+=A[m]*(1LL<<p)*(ll)(cnt[j][p]-cnt[m-1][p]);
                }
            }
        }
    } else{
        for(int k=m;k<=j;k++) {
            rep(p,MAXBITS) {
                if(A[k]&(1<<p)) {
                    ans+=A[m]*(1LL<<p)*(ll)(m-i+1);
                } else{
                    ans+=A[m]*(1LL<<p)*(ll)(cnt[m][p]-cnt[i-1][p]);
                }
            }
        }
    }
    //ans = 0;
    calc(i,m-1);
    calc(m+1,j);
}
int main() {
    freopen("in10.txt","r",stdin);
    freopen("out10.txt","w",stdout);

    int N;
    cin >> N;
    frep(i,N) {
        cin >> A[i];
        dp[i][0] = i;
        rep(j,MAXBITS) {
            cnt[i][j] = cnt[i-1][j];
            if(A[i]&(1<<j)) {
                cnt[i][j]++;
            }
        }
    }
    int p,q;
    for(int s=1;s<20;s++) {
        frep(i,N-(1<<s)+1) {
            p = dp[i][s-1];
            q = dp[i+(1<<(s-1))][s-1];
            if(A[p]>A[q]) dp[i][s] = p;
            else dp[i][s] = q;
        }
    }
    calc(1,N);
    cout << ans;
}

Second solution

#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <utility>
#include <vector>
using namespace std;

// Input limits from the problem statement (checked when main reads input).
constexpr int MAXN = 100000;   // maximum number of elements
constexpr int MAXM = 10000;    // maximum element value
constexpr int MAXBITS = 16;    // bits examined per value; 2^16 > MAXM

// a[i]: the input values (0-indexed).
int a[MAXN];
// cnt[bit][k]: how many of a[0..k-1] have `bit` set (prefix counts, cnt[bit][0] == 0).
int cnt[MAXBITS][MAXN + 1];

// A candidate maximum: its value and its position in `a`.
struct MaxInfo
{
    int value;
    int i;
    MaxInfo() {}
    MaxInfo(int v, int pos) : value(v), i(pos) {}
};

// Sparse table: st[k][j] holds the maximum of a[j .. j + 2^k - 1].
MaxInfo st[20][MAXN];

// Orders candidates by value only, so std::max keeps its first argument on ties.
inline bool operator < (const MaxInfo& lhs, const MaxInfo &rhs)
{
    return lhs.value < rhs.value;
}

// Returns the maximum (value and position) of a[l..r], assuming l <= r.
// The range is covered by two possibly overlapping sparse-table windows of
// length 2^x, where x = floor(log2(r - l + 1)); on equal values std::max
// keeps the left window's entry.
//
// x is found with integer arithmetic, which avoids depending on <cmath>'s
// log2 and on floating-point rounding.
MaxInfo getMax(int l, int r)
{
    const int len = r - l + 1;
    int x = 0;
    while ((2 << x) <= len) {
        ++ x;
    }
    return max(st[x][l], st[x][r - (1 << x) + 1]);
}

long long divideAndConquer(int l, int r)
{
    if (l > r) {
        return 0;
    }
    if (l == r) {
        return (long long) a[l] * a[l];
    }
    int mid = l + r >> 1;

    MaxInfo maxValue = getMax(l, r);
    long long ret = divideAndConquer(l, maxValue.i - 1) + divideAndConquer(maxValue.i + 1, r);
    if (maxValue.i < mid) {
        for (int i = l; i <= maxValue.i; ++ i) {
            for (int bit = 0; bit < MAXBITS; ++ bit) {
                if (a[i] >> bit & 1) {
                    ret += (long long)maxValue.value * (r - maxValue.i + 1) * (1LL << bit);
                } else {
                    ret += (long long)maxValue.value * (cnt[bit][r + 1] - cnt[bit][maxValue.i]) * (1LL << bit);
                }
            }
        }
    } else {
        for (int j = maxValue.i; j <= r; ++ j) {
            for (int bit = 0; bit < MAXBITS; ++ bit) {
                if (a[j] >> bit & 1) {
                    ret += (long long)maxValue.value * (maxValue.i - l + 1) * (1LL << bit);
                } else {
                    ret += (long long)maxValue.value * (cnt[bit][maxValue.i + 1] - cnt[bit][l]) * (1LL << bit);
                }
            }
        }
    }
    return ret;
}

// Entry point: reads n and a[0..n-1], checks them against the problem limits,
// builds the per-bit prefix counts and the range-maximum sparse table, and
// prints the Simple Sum of the whole array.
//
// scanf is called outside assert(): when NDEBUG is defined assert does not
// evaluate its argument, so reads placed inside it would be skipped entirely.
// Each check is still asserted (aborting in debug builds) and additionally
// enforced with an early exit so release builds never index past MAXN.
int main()
{
    int n = 0;
    const bool validCount = scanf("%d", &n) == 1 && 1 <= n && n <= MAXN;
    assert(validCount);
    if (!validCount) {
        return 1;
    }
    for (int i = 0; i < n; ++ i) {
        const bool validValue = scanf("%d", &a[i]) == 1 && 1 <= a[i] && a[i] <= MAXM;
        assert(validValue);
        if (!validValue) {
            return 1;
        }
        st[0][i] = MaxInfo(a[i], i);
        for (int bit = 0; bit < MAXBITS; ++ bit) {
            cnt[bit][i + 1] = cnt[bit][i] + (a[i] >> bit & 1);
        }
    }
    // st[k + 1][j] = max of a[j .. j + 2^(k+1) - 1], from two halves of length 2^k.
    for (int i = 0, len = 1; len < n; ++ i, len *= 2) {
        for (int j = 0; j + len * 2 <= n; ++ j) {
            st[i + 1][j] = max(st[i][j], st[i][j + len]);
        }
    }
    printf("%lld\n", divideAndConquer(0, n - 1));
    return 0;
}