为了配合斯坦福cs224n课程,最近开始看tensorflow文档。 今天找了一个博客用tensorflow把逻辑回归又实现了一下。
博客地址:https://segmentfault.com/a/1190000009954640
但是因为anaconda的python是3.7的还无法安装tensorflow(我的环境一直有点奇怪,如果有高人看到麻烦评论以下,我建了一个3.6的环境,也安装了tensorflow,但是在anaconda里面一直找不到安装的包),就先用pycharm编译好了再发上来。 运行结果我做了比对,降低一个量级后,分类准确率确实有下降。
引入数据
'''
使用TensorFlow对逻辑回归进行复现,
数据集为coursera ex2data1.txt,通过成绩预测学生是否会被录取
'''
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
df = pd.read_csv('ex2data1.txt', header=None)
train_data = df.values
# 前两列为输入X,第三列为输出
train_X = train_data[:, :-1]
train_y = train_data[:, -1:]
feature_num = len(train_X[0])
sample_num = len(train_X)
# print("Size of train_X: {}x{}".format(sample_num, feature_num))
# print("Size of train_y: {}x{}".format(len(train_y), len(train_y[0])))
使用 TensorFlow 定义两个变量用来存放我们的训练用数据。
placeholder为占位符
# 模型设计
X = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
需要训练的参数W和b
W = tf.Variable(tf.zeros([feature_num, 1]))
b = tf.Variable([-.9])
表达损失函数是分三步进行的:
先分别将求和内的两部分表示出来,
再将它们加和并和外面的常数m进行运算,
最后对这个向量进行求和,便得到了损失函数的值。
db = tf.matmul(X, tf.reshape(W, [-1, 1])) + b
hyp = tf.sigmoid(db)
cost0 = y * tf.log(hyp)
cost1 = (1 - y) * tf.log(1 - hyp)
cost = (cost0 + cost1) / -sample_num
loss = tf.reduce_sum(cost)
定义优化的方法,0.001是学习率
optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
训练模型 定义variable初始化
每运行1000步就输出一次W和b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
feed_dict =
for step in range(100000):
sess.run(train, feed_dict)
if step % 1000 == 0:
print(step, sess.run(W).flatten(), sess.run(b).flatten())
if __name__ == '__main__':
# logistic_regression(train_X, train_y)
w = [0.04858239, 0.04162483]
b = -5.248103
x1 = train_data[:, 0]
x2 = train_data[:, 1]
y = train_data[:, -1:]
for x1p, x2p, yp in zip(x1, x2, y):
if yp == 0:
plt.scatter(x1p, x2p, marker='x', c='r')
else:
plt.scatter(x1p, x2p, marker='o', c='g')
x = np.linspace(20, 100, 10)
y = []
for i in x:
y.append((i * -w[1] - b) / w[0])
plt.plot(x, y)
plt.show()
领取专属 10元无门槛券
私享最新 技术干货