此文不对理论做相关阐述,仅涉及代码实现:
1.熵计算公式:
其中 P 为正例在样本集 S 中所占的比例,Q 为反例所占的比例(P + Q = 1):
$$Entropy(S) = -P\log_2 P - Q\log_2 Q$$
2.信息增益计算:
$$Gain(S, A) = Entropy(S) - \sum_{v \in Values(A)} \frac{|S_v|}{|S|} Entropy(S_v)$$
其中 A 为某一属性(如 Outlook),$S_v$ 为 S 中属性 A 取值为 v 的样本子集。
举例:
输入数据(第一行为列数与样本数;随后每行为一个列名及其在各样本上的取值;最后一行为参与计算的属性名,末尾为结果列名):
5 14
Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain
Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild
Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High
Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong
PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No
Outlook Temperature Humidity Wind PlayTennis
1 package com.qunar.data.tree;
2
/**
 * Mutable (label, count) pair used to tally how often a label occurs.
 *
 * @param <T> type of the label being counted
 */
public class CountMap<T> {

    /** The label (category) being counted. */
    private T key;

    /** Number of occurrences recorded for {@link #key}. */
    private int value;

    /** Creates an empty tally: a {@code null} label with a count of zero. */
    public CountMap() {
        key = null;
        value = 0;
    }

    /**
     * Creates a tally for the given label.
     *
     * @param label the label being counted
     * @param count its initial number of occurrences
     */
    public CountMap(T label, int count) {
        key = label;
        value = count;
    }

    /** @return the label being counted (may be {@code null}) */
    public T getKey() {
        return key;
    }

    /** @param label the label to count from now on */
    public void setKey(T label) {
        key = label;
    }

    /** @return the current number of occurrences */
    public int getValue() {
        return value;
    }

