#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
import torch a=torch.rand(2,3,1) print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze().size()) #torch.Size([2, 3]) print(a.squeeze(0).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-1).size()) #torch.Size([2, 3]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze(-2).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-3).size()) #torch.Size([2, 3, 1]) print(a.squeeze(1).size()) #torch.Size([2, 3, 1]) print(a.squeeze(2).size()) #torch.Size([2, 3]) print(a.squeeze(3).size()) #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3) print(a.unsqueeze().size()) #TypeError: unsqueeze() missing 1 required positional arguments: "dim" print(a.unsqueeze(-3).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(-2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(-1).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(0).size()) #torch.Size([1, 2, 3, 1]) print(a.unsqueeze(1).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(3).size()) #torch.Size([2, 3, 1, 1]) print(torch.unsqueeze(a,3)) b=torch.rand(2,1,3,1) print(b.squeeze().size()) #torch.Size([2, 3])
补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结
学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。
1、unsqueeze()和squeeze()
torch.unsqueeze(input, dim,out=None) → Tensor
unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(0) print(a.shape) #torch.Size([1, 768]) a = a.unsqueeze(2) print(a.shape) #torch.Size([1, 768, 1])
也可以直接使用链式编程:
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(1).unsqueeze(0) print(a.shape) #torch.Size([1, 768, 1])
tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。
torch.squeeze(input, dim=None, out=None) → Tensor
squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。
同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) a=a.squeeze() print(a) print(a.shape) #torch.Size([2, 768])
图片中的维度信息就不一样,红框中的括号层数不同。
注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。
a=torch.randn(2,768) print(a.shape) #torch.Size([2, 768]) a=a.squeeze() print(a.shape) #torch.Size([2, 768])
2、expand()
这个函数的作用就是对指定的维度进行数值大小的改变。只能改变维大小为1的维,否则就会报错。不改变的维可以传入-1或者原来的数值。
torch.Tensor.expand(*sizes) → Tensor
返回张量的一个新视图,可以将张量的单个维度扩大为更大的尺寸。
a=torch.randn(1,1,3,768) print(a) print(a.shape) #torch.Size([1, 1, 3, 768]) b=a.expand(2,-1,-1,-1) print(b) print(b.shape) #torch.Size([2, 1, 3, 768]) c=a.expand(2,1,3,768) print(c.shape) #torch.Size([2, 1, 3, 768])
可以看到b和c的维度是一样的
第0维由1变为2,可以看到就直接把原来的tensor在该维度上复制了一下。
3、repeat()
repeat(*sizes)
沿着指定的维度,对原来的tensor进行数据复制。这个函数和expand()还是有点区别的。expand()只能对维度为1的维进行扩大,而repeat()对所有的维度可以随意操作。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) b=a.repeat(1,2,1) print(b) print(b.shape) #torch.Size([2, 2, 768]) c=a.repeat(3,3,3) print(c) print(c.shape) #torch.Size([6, 3, 2304])
b表示对a的对应维度进行乘以1,乘以2,乘以1的操作,所以b:torch.Size([2, 1, 768])
c表示对a的对应维度进行乘以3,乘以3,乘以3的操作,所以c:torch.Size([6, 3, 2304])
a:
b
c
4、view()
tensor.view()这个函数有点类似reshape的功能,简单的理解就是:先把一个tensor转换成一个一维的tensor,然后再组合成指定维度的tensor。例如:
word_embedding=torch.randn(16,3,768) print(word_embedding.shape) new_word_embedding=word_embedding.view(8,6,768) print(new_word_embedding.shape)
当然这里指定的维度的乘积一定要和原来的tensor的维度乘积相等,不然会报错的。16*3*768=8*6*768
另外当我们需要改变一个tensor的维度的时候,知道关键的维度,有不想手动的去计算其他的维度值,就可以使用view(-1),pytorch就会自动帮你计算出来。
word_embedding=torch.randn(16,3,768) print(word_embedding.shape) new_word_embedding=word_embedding.view(-1) print(new_word_embedding.shape) new_word_embedding=word_embedding.view(1,-1) print(new_word_embedding.shape) new_word_embedding=word_embedding.view(-1,768) print(new_word_embedding.shape)
结果如下:使用-1以后,就会自动得到其他维度维。
需要特别注意的是:view(-1,-1)这样的用法就会出错。也就是说view()函数中只能出现单个-1。
5、cat()
cat(seq,dim,out=None),表示把两个或者多个tensor拼接起来。
其中 seq表示要连接的两个序列,以元组的形式给出,例如:seq=(a,b), a,b 为两个可以连接的序列
dim 表示以哪个维度连接,dim=0, 横向连接 dim=1,纵向连接
a=torch.randn(4,3) b=torch.randn(4,3) c=torch.cat((a,b),dim=0)#横向拼接,增加行 torch.Size([8, 3]) print(c.shape) d=torch.cat((a,b),dim=1)#纵向拼接,增加列 torch.Size([4, 6]) print(d.shape)
还有一种写法:cat(list,dim,out=None),其中list中的元素为tensor。
tensors=[] for i in range(10): tensors.append(torch.randn(4,3)) a=torch.cat(tensors,dim=0) #torch.Size([40, 3]) print(a.shape) b=torch.cat(tensors,dim=1) #torch.Size([4, 30]) print(b.shape)
结果:
torch.Size([40, 3]) torch.Size([4, 30])
以上为个人经验,希望能给大家一个参考,也希望大家多多支持。如有错误或未考虑完全的地方,望不吝赐教。
免责声明:本站资源来自互联网收集,仅供用于学习和交流,请遵循相关法律法规,本站一切资源不代表本站立场,如有侵权、后门、不妥请联系本站删除!
稳了!魔兽国服回归的3条重磅消息!官宣时间再确认!
昨天有一位朋友在大神群里分享,自己亚服账号被封号之后居然弹出了国服的封号信息对话框。
这里面让他访问的是一个国服的战网网址,com.cn和后面的zh都非常明白地表明这就是国服战网。
而他在复制这个网址并且进行登录之后,确实是网易的网址,也就是我们熟悉的停服之后国服发布的暴雪游戏产品运营到期开放退款的说明。这是一件比较奇怪的事情,因为以前都没有出现这样的情况,现在突然提示跳转到国服战网的网址,是不是说明了简体中文客户端已经开始进行更新了呢?
更新日志
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]