Pytorch-自动微分模块

49739c720cb4452c9336253d032fc756.gif

🥇接下来我们进入到Pytorch的自动微分模块torch.autograd~

自动微分模块是PyTorch中用于实现张量自动求导的模块。PyTorch通过torch.autograd模块提供了自动微分的功能,这对于深度学习和优化问题至关重要,因为它可以自动计算梯度,无需手动编写求导代码。torch.autograd模块的一些关键组成部分:

  1. 函数的反向传播torch.autograd.function 包含了一系列用于定义自定义操作的函数,这些操作可以在反向传播时自动计算梯度。
  2. 计算图的反向传播torch.autograd.functional 提供了一种构建计算图并自动进行反向传播的方式,这类似于其他框架中的符号式自动微分。
  3. 数值梯度检查torch.autograd.gradcheck 用于检查数值梯度与自动微分得到的梯度是否一致,这是确保正确性的一个有用工具。
  4. 错误检测模式torch.autograd.anomaly_mode 在自动求导时检测错误产生路径,有助于调试。
  5. 梯度模式设置torch.autograd.grad_mode 允许用户设置是否需要梯度,例如在模型评估时通常不需要计算梯度。
  6. 求导方法:PyTorch提供backward()torch.autograd.grad()两种求梯度的方法。backward()会将梯度填充到叶子节点的.grad字段,而torch.autograd.grad()直接返回梯度结果。
  7. requires_grad属性:在创建张量时,可以通过设置requires_grad=True来指定该张量是否需要进行梯度计算。这样在执行操作时,PyTorch会自动跟踪这些张量的计算过程,以便后续进行梯度计算。

梯度基本计算

def func1():
    x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    f = x ** 2 +10
    # 自动微分求导
    f.backward()   # 反向求导
    # backward 函数计算的梯度值会存储在张量的 grad 变量中
    print(x.grad)
def func2():
    x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)
    # 变量经过中间计算
    f1 = x ** 2 + 10
    
    # f2 = f1.mean()  # 平均损失,相当于每个值/4
    f2 = f1.sum()  # 求和损失,相当于每个值*1
    f2.backward()
    print(x.grad)
def func3():
    x1 = torch.tensor(10, requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor(20, requires_grad=True, dtype=torch.float64)
    y = x1 ** 2 + x2 ** 2 + x1 * x2
    y = y.sum()
    y.backward()
    print(x1.grad, x2.grad)

def func4():
    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)
    x2 = torch.tensor([30, 40], requires_grad=True, dtype=torch.float64)

    y = x1 ** 2 + x2 ** 2 + x1 * x2
    y = y.sum()
    y.backward()
    print(x1.grad,x2.grad)

func1func2,它们分别处理标量张量和向量张量的梯度计算。

  • func1中,首先创建了一个标量张量x,并设置requires_grad=True以启用自动微分。然后计算f = x ** 2 + 10,接着使用f.backward()进行反向求导。最后,通过打印x.grad输出梯度值。
  • func2中,首先创建了一个向量张量x,并设置requires_grad=True以启用自动微分。然后计算f1 = x ** 2 + 10,接着使用f1.sum()对向量张量进行求和操作,得到一个标量张量f2。最后,使用f2.backward()进行反向求导。
  • func3func4分别求多个标量和向量的情况,与上面相似。

控制梯度计算

我们可以通过一些方法使 requires_grad=True 的张量在某些时候计算时不进行梯度计算。 

  1. 第一种方式是使用torch.no_grad()上下文管理器,在这个上下文中进行的所有操作都不会计算梯度。
  2. 第二种方式是通过装饰器@torch.no_grad()来装饰一个函数,使得这个函数中的所有操作都不会计算梯度。
  3. 第三种方式是通过torch.set_grad_enabled(False)来全局关闭梯度计算功能,之后的所有操作都不会计算梯度,直到下一次再次调用此方法torch.set_grad_enabled(True)开启梯度计算功能。
x = torch.tensor(10, requires_grad=True, dtype=torch.float64)
print(x.requires_grad)

# 第一种方式: 对代码进行装饰
with torch.no_grad():
    y = x ** 2
print(y.requires_grad)

# 第二种方式: 对函数进行装饰
@torch.no_grad()
def my_func(x):
    return x ** 2
print(my_func(x).requires_grad)


# 第三种方式
torch.set_grad_enabled(False)
y = x ** 2
print(y.requires_grad)

默认张量的 grad 属性会累计历史梯度值,如果需要重复计算每次的梯度,就需要手动清除。

x = torch.tensor([10, 20, 30, 40], requires_grad=True, dtype=torch.float64)

for _ in range(3):

    f1 = x ** 2 + 20
    f2 = f1.mean()

    if x.grad is not None:
        x.grad.data.zero_()   # 本身来改动

    f2.backward()
    print(x.grad)

