JAVA

[Java] 다익스트라(Dijkstra) 알고리즘 구현

myeongju 2022. 7. 25. 15:45
반응형

다익스트라 개념을 잘 모른다면 아래 포스팅을 읽고 오자.

https://myeongju00.tistory.com/52

 

[Algorithm] Java - 다익스트라(Dijkstra) 알고리즘

다익스트라 알고리즘이란? BFS와 DP를 활용한 최단경로 탐색 알고리즘이다. 즉, 그래프에서 출발점에서 목표점까지의 최단거리를 구할 때 사용하는 알고리즘이다. 이 과정에서 도착 정점 뿐만 아

myeongju00.tistory.com

구현에는 인접 행렬 방식, 우선순위 큐 방식이 있다.

노드의 개수를 V라고 할 때 인접 행렬 방식은 O( V^2 ) 우선순위 큐 방식은 O( V log V ) 의 시간 복잡도를 가진다.

따라서 알고리즘 문제풀이 시 우선순위 큐 방식을 주로 사용한다.

 

인접 행렬 방식

1. 그래프를 담을 클래스 선언 후 객체를 초기화한 후 가중치를 넣는다.

다익스트라에서는 초기 값을 0 대신 무한대를 사용하기 때문에 인접 행렬의 값을 int형의 MAX 값으로 초기화한다.

만약, 문제에서 int 범위 이상의 값을 사용한다면, 더 큰 값으로 잡아야 한다.

그 다음, 입력받은 가중치를 인접 행렬에 넣어준다.

 static int[][] map;
 static int V;
 public static void main(String[] args) throws IOException {
   BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
   StringTokenizer st = new StringTokenizer(br.readLine());
 
   V = Integer.parseInt(st.nextToken()); // 정점 개수
   int E = Integer.parseInt(st.nextToken()); // 간선 개수
 
   map = new int[V][V];
   for (int i = 0; i < V; i++) {
     Arrays.fill(map[i], Integer.MAX_VALUE);
   }
 
   for (int i = 0; i < E; i++) {
     st = new StringTokenizer(br.readLine());
 
     int start = Integer.parseInt(st.nextToken()); //출발 노드
     int end = Integer.parseInt(st.nextToken()); //도착 노드
     int weight = Integer.parseInt(st.nextToken()); //가중치
 
     map[start][end] = weight;
     map[end][start] = weight;
   }
 
   Dijkstra(0);
 }

 

2. 다익스트라 알고리즘을 구현한다.

  1. 최단 거리를 저장할 배열 및 노드 방문 여부 배열을 만든다. : 최단거리는 최솟값으로 바꿔줄 예정이므로, 초기값은 무한대로 설정한다.
  2. 인접 행렬에 초기화된 값을 아용하여 최단거리 배열을 초기화 한다. : 방문하지 않았고, 초기 값이 무한대가 아닐 때 최단 거리 배열을 업데이트 한다.
  3. 모든 노드를 돌며 최단 거리 배열 값이 가장 작은 노드를 선택해서 최소 거리를 비교한다.
 private static void Dijkstra(int node) {
     int[] distance = new int[V];          // 최단 거리를 저장할 변수
     boolean[] visited = new boolean[V];     // 해당 노드를 방문했는지 체크할 변수
 
     // distance값 초기화.
     for (int i = 0; i < V; ++i) {
       distance[i] = Integer.MAX_VALUE;
     }
 
     // 시작노드값 설정
     distance[node] = 0;
     visited[node] = true;
 
     // 연결노드 distance갱신
     for (int i = 0; i < V; ++i) {
       if (!visited[i] && map[node][i] != Integer.MAX_VALUE) {
         distance[i] = map[node][i];
       }
     }
 
     for (int i = 0; i < V - 1; ++i) { // 노드가 n개 있을 때 다익스트라를 위해서 반복수는 n-1번이면 된다.
       int min = Integer.MAX_VALUE;
       int min_index = -1;
 
       // 노드 최소값 찾기
       for (int j = 0; j < V; ++j) {
         if (!visited[j]) {
           if (distance[j] < min) {
             min = distance[j];
             min_index = j;
           }
         }
       }
 
       // 다른 노드를 거쳐서 가는 것이 더 비용이 적은지 확인
       visited[min_index] = true;
       for (int j = 0; j < V; ++j) {
         if (!visited[j] && map[min_index][j] != Integer.MAX_VALUE) {
           if (distance[min_index] + map[min_index][j] < distance[j]) {
             distance[j] = distance[min_index] + map[min_index][j];
           }
         }
       }
       
     }
 }

 

