项目中有时需要对数据进行关联分析，比如分析一家小商店中顾客的购买习惯。下面给出一个基于 Apriori 算法的简单 Java 实现：
1 package com.data.algorithm;
2
3 import com.google.common.base.Splitter;
4 import com.google.common.collect.Lists;
5 import com.google.common.collect.Maps;
6 import org.slf4j.Logger;
7 import org.slf4j.LoggerFactory;
8
9 import java.io.BufferedReader;
10 import java.io.FileInputStream;
11 import java.io.IOException;
12 import java.io.InputStreamReader;
13 import java.util.*;
14
15 /**
16 * *********************************************************
17 * <p/>
18 * Author: XiJun.Gong
19 * Date: 2017-01-20 15:06
20 * Version: default 1.0.0
21 * Class description:
22 * <p/>
23 * *********************************************************
24 */
25
26 class EOC {
27
28 private static final Logger logger = LoggerFactory.getLogger(EOC.class);
29 private Map<String, Integer> fmap; //forward map
30 private Map<Integer, String> bmap; //backward map
31 private List<Map<String, Integer>> elements = null;
32
33 private Integer maxDimension;
34
35 public EOC(final String pathFile, String separatSeq) {
36
37 BufferedReader bufferedReader = null;
38 try {
39 this.fmap = Maps.newHashMap();
40 this.bmap = Maps.newHashMap();
41 this.elements = Lists.newArrayList();
42 maxDimension = 0;
43 bufferedReader = new BufferedReader(
44 new InputStreamReader(
45 new FileInputStream(pathFile), "UTF-8"));
46 String _line = null;
47 Integer keyValue = null, mapIndex = 0;
48 while ((_line = bufferedReader.readLine()) != null) {
49 Map<String, Integer> lineMap = Maps.newHashMap();
50 if (_line.trim().length() > 1) {
51 if (separatSeq.trim().length() < 1) {
52 separatSeq = ",";
53 }
54 for (String word : Splitter.on(separatSeq).split(_line)) {
55 word = word.trim();
56 if (null == (keyValue = fmap.get(word))) {
57 keyValue = mapIndex++;
58 }
59 fmap.put(word, keyValue);
60 bmap.put(keyValue, word);
61 lineMap.put(word, keyValue);
62 }
63 if (maxDimension < lineMap.size())
64 maxDimension = lineMap.size();
65 elements.add(lineMap);
66 }
67 }
68 } catch (Exception e) {
69 logger.error("读取文件出错 , 错误原因:{}", e);
70 } finally {
71 if (bufferedReader != null) {
72 try {
73 bufferedReader.close();
74 } catch (IOException e) {
75 logger.error("bufferedReader , 错误原因:{}", e);
76 }
77 }
78 }
79 }
80
81 public Integer getMaxDimension() {
82 return maxDimension;
83 }
84
85 public float getRateOfSet(Collection<Integer> elementChild) {
86 float rateCnt = 0f;
87 int allSize = 1;
88 for (Map<String, Integer> eMap : elements) {
89 boolean flag = true;
90 for (Integer element : elementChild) {
91 if (null == eMap.get(bmap.get(element))) {
92 flag = false;
93 break;
94 }
95 }
96 if (flag) rateCnt += 1;
97 }
98 return rateCnt / ((allSize = elements.size()) > 1 ? (float) allSize : 1.0f);
99 }
100
101 public Set<Integer> getElements() {
102
103 return new HashSet<Integer>(fmap.values());
104 }
105
106 public Integer queryByKey(String key) {
107 return fmap.get(key);
108 }
109
110 public String queryByValue(Integer value) {
111 return bmap.get(value);
112 }
113 }
114
115 public class Apriori {
116 private static final Logger logger = LoggerFactory.getLogger(Apriori.class);
117 private EOC eoc = null;
118 private Integer maxDimension;
119 private final float exp = 1e-4f;
120
121 public Apriori(final String pathFile, String separatSeq, Integer maxDimension) {
122 this(pathFile, separatSeq);
123 this.maxDimension = maxDimension;
124 }
125
126 public Apriori(final String pathFile, String separatSeq) {
127 this.eoc = new EOC(pathFile, separatSeq);
128 this.maxDimension = this.eoc.getMaxDimension();
129 }
130
131 public void work(float confidenceLevel) {
132 List<Set<Integer>> listElement = null;
133 ArrayList<Set<Integer>> middleWareElement = null;
134 Map<Set<Integer>, Float> maps = null;
135 listElement = Lists.newArrayList();
136 for (Integer element : this.eoc.getElements()) {
137 Set<Integer> set = new HashSet<Integer>();
138 set.add(element);
139 listElement.add(set);
140 }
141 maps = Maps.newHashMap();
142 middleWareElement = Lists.newArrayList();
143 for (int i = 1; i < this.maxDimension; i++) {
144 for (Set<Integer> tmpSet : listElement) {
145 float rate = eoc.getRateOfSet(tmpSet);
146 if (confidenceLevel - exp <= rate)
147 maps.put(tmpSet, rate);
148 }
149 System.out.println("+++++++++++第 " + i + " 维度关联数据+++++++++++");
150 output(maps);
151 listElement.clear();
152 middleWareElement.addAll(maps.keySet());
153 maps.clear();
154 for (int j = 0; j < middleWareElement.size(); j++) {
155 Set<Integer> tmpSet = middleWareElement.get(j);
156 for (int k = j + 1; k < middleWareElement.size(); k++) {
157 Set<Integer> setChild = middleWareElement.get(k);
158 for (Integer label : setChild) {
159 if (!tmpSet.contains(label)) {
160 Set<Integer> newElement = new HashSet<Integer>(tmpSet);
161 newElement.add(label);
162 if (!listElement.contains(newElement)) {
163 listElement.add(newElement);
164 break;
165 }
166 }
167 }
168 }
169 }
170 middleWareElement.clear();
171 }
172 }
173
174 public void output(Map<Set<Integer>, Float> maps) {
175 for (Map.Entry<Set<Integer>, Float> iter : maps.entrySet()) {
176 for (Integer integer : iter.getKey()) {
177 System.out.print(eoc.queryByValue(integer) + " ");
178 }
179 System.out.println(iter.getValue()*100+"%");
180 }
181 }
182 }
1 package com.data.algorithm;
2
3
4 /**
5 * *********************************************************
6 * <p/>
7 * Author: XiJun.Gong
8 * Date: 2017-01-17 17:57
9 * Version: default 1.0.0
10 * Class description:
11 * <p/>
12 * *********************************************************
13 */
14 public class Main {
15 public static void main(String args[]) {
16 Apriori apriori = new Apriori("/home/com/src/main/java/com/qunar/data/algorithm/demo.data", ",");
17 apriori.work(0.5f);
18 }
19 }
运行结果（最小支持度 0.5）：
+++++++++++第 1 维度关联数据+++++++++++
苹果 50.0%
西红柿 75.0%
香蕉 75.0%
矿泉水 75.0%
+++++++++++第 2 维度关联数据+++++++++++
苹果 西红柿 50.0%
西红柿 香蕉 50.0%
西红柿 矿泉水 50.0%
香蕉 矿泉水 75.0%
+++++++++++第 3 维度关联数据+++++++++++
西红柿 香蕉 矿泉水 50.0%