x.grad不是x,因为x是一个tensor张量,而x.grad是x的梯度。在PyTorch中,张量的梯度是通过自动求导机制计算得到的,而不是直接等于张量本身。

梯度下降优化最优解

x = torch.tensor(10, requires_grad=True, dtype=torch.float64)

for _ in range(5000):

     
    f = x ** 2

    # 梯度清零
    if x.grad is not None:
        x.grad.data.zero_()

    # 反向传播计算梯度
    f.backward()

    # 更新参数
    x.data = x.data - 0.001 * x.grad

    print('%.10f' % x.data)

更新参数相当于通过学习率对当前数值进行迭代。

f.backward()是PyTorch中自动梯度计算的函数,用于计算张量`f`关于其所有可学习参数的梯度。在这个例子中,`f`是一个标量张量,它只有一个可学习参数`x`。当调用f.backward()`时,PyTorch会自动计算`f`关于`x`的梯度,并将结果存储在`x.grad`中。这样,我们就可以使用这个梯度来更新`x`的值,以便最小化损失函数`f`。

梯度计算注意

当对设置 requires_grad=True 的张量使用 numpy 函数进行转换时, 会出现如下报错:

Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.

此时, 需要先使用 detach 函数将张量进行分离, 再使用 numpy 函数。detach 之后会产生一个新的张量, 新的张量作为叶子结点,并且该张量和原来的张量共享数据, 但是分离后的张量不需要计算梯度。

import torch

def func1():

    x = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # Can't call numpy() on Tensor that requires grad. Use tensor.detach().numpy() instead.
    # print(x.numpy())  # 错
    print(x.detach().numpy())  


def func2():

    x1 = torch.tensor([10, 20], requires_grad=True, dtype=torch.float64)

    # x2 作为叶子结点
    x2 = x1.detach()

    # 两个张量的值一样: 140421811165776 140421811165776
    print(id(x1.data), id(x2.data))
    x2.data = torch.tensor([100, 200])
    print(x1)
    print(x2)

    # x2 不会自动计算梯度: False
    print(x2.requires_grad)

7017d1cccb2c45cd845fefae64ed1947.gif

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/557433.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

视觉位置识别与多模态导航规划

前言 机器人感知决策是机器人移动的前提,机器人需要对周围环境实现理解,而周围环境通常由静态环境与动态环境构成。机器人在初始状态或者重启时需要确定当前所处的位置,然后根据用户的指令或意图,开展相应移动或抓取操作。通过视觉…

Mamba:使用选择性状态空间的线性时间序列建模

本文主要是关于mamba论文的详解~ 论文:Mamba: Linear-Time Sequence Modeling with Selective State Spaces 论文地址:https://arxiv.org/ftp/arxiv/papers/2312/2312.00752.pdf 代码:state-spaces/mamba (github.com) Demo:state-spaces (St…

Java 算法篇-深入了解 BF 与 KMP 算法

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 BF 算法概述 1.1 BF 算法实际使用 2.0 KMP 算法概述 2.1 KMP 算法实际使用 2.2 相比于 BF 算法实现,KMP 算法的重要思想 2.3 为什么要这样设计&#x…

C++面向对象程序设计-北京大学-郭炜【课程笔记(六)】

C面向对象程序设计-北京大学-郭炜【课程笔记&#xff08;六&#xff09;】 1、可变长数组类的实现2、流插入运算符和流提取运算符的重载2.1、对形如cout << 5 ; 单个"<<"进行重载2.2、对形如cout << 5 << “this” ;连续多个"<<&…

地埋电缆故障检测方法有哪些?地埋电缆故障检测费用是多少?

地埋电缆故障检测方法主要涵盖脉冲反射法、桥接法、高压闪络法和声波定位法等多种方法。选择适当的方法取决于故障类型、电缆类型和实际现场条件。至于地埋电缆故障检测费用则受到多个因素的影响&#xff0c;包括故障类型、检测方法的复杂性、检测设备的先进程度以及所处地区的…

【强化学习的数学原理-赵世钰】课程笔记(十)Actor-Critic 方法

目录 一.最简单的 actor-critic&#xff08;QAC&#xff09;&#xff1a;The simplest actor-critic (QAC) 二.Advantage actor-critic (A2C) 三.Off-policy actor-critic 方法 四. Deterministic actor critic(DPG) Actor-Critic 方法把基于 value 的方法&#xff0c;特别…

删除顺序表中所有值为X的元素(顺序表,单链表)

目录 时间复杂度为O(1)(顺序表)&#xff1a;代码实现&#xff1a; 运行结果&#xff1a; 时间复杂度为O(n)(顺序表)&#xff1a;代码实现&#xff1a; 运行结果&#xff1a; 单链表&#xff1a;时间复杂度o&#xff08;n&#xff09;:代码实现&#xff1a; 时间复杂度为O(1…

调研-转换zpl为png

文章目录 前言ZPLZPL相关转换的网站一、labelary常用功能 二、labelzoom三、https://www.htmltozpl.com/docs/demo/html-to-zpl四、 开源仓库&#xff1a;JSZPL五、 开源仓库&#xff1a;BinaryKits.Zpl六 redhawk其他相关概述Lodop 处理zpl 前言 为了解决ZPL指令转换为png&am…

软件需求开发和管理过程性指导文件

1. 目的 2. 适用范围 3. 参考文件 4. 术语和缩写 5. 需求获取的方式 5.1. 与用户交谈向用户提问题 5.1.1. 访谈重点注意事项 5.1.2. 访谈指南 5.2. 参观用户的工作流程 5.3. 向用户群体发调查问卷 5.4. 已有软件系统调研 5.5. 资料收集 5.6. 原型系统调研 5.6.1. …

Cesium中实现镜头光晕

镜头光晕 镜头光晕 (Lens Flares) 是模拟相机镜头内的折射光线的效果&#xff0c;主要作用就是让太阳光/其他光源更加真实&#xff0c;和为您的场景多增添一些气氛。 Cesium 中实现 其实 Cesium 里面也是有实现一个镜头光晕效果的&#xff0c;添加方式如下&#xff0c;只是效…

Leetcode - 周赛393

目录 一&#xff0c;3114. 替换字符可以得到的最晚时间 二&#xff0c;3115. 素数的最大距离 三&#xff0c;3116. 单面值组合的第 K 小金额 四&#xff0c; 3117. 划分数组得到最小的值之和 一&#xff0c;3114. 替换字符可以得到的最晚时间 本题是一道模拟题&#xff0c;…

泛型的初步认识(1)

前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&#x…

标准版uni-app移动端页面添加/开发操作流程

页面简介 uni-app项目中&#xff0c;一个页面就是一个符合Vue SFC规范的.vue文件或.nvue文件。 .vue页面和.nvue页面&#xff0c;均全平台支持&#xff0c;差异在于当uni-app发行到App平台时&#xff0c;.vue文件会使用webview进行渲染&#xff0c;.nvue会使用原生进行渲染。…

HCIP的学习(10)

OSPF不规则区域划分 区域划分 非骨干与骨干区域直接相连骨干区域唯一 限制规则&#xff1a; 非骨干区域之间不允许直接相互发布区域间路由信息OSPF区域水平分割&#xff1a;从非骨干区域收到的路由信息&#xff0c;ABR设备能接收到不能使用&#xff08;从某区域传出的路由&…

全新升级轻舟知识付费系统引流变现至上利器

知识付费系统&#xff1a;引流变现至上利器 本系统参考各大主流知识付费系统&#xff0c;汇总取其精华&#xff0c;自主研发&#xff0c;正版授权系统。 我们给你搭建搭建一个独立运营的知识付费平台&#xff0c;搭建好之后&#xff0c;你可以自由的运营管理。网站里面的名称…

【机器学习】分类与预测算法评价的方式介绍

一、引言 1、机器学习分类与预测算法的重要性 在数据驱动的时代&#xff0c;机器学习已经成为了处理和分析大规模数据的关键工具。分类与预测作为机器学习的两大核心任务&#xff0c;广泛应用于各个领域&#xff0c;如金融、医疗、电商等。分类算法能够对数据进行有效归类&…

web前端网络相关知识

一、OSI 7层参考模型 1.物理层&#xff08;光纤、电缆等物理介质&#xff09; 传播比特流&#xff08;bit&#xff09; 01010101的形式 2.数据链路层&#xff08;交换机&#xff0c;mac地址&#xff09; 将比特流组合成字节&#xff0c;组合成帧&#xff0c;用mac地址访问&…

bugku-web-login2

这里提示是命令执行 抓包发现有五个报文 其中login.php中有base64加密语句 $sql"SELECT username,password FROM admin WHERE username".$username.""; if (!empty($row) && $row[password]md5($password)){ } 这里得到SQL语句的组成&#xff0c;…

CRMEB PRO安装系统配置清单

统在安装完成之后&#xff0c;需要对系统进行一系列的配置&#xff0c;才能正常使用全部的功能&#xff0c;以下是官方整理的配置清单

Xinstall带你进入一键通过URL打开App的新时代

在移动互联网时代&#xff0c;App已经成为我们日常生活中不可或缺的一部分。然而&#xff0c;在使用App的过程中&#xff0c;我们常常会遇到一些烦恼。比如&#xff0c;当我们通过一个网页链接想要打开对应的App时&#xff0c;往往需要先复制链接&#xff0c;然后在App中粘贴&a…
最新文章