3. 전체 코드

 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.util.Arrays;
 import java.util.StringTokenizer;
 
 public class DijkstraStudy {
     static int[][] map;
     static int V;
     public static void main(String[] args) throws IOException {
         BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
         StringTokenizer st = new StringTokenizer(br.readLine());
 
         V = Integer.parseInt(st.nextToken()); // 정점 개수
         int E = Integer.parseInt(st.nextToken()); // 간선 개수
 
         map = new int[V][V];
         for (int i = 0; i < V; i++) {
             Arrays.fill(map[i], Integer.MAX_VALUE);
         }
 
         for (int i = 0; i < E; i++) {
             st = new StringTokenizer(br.readLine());
 
             int start = Integer.parseInt(st.nextToken()); //출발 노드
             int end = Integer.parseInt(st.nextToken()); //도착 노드
             int weight = Integer.parseInt(st.nextToken()); //가중치
 
             map[start][end] = weight;
             map[end][start] = weight;
         }
 
         Dijkstra(0);
     }
 
     private static void Dijkstra(int node) {
         int[] distance = new int[V];          // 최단 거리를 저장할 변수
         boolean[] visited = new boolean[V];     // 해당 노드를 방문했는지 체크할 변수
 
         // distance값 초기화.
         Arrays.fill(distance, Integer.MAX_VALUE);
 
         // 시작노드값 초기화.
         distance[node] = 0;
         visited[node] = true;
 
         // 결과값 출력
         for (int i = 0; i < V; ++i) {
             if (distance[i] == Integer.MAX_VALUE) System.out.print("∞ ");
             else System.out.print(distance[i] + " ");
         }
         System.out.println();
 
         // 연결노드 distance갱신
         for (int i = 0; i < V; ++i) {
             if (!visited[i] && map[node][i] != Integer.MAX_VALUE) {
                 distance[i] = map[node][i];
             }
         }
 
         // 결과값 출력
         for (int i = 0; i < V; ++i) {
             if (distance[i] == Integer.MAX_VALUE) System.out.print("∞ ");
             else System.out.print(distance[i] + " ");
         }
         System.out.println();
 
         for (int i = 0; i < V - 1; ++i) {
             // 노드가 n개 있을 때 다익스트라를 위해서 반복수는 n-1번이면 된다.
             int min = Integer.MAX_VALUE;
             int min_index = -1;
 
             // 노드 최소값 찾기
             for (int j = 0; j < V; ++j) {
                 if (!visited[j]) {
                     if (distance[j] < min) {
                         min = distance[j];
                         min_index = j;
                     }
                 }
             }
 
             // 다른 노드를 거쳐서 가는 것이 더 비용이 적은지 확인
             visited[min_index] = true;
             for (int j = 0; j < V; ++j) {
                 if (!visited[j] && map[min_index][j] != Integer.MAX_VALUE) {
                     if (distance[min_index] + map[min_index][j] < distance[j]) {
                         distance[j] = distance[min_index] + map[min_index][j];
                     }
                 }
             }
 
             // 결과값 출력
             for (int j = 0; j < V; ++j) {
                 if (distance[j] == Integer.MAX_VALUE) System.out.print("∞ ");
                 else System.out.print(distance[j] + " ");
             }
             System.out.println();
         }
     }
 }

 

입력 값

6 9
0 1 7
0 2 9
0 5 14
1 2 10
1 3 15
2 3 11
2 5 2
3 4 6
4 5 9

과정을 처음부터 노드마다 출력해보면 결과는 다음과 같다.

 

우선순위 큐 방식 (연결리스트 사용)

1. 우선순위에 넣을 노드 클래스를 선언한다.

노드까지의 가중치와 노드의 인덱스를 객체로 넣는다.

가중치를 기준으로 Comparable을 선언하여 우선순위 큐의 판단 기준을 제공한다.

 static class Node implements Comparable<Node>{
     private int index;
     private int weight;
 
     public Node(int index, int weight) {
       this.index = index;
       this.weight = weight;
     }
 
     @Override
     public int compareTo(Node o) {
       return Integer.compare(this.weight, o.weight);
     }
 }

 

2. 연결리스트 생성

그래프를 표현할 때 보통은 2차원 배열로 선언을 많이 하는데, 2차원 배열보다는 List컬렉션을 통해서 구현하는 것을 추천한다. 왜냐하면, 모든 노드가 간선으로 연결된 것이 아니라면 2차원 배열은 간선이 존재하지 않는 경우의 값도 저장하지만, 2차원 List로 표현을 하면 간선이 존재하는 노드들끼리의 연결만 표현할 수 있기 때문이다.

  1. graph를 생성해서 정점의 수만큼 ArrayList를 넣어준다.
  2. 입력받은 시작 노드를 가져와서 도착 노드와 가중치를 Node 로 넣어준다. : 방향이 없기 때문에 양쪽 다 넣어준다.
 // 그래프를 표현 할 List
 static List<List<Node>> graph = new ArrayList<>();
 static int V;
 
 public static void main(String[] args) throws IOException {
   BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
   StringTokenizer st = new StringTokenizer(br.readLine());
 
   V = Integer.parseInt(st.nextToken()); // 정점 개수
   int E = Integer.parseInt(st.nextToken()); // 간선 개수
 
   for (int i = 0; i < V; i++) {
     graph.add(new ArrayList<>());
   }
 
   for (int i = 0; i < E; i++) {
     st = new StringTokenizer(br.readLine());
 
     int start = Integer.parseInt(st.nextToken()); //출발 노드
     int end = Integer.parseInt(st.nextToken()); //도착 노드
     int weight = Integer.parseInt(st.nextToken()); //가중치
 
     graph.get(start).add(new Node(end, weight));
     graph.get(end).add(new Node(start, weight));
   }
 }

 

