位置: IT常识 - 正文

pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

编辑:rootadmin
pytorch如何搭建一个最简单的模型, 一、搭建模型的步骤

推荐整理分享pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络),希望有所帮助,仅作参考,欢迎阅读内容。

文章相关热门搜索词:pytorch搭建gan,pytorch环境搭建mac,pytorch如何搭建神经网络,pytorch怎么装,pytorch创建模型,pytorch 搭建简单网络,pytorch搭建gan,pytorch搭建gan,内容如对您有帮助,希望把文章链接给更多的朋友!

在 PyTorch 中,可以使用 torch.nn 模块来搭建深度学习模型。具体步骤如下:

定义一个继承自 torch.nn.Module 的类,这个类将作为我们自己定义的模型。

在类的构造函数 __init__() 中定义网络的各个层和参数。可以使用 torch.nn 模块中的各种层,如 Conv2d、BatchNorm2d、Linear 等。

在类中定义前向传播函数 forward(),实现模型的具体计算过程。

将模型部署到 GPU 上,可以使用 model.to(device) 将模型移动到指定的 GPU 设备上。

二、简单的例子pytorch如何搭建一个最简单的模型,(pytorch如何搭建神经网络)

下面是一个简单的例子,演示了如何使用 torch.nn 模块搭建一个简单的全连接神经网络:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x

MyNet 的神经网络类,它继承自 torch.nn.Module。在构造函数 __init__() 中定义了两个全连接层,一个 ReLU 激活函数,并将它们作为网络的成员变量。在前向传播函数 forward() 中,首先将输入的图像数据 x 压成一维向量,然后依次经过两个全连接层和一个 ReLU 激活函数,最终得到模型的输出结果。

在模型训练之前,需要将模型部署到 GPU 上,可以使用以下代码将模型移动到 GPU 上:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)如何将loss函数添加到模型中去呢?

在 PyTorch 中,通常将损失函数作为单独的对象来定义,并在训练过程中手动计算和优化损失。为了将损失函数添加到模型中,需要在模型类中添加一个成员变量,然后在前向传播函数中计算损失。

下面是一个例子,演示了如何在模型中添加交叉熵损失函数:

import torch.nn as nnclass MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) self.loss_fn = nn.CrossEntropyLoss() def forward(self, x, y): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) loss = self.loss_fn(x, y) return x, loss

在模型类 MyNet 的构造函数中添加了一个成员变量 self.loss_fn,它是交叉熵损失函数。在前向传播函数 forward() 中,传入两个参数 x 和 y,其中 x 是输入图像数据,y 是对应的标签。在函数中先执行正向传播计算,然后计算交叉熵损失,并将损失值作为输出返回。

实际训练代码

在实际训练过程中,首先将模型输出结果 x 和标签 y 传入前向传播函数 forward() 中计算损失,然后使用优化器更新模型的权重和偏置。代码如下:

model = MyNet()optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs, loss = model(inputs, labels) loss.backward() optimizer.step()

在上面的代码中,使用随机梯度下降优化器 torch.optim.SGD 来更新模型的参数。在每个批次中,首先将输入数据和标签移动到 GPU 上,然后使用 optimizer.zero_grad() 将梯度清零。接着执行前向传播计算,并得到损失值 loss。最后使用 loss.backward() 计算梯度并执行反向传播,使用 optimizer.step() 更新模型参数。

2023.03.27更新 完整的代码

