您的当前位置:首页>全部文章>文章详情

【Python】使用Python实现小批量梯度下降算法的代码逻辑

CrazyPanda发表于:2024-01-22 20:03:43浏览:336次TAG:

让theta=模型参数和max_iters=时期数。对于itr=1,2,3,...,max_iters:对于mini_batch(X_mini,y_mini):

批量X_mini的前向传递:

1、对小批量进行预测

2、使用参数的当前值计算预测误差(J(theta))

后传:计算梯度(theta)=J(theta)wrt theta的偏导数

更新参数:theta=theta–learning_rate*gradient(theta)

Python实现梯度下降算法的代码流程

第一步:导入依赖项,为线性回归生成数据,并可视化生成的数据。以8000个数据示例,每个示例都有2个属性特征。这些数据样本进一步分为训练集(X_train,y_train)和测试集(X_test,y_test),分别有7200和800个样本。

import numpy as np
import matplotlib.pyplot as plt
 
mean=np.array([5.0,6.0])
cov=np.array([[1.0,0.95],[0.95,1.2]])
data=np.random.multivariate_normal(mean,cov,8000)
 
plt.scatter(data[:500,0],data[:500,1],marker='.')
plt.show()
data=np.hstack((np.ones((data.shape[0],1)),data))
split_factor=0.90
split=int(split_factor*data.shape[0])
X_train=data[:split,:-1]
y_train=data[:split,-1].reshape((-1,1))
X_test=data[split:,:-1]
y_test=data[split:,-1].reshape((-1,1))
 
print(& quot Number of examples in training set= % d & quot % (X_train.shape[0]))
print(& quot Number of examples in testing set= % d & quot % (X_test.shape[0]))
opO7p4hNxXmtXwQhPjooZjdt8AgU.jpg

训练集中的示例数=7200测试集中的示例数=800

第二步:

使用小批量梯度下降实现线性回归的代码。gradientDescent()是主要的驱动函数,其他函数是辅助函数:

进行预测——hypothesis()

计算梯度——gradient()

计算误差——cost()

创建小批量——create_mini_batches()

驱动程序函数初始化参数,计算模型的最佳参数集,并返回这些参数以及一个列表,其中包含参数更新时的错误历史记录。

def hypothesis(X,theta):
    return np.dot(X,theta)
 
def gradient(X,y,theta):
    h=hypothesis(X,theta)
    grad=np.dot(X.transpose(),(h-y))
    return grad
 
def cost(X,y,theta):
    h=hypothesis(X,theta)
    J=np.dot((h-y).transpose(),(h-y))
    J/=2
    return J[0]
 
def create_mini_batches(X,y,batch_size):
    mini_batches=[]
    data=np.hstack((X,y))
    np.random.shuffle(data)
    n_minibatches=data.shape[0]//batch_size
    i=0
    for i in range(n_minibatches+1):
        mini_batch=data[i*batch_size:(i+1)*batch_size,:]
        X_mini=mini_batch[:,:-1]
        Y_mini=mini_batch[:,-1].reshape((-1,1))
        mini_batches.append((X_mini,Y_mini))
    if data.shape[0]%batch_size!=0:
       mini_batch=data[i*batch_size:data.shape[0]]
       X_mini=mini_batch[:,:-1]
       Y_mini=mini_batch[:,-1].reshape((-1,1))
       mini_batches.append((X_mini,Y_mini))
    return mini_batches
 
def gradientDescent(X,y,learning_rate=0.001,batch_size=32):
    theta=np.zeros((X.shape[1],1))
    error_list=[]
    max_iters=3
    for itr in range(max_iters):
        mini_batches=create_mini_batches(X,y,batch_size)
        for mini_batch in mini_batches:
            X_mini,y_mini=mini_batch
            theta=theta-learning_rate*gradient(X_mini,y_mini,theta)
            error_list.append(cost(X_mini,y_mini,theta))
    return theta,error_list

