前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >朴素贝叶斯分类器(离散型)算法实现(一)

朴素贝叶斯分类器(离散型)算法实现(一)

作者头像
Gxjun
发布2018-03-27 10:40:41
1K0
发布2018-03-27 10:40:41
举报
文章被收录于专栏:ml

1. 贝叶斯定理:    

   (1)   P(A^B) = P(A|B)P(B) = P(B|A)P(A) 

 由(1)得

   P(A|B) = P(B|A)*P(A)/[p(B)]

贝叶斯在最基本题型:

假定一个场景,在一所高中男女比例为4:6, 留长头发的有男学生有女学生, 我们设定女生都留长发 , 而男生中有10%的留长发,90%留短发.那么如果我们看到远处一个长发背影?请问是一只男学生的概率?

  分析:

    P(男|长发) = P(长发|男)*P(男)/[p(长发)] 

        = (1/10)*(4/10)/[(6+4*(1/10))/10]

        =1/16 =0.0625

   P(女|长发) =P(长发|女)*P(女)/[p(长发)]

                  =1*(6/10)/[(6+4*(1/10))/10]

                 =30/32 =15/16

再举一个列子:

某个医院早上收了六个门诊病人,如下表。

  症状  职业   疾病   打喷嚏 护士   感冒    打喷嚏 农夫   过敏    头痛  建筑工人 脑震荡    头痛  建筑工人 感冒    打喷嚏 教师   感冒    头痛  教师   脑震荡

现在又来了第七个病人,是一个打喷嚏的建筑工人。请问他患上感冒的概率有多大?(来源: http://www.ruanyifeng.com/blog/2013/12/naive_bayes_classifier.html)

Java代码实现:

代码语言:javascript
复制
 1 /**
 2  * *********************************************************
 3  * <p/>
 4  * Author:     XiJun.Gong
 5  * Date:       2016-08-31 20:36
 6  * Version:    default 1.0.0
 7  * Class description:
 8  * <p>特征库</p>
 9  * <p/>
10  * *********************************************************
11  */
12 
13 public class FeaturePoint {
14 
15     private String key;
16     private double p;
17 
18     public FeaturePoint(String key) {
19         this(key, 1);
20     }
21 
22     public FeaturePoint(String key, double p) {
23         this.key = key;
24         this.p = p;
25     }
26 
27     public String getKey() {
28         return key;
29     }
30 
31     public void setKey(String key) {
32         this.key = key;
33     }
34 
35     public double getP() {
36         return p;
37     }
38 
39     public void setP(double p) {
40         this.p = p;
41     }
42 }
代码语言:javascript
复制
 1 import com.google.common.collect.ArrayListMultimap;
 2 import com.google.common.collect.Multimap;
 3 
 4 import java.util.Collection;
 5 import java.util.List;
 6 
 7 /**
 8  * *********************************************************
 9  * <p/>
10  * Author:     XiJun.Gong
11  * Date:       2016-08-31 15:48
12  * Version:    default 1.0.0
13  * Class description:
14  * <p/>
15  * *********************************************************
16  */
17 
18 public class Bayes {
19     private static Multimap<String, FeaturePoint> map = ArrayListMultimap.create();
20 
21     /*喂数据*/
22     public void input(List<String> labels) {
23 
24         for (String key : labels) {
25             Collection<FeaturePoint> features = map.get(key);
26             for (String value : labels) {
27                 if (features == null || features.size() < 1) {
28                     map.put(key, new FeaturePoint(value));
29                     continue;
30                 }
31                 boolean tag = false;
32                 for (FeaturePoint feature : features) {
33                     if (feature.getKey().equals(value)) {
34                         Double num = feature.getP() + 1;
35                         map.remove(key, feature);
36                         map.put(key, new FeaturePoint(value, num));
37                         tag = true;
38                         break;
39                     }
40                 }
41                 if (!tag)
42                     map.put(key, new FeaturePoint(value));
43             }
44         }
45     }
46 
47     /*构造模型*/
48     public void excute(List<String> labels) {
49         //   excute(labels, null);
50     }
51 
52     /*构造模型*/
53     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
54 
55         Double denominator = 1d;    //分母
56         Double numerator = 1d;      //分子
57         Double coughNum = 0d;
58        /*选择相关性分子*/
59         Collection<FeaturePoint> featurePoints = map.get(judge);
60         for (FeaturePoint featurePoint : featurePoints) {
61             if (judge.equals(featurePoint.getKey())) {
62                 coughNum = featurePoint.getP();
63                 denominator *= (featurePoint.getP() / dataSize);
64                 break;
65             }
66         }
67 
68         Integer size = featurePoints.size() - 1; //容量
69         for (String label : labels) {
70             for (FeaturePoint featurePoint : featurePoints) {
71                 if (label.equals(featurePoint.getKey())) {
72                     denominator *= (featurePoint.getP() / coughNum);
73                     for (FeaturePoint feature : map.get(label)) {
74                         if (label.equals(feature.getKey())) {
75                             numerator *= (feature.getP() / dataSize);
76                         }
77                     }
78                 }
79             }
80         }
81 
82         return denominator / numerator;
83     }
84 
85 }
代码语言:javascript
复制
 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Bayes bayes = new Bayes();
24         while (scanner.hasNext()) {
25 
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             String judge = scanner.next();
38             System.out.println(bayes.excute(list, judge,row));
39             ;
40         }
41 
42     }
43 }

pom.xml包

代码语言:javascript
复制
    <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>3.8.1</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.google.guava</groupId>
            <artifactId>guava</artifactId>
            <version>18.0</version>
        </dependency>

结果:

代码语言:javascript
复制
1 3 6
2 打喷嚏 护士   感冒 
3   打喷嚏 农夫   过敏 
4   头痛  建筑工人 脑震荡 
5   头痛  建筑工人 感冒 
6   打喷嚏 教师   感冒 
7   头痛  教师   脑震荡
8 打喷嚏  建筑工人 感冒
9 0.6666666666666666 
代码语言:javascript
复制
1 3 6
2   打喷嚏 护士   感冒 
3   打喷嚏 农夫   过敏 
4   头痛  建筑工人 脑震荡 
5   头痛  建筑工人 感冒 
6   打喷嚏 教师   感冒 
7   头痛  教师   脑震荡
8 打喷嚏 护士   感冒 
9 1.3333333333333333
代码语言:javascript
复制
 1 2 50
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52             
53 长发 男
54 0.06250000000000001
代码语言:javascript
复制
 1 2 50
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 长发 女
53 0.9375

 利用贝叶斯进行分类?

代码语言:javascript
复制
  1 import com.google.common.collect.ArrayListMultimap;
  2 import com.google.common.collect.Lists;
  3 import com.google.common.collect.Multimap;
  4 
  5 import java.util.Collection;
  6 import java.util.List;
  7 
  8 /**
  9  * *********************************************************
 10  * <p/>
 11  * Author:     XiJun.Gong
 12  * Date:       2016-08-31 15:48
 13  * Version:    default 1.0.0
 14  * Class description:
 15  * <p/>
 16  * *********************************************************
 17  */
 18 
 19 public class Bayes {
 20     private Multimap<String, FeaturePoint> map = null;
 21     private List<String> featurePool = null;
 22 
 23     public Bayes() {
 24         map = ArrayListMultimap.create();
 25         featurePool = Lists.newArrayList();
 26     }
 27 
 28     public void add(String label) {
 29         featurePool.add(label);
 30     }
 31 
 32     /*喂数据*/
 33     public void input(List<String> labels) {
 34 
 35         for (String key : labels) {
 36             Collection<FeaturePoint> features = map.get(key);
 37             for (String value : labels) {
 38                 if (features == null || features.size() < 1) {
 39                     map.put(key, new FeaturePoint(value));
 40                     continue;
 41                 }
 42                 boolean tag = false;
 43                 for (FeaturePoint feature : features) {
 44                     if (feature.getKey().equals(value)) {
 45                         Double num = feature.getP() + 1;
 46                         map.remove(key, feature);
 47                         map.put(key, new FeaturePoint(value, num));
 48                         tag = true;
 49                         break;
 50                     }
 51                 }
 52                 if (!tag)
 53                     map.put(key, new FeaturePoint(value));
 54             }
 55         }
 56     }
 57 
 58     /*最符合那个分类*/
 59     public String excute(List<String> labels, Integer dataSize) {
 60 
 61         Double max = -999999999d;
 62         String max_obj = null;
 63         List<Double> ans = Lists.newArrayList();
 64         for (String label : featurePool) {
 65             Double p = excute(labels, label, dataSize);
 66             ans.add(p);
 67             if (max < p) {
 68                 max_obj = label;
 69                 max = p;
 70             }
 71         }
 72         return max_obj;
 73     }
 74 
 75     /*构造模型*/
 76     public Double excute(final List<String> labels, final String judge, Integer dataSize) {
 77 
 78         Double denominator = 1d;    //分母
 79         Double numerator = 1d;      //分子
 80         Double coughNum = 0d;
 81        /*选择相关性分子*/
 82         Collection<FeaturePoint> featurePoints = map.get(judge);
 83         for (FeaturePoint featurePoint : featurePoints) {
 84             if (judge.equals(featurePoint.getKey())) {
 85                 coughNum = featurePoint.getP();
 86                 denominator *= (featurePoint.getP() / dataSize);
 87                 break;
 88             }
 89         }
 90        /*O(n^3)*/
 91         Integer size = featurePoints.size() - 1; //容量
 92         for (String label : labels) {
 93             for (FeaturePoint featurePoint : featurePoints) {
 94                 if (label.equals(featurePoint.getKey())) {
 95                     denominator *= (featurePoint.getP() / coughNum);
 96                     for (FeaturePoint feature : map.get(label)) {
 97                         if (label.equals(feature.getKey())) {
 98                             numerator *= (feature.getP() / dataSize);
 99                         }
100                     }
101                 }
102             }
103         }
104 
105         return denominator / numerator;
106     }
107 
108 }
代码语言:javascript
复制
 1 import com.google.common.collect.Lists;
 2 
 3 import java.util.List;
 4 import java.util.Scanner;
 5 
 6 /**
 7  * *********************************************************
 8  * <p/>
 9  * Author:     XiJun.Gong
10  * Date:       2016-09-01 14:58
11  * Version:    default 1.0.0
12  * Class description:
13  * <p/>
14  * *********************************************************
15  */
16 public class Main {
17 
18     public static void main(String args[]) {
19 
20         Scanner scanner = new Scanner(System.in);
21         Integer size = scanner.nextInt();
22         Integer row = scanner.nextInt();
23         Integer category = scanner.nextInt();
24         while (scanner.hasNext()) {
25             Bayes bayes = new Bayes();
26             for (int ro = 0; ro < row; ro++) {
27                 List<String> list = Lists.newArrayList();
28                 for (int i = 0; i < size; i++) {
29                     list.add(scanner.next());
30                 }
31                 bayes.input(list);
32             }
33             List<String> list = Lists.newArrayList();
34             for (int i = 0; i < size - 1; i++) {
35                 list.add(scanner.next());
36             }
37             for (int i = 0; i < category; i++) {
38                 bayes.add(scanner.next());
39             }
40             System.out.println(bayes.excute(list, row));
41         }
42 
43     }
44 }

结果:

代码语言:javascript
复制
 1 2 50 2
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 长发
53 男 女
54 女
代码语言:javascript
复制
 1 2 50 2
 2 男  长发
 3 男  短发
 4 男  短发
 5 男  短发
 6 男  短发
 7 男  短发
 8 男  短发
 9 男  短发
10 男  短发
11 男  短发
12 男  短发
13 男  短发
14 男  短发
15 男  短发
16 男  短发
17 男  短发
18 男  短发
19 男  短发
20 男  短发
21 男  长发
22 女  长发
23 女  长发
24 女  长发
25 女  长发
26 女  长发
27 女  长发
28 女  长发
29 女  长发
30 女  长发
31 女  长发
32 女  长发
33 女  长发
34 女  长发
35 女  长发
36 女  长发
37 女  长发
38 女  长发
39 女  长发
40 女  长发
41 女  长发
42 女  长发
43 女  长发
44 女  长发
45 女  长发
46 女  长发
47 女  长发
48 女  长发
49 女  长发
50 女  长发
51 女  长发
52 短发
53 男 女
54 男
本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2016-09-01 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档