# -*-coding:utf-8-*-# !/usr/bin/env python# @Time : 2023/3/27 上午11:00# @Author : loveinfall uestc# @File : csdn_test_.py# @Description :import torchimport torch.nn as nnimport torch.utils.data as dataimport cv2####################### model ###########################class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.fc1 = nn.Linear(784, 512) self.relu = nn.ReLU() self.fc2 = nn.Linear(512, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x###################### end ############################################### loss 函数 #############################loss_fn = nn.CrossEntropyLoss()################## end #################################################### dataloader 需要自己构建 ############class image_folder(data.Dataset): def __init__(self): self.image_dirs = []#构造数据读取路径列表 self.label_dirs = [] def __getitem__(self,index): image = cv2.imread(self.image_dirs[index]) label = 'read data'#根据实际情况,写 return image,label def __len__(self): return 'len(data)'train_dataset = image_folder()data_loader = data.DataLoader( train_dataset, batch_size=3, shuffle=True, num_workers=2, pin_memory=True)#################### end ##################################################### train #######################@#####device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = MyNet().to(device)optimizer = torch.optim.SGD(model.parameters(), lr=0.01)for inputs, labels in data_loader: inputs = inputs.to(device) labels = labels.to(device) optimizer.zero_grad() outputs = model(inputs) loss = loss_fn(outputs,labels) loss.backward() optimizer.step()
本文链接地址:https://www.jiuchutong.com/zhishi/297834.html 转载请保留说明!

上一篇:【Vue】图片拉近、全屏背景实战经验总结(vue图片点击放大)