调用gradientDescent()函数来计算模型参数(theta)并可视化误差函数的变化。

theta,error_list=gradientDescent(X_train,y_train)
print("Bias=",theta[0])
print("Coefficients=",theta[1:])
 
plt.plot(error_list)
plt.xlabel("Number of iterations")
plt.ylabel("Cost")
plt.show()

偏差=[0.81830471]系数=[[1.04586595]]

opO7p4hNxXmtXwQhPjooZjdt8AgU.jpg

第三步:对测试集进行预测并计算预测中的平均绝对误差。

y_pred=hypothesis(X_test,theta)
plt.scatter(X_test[:,1],y_test[:,],marker='.')
plt.plot(X_test[:,1],y_pred,color='orange')
plt.show()
 
error=np.sum(np.abs(y_test-y_pred)/y_test.shape[0])
print(& quot Mean absolute error=&quot,error)

opO7p4hNxXmtXwQhPjooZjdt8AgU.jpg

平均绝对误差=0.4366644295854125

橙色线代表最终假设函数:theta[0]+theta[1]*X_test[:,1]+theta[2]*X_test[:,2]=0


猜你喜欢

【Python】查看pandas版本的方法
如何查看pandas版本信息,需要具体代码示例Pandas是Python中一个十分受欢迎的数据处理库,广泛应用于数据分析、数据清洗和数据转换等领域。在使用pandas之前,我们通常需要了解当前所安装的pandas版本以确保我们使用的是最新版本或兼容的版本。本文将介绍如何查看pandas版本信息,并提供具体的代码示例。要查看pandas版本信息,我们可以使用pandas库中提供的__version__属性。下面是一段简单的示例代码:import pandas as 
发表于:2024-01-10 浏览:311 TAG:
【Python】使用Pandas重命名列名实现高效数据处理
高效数据处理:使用Pandas修改列名,需要具体代码示例数据处理是数据分析中一个非常重要的环节,而在数据处理过程中,经常需要对数据的列名进行修改。Pandas是一个强大的数据处理库,提供了丰富的方法和函数帮助我们快速高效地处理数据。本文将介绍如何使用Pandas修改列名,并提供具体的代码示例。在实际的数据分析中,原始数据的列名可能存在命名规范不统一、不易理解等问题,这就需要我们根据实际需求对列名进行修改。下面是一个示例数据集,包含三列数据:姓名、年龄和性别。import pandas&
发表于:2024-01-11 浏览:318 TAG:
【Python】Python 入门的60个基础练习
文章目录01-Hello World02-print 函数03-基本运算04-input05-输入输出基础练习06-字符串使用基础07-列表基础08-元组基础09-字典基础10-基本判断11-条件表达式、三元运算符12-判断练习:用户名和密码是否正确13-猜数:基础实现14-成绩分类 115-成绩分类 216-石头剪刀布17-改进的石头剪刀布18-猜数,直到猜对19-猜数,5 次机会20-while 循环,累加至 10021-while-break2
发表于:2023-11-28 浏览:1308 TAG:
【Python】学会应对Python中len函数常见问题和解决方法的技巧
快速掌握Python中len函数的常见问题和解决方法一、引言Python中的len函数是一个常用的内建函数,用来获取容器对象的长度或元素个数。尽管len函数使用简单,但在实际应用时,仍有一些常见问题和解决方法值得我们注意。本文将重点介绍len函数的常见问题和解决方法,并提供具体的代码示例,旨在帮助读者快速掌握和应用。二、常见问题及解决方法问题一:如何获取字符串的长度?解决方法:可以使用len函数获取字符串的长度。下面是一个具体的代码示例:string = "Hell
发表于:2024-01-15 浏览:326 TAG:
【Python】简单指南:逐步安装matplotlib并制作引人注目图表
一步步教你安装matplotlib,轻松绘制漂亮图表,需要具体代码示例在数据分析和数据可视化领域,matplotlib 是一个非常强大的 Python 库。它为我们提供了丰富的绘图功能,可以轻松地创建各种类型的图表。本文将向大家介绍如何安装 matplotlib,并提供一些具体的代码示例,以帮助大家更好地掌握该库。步骤一:安装 Python首先,我们需要安装 Python。在官方网站(https://www.python.org/downloads/)上可以下载最新的 Python 版本。根据你
发表于:2024-01-12 浏览:285 TAG:
【Python】深度剖析len函数的意义与用法
深入解析len函数的含义和用途在许多编程语言中,len函数常常用于获取字符串、列表、元组、字典等数据结构的长度。在本文中,我们将深入解析len函数的含义和用途,并提供具体的代码示例。一、len函数的含义len函数是Python标准库中内置的函数之一,用于返回给定数据结构的长度。具体来说,len函数可以用于返回字符串中字符的数量、列表中元素的数量,以及字典中键值对的数量等。二、len函数的用途获取字符串的长度字符串是一系列字符的集合,而len函数可以帮助我们快速获取字符串的长度。下面是一个示例代码
发表于:2024-01-02 浏览:338 TAG:
【Python】ChatGPT和Python的完美结合:打造多领域聊天机器人
ChatGPT和Python的完美结合:打造多领域聊天机器人引言:随着人工智能技术的快速发展,聊天机器人已经成为了现今社交媒体和客户服务领域中的重要工具。聊天机器人能够通过自然语言处理和机器学习技术,与用户进行自动化的对话,提供信息和解决问题。本文将介绍如何使用OpenAI的ChatGPT模型和Python语言,打造一个多领域的聊天机器人。一、了解ChatGPT模型ChatGPT是OpenAI基于GPT-3模型开发的一个聊天机器人模型。它可以接受一系列指令和问题,并生成连贯的回答。ChatGPT
发表于:2024-01-24 浏览:296 TAG:
【Python】如何用Python编写线性查找算法?
如何用Python编写线性查找算法?线性查找是最简单的搜索算法之一,也被称为顺序搜索。它的原理很简单,就是从头到尾遍历待查找的数据集合,逐个比较查找目标和数据集合中的元素。下面我们将介绍如何使用Python编写线性查找算法,并给出具体的代码示例。算法实现步骤:遍历待查找的数据集合,逐个比较目标和元素。若找到目标,返回元素的索引位置。若遍历完所有元素仍未找到目标,返回-1。代码示例:def linear_search(arr, target):  for i
发表于:2024-01-16 浏览:298 TAG:
【Python】Linux环境中基于Python脚本实现监控网站可用性的技术
Linux环境中基于Python脚本实现监控网站可用性的技术摘要:本文介绍了如何在Linux环境中使用Python脚本来监控网站的可用性。具体包括通过发送HTTP请求和解析响应来检测网站是否可访问,以及如何将该监控脚本配置为定时任务并发送报警邮件。1. 简介随着互联网的发展,网站可用性成为了一个至关重要的指标。如果网站无法正常访问,将会给用户带来巨大不便,甚至会影响到企业的品牌形象和业务运营。因此,及时监控网站的可用性是非常重要的。2. 准备工作为了能够使用Python监控网站可用性,我们首先需
发表于:2024-01-19 浏览:264 TAG:
【Python】使用Python中的len函数统计文本中的单词数量的示例
Python中的len函数应用实例:如何利用它统计文本中的单词数量在Python编程中,len函数是一个非常有用的函数,它用于返回一个对象的长度或元素的个数。在本文中,将介绍如何使用len函数来统计文本中的单词数量,并提供具体的代码示例。在开始编写代码之前,需要先了解一下如何定义一个单词。在本文中,我们将使用空格作为单词的分隔符,也就是说,任何两个空格之间的字符串都被认为是一个单词。下面是一个简单的代码示例,展示了如何使用len函数统计文本中的单词数量:def count_words(
发表于:2024-01-15 浏览:324 TAG: