# 参考代码 (reference code): https://www.cnblogs.com/bsdr/p/5405082.html
"""
我爱薯条
"""
import warnings

import numpy as np
class Perceptron:
    """Primal-form perceptron for binary classification with labels +1 / -1.

    Each misclassified sample ``(x, y)`` updates the parameters by
    ``weight += lr * y * x`` and ``bias += lr * y``. Training repeats full
    passes over the data until a pass makes no update.
    """

    def __init__(self, lr):
        """Store the learning rate.

        Args:
            lr: step size applied to every weight/bias update.
        """
        self.lr = lr  # learning rate

    def sign(self, func):
        """Return 1 if ``func`` is strictly positive, otherwise -1.

        ``func`` is the raw score ``w . x + b``. A score of exactly 0 maps to
        -1, so a point lying on the decision boundary counts as a negative
        prediction.
        """
        if func > 0:
            return 1
        else:
            return -1

    def fit(self, x_train, y_train, max_epochs=None):
        """Train on the given samples and return ``(weight, bias)``.

        Args:
            x_train: 2-D array-like of features, shape (n_samples, n_features).
            y_train: labels in {+1, -1}, either flat ``(n_samples,)`` or a
                column ``(n_samples, 1)``. The returned bias has the shape of
                a single label: a scalar for flat labels, shape ``(1,)`` for
                column labels.
            max_epochs: optional cap on the number of passes over the data.
                ``None`` (the default) loops until convergence. On data that
                is not linearly separable, that loop never terminates.

        Returns:
            ``(weight, bias)``: ``weight`` is a float ndarray of shape
            (n_features,).

        The current parameters are printed after every update, and
        ``"Done~"`` is printed once a full pass makes no mistakes. If
        ``max_epochs`` passes finish without converging, a ``RuntimeWarning``
        is issued and the current parameters are returned.
        """
        # Convert list input to arrays. With plain lists, `y * sign(...)`
        # would be list repetition instead of multiplication.
        x_train = np.asarray(x_train)
        y_train = np.asarray(y_train)
        weight = np.zeros(x_train.shape[1])
        bias = 0
        epoch = 0
        while max_epochs is None or epoch < max_epochs:
            epoch += 1
            wrong_count = 0
            for x, y in zip(x_train, y_train):
                # With labels in {+1, -1}, a product <= 0 means the
                # prediction disagrees with the label.
                if y * self.sign(np.dot(x, weight) + bias) <= 0:
                    # Take the scalar label so that flat and column label
                    # layouts both work (the old `y[0]` crashed on flat labels).
                    y_value = np.ravel(y)[0]
                    weight = weight + self.lr * (y_value * x)
                    # Keep `y` itself here so the bias keeps the label's shape.
                    bias = bias + self.lr * y
                    wrong_count += 1
                    print(weight, bias)
            if wrong_count == 0:
                print("Done~")
                return weight, bias
        warnings.warn(
            f"Perceptron did not converge within {max_epochs} epochs; "
            "the data may not be linearly separable.",
            RuntimeWarning,
            stacklevel=2,
        )
        return weight, bias
# Demo: fit the perceptron with learning rate 1 on three 2-D points,
# (3, 3) and (4, 3) labelled +1 and (1, 1) labelled -1.
pp = Perceptron(lr=1)
a1 = np.array([[3, 3],
               [4, 3],
               [1, 1]])
y1 = np.array([1, 1, -1]).reshape(-1, 1)  # labels as a (3, 1) column
pp.fit(x_train=a1, y_train=y1)
# Output of the demo above (weight and bias printed after each update):
# [3. 3.] [1]
# [2. 2.] [0]
# [1. 1.] [-1]
# [0. 0.] [-2]
# [3. 3.] [-1]
# [2. 2.] [-2]
# [1. 1.] [-3]
# Done~