들어가며

  • 이전에는 AVL트리의 벨런스를 맞추는데 사용되는 연산 전반과 트리에 노드를 삽입하는 연산까지 구현하였다.
  • 이번에는 AVL 트리에서 노드를 삭제하는 연산과 그 결과로 인해 벨런스가 맞지 않을 경우 어떻게 동작하는지에 대해서 구현해보겠다.
  • 직전까지 구현한 코드는 아래와 같다.
import java.util.*;

class Node{
    int data;
    int height;
    Node left;
    Node right;
    Node(int data){
        this.data = data;
        this.left = null;
        this.right = null;
        this.height = 0;
    }
}

class AVL_Tree {
    Node root;

    AVL_Tree(){
        this.root = null;
    }

    private int getHeight(Node node){
        if (node == null){
            return -1;
        }

        return node.height;
    }

    private int getBalance(Node node){
        if (node == null){
            return 0;
        }

        return getHeight(node.left) - getHeight(node.right);
    }

    private Node rightRotate(Node node){
        Node lNode = node.left;

        node.left = lNode.right;
        lNode.right = node;

        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        lNode.height = Math.max(getHeight(lNode.left), getHeight(lNode.right)) + 1;

        return lNode;
    }

    private Node leftRotate(Node node){
        Node rNode = node.right;

        node.right = rNode.left;
        rNode.left = node;

        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        rNode.height = Math.max(getHeight(rNode.left), getHeight(rNode.right)) + 1;

        return rNode;
    }

    private Node LR_Rotate(Node node){
        node.left = leftRotate(node.left);
        return rightRotate(node);
    }

    private Node RL_Rotate(Node node){
        node.right = rightRotate(node.right);
        return leftRotate(node);
    }

    public void addNode(int key){
        if (searchNode(this.root, key)){
            return;
        }

        this.root = add(this.root, key);
    }

    private Node add(Node cur, int key){
        if (cur == null){
            return new Node(key);
        }

        if (cur.data < key){
            cur.right = add(cur.right, key);
        } else{
            cur.left = add(cur.left, key);
        }

        cur.height = Math.max(getHeight(cur.left), getHeight(cur.right)) + 1;

        int balance = getBalance(cur);

        if (balance > 1 && cur.data < key){
            return LR_Rotate(cur);
        }

        if (balance > 1 && cur.data > key){
            return rightRotate(cur);
        }

        if (balance < -1 && cur.data < key){
            return leftRotate(cur);
        }

        if (balance < -1 && cur.data > key){
            return RL_Rotate(cur);
        }

        return cur;
    }

    public boolean searchNode(Node cur, int key){
        
    }

    public void removeNode(int key){
        
    }

    private Node remove(Node cur, int key){
        
    }

    public void levelTraversal(){
        Deque<Node> queue = new ArrayDeque<>();

        queue.addLast(this.root);

        while (!queue.isEmpty()){
            Node cur = queue.pollFirst();

            System.out.print(cur.data + " ");

            if (cur.left != null) queue.addLast(cur.left);
            if (cur.right != null) queue.addLast(cur.right);
        }

        System.out.println();
    }

    public void traversal(Node cur){
        if (cur == null){
            return;
        }

        traversal(cur.left);
        System.out.print(cur.data + " ");
        traversal(cur.right);
    }
}

 

 

 

트리의 삭제 연산

  • 삭제 연산 역시 삽입 연산을 이전에 재귀로 구현한 것 처럼 동일하게 재귀로 구현할 것이다. 그를 위해 노드가 삭제된 트리를 갱신하기 위한 추가적인 함수가 필요하다.
  • 이는 아래와 같이 구현된다. 
public void removeNode(int key){
    if (!searchNode(this.root, key)){
        return;
    }

    this.root = remove(this.root, key);
}
  • 노드의 삭제 연산은 이진 탐색 트리에서의 삭제 연산과 동일하다.먼저 삭제할 노드의 위치를 찾고 해당 노드의 자식 유무를 확인하고 자식 유무에 따라 삭제 처리를 다르게 구현한다.
