알고리즘&문제풀이

[Softeer] 21년 재직자 대회 본선 - 거리 합 구하기

CSE 2025. 12. 12. 17:00

문제

<배경>

현호는 사내 네트워크 분석 업무를 담당하게 되었다.

현재 사내 네트워크는 $N$개의 노드를 가지는 트리 형태의 네트워크인데,

이 말은 두 노드간의 연결이 정확히 $N-1$개 있어서 이 연결만으로 모든 노드간에 통신을 할 수 있다는 뜻이다.

 

각 노드에 $1$에서 $N$사이의 번호를 붙이면 $i$번째 연결은 $x_i$번 노드와 $y_i$번 노드를 양방향으로 연결하며, 통신에 걸리는 시간은 $t_i$이다.

$D(i,j)$ $i$번 노드와 $j$번 노드 사이의 거리를 나타내는데, $i$번 노드에서 여러 연결을 거쳐 $j$번 노드에 도달하기 위해 걸리는 최소 시간이다.

노드를 들를 때 추가적인 작업이 없는 이상적인 시간을 따진다.

 

현호는 네트워크 분석을 위해 어떤 노드를 기준으로 다른 모든 노드 사이와의 거리의 합을 알고 싶다.

즉, $\sum_{j=1}^N D(i,j)$을 알고 싶다.

 

<입력>

<출력>

$N$개의 줄에 걸쳐서, $i$번째 줄에는 $i$번 노드와 다른 모든 노드 사이의 거리의 합, $\sum_{j=1}^N D(i,j)$를 출력한다.

 

<예시>

아래 그림에서 $D(1,j)$을 구해보자.

아래처럼 계산이 된다.

 


풀이

<알고리즘 설명>

위 예시에서 $D(1,j)$의 합은 38이다.

모든 노드에 대해 이렇게 일일히 계산하면 문제의 제한을 넘어간다.

따라서 subtree size라는 속성을 하나 만들자.

이 값은 자신을 포함한 하위 노드의 개수를 의미한다.

예를 들면, 1번노드는 7, 2번 노드는 1, 3번 노드는 3인 셈이다.

 

이 값을 이용해 연산량을 줄여보자,

2번 노드의 입장에서 보면,

1번노드에서 5의 거리를 더 이동해야 자신에게 올 수 있다.

다른 노드 6개가 모두 5의 거리를 더 이동해야하므로 38에 5x6을 더하고,

반대로 D(1,2)는 5에서 0이 되므로 5x1을 빼줘야 한다.

 

즉 38 + 5x6 - 5x1 처럼 계산할 수 있게 된다.

 

3번 노드에 대해 한 번 더 계산해보자.

1,2,4,7번 노드에 대해서는 1번 노드가 기준이었을 때 대비 길이가 2씩 증가하고, (+4x2)

3,5,6번 노드에 대해서는 2씩 감소한다.(-3x2)

 

<코드>

일단 입력 처리를 먼저한다.

N = int(input())
node =[[] for _ in range(N+1)]
subtreesize = [0] * (N+1)
distsum = [0] * (N+1)

for i in range(N-1):
    x,y,t = map(int,input().split())
    node[x].append([y,t])
    node[y].append([x,t])

 

이제 두 개의 핵심 함수를 만들어야한다.

dfs1은 subtreesize를 계산하는 함수이다.

또 distsum을 계산하는데, 이 값은 자신의 하위노드들에 대해서만 자신과의 거리 합을 의미한다.

따라서 처음 루트 노드에 대해서만 맞는 값이 들어간다.

다른 노드의 값은 실제와 distsum의 값이 틀려도 상관없는게, 우리가 필요한 것은 루트 노드에서의

다른 노드까지의 거리 합이기 때문이다.

 

dfs2에서는

dfs1에서 구한 1번 노드의 distsum값에서 위에서 설명한 비용을 더하고 빼는 작업을 통해,

하위 노드들에게 올바른 값을 바로 계산할 수 있도록 한다.

 

current에서 child노드로 오면서 추가되는 비용 : weight * (N - subtreesize[child])

child의 하위노드들과는 current보다 더 가까워 지는 것이므로 빠지는 비용 : weight * (subtreezise[child])

 

이므로 weight * (N - 2*subtreesize[child])가 된다.

def dfs1(current, parent): 
    subtreesize[current] = 1 #자기자신
    for i in range(len(node[currnet])): #자신한테서 나가는 모든 자식노드에 대해 
        child = node[current][i][0] #자식 노드의 번호
        weight = node[current][i][1] #가중치
        if child != parent:
            dfs1(child, current)
            distsum[current] += distsum[child] + subtreesize[child]*weight
            subtreesize[current] += subtreesize[child] #자식들의 subtreesize 합치기
    return 

def dfs2(current, parent):
    for i in range(len(node[currnet])): #자신한테서 나가는 모든 노드에 대해 
        child = node[current][i][0] #자식 노드의 번호
        weight = node[current][i][1] #가중치
        if child != parent :
            distsum[child] = distsum[current] + weight * (N - 2*subtreesize[child])
            dfs2(child, current)
    return

 

최종적으로 재귀제한을 풀어주고 실행하면 정답이다.

import sys
sys.setrecursionlimit(10**6)

dfs1(1,1)
dfs2(1,1)
for i in range(1, N+1):
    print(distsum[i])