在代碼中添加以下兩行可以解決:
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
補(bǔ)充:pytorch訓(xùn)練過(guò)程顯存一直增加的問(wèn)題
之前遇到了爆顯存的問(wèn)題,卡了很久,試了很多方法,總算解決了。
總結(jié)下自己試過(guò)的幾種方法:
**1. 使用torch.cuda.empty_cache()
在每一個(gè)訓(xùn)練epoch后都添加這一行代碼,可以讓訓(xùn)練從較低顯存的地方開(kāi)始,但并不適用爆顯存的問(wèn)題,隨著epoch的增加,最大顯存占用仍然會(huì)提示out of memory 。
2.使用torch.backends.cudnn.enabled = True 和 torch.backends.cudnn.benchmark = True
原理不太清楚,用法和1一樣。但是幾乎沒(méi)有效果,直接pass。
3.最重要的:查看自己的forward函數(shù)是否存在泄露。
常需要在forward函數(shù)里調(diào)用其他子函數(shù),這時(shí)候要特別注意:
input盡量不要寫在for循環(huán)里面?。。?/p>
子函數(shù)里如果有append()等函數(shù),一定少用,能不用就不用?。?!
子函數(shù)list一定少用,能不用就不用!??!
總之,子函數(shù)一般也不會(huì)太復(fù)雜,直接寫出來(lái),別各種for,嵌套,變量。!??!
補(bǔ)充:Pytorch顯存不斷增長(zhǎng)問(wèn)題的解決思路
這個(gè)問(wèn)題,我先后遇到過(guò)兩次,每次都異常艱辛的解決了。
在網(wǎng)上,關(guān)于這個(gè)問(wèn)題,你可以找到各種看似不同的解決方案,但是都沒(méi)能解決我的問(wèn)題。所以只能自己摸索,在摸索的過(guò)程中,有了一個(gè)排查問(wèn)題點(diǎn)的思路。
下面舉個(gè)例子說(shuō)一下我的思路。
大體思路
其實(shí)思路很簡(jiǎn)單,就是在代碼的運(yùn)行階段輸出顯存占用量,觀察在哪一塊存在顯存劇烈增加或者顯存異常變化的情況。
但是在這個(gè)過(guò)程中要分級(jí)確認(rèn)問(wèn)題點(diǎn),也即如果存在三個(gè)文件main.py、train.py、model.py。
在此種思路下,應(yīng)該先在main.py中確定問(wèn)題點(diǎn),然后,從main.py中進(jìn)入到train.py中,再次輸出顯存占用量,確定問(wèn)題點(diǎn)在哪。
隨后,再?gòu)膖rain.py中的問(wèn)題點(diǎn),進(jìn)入到model.py中,再次確認(rèn)。
如果還有更深層次的調(diào)用,可以繼續(xù)追溯下去。
具體例子
main.py
def train(model,epochs,data):
for e in range(epochs):
print("1:{}".format(torch.cuda.memory_allocated(0)))
train_epoch(model,data)
print("2:{}".format(torch.cuda.memory_allocated(0)))
eval(model,data)
print("3:{}".format(torch.cuda.memory_allocated(0)))
假設(shè)1與2之間顯存增加極為劇烈,說(shuō)明問(wèn)題出在train_epoch中,進(jìn)一步進(jìn)入到train.py中。
train.py
def train_epoch(model,data):
model.train()
optim=torch.optimizer()
for batch_data in data:
print("1:{}".format(torch.cuda.memory_allocated(0)))
output=model(batch_data)
print("2:{}".format(torch.cuda.memory_allocated(0)))
loss=loss(output,data.target)
print("3:{}".format(torch.cuda.memory_allocated(0)))
optim.zero_grad()
print("4:{}".format(torch.cuda.memory_allocated(0)))
loss.backward()
print("5:{}".format(torch.cuda.memory_allocated(0)))
utils.func(model)
print("6:{}".format(torch.cuda.memory_allocated(0)))
如果在1,2之間,5,6之間同時(shí)出現(xiàn)顯存增加異常的情況。此時(shí)需要使用控制變量法,例如我們先讓5,6之間的代碼失效,然后運(yùn)行,觀察是否仍然存在顯存爆炸。如果沒(méi)有,說(shuō)明問(wèn)題就出在5,6之間下一級(jí)的代碼中。進(jìn)入到下一級(jí)代碼,進(jìn)行調(diào)試:
utils.py
def func(model):
print("1:{}".format(torch.cuda.memory_allocated(0)))
a=f1(model)
print("2:{}".format(torch.cuda.memory_allocated(0)))
b=f2(a)
print("3:{}".format(torch.cuda.memory_allocated(0)))
c=f3(b)
print("4:{}".format(torch.cuda.memory_allocated(0)))
d=f4(c)
print("5:{}".format(torch.cuda.memory_allocated(0)))
此時(shí)我們?cè)僬故玖硪环N調(diào)試思路,先注釋第5行之后的代碼,觀察顯存是否存在先訓(xùn)爆炸,如果沒(méi)有,則注釋掉第7行之后的,直至確定哪一行的代碼出現(xiàn)導(dǎo)致了顯存爆炸。假設(shè)第9行起作用后,代碼出現(xiàn)顯存爆炸,說(shuō)明問(wèn)題出在第九行,顯存爆炸的問(wèn)題鎖定。
幾種導(dǎo)致顯存爆炸的情況
pytorch的hook機(jī)制可能導(dǎo)致,顯存爆炸,hook函數(shù)取出某一層的輸入輸出跟權(quán)重后,不可進(jìn)行存儲(chǔ),修改等操作,這會(huì)造成hook不能回收,進(jìn)而導(dǎo)致取出的輸入輸出權(quán)重都可能不被pytorch回收,所以模型的負(fù)擔(dān)越來(lái)也大,最終導(dǎo)致顯存爆炸。
這種情況是我第二次遇到顯存爆炸查出來(lái)的,非常讓人匪夷所思。在如下代碼中,p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))),導(dǎo)致了顯存爆炸,這個(gè)問(wèn)題點(diǎn)就是通過(guò)上面的方法確定的。
這個(gè)P是一個(gè)矩陣,在使用p.sub_的方式更新P的時(shí)候,導(dǎo)致了顯存爆炸。
將這行代碼修改為p=p-(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k))),顯存爆炸的問(wèn)題解決。
def pro_weight(p, x, w, alpha=1.0, cnn=True, stride=1):
if cnn:
_, _, H, W = x.shape
F, _, HH, WW = w.shape
S = stride # stride
Ho = int(1 + (H - HH) / S)
Wo = int(1 + (W - WW) / S)
for i in range(Ho):
for j in range(Wo):
# N*C*HH*WW, C*HH*WW = N*C*HH*WW, sum -> N*1
r = x[:, :, i * S: i * S + HH, j * S: j * S + WW].contiguous().view(1, -1)
# r = r[:, range(r.shape[1] - 1, -1, -1)]
k = torch.mm(p, torch.t(r))
p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k)))
w.grad.data = torch.mm(w.grad.data.view(F, -1), torch.t(p.data)).view_as(w)
else:
r = x
k = torch.mm(p, torch.t(r))
p.sub_(torch.mm(k, torch.t(k)) / (alpha + torch.mm(r, k)))
w.grad.data = torch.mm(w.grad.data, torch.t(p.data))
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教。
您可能感興趣的文章:- Python深度學(xué)習(xí)之使用Pytorch搭建ShuffleNetv2
- win10系統(tǒng)配置GPU版本Pytorch的詳細(xì)教程
- 淺談pytorch中的nn.Sequential(*net[3: 5])是啥意思
- pytorch visdom安裝開(kāi)啟及使用方法
- PyTorch CUDA環(huán)境配置及安裝的步驟(圖文教程)
- pytorch中的nn.ZeroPad2d()零填充函數(shù)實(shí)例詳解
- 使用pytorch實(shí)現(xiàn)線性回歸
- pytorch實(shí)現(xiàn)線性回歸以及多元回歸
- Pytorch 使用tensor特定條件判斷索引
- 在Windows下安裝配置CPU版的PyTorch的方法
- PyTorch兩種安裝方法
- PyTorch的Debug指南