    /** @param count the new number of occurrences */
    public void setValue(int count) {
        value = count;
    }
}
1 package com.qunar.data.tree;
2
3 import com.google.common.collect.ArrayListMultimap;
4 import com.google.common.collect.Maps;
5 import com.google.common.collect.Multimap;
6 import com.google.common.collect.Sets;
7
8 import java.util.*;
9
10 /**
11 * *********************************************************
12 * <p/>
13 * Author: XiJun.Gong
14 * Date: 2016-09-02 14:24
15 * Version: default 1.0.0
16 * Class description:
17 * <p>决策树</p>
18 * <p/>
19 * *********************************************************
20 */
21
22 public class DecisionTree<T, K> {
23
24 private static String positiveExampleType = "Yes";
25 private static String counterExampleType = "No";
26
27
28 public double pLog2(final double p) {
29 if (0 == p) return 0;
30 return p * (Math.log(p) / Math.log(2));
31 }
32
33 /**
34 * 熵计算
35 *
36 * @param positiveExample 正例个数
37 * @param counterExample 反例个数
38 * @return 熵值
39 */
40 public double entropy(final double positiveExample, final double counterExample) {
41
42 double total = positiveExample + counterExample;
43 double positiveP = positiveExample / total;
44 double counterP = counterExample / total;
45 return -1d * (pLog2(positiveP) + pLog2(counterP));
46 }
47
48 /**
49 * @param features 特征列表
50 * @param results 对应结果
51 * @return 将信息整合成新的格式
52 */
53 public Multimap<T, CountMap<K>> merge(final List<T> features, final List<T> results) {
54 //数据转化
55 Multimap<T, CountMap<K>> InfoMap = ArrayListMultimap.create();
56 Iterator result = results.iterator();
57 for (T feature : features) {
58 K res = (K) result.next();
59 boolean tag = false;
60 Collection<CountMap<K>> countMaps = InfoMap.get(feature);
61 for (CountMap countMap : countMaps) {
62 if (countMap.getKey().equals(res)) {
63 /*修改值*/
64 int num = countMap.getValue() + 1;
65 InfoMap.remove(feature, countMap);
66 InfoMap.put(feature, new CountMap<K>(res, num));
67 tag = true;
68 break;
69 }
70 }
71 if (!tag)
72 InfoMap.put(feature, new CountMap<K>(res, 1));
73 }
74
75 return InfoMap;
76 }
77
78 /**
79 * 信息增益
80 *
81 * @param infoMap 因素(Outlook,Temperature,Humidity,Wind)对应的结果
82 * @param dataTable 输入的数据表
83 * @param type 因素中的类型(Outlook{Sunny,Overcast,Rain})
84 * @param entropyS 总的熵值
85 * @param totalSize 总的样本数
86 * @return 信息增益
87 */
88 public double gain(Multimap<T, CountMap<K>> infoMap,
89 Map<K, List<T>> dataTable,
90 final String type,
91 double entropyS,
92 final int totalSize) {
93 //去重
94 Set<T> subTypes = Sets.newHashSet();
95 subTypes.addAll(dataTable.get(type));
96 /*计算*/
97 for (T subType : subTypes) {
98 Collection<CountMap<K>> countMaps = infoMap.get(subType);
99 double subSize = 0;
100 double positiveExample = 0;
101 double counterExample = 0;
102 for (CountMap<K> countMap : countMaps) {
103 subSize += countMap.getValue();
104 if (positiveExampleType.equals(countMap.getKey()))
105 positiveExample = countMap.getValue();
106 else
107 counterExample = countMap.getValue();
108 }
109 entropyS -= (subSize / totalSize) * entropy(positiveExample, counterExample);
110 }
111 return entropyS;
112 }
113
114 /**
115 * 计算
116 *
117 * @param dataTable 数据表
118 * @param types 因素列表{Outlook,Temperature,Humidity,Wind}
119 * @param resultType 结果(PlayTennis)
120 * @return 返回信息增益集合
121 */
122 public Map<String, Double> calculate(Map<K, List<T>> dataTable, List<K> types, K resultType) {
123
124 Map<String, Double> answer = Maps.newHashMap();
125 List<T> results = dataTable.get(resultType);
126 int totalSize = results.size();
127 int positiveExample = 0;
128 int counterExample = 0;
129 double entropyS = 0d;
130 for (T ExampleType : results) {
131 if (positiveExampleType.equals(ExampleType)) {
132 ++positiveExample;
133 continue;
134 }
135 ++counterExample;
136 }
137 /*计算总的熵*/
138 entropyS = entropy(positiveExample, counterExample);
139
140 Multimap<T, CountMap<K>> infoMap;
141 for (K type : types) {
142 infoMap = merge(dataTable.get(type), results);
143 double _gain = gain(infoMap, dataTable, (String) type, entropyS, totalSize);
144 answer.put((String) type, _gain);
145 }
146 return answer;
147 }
148
149 } 1package com.qunar.data.tree;
2
3 import com.google.common.collect.Lists;
4 import com.google.common.collect.Maps;
5
6 import java.util.*;
7
8 /**
9 * *********************************************************
10 * <p/>
11 * Author: XiJun.Gong
12 * Date: 2016-09-02 16:43
13 * Version: default 1.0.0
14 * Class description:
15 * <p/>
16 * *********************************************************
17 */
18 public class Main {
19
20 public static void main(String args[]) {
21
22 Scanner scanner = new Scanner(System.in);
23 while (scanner.hasNext()) {
24 DecisionTree<String, String> dt = new DecisionTree();
25 Map<String, List<String>> dataTable = Maps.newHashMap();
26 /*Map<String, List<String>> dataTable = Maps.newHashMap();*/
27 List<String> types = Lists.newArrayList();
28 String resultType;
29 int factorSize = scanner.nextInt();
30 int demoSize = scanner.nextInt();
31 String type;
32
33 for (int i = 0; i < factorSize; i++) {
34 List<String> demos = Lists.newArrayList();
35 type = scanner.next();
36 for (int j = 0; j < demoSize; j++) {
37 demos.add(scanner.next());
38 }
39 dataTable.put(type, demos);
40 }
41 for (int i = 1; i < factorSize; i++) {
42 types.add(scanner.next());
43 }
44 resultType = scanner.next();
45 Map<String, Double> ans = dt.calculate(dataTable, types, resultType);
46 List<Map.Entry<String, Double>> list = new ArrayList<Map.Entry<String, Double>>(ans.entrySet());
47 Collections.sort(list, new Comparator<Map.Entry<String, Double>>() {
48
49
50 @Override
51 public int compare(Map.Entry<String, Double> o1, Map.Entry<String, Double> o2) {
52 return (o2.getValue() > o1.getValue() ? 1 : -1);
53 }
54 });
55
56 for (Map.Entry<String, Double> iterator : list) {
57 System.out.println(iterator.getKey() + "= " + iterator.getValue());
58 }
59 }
60 }
61
62 }
63 /**
64 *使用举例:*
65 5 14
66 Outlook Sunny Sunny Overcast Rain Rain Rain Overcast Sunny Sunny Rain Sunny Overcast Overcast Rain
67 Temperature Hot Hot Hot Mild Cool Cool Cool Mild Cool Mild Mild Mild Hot Mild
68 Humidity High High High High Normal Normal Normal High Normal Normal Normal High Normal High
69 Wind Weak Strong Weak Weak Weak Strong Strong Weak Weak Weak Strong Strong Weak Strong
70 PlayTennis No No Yes Yes Yes No Yes No Yes Yes Yes Yes Yes No
71 Outlook Temperature Humidity Wind PlayTennis
72 */
运行结果(按信息增益从大到小排序):
Outlook= 0.2467498197744391
Humidity= 0.15183550136234136
Wind= 0.04812703040826927
Temperature= 0.029222565658954647