834. Sum of Distances in Tree
834. Sum of Distances in Tree
题目: https://leetcode.com/problems/sum-of-distances-in-tree/
难度: Hard
题意:
- 给定一棵树
- 问每个节点到其他节点的距离之和是多少
- N <= 10000
思路:
-
最简单的方法,假设我们要求节点0与其他节点的距离之和,只需要一遍dfs即可,其他节点同理。复杂度是o(n^2)
-
注意到,树中每个边都是一个桥(就是去掉这个边图就变成非连通了),树中的任意两个节点都是一个固定的路径。当我们计算不同节点的dfs,其实是有一部分路径是重复计算的(自己拿笔画一下)
-
来看看分析的思路。从问题出发,要问每个节点到其他节点的距离和,想象这个点A是树根(每个点都可以是树根),那么我们遍历一下所有连接这个节点的边。想象一下某个边e连接的是一个子树,点A跟这个子树所有的点都得经过边e,因此我们把问题下推,计算这个子树的树根与其他节点的距离之和,然后加上这个子树的数量就等于点A跟这个子树所有的点的距离和
-
有个点需要注意的,由于边是无向的,我们动态规划的方向也需要两个方向,怎么理解呢。比如说,假设A和B连着。以A为树根计算子问题的时候,顺序是A->B,以B为树根计算子问题的时候,顺序是B->A,相当于对边进行动态规划,规划两个方向
-
解法知道了,就需要编码。注意到,N <= 10000,如果递归的话,小心stack over flow。需要自己模拟栈
代码:
class Solution {
private class Edge {
private int x;
private int y;
private int num;
private int dist;
public Edge(int x, int y, int num, int dist) {
this.x = x;
this.y = y;
this.num = num;
this.dist = dist;
}
public int getX() {
return x;
}
public void setX(int x) {
this.x = x;
}
public int getY() {
return y;
}
public void setY(int y) {
this.y = y;
}
public int getNum() {
return num;
}
public void setNum(int num) {
this.num = num;
}
public int getDist() {
return dist;
}
public void setDist(int dist) {
this.dist = dist;
}
}
private void solve(Edge edge, List[] edgeList) {
if (edge.num != -1) {
return;
}
LinkedList<Edge> queue = new LinkedList<Edge>();
Stack<Edge> stack = new Stack<Edge>();
queue.addLast(edge);
stack.add(edge);
while (!queue.isEmpty()) {
Edge e = queue.pollFirst();
for (Edge next: (List<Edge>) edgeList[e.y]) {
if (next.y == e.x) {
continue;
}
if (next.num != -1) {
continue;
}
queue.addLast(next);
stack.add(next);
}
}
while (!stack.empty()) {
Edge e = stack.pop();
e.num = e.dist = 0;
for (Edge next: (List<Edge>) edgeList[e.y]) {
if (next.y == e.x) {
continue;
}
e.num += next.num;
e.dist += next.num + next.dist;
}
e.num ++;
}
}
public int[] sumOfDistancesInTree(int N, int[][] edges) {
List[] edgeList = new List[N];
for (int i = 0;i < edgeList.length;i++) {
edgeList[i] = new ArrayList<Edge>();
}
for (int i = 0;i < edges.length;i++) {
edgeList[edges[i][0]].add(new Edge(edges[i][0], edges[i][1], -1, -1));
edgeList[edges[i][1]].add(new Edge(edges[i][1], edges[i][0], -1, -1));
}
int[] result = new int[N];
for (int i = 0;i < N;i++) {
result[i] = 0;
for (Edge next: (List<Edge>) edgeList[i]) {
if (next.dist == -1) {
solve(next, edgeList);
}
result[i] += next.dist + next.num;
}
}
return result;
}
}