torch.where() 用于將兩個(gè)broadcastable的tensor組合成新的tensor,類似于c++中的三元操作符“?:”
區(qū)別于python numpy中的where()直接可以找到特定條件元素的index
想要實(shí)現(xiàn)numpy中where()的功能,可以借助nonzero()
對(duì)應(yīng)numpy中的where()操作效果:
補(bǔ)充:Pytorch torch.Tensor.detach()方法的用法及修改指定模塊權(quán)重的方法
detach的中文意思是分離,官方解釋是返回一個(gè)新的Tensor,從當(dāng)前的計(jì)算圖中分離出來(lái)
需要注意的是,返回的Tensor和原Tensor共享相同的存儲(chǔ)空間,但是返回的 Tensor 永遠(yuǎn)不會(huì)需要梯度
import torch as t a = t.ones(10,) b = a.detach() print(b) tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
–假如A網(wǎng)絡(luò)輸出了一個(gè)Tensor類型的變量a, a要作為輸入傳入到B網(wǎng)絡(luò)中,如果我想通過(guò)損失函數(shù)反向傳播修改B網(wǎng)絡(luò)的參數(shù),但是不想修改A網(wǎng)絡(luò)的參數(shù),這個(gè)時(shí)候就可以使用detcah()方法
a = A(input) a = detach() b = B(a) loss = criterion(b, target) loss.backward()
import torch as t x = t.ones(1, requires_grad=True) x.requires_grad #True y = t.ones(1, requires_grad=True) y.requires_grad #True x = x.detach() #分離之后 x.requires_grad #False y = x+y #tensor([2.]) y.requires_grad #我還是True y.retain_grad() #y不是葉子張量,要加上這一行 z = t.pow(y, 2) z.backward() #反向傳播 y.grad #tensor([4.]) x.grad #None
以上代碼就說(shuō)明了反向傳播到y(tǒng)就結(jié)束了,沒(méi)有到達(dá)x,所以x的grad屬性為None
–假如A網(wǎng)絡(luò)輸出了一個(gè)Tensor類型的變量a, a要作為輸入傳入到B網(wǎng)絡(luò)中,如果我想通過(guò)損失函數(shù)反向傳播修改A網(wǎng)絡(luò)的參數(shù),但是不想修改B網(wǎng)絡(luò)的參數(shù),這個(gè)時(shí)候又應(yīng)該怎么辦了?
這時(shí)可以使用Tensor.requires_grad屬性,只需要將requires_grad修改為False即可.
for param in B.parameters(): param.requires_grad = False a = A(input) b = B(a) loss = criterion(b, target) loss.backward()
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。
標(biāo)簽:股票 駐馬店 呼和浩特 江蘇 湖州 衡水 中山 畢節(jié)
巨人網(wǎng)絡(luò)通訊聲明:本文標(biāo)題《Pytorch 使用tensor特定條件判斷索引》,本文關(guān)鍵詞 Pytorch,使用,tensor,特定條件,;如發(fā)現(xiàn)本文內(nèi)容存在版權(quán)問(wèn)題,煩請(qǐng)?zhí)峁┫嚓P(guān)信息告之我們,我們將及時(shí)溝通與處理。本站內(nèi)容系統(tǒng)采集于網(wǎng)絡(luò),涉及言論、版權(quán)與本站無(wú)關(guān)。