Skip to content

834. Sum of Distances in Tree

834. Sum of Distances in Tree

题目: https://leetcode.com/problems/sum-of-distances-in-tree/

难度: Hard

题意:

  1. 给定一棵树
  2. 问每个节点到其他节点的距离之和是多少
  3. N <= 10000

思路:

  • 最简单的方法,假设我们要求节点0与其他节点的距离之和,只需要一遍dfs即可,其他节点同理。复杂度是o(n^2)

  • 注意到,树中每个边都是一个桥(就是去掉这个边图就变成非连通了),树中的任意两个节点都是一个固定的路径。当我们计算不同节点的dfs,其实是有一部分路径是重复计算的(自己拿笔画一下)

  • 来看看分析的思路。从问题出发,要问每个节点到其他节点的距离和,想象这个点A是树根(每个点都可以是树根),那么我们遍历一下所有连接这个节点的边。想象一下某个边e连接的是一个子树,点A跟这个子树所有的点都得经过边e,因此我们把问题下推,计算这个子树的树根与其他节点的距离之和,然后加上这个子树的数量就等于点A跟这个子树所有的点的距离和

  • 有个点需要注意的,由于边是无向的,我们动态规划的方向也需要两个方向,怎么理解呢。比如说,假设A和B连着。以A为树根计算子问题的时候,顺序是A->B,以B为树根计算子问题的时候,顺序是B->A,相当于对边进行动态规划,规划两个方向

  • 解法知道了,就需要编码。注意到,N <= 10000,如果递归的话,小心stack over flow。需要自己模拟栈

代码:

class Solution {
     private class Edge {
        private int x;
        private int y;
        private int num;
        private int dist;

        public Edge(int x, int y, int num, int dist) {
            this.x = x;
            this.y = y;
            this.num = num;
            this.dist = dist;
        }

        public int getX() {
            return x;
        }

        public void setX(int x) {
            this.x = x;
        }

        public int getY() {
            return y;
        }

        public void setY(int y) {
            this.y = y;
        }

        public int getNum() {
            return num;
        }

        public void setNum(int num) {
            this.num = num;
        }

        public int getDist() {
            return dist;
        }

        public void setDist(int dist) {
            this.dist = dist;
        }
    }

    private void solve(Edge edge, List[] edgeList) {
        if (edge.num != -1) {
            return;
        }

        LinkedList<Edge> queue = new LinkedList<Edge>();
        Stack<Edge> stack = new Stack<Edge>();

        queue.addLast(edge);
        stack.add(edge);

        while (!queue.isEmpty()) {
            Edge e = queue.pollFirst();
            for (Edge next: (List<Edge>) edgeList[e.y]) {
                if (next.y == e.x) {
                    continue;
                }
                if (next.num != -1) {
                    continue;
                }
                queue.addLast(next);
                stack.add(next);
            }
        }

        while (!stack.empty()) {
            Edge e = stack.pop();
            e.num = e.dist = 0;
            for (Edge next: (List<Edge>) edgeList[e.y]) {
                if (next.y == e.x) {
                    continue;
                }
                e.num += next.num;
                e.dist += next.num + next.dist;
            }
            e.num ++;
        }
    }

    public int[] sumOfDistancesInTree(int N, int[][] edges) {
        List[] edgeList = new List[N];
        for (int i = 0;i < edgeList.length;i++) {
            edgeList[i] = new ArrayList<Edge>();
        }

        for (int i = 0;i < edges.length;i++) {
            edgeList[edges[i][0]].add(new Edge(edges[i][0], edges[i][1], -1, -1));
            edgeList[edges[i][1]].add(new Edge(edges[i][1], edges[i][0], -1, -1));
        }

        int[] result = new int[N];
        for (int i = 0;i < N;i++) {
            result[i] = 0;
            for (Edge next: (List<Edge>) edgeList[i]) {
                if (next.dist == -1) {
                    solve(next, edgeList);
                }
                result[i] += next.dist + next.num;
            }
        }
        return result;
    }
}


回到顶部