Kruskal 알고리즘의 정리

  • Kruskal 알고리즘은 다음과 같은 방식을 따른다.
    • 가중치에 대해서 간선들을 정렬한다.
    • 각 정점에 대한 Union-Find 자료구조를 초기화시킨다.
    • 정렬된 간선들을 하나씩 택하면서 find를 통해 동일 집합에 포함되는지를 판단한다. 이때, 동일 집합이 아니라면, 가중치의 합에 해당 간선의 가중치를 더하고, 두 정점을 합친다(Union)
  • Union-Find를 어떻게 구현하는지는 앞서 설명한 글이 있으니 그것을 참고한다.

 

 

 

Kruskal 알고리즘의 구현

// 알고리즘 - 최소 신장 트리
// 크루스칼 알고리즘

import java.util.Arrays;

public class Main { // O(E logE)
    static int[] parents;
    static int[] size;
    public static int kruskal(int[][] data, int v, int e) {
        
    }

    // 해당 노드의 부모를 변경하는 메서드
    public static void union(int a, int b){
        
    }

    // 해당 노드의 최종적인 부모를 찾는 메서드
    public static int find(int a){
        
    }

    public static void main(String[] args) {
        // Test code
        int v = 7;
        int e = 10;
        int[][] graph = {{1, 3, 1}, {1, 2, 9}, {1, 6, 8}, {2, 4, 13}, {2, 5, 2}, {2, 6, 7}, {3, 4, 12}, {4, 7, 17}, {5, 6, 5}, {5, 7, 20}};

        System.out.println(kruskal(graph, v, e));
    }
}
  • main에서 그래프가 주어질 때 Kruskal의 알고리즘을 통하여 최소 신장 트리를 구하고, 그때의 가중치의 합을 출력해야 한다.
  • 먼저 각 정점에 대하여 disjoint-set을 정의해야 한다. 또한 우린 weighted union을 사용할 것이기 때문에 트리의 노드의 개수를 저장하는 size배열 역시 만들어야 한다.
  • Kruskal에서 가장 먼저 해야할 일은 위의 disjoint-set과 size의 초기화, 그리고 가중치에 대해 $E$를 오름차순으로 정렬하는 것이다.
// 알고리즘 - 최소 신장 트리
// 크루스칼 알고리즘

import java.util.Arrays;

public class Main { // O(E logE)
    static int[] parents;
    static int[] size;
    
    public static int kruskal(int[][] data, int v, int e) {
        int weightSum = 0;
        
        Arrays.sort(data, (x, y) -> x[2] - y[2]);
        
        parents = new int[v + 1];
        size = new int[v + 1];
        
        for(int i = 1; i < v + 1; i++){
            size[i] = 1;
            parents[i] = i;
        }
    }
}
  • 초기 각 집합의 크기는 1이고, 부모 노드는 자기 자신이다.
  • 이후에는 각 간선을 선택하여 두 노드가 동일한 집합인지를 확인하고, 동일한 집합이 아니라면, 안전한 간선이기 때문에 가중치의 합에 해당 간선의 가중치를 합하고, 두 노드를 union한다. 선택된 간선이 정점의 개수 - 1개라면, MST를 찾은 것이므로 가중치의 합을 반환한다.
// 알고리즘 - 최소 신장 트리
// 크루스칼 알고리즘

import java.util.Arrays;

public class Main { // O(E logE)
    static int[] parents;
    static int[] size;
    
    public static int kruskal(int[][] data, int v, int e) {
        int weightSum = 0;
        int cnt = 0;
        
        Arrays.sort(data, (x, y) -> x[2] - y[2]);
        
        parents = new int[v + 1];
        size = new int[v + 1];
        
        for(int i = 1; i < v + 1; i++){
            size[i] = 1;
            parents[i] = i;
        }
        
        for(int i = 0; i < data.length; i++){
            if(cnt == v - 1){
                return weightSum;
            }
        
            if(find(data[i][0]) != find(data[i][1])){
               weightSum += data[i][2];
               union(data[i][0], data[i][1]);
               cnt++;
            }
        }
        
        return weightSum;
    }
}
  • 이제 남은건 find 메서드와 union메서드를 구현하는 것만 남았다.
  • 이 둘의 구현은 이전 글에서 의사 코드로나마 확인했기 때문에 바로 구현을 시작할 수 있다.
// 알고리즘 - 최소 신장 트리
// 크루스칼 알고리즘

import java.util.Arrays;

public class Main { // O(E logE)
    static int[] parents;
    static int[] size;
    
    public static int kruskal(int[][] data, int v, int e) {
        ...
    }
    
    public static int find(int x){
        if(x == parents[x]){
            return x;
        }
        
        return parents[x] = find(x);
    }
    
    public static void union(int u, int v){
        int parentOfU = find(u);
        int parentOfV = find(v);
        
        if(size[parentOfU] < size[parentOfV]){
            parents[sizeOfU] = parentOfV;
            size[parentOfV] += size[parentOfU];
        } else{
            parents[sizeOfV] = parentOfU;
            size[parentOfU] += size[parentOfV];
        }
    }
}

 

 

 

 

정리

  • 이렇게 Kruskal의 모든 분석과 구현이 끝났다. MST의 개념과 일반화된 알고리즘부터, Kruskal에 어떻게 적용되고, 어떻게 구현되는지까지 살펴보았다.
  • 다음 글에서는 MST의 다른 알고리즘인 Prim 알고리즘에 대하여 알아볼 것이다.
  • 전체적인 Kruskal 알고리즘의 구현 코드는 아래와 같다.
// 알고리즘 - 최소 신장 트리
// 크루스칼 알고리즘

import java.util.Arrays;

public class Main { // O(E logE)
    static int[] parents;
    static int[] size;
    
    public static int kruskal(int[][] data, int v, int e) {
        int weightSum = 0;
        int cnt = 0;
        
        Arrays.sort(data, (x, y) -> x[2] - y[2]);
        
        parents = new int[v + 1];
        size = new int[v + 1];
        
        for(int i = 1; i < v + 1; i++){
            size[i] = 1;
            parents[i] = i;
        }
        
        for(int i = 0; i < data.length; i++){
            if(cnt == v - 1){
                return weightSum;
            }
        
            if(find(data[i][0]) != find(data[i][1])){
               weightSum += data[i][2];
               union(data[i][0], data[i][1]);
               cnt++;
            }
        }
        
        return weightSum;
    }
    
    public static int find(int x){
        if(x == parents[x]){
            return x;
        }
        
        return parents[x] = find(x);
    }
    
    public static void union(int u, int v){
        int parentOfU = find(u);
        int parentOfV = find(v);
        
        if(size[parentOfU] < size[parentOfV]){
            parents[sizeOfU] = parentOfV;
            size[parentOfV] += size[parentOfU];
        } else{
            parents[sizeOfV] = parentOfU;
            size[parentOfU] += size[parentOfV];
        }
    }
}

 

복사했습니다!