下一篇:什么是前后端分离(什么是前后端分离的方式)

  • word文档如何删除空白页(word文档如何删除水印)

    word文档如何删除空白页(word文档如何删除水印)

  • 荣耀x10有密码保险箱功能吗(荣耀x10有密码保护吗)

    荣耀x10有密码保险箱功能吗(荣耀x10有密码保护吗)

  • 一加七和一加七pro对比(一加七和一加七pro续航)

    一加七和一加七pro对比(一加七和一加七pro续航)

  • 笔记本8g和16g区别大吗(笔记本8g和16g哪个好)

    笔记本8g和16g区别大吗(笔记本8g和16g哪个好)

  • 淘宝直播购物车在哪(淘宝直播购物车链接怎么弄)

    淘宝直播购物车在哪(淘宝直播购物车链接怎么弄)

  • 苹果11建议更新吗(苹果11建议更新ios16.2吗)

    苹果11建议更新吗(苹果11建议更新ios16.2吗)

  • 戴尔安全模式按f几(戴尔安全模式按什么键)

    戴尔安全模式按f几(戴尔安全模式按什么键)

  • 淘宝退款时间怎么规定的(淘宝退款怎么改退款路径)

    淘宝退款时间怎么规定的(淘宝退款怎么改退款路径)

  • ipad pro11可以接鼠标吗(ipadpro11能用typec耳机吗)

    ipad pro11可以接鼠标吗(ipadpro11能用typec耳机吗)

  • 网络忙请稍后再拨是啥意思(网络忙请稍后再试 12306)

    网络忙请稍后再拨是啥意思(网络忙请稍后再试 12306)

  • idc的含义(idc是指什么)

    idc的含义(idc是指什么)

  • 昨天看过的抖音怎么找(昨天看过的抖音还能找到吗)

    昨天看过的抖音怎么找(昨天看过的抖音还能找到吗)

  • 华为lonal00什么型号(华为lonal00手机报价)

    华为lonal00什么型号(华为lonal00手机报价)

  • 为什么支付宝显示网络错误(为什么支付宝显示银行反馈此卡不可用)

    为什么支付宝显示网络错误(为什么支付宝显示银行反馈此卡不可用)

  • 手机上的指纹锁怎样设置(手机上的指纹锁怎么搞)

    手机上的指纹锁怎样设置(手机上的指纹锁怎么搞)

  • 机械表时间不准怎么调(机械表时间不准怎么办)

    机械表时间不准怎么调(机械表时间不准怎么办)

  • 腾讯大王卡拼多多免流吗(腾讯大王卡拼多多视频免流吗)

    腾讯大王卡拼多多免流吗(腾讯大王卡拼多多视频免流吗)

  • 电脑怎么锁定当前窗口(电脑怎么锁定当前系统)

    电脑怎么锁定当前窗口(电脑怎么锁定当前系统)

  • 如何修改网页源代码(如何修改网页源代码生效)

    如何修改网页源代码(如何修改网页源代码生效)

  • 苹果无线耳机怎么查序列号(苹果无线耳机怎么连接手机)

    苹果无线耳机怎么查序列号(苹果无线耳机怎么连接手机)

  • 计算机能直接执行的程序是(计算机能直接执行的指令包括哪些)

    计算机能直接执行的程序是(计算机能直接执行的指令包括哪些)

  • 硬盘温度过高(硬盘温度过高怎么解决)

    硬盘温度过高(硬盘温度过高怎么解决)

  • Windows Update出现错误代码80070103怎么办(windows update更新错误)

    Windows Update出现错误代码80070103怎么办(windows update更新错误)

  • vue查询数据el-table不更新数据(vue 查询)

    vue查询数据el-table不更新数据(vue 查询)

  • 未达起征点增值税申报表怎么填
  • 企业其他税负率计算公式?
  • 全国税务师官网报名
  • 公司账户转私人账户要多久时间
  • 印花税计税依据是什么
  • 法人股东分红要交企业所得税吗
  • 其他应收款重分类
  • 单位购牙膏牙刷卫生纸怎么做账
  • 发生额对照表
  • 支票罚金
  • 作废的支票银行怎么处理
  • 会计核算形式
  • 事业单位取暖费什么时候发
  • 超市商品售出可以退货吗
  • 向公司一般户的银行借款怎么做账?
  • 房地产经纪公司经营范围
  • 分公司利润如何分红
  • 中药材免税还能抵扣收购发票
  • 人民法院被收买了怎么办
  • 出租房电费怎么结算
  • 电脑开启语音按什么键
  • 电脑怎么安装双显卡
  • 电脑维修中常用的软件
  • 企业清算期限如何规定
  • mac截图如何保存到照片
  • win10更新21h1后很卡
  • laravel快速入门
  • win10高级功能
  • php ftp功能
  • 原材料运费可以计入制造费用吗
  • 企业所得税的概述
  • 如何搭建chatGPT
  • 请问简单的
  • 设备租赁费属于固定成本吗
  • 资产负债表要点
  • vue项目上线教程
  • 银行汇票未用退回情况说明
  • 制造费用期末怎么结转
  • 前端css要掌握到什么程度
  • vue3和ts
  • 开发票的销售收入,正规的做账怎么做
  • java异常编程题
  • 公司员工抽奖活动
  • 个人向公司借款协议书范本
  • dedecms主页修改
  • 织梦标签教程
  • 财务费用相关指标
  • 上月未结账本月不能结账
  • mysql怎么给字段添加中文备注
  • 银行存款日记账电子表格模板
  • 主营业务成本大于主营业务收入怎么办
  • 租厂房需要办环评注意事项
  • 无息的银行承兑汇票
  • 合并报表时抵消内部交易包含的未实现损益的影响包括
  • 现金日记账金额怎么填写
  • 会计凭证的粘贴顺序
  • 组织机构代码证图片
  • 摊余成本通俗
  • sql联合主键设置外键
  • sql多级汇总
  • linux自动化装机
  • linux怎样浏览文件中的内容
  • windows临时文件在哪里
  • win10 mobile 1709
  • win8.1网络设置
  • win7任务栏变小图标
  • 麒麟linux系统怎么安装软件
  • 我的第二个姐姐用英语怎么说
  • 基于个人同意处理个人信息的个人什么撤回其同意
  • bootstrap 下拉按钮
  • python读json文件和写json文件
  • 编写批处理
  • unity3d documentation
  • JavaScript window.document的属性、方法和事件小结
  • 青岛市税务局内设机构
  • 进口汽车零部件编码查询
  • sp海淘3档到国内什么快递
  • 电子税务局怎么删除办税员
  • 新时代新思想基层医疗宣讲
  • 开票没有0还是o
  • 免责声明:网站部分图片文字素材来源于网络,如有侵权,请及时告知,我们会第一时间删除,谢谢! 邮箱:opceo@qq.com

    鄂ICP备2023003026号

    网站地图: 企业信息 工商信息 财税知识 网络常识 编程技术

    友情链接: 武汉网站建设