堆有两个特性:
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
*
* 功能描述: 最小堆的实现
*
* @version 2.0.0
* @author zhiminchen
*/
public class MinHeap<E extends Comparable<E>> {
// 用于存储数据E对象
private E[] data;
// 当前数组已经存储了多少个数据
private int size;
public MinHeap(int capacity) {
data = (E[]) new Comparable[capacity];
size = 0;
}
public MinHeap() {
this(16);
}
/**
* 传入一个array数组,构造堆
*
* @param array
*/
public MinHeap(E[] array) {
data = (E[]) new Comparable[array.length];
// 将array数据复制到data数组中
for (int i = 0; i < data.length; i++) {
data[i] = array[i];
}
size = array.length;
// 对堆中非叶子结点做siftDown操作, 这样堆的特性就维护好了
for (int i = parent(array.length - 1); i >= 0; i--) {
siftDown(i);
}
}
/**
*
* 功能描述: 堆是否为空的方法
*
* @return boolean
* @version 2.0.0
* @author zhiminchen
*/
public boolean isEmpty() {
return size == 0;
}
/**
*
* 功能描述: 返回当前堆的大小
*
* @return int
* @version 2.0.0
* @author zhiminchen
*/
public int size() {
return size;
}
/**
*
* 功能描述: 得到index的父亲结点
*
* @param index
* @return int
* @version 2.0.0
* @author zhiminchen
*/
private int parent(int index) {
if (index == 0)
throw new IllegalArgumentException("index-0 doesn't have parent.");
return (index - 1) / 2;
}
/**
*
* 功能描述: 得到index的左孩子结点
*
* @param index
* @return int
* @version 2.0.0
* @author zhiminchen
*/
private int left(int index) {
return 2 * index + 1;
}
/**
*
* 功能描述: 得到index下的右孩子结点
*
* @param index
* @return int
* @version 2.0.0
* @author zhiminchen
*/
private int right(int index) {
return 2 * index + 2;
}
/**
*
* 功能描述: 往堆中增加元素element
*
* @param element void
* @version 2.0.0
* @author zhiminchen
*/
public void add(E element) {
if (size >= data.length) {
resize(data.length * 2);
}
data[size] = element;
siftUp(size);
size++;
}
/**
*
* 功能描述: 删除堆顶的元素,并返回
*
* @return E
* @version 2.0.0
* @author zhiminchen
*/
public E remove() {
if (size <= 0) {
return null;
}
E ret = data[0];
// 将数组的最后一个元素放到队列的头
data[0] = data[size - 1];
data[size - 1] = null;
// 对size进行维护
size--;
// 对零位的数据进行下沉操作, 以保持堆的特性
siftDown(0);
return ret;
}
/**
*
* 功能描述: 查看堆中最小的元素
*
* @return E
* @version 2.0.0
* @author zhiminchen
*/
public E findMin() {
if (size == 0)
throw new IllegalArgumentException("Can not findMin when heap is empty.");
return data[0];
}
/**
*
* 功能描述: 对index的元素进行下沉操作,以保持推的特性
*
* @param index void
* @version 2.0.0
* @author zhiminchen
*/
private void siftDown(int index) {
int left = left(index);
// 递归退出条件 left>=size说明当前节点已经是叶子节点了
if (left >= size) {
return;
}
// 假设index的左节点是最小的节点
E min = data[left];
// 如果index节点有右孩子并且右孩子比左孩子还小, index结点需要跟最小的节点进行交换
if (left + 1 < size && min.compareTo(data[left + 1]) > 0) {
min = data[left + 1];
left = left + 1;
}
// 当index节点的数据比最小的节点还大的时候,则进行交换,否则结束
if (data[index].compareTo(min) > 0) {
swap(index, left);
siftDown(left);
}
}
/**
*
* 功能描述: 对index元素进行上浮操作, 以维持队的特性
*
* @param size2 void
* @version 2.0.0
* @author zhiminchen
*/
private void siftUp(int index) {
// 递归退出的条件
if (index <= 0) {
return;
}
// index节点比父亲节点小, 则进行上浮
if (data[index].compareTo(data[parent(index)]) < 0) {
swap(index, parent(index));
// 递归调用,继续上浮
siftUp(parent(index));
}
}
/**
*
* 功能描述: 交换索引为a, b的数据
*
* @param a
* @param b void
* @version 2.0.0
* @author zhiminchen
*/
private void swap(int a,
int b) {
E temp = data[a];
data[a] = data[b];
data[b] = temp;
}
/**
*
* 功能描述: 对堆进行扩容
*
* @param capacity void
* @version 2.0.0
* @author zhiminchen
*/
private void resize(int capacity) {
E[] temp = (E[]) new Object[capacity];
for (int i = 0; i < data.length; i++) {
temp[i] = data[i];
}
data = temp;
}
@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(String.format("heap size : %d capacity : %d ", size, data.length));
sb.append("[");
for (int i = 0; i < size; i++) {
sb.append(data[i]);
if (i != size - 1) {
sb.append(", ");
}
}
sb.append("]");
return sb.toString();
}
public static void main(String[] args) {
Random random = new Random();
int n = 100;
Integer array[] = new Integer[n];
for (int i = 0; i < 100; i++) {
array[i] = random.nextInt(Integer.MAX_VALUE);
}
MinHeap<Integer> heap = new MinHeap<Integer>(array);
System.out.println(heap);
List<Integer> list = new ArrayList<Integer>();
while (!heap.isEmpty()) {
list.add(heap.remove());
}
for (int i = 1; i < list.size(); i++) {
if (list.get(i - 1) > list.get(i)) {
throw new RuntimeException();
}
}
}
}