3. 다익스트라 알고리즘 구현

  private static void Dijkstra(int index) {
    PriorityQueue<Node> pq = new PriorityQueue<>();
    int[] distance = new int[V];          // 최단 거리를 저장할 변수
 
    // distance값 초기화.
    Arrays.fill(distance, Integer.MAX_VALUE);
 
    // 시작노드값 초기화.
    distance[index] = 0;
    pq.offer(new Node(index, 0));
 
    while (!pq.isEmpty()) {
      // 큐에서 노드 꺼내기
      Node node = pq.poll();
      int nodeIndex = node.index; //꺼낸 노드의 인덱스
      int weight = node.weight;
 
      /**
              * 큐는 최단거리를 기준으로 오름차순 정렬되고 있습니다.
              * 만약 현재 꺼낸 노드의 거리가 최단거리테이블의 값보다 크다면 해당 노드는 이전에 방문된 노드이므로,
              * 해당노드와 연결 된 노드를 탐색하지 않고 큐에서 다음 노드를 꺼냅니다.*/
 
      if(weight > distance[nodeIndex]) {
        continue;
      }
 
      for (Node linkedNode : graph.get(nodeIndex)) {
        if(weight + linkedNode.weight < distance[linkedNode.index]) {
          //최단 테이블 갱신
          distance[linkedNode.index] = weight + linkedNode.weight;
          // 갱신 된 노드를 우선순위 큐에 넣어
          pq.offer(new Node(linkedNode.index, distance[linkedNode.index]));
        }
      }
 
      // 결과값 출력
      for (int i = 0; i < V; ++i) {
        if (distance[i] == Integer.MAX_VALUE) System.out.print("∞ ");
        else System.out.print(distance[i] + " ");
      }
      System.out.println();
    }
  }

 

4. 전체 코드

 import java.io.BufferedReader;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.util.*;
 
 public class DijkstraStudy {
     // 그래프를 표현 할 List
     static List<List<Node>> graph = new ArrayList<>();
     static int V;
 
     public static void main(String[] args) throws IOException {
         BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
         StringTokenizer st = new StringTokenizer(br.readLine());
 
         V = Integer.parseInt(st.nextToken()); // 정점 개수
         int E = Integer.parseInt(st.nextToken()); // 간선 개수
 
         for (int i = 0; i < V; i++) {
             graph.add(new ArrayList<>());
         }
 
         for (int i = 0; i < E; i++) {
             st = new StringTokenizer(br.readLine());
 
             int start = Integer.parseInt(st.nextToken()); //출발 노드
             int end = Integer.parseInt(st.nextToken()); //도착 노드
             int weight = Integer.parseInt(st.nextToken()); //가중치
 
             graph.get(start).add(new Node(end, weight));
             graph.get(end).add(new Node(start, weight));
         }
 
         Dijkstra(0);
     }
 
     private static void Dijkstra(int index) {
         PriorityQueue<Node> pq = new PriorityQueue<>();
         int[] distance = new int[V];          // 최단 거리를 저장할 변수
 
         // distance값 초기화.
         Arrays.fill(distance, Integer.MAX_VALUE);
 
         // 시작노드값 초기화.
         distance[index] = 0;
         pq.offer(new Node(index, 0));
 
         while (!pq.isEmpty()) {
             // 큐에서 노드 꺼내기
             Node node = pq.poll();
             int nodeIndex = node.index; //꺼낸 노드의 인덱스
             int weight = node.weight;
             /**
              * 큐는 최단거리를 기준으로 오름차순 정렬되고 있습니다.
              * 만약 현재 꺼낸 노드의 거리가 최단거리테이블의 값보다 크다면 해당 노드는 이전에 방문된 노드이므로,
              * 해당노드와 연결 된 노드를 탐색하지 않고 큐에서 다음 노드를 꺼냅니다.*/
             if(weight > distance[nodeIndex]) {
                 continue;
             }
 
             for (Node linkedNode : graph.get(nodeIndex)) {
                 if(weight + linkedNode.weight < distance[linkedNode.index]) {
                     //최단 테이블 갱신
                     distance[linkedNode.index] = weight + linkedNode.weight;
                     // 갱신 된 노드를 우선순위 큐에 넣어
                     pq.offer(new Node(linkedNode.index, distance[linkedNode.index]));
                 }
             }
 
              //결과값 출력
             for (int i = 0; i < V; ++i) {
                 if (distance[i] == Integer.MAX_VALUE) System.out.print("∞ ");
                 else System.out.print(distance[i] + " ");
             }
             System.out.println();
         }
     }
 
     static class Node implements Comparable<Node>{
         private int index;
         private int weight;
 
         public Node(int index, int weight) {
             this.index = index;
             this.weight = weight;
         }
 
         @Override
         public int compareTo(Node o) {
             return Integer.compare(this.weight, o.weight);
         }
     }
 }

입력값은 인접 행렬 방식과 같고, 이번에도 과정을 다 출력해보았다.

반응형