# Fit an ordinary least-squares line as a reference solution, then reproduce
# the fit by hand with a cost function, its partial derivatives and batch
# gradient descent.
#
# NOTE(review): `pga` (a DataFrame with 'distance' and 'accuracy' columns) and
# `plt` (matplotlib.pyplot) are not defined in this section -- presumably they
# are set up earlier in the file; confirm before running it standalone.
from sklearn.linear_model import LinearRegression
import numpy as np

lr = LinearRegression()
# scikit-learn expects a 2-D feature matrix. Selecting with a list of column
# names yields one; indexing a Series with [:, np.newaxis] is rejected by
# current pandas versions.
lr.fit(pga[['distance']], pga['accuracy'])
theta0 = lr.intercept_
theta1 = lr.coef_
print(theta0)
print(theta1)


def cost(x, y, theta0, theta1):
    """Return the halved mean squared error of the line theta1 * x + theta0.

    Args:
        x: Sequence of feature values.
        y: Sequence of target values, same length as ``x``.
        theta0: Intercept of the line.
        theta1: Slope of the line.

    Returns:
        sum((theta1 * x + theta0 - y) ** 2) / (2 * len(x)).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    residuals = theta1 * x + theta0 - y
    return (residuals ** 2).sum() / (2 * len(x))


# Plot the cost as a function of the slope with the intercept pinned at 100.
theta0 = 100
theta1s = np.linspace(-3, 2, 197)
# A dedicated loop name keeps the fitted `theta1` above from being clobbered.
costs = [
    cost(pga['distance'], pga['accuracy'], theta0, slope) for slope in theta1s
]
plt.plot(theta1s, costs)
plt.show()
print(pga.distance)


def partial_cost_theta0(x, y, theta0, theta1):
    """Return the partial derivative of ``cost`` with respect to theta0.

    The model is the linear fit y = theta1 * x + theta0 (not a sigmoid), so
    the derivative is the mean of the residuals. The whole series is processed
    at once rather than element by element.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    residuals = theta1 * x + theta0 - y
    return residuals.sum() / len(residuals)


partial0 = partial_cost_theta0(pga.distance, pga.accuracy, 1, 1)