private Node remove(Node cur, int key){
    if (cur == null){
        return null;
    }

    if (cur.data < key){
        cur.right = remove(cur.right, key);
    } else if (cur.data > key){
        cur.left = remove(cur.left, key);
    } else{
        if (cur.left == null){
            return cur.right;
        } else if (cur.right == null){
            return cur.left;
        } else{
            Node predecessor = cur;
            Node successor = cur.left;

            while (successor.right != null){
                predecessor = successor;
                successor = successor.right;
            }

            if (predecessor == cur){
                cur.data = successor.data;
                cur.left = successor.left;
            } else{
                cur.data = successor.data;
                predecessor.right = successor.left;
            }
        }
    }
}
  • 언듯 난해해 보이지만 이전에 구현했던 반복문을 재귀화한것이다.
  • 삭제 처리를 한 후 다시 거슬러 올라가며 재귀로 호출되었던 모든 노드의 높이를 계산한다. 그리고 벨런스 역시 계산한다.
  • 이때 삽입 연산과 달라지는점이 있는데 삽입 연산은 삽입되는 위치로 인해 불균형이 발생하였기 때문에 값의 비교로 LR, LL. RL, RR의 연산을 판단할 수 있었지만 삭제 연산은 그럴 수 없다. 따라서 해당 노드의 좌측 서브트리의 높이와 우측 서브트리의 높이를 확인하여 불균형이 발생하였는지 체크해야 한다.
    • 예시를 들어 벨런스가 1보다 크다고 가정하자. 이는 좌측 서브트리에 문제가 생긴 것이다. 그렇다면 좌측 서브트리의 벨런스 상태를 확인하여 1보다 크다면 다시 그 좌측 서브트리에 문제가 있는 것이므로 LL 연산을 수행하면 된다.
    • 만약 좌측 서브트리의 벨런스 상태가 -1보다 작다면, 그 우측 서브트리에 문제가 있다는 것이므로 LR 연산을 수행해야 한다.
private Node remove(Node cur, int key){
    if (cur == null){
        return null;
    }

    if (cur.data < key){
        cur.right = remove(cur.right, key);
    } else if (cur.data > key){
        cur.left = remove(cur.left, key);
    } else{
        if (cur.left == null){
            return cur.right;
        } else if (cur.right == null){
            return cur.left;
        } else{
            Node predecessor = cur;
            Node successor = cur.left;

            while (successor.right != null){
                predecessor = successor;
                successor = successor.right;
            }

            if (predecessor == cur){
                cur.data = successor.data;
                cur.left = successor.left;
            } else{
                cur.data = successor.data;
                predecessor.right = successor.left;
            }
        }
    }
    
    cur.height = Math.max(getHeight(cur.left), getHeight(cur.right)) + 1;

    int balance = getBalance(cur);

    if (balance > 1 && getBalance(cur.left) < 0){
        return LR_Rotate(cur);
    }

    if (balance > 1 && getBalance(cur.left) > 0){
        return rightRotate(cur);
    }

    if (balance < -1 && getBalance(cur.right) < 0){
        return leftRotate(cur);
    }

    if (balance < -1 && getBalance(cur.right) > 0){
        return RL_Rotate(cur);
    }

    return cur;
}
  • 삭제 연산의 구현이 종료되었다. 유심히 보면 단지 이진 탐색 트리의 연산에 벨런스 체크를 수행하고 그에 따라 연산을 진행한다는 부분만 틀리고 나머진 모두 동일하다는 것을 알 수 있다.

 

 

노드의 탐색

  • 노드의 탐색은 딱히 달라진 것이 없으며 이전에는 반복문을 통해 구현한 것을 재귀로 구현한 것이 다르다.
public boolean searchNode(Node cur, int key){
    if (cur == null){
        return false;
    }

    if (cur.data == key){
        return true;
    }

    if (cur.data < key){
        return searchNode(cur.right, key);
    } else{
        return searchNode(cur.left, key);
    }
}
  • 물론 이것도 루트노드를 매개변수에 넣을 필요 없이 키 값으로만 찾을 수 있도록 인터페이스격인 메서드를 만들어도 된다. 그렇다면 결과는 다음과 같을 것이다.
public boolean searchNode(int key){
	return search(this.root, key);
}

private boolean search(Node cur, int key){
    if (cur == null){
        return false;
    }

    if (cur.data == key){
        return true;
    }

    if (cur.data < key){
        return searchNode(cur.right, key);
    } else{
        return searchNode(cur.left, key);
    }
}

 

 

 

정리

  • 이렇게 AVL트리의 모든 기능 구현을 마쳤다. 기존 이진 탐색 트리의 경우 데이터가 정렬된 상태로 들어왔을 때 최악의 탐색 효율이 나올 수 있지만 AVL트리는 내부 균형 유지 연산을 통해 정렬된 상태로 들어와도 빠른 탐색을 보장한다.
  • 전체적인 코드는 아래와 같다.
import java.util.*;

class Node{
    int data;
    int height;
    Node left;
    Node right;
    Node(int data){
        this.data = data;
        this.left = null;
        this.right = null;
        this.height = 0;
    }
}

class AVL_Tree {
    Node root;

    AVL_Tree(){
        this.root = null;
    }

    private int getHeight(Node node){
        if (node == null){
            return -1;
        }

        return node.height;
    }

