[이론] 다익스트라

자세한 설명은 kks 블로그를 보고 이해했습니다. 기억해야하는 부분만 정리해보겠습니다.

조건

  1. 간선은 항상 0 이상이다.(음수 간선의 경우 벨만 포드 알고리즘를 사용합니다.)

알고리즘

  1. 시작 노드를 고른다.
  2. 시작 노드 ~ 각 노드까지의 최단 거리를 저장하는 배열 dist는 생성한다.
  3. 시작노드가 0일 때 dist[0] =0, 나머지는 무한이다.
  4. 시작노드에서 아직 방문하지 않은 정점들의 dist[i]를 갱신한다.
  5. 시작노드에서 아직 방문하지 않은 정점 중 가장 거리(dist[curr] + d[curr][next])가 가장 가까운 곳을 방문한다.
  6. 해당 정점에서 인접하고 아직 방문하지 않은 정점들의 거리를 갱신한다.

특징

  • 한번 방문한 상태의 노드의 dist[i]는 갱신되지 않는다. 즉, 이미 방문을 했다면 그 노드까지의 최단 거리를 이미 구한 것이다.
    1. (가정) 아직 방문하지 않은 정점들 중 가장 dist 값이 작은 정점이 u고, 사실 dist[u]는 최단거리보다는 아직 크다.(아직 방문하지 않은 노드 중에서. 방문한 노드까지 포함해서 최소라면 가정 자체를 세울 수 없다. 그 자체로 최소값의 정의이기 때문이다.)
    2. 시작노드와 u 사이에 어떤 노드 v를 거쳐가는 더 빠른 경로가 존재한다.
    3. 즉, dist[v] > dist[v] + d[v][u]가 존재한다.
    4. 위 조건을 만족하는 v는 이미 방문 상태여야 한다. 방문하지 않은 v를 통해서 u까지 더 빨라지는 경로가 있다면, 방문하지 않은 정점 중 가장 u의 dist 값이 작다는 것에 모순이 생긴다.
    5. 그런데 v가 이미 방문상태였다면 ` dist[v] > dist[v] + d[v][u]`에 의해 이미 갱신되었어야 한다.
    6. 따라서 한 번 방문한 dist[u]는 매번 최단거리가 저장되어 있다.

구현

아직 방문하지 않은 정점들 중에서 dist 값이 가장 작은 정점을 찾아서 방문해야하는데, 이런 정점을 찾기 위해서 O(N-1)이 걸리고 N개의 노드에 대해서 반복해야하기 때문에 O(N^2)이 걸립니다. 이를 빠르게 하기 위해서 우선순위 큐를 사용합니다. 우선순위 큐에는 pair<dist[v],v>를 넣습니다. pair는 first를 먼저 기준으로 정렬하기 때문에 dist[v]가 가장 작은 값의 노드 번호 v를 사용합니다. 만약 top의 v가 이미 방문한 곳이라면 무시하고 그 다음의 top을 사용하면 됩니다.

시간 복잡도 :O(ElogV)

PQ에 pair $v^2$개가 들어와도 PQ는 O(logN)시간에 값을 처리하기 때문에, top 값을 처리하기 위해서 $O(logN^2) = O(2logN) = O(logN)$에 처리할 수 있습니다. 그리고 최대 간선E의 갯수만큼 top을 꺼내는 과정을 반복하기 때문에 O(ElogV)의 시간복잡도를 가집니다.

소스코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;

using ll = long long;
using pii = pair<int, int>;

const int MAX = 1001;

