Updated:

1. 문제 풀이

2가지 방법을 사용해서 문제를 풀었다.

  1. 첫번째 풀이
    • 가중치가 있는 그래프에서 가장 먼 거리를 찾는 것이므로 dijkstra 알고리즘을 이용해 풀었다.
    • 각 노드마다 dijkstra알고리즘을 적용해야하기때문에 시간이 오래걸린다. (오래걸리지만 통과하였다.)
    • 시간복잡도: $O(N^2log N)$
  2. 두번째 풀이
    • 트리 구조를 이용해 풀었다.
    • 서브트리가 존재할 때, 해당 서브트리의 지름은 리프 노드들 간의 거리 중 하나이기 때문에 DFS를 이용해 재귀적 방법으로 풀었다.
    • 시간 복잡도: $O(NlogN)$, 새로운 노드를 DFS로 탐색할 때마다 자식 노드간 최장거리를 구하기 위해 정렬을 사용했기 때문에 $logN$이 추가 되었다.

내가 풀었던 방법외에도 $O(N)$ 풀이가 존재한다고 한다.

2. 코드

import sys
def input(): return sys.stdin.readline().rstrip()
sys.setrecursionlimit(100000)

def dfs(node):
    global result
    if not graph[node]: return 0
    ret = 0 # node가 부모노드인 서브 트리에서 node와 리프 노드 사이의 거리 중 최장 거리
    dist = [] # node가 부모노드인 서브 트리에서 node와 리프 노드 사이의 거리들을 담은 배열
    for nxt, c in graph[node]:
        if not visited[nxt]:
            visited[node] = 1
            nxt_d = dfs(nxt)
            dist.append(c+nxt_d)
            ret = max(ret, c+nxt_d)
    dist.sort() # 리프 노드들 사이의 거리 중 최장 거리가 해당 서브 트리의 지름이다. (트리 지름 후보)
    if len(dist) >= 2:
        result = max(result, dist[-2] + dist[-1])
    else:
        result = max(result, dist[-1])
    return ret

n = int(input())
graph = [[] for _ in range(n+1)]
for _ in range(n-1):
    a, b, c = map(int,input().split())
    graph[a].append((b,c))

result = 0
visited = [0] * (n+1)
dfs(1)
print(result)

Leave a comment