31 December 2016

介绍

线段树,也叫区间树,英文名叫 segment tree,该数据结构主要是用来实现区间查询。

假设,我们要求 从 1 到 10 的和,那么我们会从 1 加到 10 来得出结果是 55,要算 1 到 10000 的和也是类似,那么要算出 1 到 10000 个数里面任意给定范围的和,我们还是要从这个范围的起点开始算起。

现在,我们将 1 到 10000 这些数字的第 1000 个数修改成 n,那么我们就重新开始计算,从 1 开始算到 10000。

按照人类的思维,我们可以直接直接算出 一开始的第 1000 个数和 n 的差,然后和之前结果进行一次对比就行了。在计算机看来,就需要保留上一次计算的值。因此就有

sum[1, 10000]  = xxx;   第 1 个数到 第 10000 个数的和
sum[43, 999]   = xxx;   第 43 个数到 第 999 个数的和
sum[888, 1500] = xxx;   第 888 个数到 第 1500 个数的和
...

现在我们要修改 第 1000 个数的值为 n 的话,我们就要更改所有包含 1000 这个数的区间的和,例如 [1, 100000], [99-1500] 等类似包含 1000 的区间。

实现

在实现中,我们不能随便生成一个区间,区间的生成是由一开始的始末依次二分生成的。一图胜千言,下面是由区间 [1-10] 生成的线段树

一开始, 1,10 分成 1,56,10 两半,之后以此类推。

现在我们来看看区间是怎么计算的:

要计算 1-10 区间的和,就先要计算 1-5,再加上 6-10;
要计算 1-5  区间的和,就先要计算 1-3,再加上 4-5;
要计算 1-3  区间的和,就先要计算 1-2,再加上 3-3;
...

以此类推,我们计算到区间的 左区间和右区间相等的时候,才能往上迭代(图上的红色表示区间和)

代码:

int data[11] = {-1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};  // 除去 data[0]
int tree[100];

int left(int x) {         // 左子树对应的标号
    return 2*x;
}

int right(int x) {        // 右子树对应的标号
    return 2*x + 1;
}


// 构建线段树
void build(int root = 1, int l = 1, int r = 10) {
    if (l == r) {         // 如果当前的区间只有一个节点
        tree[root] = data[l];
        return;
    }

    if (l > r) {          // 越界,左边不能大于右边
        return;
    }

    int m = (l + r) / 2;

    build(left(root),  l,   m);
    build(right(root), m+1, r);

    tree[root] = tree[left(root)] + tree[right(root)]; // 往上迭代
}

查询

在构建完线段树之后,我们要查询区间 [query_min, query_max] 的和

区间的开始是从 [1, 10] 往中间缩。

如果我们要查询的是 1-10 的和,那么应该直接返回 tree[1]
如果我们要查询的是 1-8 的和,那么查询的过程是:

发现查询区间 [1-8] 不能完全包含 [l-r] (即[1-10])
--> 分成左右区间 [1-5] [6-10]
--> 发现查询区间 [1-8] 完全包含 [1-5]
--> 发现查询区间 [1-8] 部分包含 [6-10],将 6-10 分成 [6-8], [9-10]
----> 发现查询区间 [1-8] 部分包含 [6-8],并且 [9-10] 不在查询区间内

返回 sum([1-5]) + sum([6-8]) + sum([9, 10])
    15           21           0            ==> 36

如果我们要查询的是 3-7 的和,那么查询的过程是:

发现查询区间 [3-7] 不能完全包含 [l-r] ([1-10])
--> 分成左右区间 [1-5] [6-10]

    处理 [1-5]
----> 查询区间 [3-7] 不能完全包含 [1-5],分成 [1-3] 和 [4-5]

      处理 [1-3]
------> 查询区间 [3-7] 不能完全包含 [1-3],分成 [1-2], [3-3], --> [3-3] 完全包含

      处理 [4-5]
------> 查询区间 [3-7] 完全包含 [4-5]

    处理 [6-10]
----> 查询区间 [3-7] 不能完全包含 [6-10],分成 [6-8] 和 [9-10]

      处理 [6-8]
------> 查询区间 [3-7] 不能完全包含 [6-8],分成 [6-7], [8-8], --> [6-7] 完全包含

      处理 [9-10]
------> [9-10] 超出范围,返回 0

结果是 sum([3-3]) + sum([4-5]) + sum(6-7)
int getSum(int query_min, int query_max, int root = 1, int l = 1, int r = 10) {
    if (query_min <= l && r <= query_max) {
        return tree[root];
    }

    if (query_min > r || query_max < l) { // l, r 不在查询区间内
        return 0;
    }

    int m = (l + r) / 2;

    int lsum = getSum(query_min, query_max, left(root),  l,   m);
    int rsum = getSum(query_min, query_max, right(root), m+1, r);

    return lsum + rsum;
}

更新

更新相当于重建

void update(int index, int v) {
    data[index] = v;
    build();
}

懒惰更新

实际上,我们可以将对区间的更新推迟到我们需要获取这个区间的值的时候而不是马上重建整棵树,这个叫懒惰更新,lazy update

updateInterval(updateL, updateR, root, l, r)

lazy update 遵循以下原则:

  1. 如果 当前的节点范围 (l, r) 不在 更新范围 (updateL, updateR) 内,不做处理
  2. 如果 当前的节点有待更新的数据 ,将待更新的数据更新到当前节点。将 当前节点的更新信息 “推迟” 给左右两个子节点
  3. 如果 当前的节点范围 完全在 更新范围 内,更新当前的节点。将 待更新的信息 “推迟” 给左右两个子节点
  4. 如果 当前的节点范围 有一部分在 更新范围 内,那么遍历左右两个子节点
void updateInterval(int updateL, int updateR, int diff, int root = 1, int l = 1, int r = 9) {
    // 不在更新范围内
    if (l > r || l > updateR || r < updateL) {
        return ;
    }

    // 当前的节点有待更新的数据
    if (lazy[root] != 0) {
        tree[root] += (r-l+1) * lazy[root];

        if (l != r) {      // 存在子节点
            lazy[left(root)]  += lazy[root];
            lazy[right(root)] += lazy[root];
        }

        lazy[root] = 0;
    }

    // 当前的节点范围 完全在 更新范围 内
    if (l >= updateL && r <= updateR) {
        tree[root] += (r-l+1) * diff;

        if (l != r) {
            lazy[left(root)]  += diff;
            lazy[right(root)] += diff;
        }
        return;
    }

    // 当前的节点范围 有一部分在 更新范围
    int m = (l + r) / 2;
    updateInterval(updateL, updateR, diff, left(root),  l,   m);
    updateInterval(updateL, updateR, diff, right(root), m+1, r);

    tree[root] = tree[left(root)] + tree[right(root)];
}
int getSum(int query_min, int query_max, int root, int l, int r) {
    // 当前的节点有待更新的数据
    if (lazy[root] != 0) {
        tree[root] += (r-l+1) * lazy[root];

        if (l != r) {      // 存在子节点
            lazy[left(root)]  += lazy[root];
            lazy[right(root)] += lazy[root];
        }

        lazy[root] = 0;
    }

    if (l > r || l > query_max || r < query_min) {
        return 0;
    }

    if (l >= query_min && r <= query_max) {
        return tree[root];
    }

    int m = (l + r) / 2;
    int lsum = getSumWrap(query_min, query_max, left(root),  l,   m);
    int rsum = getSumWrap(query_min, query_max, right(root), m+1, r);

    return lsum + rsum;
}