int n, m, s, d; // 노드 갯수, 간선 갯수, 시작 노드
vector<pii> adj[MAX];
int dist[MAX];
bool visited[MAX];
priority_queue<pii, vector<pii>, greater<pii> > pq;
int main()
{
    //freopen("input.txt", "r", stdin);
    ios_base::sync_with_stdio(0); cin.tie(0);
    // 초기화
    fill(dist, dist + MAX, INT_MAX);
    fill(visited, visited + MAX, false);
    // 입력 및 인접리스트 저장
    cin >> n >> m;
    int u, v, w;
    for (int i = 0; i < m; i++) {
        cin >> u >> v >> w;
        adj[u - 1].push_back(make_pair(v - 1, w));
    }
    cin >> s >> d;
    s--, d--;
    // 시작 노드
    dist[s] = 0;
    pq.push(make_pair(dist[s], s));
    while (!pq.empty()) {
        int curr = 0;
        // pq에서 방문하지 않은 top 찾기
        do {
            curr = pq.top().second;
            pq.pop();
        } while (!pq.empty() && visited[curr]);
        // 방문처리
        visited[curr] = true;

        // top에서 인접한 곳 중 더 짧아지는
        for (pii& p : adj[curr]) {
            int next = p.first;
            int weight = p.second;
            // 시작노드 - 노드 1 - 노드 2 - 노드 3 - 노드 1
            // 로 이동하는 경우 노드 3은 처음 방문이라서
            // curr에 할당되지만 노드3에 인접한 노드1은
            // 이미 pq에 해당 노드의 pair가 push 된 상태
            if (dist[next] > dist[curr] + weight) {
                dist[next] = dist[curr] + weight;
                pq.push(make_pair(dist[next], next));
            }
        }
    }
    // 결과 확인
    cout << dist[d] << "\n";

    return 0;
}

python O(N^2) 코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# 단방향 그래프인 경우
import sys

input = sys.stdin.readline
INF = int(1e9)

n, m = map(int, input().split())
start = int(input())
graph = [[] for i in range(n + 1)]
visited = [False] * (n + 1)
distance = [INF] * (n + 1)
for _ in range(m):
    a, b, c = map(int, input().split())
    # 노드 a에서 b로 가는 비용이 c
    graph[a].append((b, c))


# 방문하지 않은 노드 중에서, 가장 최단거리가 짧은 노드의 번호를 반환
def get_smallest_node():
    min_value = INF
    index = 0
    for i in range(1, n + 1):
        if distance[i] < min_value and not visited[i]:
            min_value = distance[i]
            index = i
    return index


def dijkstra(start):
    distance[start] = 0
    visited[start] = True
    # 노드 j[0]에서 j[1]로 가는 비용
    for i in graph[start]:
        distance[i[0]] = i[1]
    # 시작노드를 제외한 각 노드에 대해 n-1번 반복
    for i in range(n - 1):
        x = get_smallest_node()
        visited[x] = True
        # 각 노드에서의 최단거리 갱신
        for nx, weight in graph[x]:
            cost = distance[x] + weight
            if cost < distance[nx]:
                distance[nx] = cost


dijkstra(start)

for i in range(1, n + 1):
    if distance[i] == INF:
        print("INFINITY")
    else:
        print(distance[i])

python 최적 알고리즘

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# 단방향 그래프인 경우
import heapq
import sys

input = sys.stdin.readline
INF = int(1e9)

n, m = map(int, input().split())
start = int(input())
graph = [[] for i in range(n + 1)]
visited = [False] * (n + 1)
distance = [INF] * (n + 1)
for _ in range(m):
    a, b, c = map(int, input().split())
    # 노드 a에서 b로 가는 비용이 c
    graph[a].append((b, c))


def dijkstra(start):
    q = []
    heapq.heappush(q, (0, start))
    distance[start] = 0
    while q:
        x_weight, x = heapq.heappop(q)
        if distance[x] < x_weight:
            continue
        for nx, nx_weight in graph[x]:
            cost = x_weight + nx_weight
            if cost < distance[nx]:
                distance[nx] = cost
                heapq.heappush(q, (cost, nx))


dijkstra(start)

for i in range(1, n + 1):
    if distance[i] == INF:
        print("INFINITY")
    else:
        print(distance[i])

Success Notice: 수고하셨습니다. :+1:

Leave a comment