问题描述
二维平面上,给定N(大约200000)个点,这些点的x和y的取值范围都是[1,3000]之间的整数,给定M(大约200000)个查询,每个查询输入一个点P(px,py)。对于每个查询,求N个点到点P的距离之和。
输入
第一行两个整数N和M
接下来N行表示N个点的x和y坐标 接下来M行表示M个查询的x和y坐标3 35 55 10 10 51 15 5 10 10
输出
M个正整数,每个查询输出一个正整数
思路
二维前缀和
定义二维数组xsum[3001][3001],xsum[x,y]表示区域[0,x],[0,y]上x坐标之和。这样一来,任意区域[fx,tx],[fy,ty]的x坐标之和就可以表示为xsum[fy,ty]-xsum[fx,ty]-xsum[tx,fy]+xsum[fx,tx]。
对于每个查询的点P,只需要处理它的左上方、左下方、右上方、右下方四个区域中的距离之和(方向使用二维空间直角坐标系)。
例如,点P左上方的点数为K,点P左上方的x坐标之和、y坐标之和分别为xs,ys,则点P左上方的点到P的距离为:K*px-xs+ys-K*py
代码
#include#include using namespace std;const int maxn = 200007;const int ma = 3003;int a[ma][ma];int xsum[ma][ma], ysum[ma][ma];int c[ma][ma];//前开后闭区间int getxsum(int fx, int fy, int tx, int ty) {//前开后闭区间 return xsum[tx][ty] - xsum[fx][ty] - xsum[tx][fy] + xsum[fx][fy];}int getysum(int fx, int fy, int tx, int ty) { return ysum[tx][ty] - ysum[fx][ty] - ysum[tx][fy] + ysum[fx][fy];}int getcnt(int fx, int fy, int tx, int ty) { return c[tx][ty] - c[tx][fy] - c[fx][ty] + c[fx][fy];} int main() { freopen("in.txt", "r", stdin); int N, M; cin >> N >> M; memset(a, 0, sizeof(a)); memset(xsum, 0, sizeof(xsum)); memset(ysum, 0, sizeof(ysum)); memset(c, 0, sizeof(c)); for (int i = 0; i < N; i++) { int x, y; cin >> x >> y; a[x][y]++; } for (int x = 1; x < ma; x++) { for (int y = 1; y < ma; y++) { xsum[x][y] = a[x][y] * x + xsum[x][y - 1]+xsum[x-1][y]-xsum[x-1][y-1]; ysum[x][y] = a[x][y] * y + ysum[x][y - 1]+ysum[x-1][y]-ysum[x-1][y-1]; c[x][y] = a[x][y]-c[x-1][y-1]+c[x-1][y]+c[x][y-1]; } } for (int i = 0; i < M; i++) { int x, y; cin >> x >> y; int s = 0; int cnt = getcnt(0, 0, x , y ); s += cnt * x + cnt * y - getxsum(0, 0, x , y) - getysum(0, 0, x, y ); cnt = getcnt(x , 0, ma-1, y ); s += getxsum(x, 0, ma-1, y) - cnt*x + cnt * y - getysum(x , 0, ma-1, y ); cnt = getcnt(0, y , x, ma-1); s += getysum(0, y, x, ma-1) - y * cnt + cnt * x - getxsum(0, y, x, ma-1); cnt = getcnt(x , y , ma-1, ma-1); s += getxsum(x , y , ma-1, ma-1) + getysum(x , y , ma-1, ma-1) - cnt * x - cnt * y; cout << s << endl; } return 0;}
总结
多维前缀和的计算方式可以认为是容斥原理。对于三维前缀和,加减号可以直接通过二进制表示的奇偶性来表示。
下面来段代码测试一下这个思路
import java.util.Random;public class Main {class Point { int x, y, z; Point(int x, int y, int z) { this.x = x; this.y = y; this.z = z; } @Override public String toString() { return String.format("(%d,%d,%d)", x, y, z); }}final int N = 100;//100*100*100的三维空间内final int POINT_COUNT = 4000;//点的个数final int QUESTION_COUNT = 4000;//问题的个数Random r = new Random(0);//随机一个点,点的各个维度取值范围要在1到N-1之间Point randomPoint() { return new Point(r.nextInt(N - 2) + 1, r.nextInt(N - 2) + 1, r.nextInt(N - 2) + 1);}//生成问题Point[] generateProblem() { Point[] a = new Point[POINT_COUNT]; for (int i = 0; i < a.length; i++) { a[i] = randomPoint(); } return a;}//绝对正确的方法class Stupid { Point[] a; Stupid(Point[] a) { this.a = a; } int solve(Point fp, Point tp) { int s = 0; for (Point i : a) { if (i.x >= fp.x && i.y >= fp.y && i.z >= fp.z && i.x <= tp.x && i.y <= tp.y && i.z <= tp.z) { s++; } } return s; }}//快速方法class Fast { int[][][] a, c; Fast(Point[] p) { //a[i,j,k]表示i,j,k处的点的个数 a = new int[N][N][N]; //c[i,j,k]表示000到ijk处的点的总数 c = new int[N][N][N]; for (Point i : p) { a[i.x][i.y][i.z]++; } for (int i = 1; i < N; i++) { for (int j = 1; j < N; j++) { for (int k = 1; k < N; k++) { c[i][j][k] = c[i - 1][j][k] + c[i][j - 1][k] + c[i][j][k - 1] - c[i - 1][j - 1][k] - c[i - 1][j][k - 1] - c[i][j - 1][k - 1] + c[i - 1][j - 1][k - 1] - a[i][j][k]; } } } } int solve(Point fp, Point tp) { int s = 0; for (int i = 0; i < 8; i++) { //i各个bit int one = i & 1, two = (i >> 1) & 1, three = (i >> 2) & 1; int x = tp.x, y = tp.y, z = tp.z; //符号位 int sgn = (one ^ two ^ three) == 0 ? -1 : 1; if (one != 0) x = fp.x - 1; if (two != 0) y = fp.y - 1; if (three != 0) z = fp.z - 1; s += c[x][y][z] * sgn; } return s; }}Main() { Point[] p = generateProblem(); Stupid stupid = new Stupid(p); Fast fast = new Fast(p); for (int i = 0; i < QUESTION_COUNT; i++) { //生成一对点作为查询区间,起始点的各个坐标必须小于终结点的各个坐标 Point fp = randomPoint(), tp = randomPoint(); if (fp.x > tp.x) { int temp = fp.x; fp.x = tp.x; tp.x = temp; } if (fp.y > tp.y) { int temp = fp.y; fp.y = tp.y; tp.y = temp; } if (fp.z > tp.z) { int temp = fp.z; fp.z = tp.z; tp.z = temp; } int realAns = stupid.solve(fp, tp); int mine = fast.solve(fp, tp); if (realAns != mine) { throw new RuntimeException("error on from=" + fp + ",to=" + tp + " " + realAns + " " + mine); } }}public static void main(String[] args) { new Main();}}