    private int getBalance(Node node){
        if (node == null){
            return 0;
        }

        return getHeight(node.left) - getHeight(node.right);
    }

    private Node rightRotate(Node node){
        Node lNode = node.left;

        node.left = lNode.right;
        lNode.right = node;

        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        lNode.height = Math.max(getHeight(lNode.left), getHeight(lNode.right)) + 1;

        return lNode;
    }

    private Node leftRotate(Node node){
        Node rNode = node.right;

        node.right = rNode.left;
        rNode.left = node;

        node.height = Math.max(getHeight(node.left), getHeight(node.right)) + 1;
        rNode.height = Math.max(getHeight(rNode.left), getHeight(rNode.right)) + 1;

        return rNode;
    }

    private Node LR_Rotate(Node node){
        node.left = leftRotate(node.left);
        return rightRotate(node);
    }

    private Node RL_Rotate(Node node){
        node.right = rightRotate(node.right);
        return leftRotate(node);
    }

    public void addNode(int key){
        if (searchNode(key)){
            return;
        }

        this.root = add(this.root, key);
    }

    private Node add(Node cur, int key){
        if (cur == null){
            return new Node(key);
        }

        if (cur.data < key){
            cur.right = add(cur.right, key);
        } else{
            cur.left = add(cur.left, key);
        }

        cur.height = Math.max(getHeight(cur.left), getHeight(cur.right)) + 1;

        int balance = getBalance(cur);

        if (balance > 1 && cur.data < key){
            return LR_Rotate(cur);
        }

        if (balance > 1 && cur.data > key){
            return rightRotate(cur);
        }

        if (balance < -1 && cur.data < key){
            return leftRotate(cur);
        }

        if (balance < -1 && cur.data > key){
            return RL_Rotate(cur);
        }

        return cur;
    }

    public boolean searchNode(int key){
        return search(this.root, key);
    }

    private boolean search(Node cur, int key){
        if (cur == null){
            return false;
        }

        if (cur.data == key){
            return true;
        }

        if (cur.data < key){
            return search(cur.right, key);
        } else{
            return search(cur.left, key);
        }
    }

    public void removeNode(int key){
        if (!search(this.root, key)){
            return;
        }

        this.root = remove(this.root, key);
    }

    private Node remove(Node cur, int key){
        if (cur == null){
            return null;
        }

        if (cur.data < key){
            cur.right = remove(cur.right, key);
        } else if (cur.data > key){
            cur.left = remove(cur.left, key);
        } else{
            if (cur.left == null){
                return cur.right;
            } else if (cur.right == null){
                return cur.left;
            } else{
                Node predecessor = cur;
                Node successor = cur.left;

                while (successor.right != null){
                    predecessor = successor;
                    successor = successor.right;
                }

                if (predecessor == cur){
                    cur.data = successor.data;
                    cur.left = successor.left;
                } else{
                    cur.data = successor.data;
                    predecessor.right = successor.left;
                }
            }
        }

        cur.height = Math.max(getHeight(cur.left), getHeight(cur.right)) + 1;

        int balance = getBalance(cur);

        if (balance > 1 && getBalance(cur.left) < 0){
            return LR_Rotate(cur);
        }

        if (balance > 1 && getBalance(cur.left) > 0){
            return rightRotate(cur);
        }

        if (balance < -1 && getBalance(cur.right) < 0){
            return leftRotate(cur);
        }

        if (balance < -1 && getBalance(cur.right) > 0){
            return RL_Rotate(cur);
        }

        return cur;
    }

    public void levelTraversal(){
        Deque<Node> queue = new ArrayDeque<>();

        queue.addLast(this.root);

        while (!queue.isEmpty()){
            Node cur = queue.pollFirst();

            System.out.print(cur.data + " ");

            if (cur.left != null) queue.addLast(cur.left);
            if (cur.right != null) queue.addLast(cur.right);
        }

        System.out.println();
    }

    public void traversal(Node cur){
        if (cur == null){
            return;
        }

        traversal(cur.left);
        System.out.print(cur.data + " ");
        traversal(cur.right);
    }
}

public class StudyCode {
    public static void main(String[] args){
        AVL_Tree avlt = new AVL_Tree();

        avlt.addNode(10);
        avlt.addNode(20);
        avlt.levelTraversal();
        avlt.addNode(30);
        avlt.addNode(40);
        avlt.addNode(50);
        avlt.levelTraversal();

        System.out.println(avlt.searchNode(50));

        avlt.traversal(avlt.root);
        System.out.println();

        avlt.removeNode(40);
        avlt.removeNode(10);

        avlt.levelTraversal();

    }
}
복사했습니다!