def partial_cost_theta1(x, y, theta0, theta1):
    """Return the partial derivative of ``cost`` with respect to theta1.

    For the linear model y = theta1 * x + theta0 this is the mean of the
    residuals weighted by ``x``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    weighted = (theta1 * x + theta0 - y) * x
    return weighted.sum() / len(weighted)


partial1 = partial_cost_theta1(pga.distance, pga.accuracy, 0, 5)
print(partial0)
print(partial1)


def gradient_descent(x, y, alpha=0.1, theta0=0, theta1=0,
                     max_iterations=1000, convergence_thres=0.000001):
    """Fit theta0 and theta1 by batch gradient descent.

    Each iteration computes both partial derivatives from the current
    parameters, moves both parameters against the gradient, and records the
    new cost. Iteration stops when the cost changes by no more than
    ``convergence_thres`` between steps or after ``max_iterations`` steps.

    Args:
        x: Sequence of feature values.
        y: Sequence of target values, same length as ``x``.
        alpha: Learning rate.
        theta0: Initial intercept.
        theta1: Initial slope.
        max_iterations: Upper bound on the number of update steps.
        convergence_thres: Minimum cost change that counts as progress.

    Returns:
        Dict with keys ``'theta0'``, ``'theta1'`` and ``'costs'`` (the cost
        before the first step followed by the cost after every step).
    """
    current_cost = cost(x, y, theta0, theta1)
    costs = [current_cost]
    # Seed the previous cost far enough away that the loop runs at least once.
    previous_cost = current_cost + convergence_thres + 1.0
    counter = 0
    while (abs(current_cost - previous_cost) > convergence_thres
           and counter < max_iterations):
        # Both updates are computed before either parameter changes so the
        # step is simultaneous.
        update0 = alpha * partial_cost_theta0(x, y, theta0, theta1)
        update1 = alpha * partial_cost_theta1(x, y, theta0, theta1)
        theta0 -= update0
        theta1 -= update1
        previous_cost = current_cost
        current_cost = cost(x, y, theta0, theta1)
        costs.append(current_cost)
        counter += 1
    return {'theta0': theta0, 'theta1': theta1, "costs": costs}


print("Theta1 =", gradient_descent(pga.distance, pga.accuracy)['theta1'])

# Run once with the smaller learning rate and reuse the result; the history is
# stored under 'costs' (looking up 'cost' raises KeyError).
descent = gradient_descent(pga.distance, pga.accuracy, alpha=.01)
costs = descent['costs']
print(descent['theta1'])
plt.scatter(range(len(costs)), costs)
plt.show()
# "Preview" -- web-page residue from the original paste; not part of the script.
數(shù)據(jù)集 : 復制下面數(shù)據(jù),,保存為: pga.csv distance,accuracy 290.3,59.5 302.1,54.7 287.1,62.4 282.7,65.4 299.1,52.8 300.2,51.1 300.9,58.3 279.5,73.9 287.8,67.6 284.7,67.2 296.7,60 283.3,59.4 284,72.2 292,62.1 282.6,66.5 287.9,60.9 279.2,67.3 291.7,64.8 289.9,58.1 289.8,61.7 298.8,56.4 280.8,60.5 294.9,57.5 287.5,61.8 282.7,56 277.7,72.5 270.5,71.7 285.2,66 315.1,55.2 281.9,67.6 293.3,58.2 286,59.9 285.6,58.2 289.9,65.7 277.5,59 293.6,56.8 301.1,65.4 300.8,63.4 287.4,67.3 281.8,72.6 277.4,63.1 279.1,66.5 287.4,66.4 280.9,62.3 287.8,57.2 261.4,69.2 272.6,69.4 291.3,65.3 294.2,52.8 285.5,49 287.9,61.1 282.2,65.6 301.3,58.2 276.2,61.7 281.6,68.1 275.5,61.2 309.7,53.1 287.7,56.4 291.6,56.9 284.1,65 299.6,57.5 282.7,60 271.5,72 292.1,58.2 295,59.4 274.9,69 273.6,68.7 299.9,60.1 279.9,74 289.9,66 283.6,59.8 310.3,52.4 291.7,65.6 284.2,63.2 295,53.5 298.6,55.1 297.4,60.4 299.7,67.7 284.4,69.7 286.4,72.4 285.9,66.9 297.6,54.3 272.5,62 277,66.2 287.6,60.9 280.4,69.4 280,63.7 295.4,52.8 274.4,68.8 286.5,73.1 287.7,65.2 291.5,65.9 279,69.4 299,65.2 290.1,69.1 288.9,67.9 288.8,68.2 283.2,61 293.2,58.4 285.3,67.3 284.1,65.7 281.4,67.7 286.1,61.4 284.9,62.3 284.8,68.1 296,62 282.9,71.8 280.9,67.8 291.2,62 292.8,62.2 291,61.9 285.7,62.4 283.9,62.9 298.4,61.5 285.1,65.3 286.1,60.1 283.1,65.4 289.4,58.3 284.6,70.7 296.6,62.3 295.9,64.9 295.2,62.8 293.9,54.5 275,65.5 286.8,69.5 291.1,64.4 284.8,62.5 283.7,59.5 295.4,66.9 291.8,62.7 274.9,72.3 302.9,61.2 272.1,80.4 274.9,74.9 296.3,59.4 286.2,58.8 294.2,63.3 284.1,66.5 299.2,62.4 275.4,71 273.2,70.9 281.6,65.9 295.7,55.3 287.1,56.8 287.7,66.9 296.7,53.7 282.2,64.2 291.7,65.6 281.6,73.4 311,56.2 278.6,64.7 288,65.7 276.7,72.1 292,62 286.4,69.9 292.7,65.7 294.2,62.9 278.6,59.6 283.1,69.2 284.1,66 278.6,73.6 291.1,60.4 294.6,59.4 274.3,70.5 274,57.1 283.8,62.7 272.7,66.9 303.2,58.3 282,70.4 281.9,61 287,59.9 293.5,63.8 283.6,56.3 296.9,55.3 290.9,58.2 303,58.1 292.8,61.1 281.1,65 293,61.1 284,66.5 279.8,66.7 292.9,65.4 284,66.9 
282,64.5 280.6,64 287.7,63.4 287.7,63.4 298.3,59.5 299.6,53.4 291.3,62.5 295.2,61.4 288,62.4 297.8,59.5 286,62.6 285.3,66.2 286.9,63.4 275.1,73.7
|