// Code: Select all
#include <algorithm>
#include <cassert>
#include <cstdio>
#include <cstdlib>
using namespace std;
// Capacity of the input array (4096 elements). Must be a power of two so
// that the heap-numbered segment tree fits in node[2*MAXN].
const int MAXN = 1 << 12;
// Largest of three integers.
int max3(int a, int b, int c) { return std::max(std::max(a, b), c); }
// Input values; only array[0..N-1] is used (N can be at most MAXN).
// NOTE(review): with `using namespace std;` in effect, an unqualified `array`
// can clash with std::array if a standard header declares it; `::array` is
// always unambiguous.
int array[MAXN];
int N;

// Segment tree over array[0..N-1], stored with heap-style numbering:
// node[1] is the root, and node[i] has children node[2*i] and node[2*i+1].
// A node covering [u, v] splits at c = (u+v)/2 into [u, c] and [c+1, v];
// the root covers [0, N-1]. To save memory, u and v are not stored; they are
// passed along as function arguments instead.
//
// For node i covering [u, v], the four stored values are:
//   node[i][0] = sum_{u<=k<=v} array[k]                           (total)
//   node[i][1] = max(0, max_{u<=s<=v} sum_{s<=k<=v} array[k])     (best suffix)
//   node[i][2] = max(0, max_{u<=t<=v} sum_{u<=k<=t} array[k])     (best prefix)
//   node[i][3] = max(0, max_{u<=s<=t<=v} sum_{s<=k<=t} array[k])  (best subarray)
// The outer max(0, ...) admits the empty sum, so values 1..3 are never negative.
//
// Sizing: MAXN must be a power of two. Then for N <= MAXN the tree depth is
// at most log2(MAXN), so every heap index stays below 2*MAXN.
int node[2*MAXN][4];
void build(int root, int u, int v) {
if (u == v) { // leaf
node[root][0] = array[u];
node[root][1] = node[root][2] = node[root][3] = max(0, array[u]);
} else {
int c = (u+v)/2, L = 2*root, R = 2*root+1;
build(L, u, c);
build(R, c+1, v);
// Try to carefully digest these relations. The function query() below
// basically has to mirror them in a lot more complicated way.
node[root][0] = node[L][0] + node[R][0];
node[root][1] = max(node[R][1], node[R][0] + node[L][1]);
node[root][2] = max(node[L][2], node[L][0] + node[R][2]);
node[root][3] = max3(node[L][3], node[R][3], node[L][1] + node[R][2]);
}
}
// The four per-interval values, with the same layout and meaning as node[i][0..3]:
// v[0] = sum, v[1] = best suffix, v[2] = best prefix, v[3] = best subarray,
// where v[1..3] admit the empty sum (so they are never negative).
struct SegSummary {
    int v[4];
};

// Combines the summaries of two adjacent intervals: `left` immediately
// followed by `right`. These are the same relations that build() uses.
static SegSummary seg_merge(const SegSummary &left, const SegSummary &right) {
    SegSummary out;
    out.v[0] = left.v[0] + right.v[0];
    out.v[1] = std::max(right.v[1], right.v[0] + left.v[1]);
    out.v[2] = std::max(left.v[2], left.v[0] + right.v[2]);
    out.v[3] = max3(left.v[3], right.v[3], left.v[1] + right.v[2]);
    return out;
}

// Summary of the intersection of node root's interval [ru, rv] with the
// query interval [qu, qv]. An empty intersection yields all zeros.
// All four values come out of one traversal, so only the nodes that partially
// overlap the query recurse (at most two per level): O(log N) work.
static SegSummary seg_query(int root, int ru, int rv, int qu, int qv) {
    qu = std::max(qu, ru);
    qv = std::min(qv, rv);
    if (qu > qv) {
        SegSummary empty = {{0, 0, 0, 0}};
        return empty;
    }
    if (qu == ru && qv == rv) {
        // Fully covered (this includes every leaf reached): use the stored values.
        SegSummary whole;
        for (int k = 0; k < 4; k++) whole.v[k] = node[root][k];
        return whole;
    }
    const int rc = (ru + rv) / 2, L = 2 * root, R = 2 * root + 1;
    if (qv <= rc) return seg_query(L, ru, rc, qu, qv);      // only the left child overlaps
    if (qu > rc) return seg_query(R, rc + 1, rv, qu, qv);   // only the right child overlaps
    return seg_merge(seg_query(L, ru, rc, qu, qv),
                     seg_query(R, rc + 1, rv, qu, qv));
}

// [ru, rv] = interval of node root
// [qu, qv] = the query interval (it is clipped to [ru, rv])
// kind     = which value to return (0..3), defined as for node[i][kind]
//            with qu and qv substituted for u and v.
// Returns 0 when the clipped interval is empty.
static int query(int root, int ru, int rv, int qu, int qv, int kind) {
    assert(0 <= kind && kind <= 3);
    return seg_query(root, ru, rv, qu, qv).v[kind];
}
// Best sum of a (possibly empty) contiguous run inside array[low..high],
// with both ends inclusive:
//   max(0, max_{low <= i <= j <= high} sum_{i <= k <= j} array[k])
// Requires 0 <= low <= high < N and a tree already built with build(1, 0, N-1).
int query_fast(int low, int high) {
    assert(0 <= low && low <= high && high < N);
    const int root_index = 1;     // heap index of the root node
    const int best_subarray = 3;  // node[][3]: best (possibly empty) subarray sum
    return query(root_index, 0, N - 1, low, high, best_subarray);
}
int query_naive(int low, int high) {
int res = 0;
for (int u = low; u <= high; u++) {
int sum = 0;
for (int v = u; v <= high; v++) {
sum += array[v];
res = max(res, sum);
}
}
return res;
}
// Sample usage and unittest:
int main() {
// prepare input
srand(53387);
N = MAXN;
for (int i = 0; i < N; i++) array[i] = rand() % 1000 - 500;
// build the tree
build(1, 0, N-1);
for (int count = 0; count < 1000; count++) {
int a = rand() % N, b = rand() % N;
if (a > b) swap(a, b);
int w = query_fast(a, b), z = query_naive(a, b);
printf("a=%d b=%d naive=%d fast=%d\n", a, b, z, w);
assert(z == w);
}
printf("PASSED\n");
}
// I hope it works.