如何使用Pytorch實(shí)現(xiàn)two-head(多輸出)模型
1. two-head模型定義
先放一張我要實(shí)現(xiàn)的模型結(jié)構(gòu)圖:
如上圖,就是一個(gè)two-head模型,也是一個(gè)但輸入多輸出模型。該模型的特點(diǎn)是輸入一個(gè)x和一個(gè)t,h0和h1中只有一個(gè)會(huì)輸出,所以可能這不算是一個(gè)典型的多輸出模型。
2.實(shí)現(xiàn)所遇到的困難 一開始的想法:
這不是很簡單嘛,做一個(gè)判斷不就完了,t=0時(shí)模型為前半段加h0,t=1時(shí)模型為前半段加h1。但實(shí)現(xiàn)的時(shí)候傻眼了,發(fā)現(xiàn)在真正前向傳播的時(shí)候t是一個(gè)tensor,有0有1,沒法兒進(jìn)行判斷。
靈機(jī)一動(dòng),又生一法:把這個(gè)模型變?yōu)槿齻€(gè)模型,前半段是一個(gè)模型(r),后面的h0和h1分別為另兩個(gè)模型。把數(shù)據(jù)集按t=0和1分開,分別訓(xùn)練兩個(gè)模型:r+h0和r+h1。
但是后來搜如何進(jìn)行模型串聯(lián),發(fā)現(xiàn)極為麻煩。
3.解決方案
后來在pytorch的官方社區(qū)中看到一個(gè)極為簡單的方法:
(1) 按照一般的多輸出模型進(jìn)行實(shí)現(xiàn),代碼如下:
def forward(self, x):
#三層的表示層
x = F.elu(self.fcR1(x))
x = F.elu(self.fcR2(x))
x = F.elu(self.fcR3(x))
#two-head,兩個(gè)head分別進(jìn)行輸出
y0 = F.elu(self.fcH01(x))
y0 = F.elu(self.fcH02(y0))
y0 = F.elu(self.fcH03(y0))
y1 = F.elu(self.fcH11(x))
y1 = F.elu(self.fcH12(y1))
y1 = F.elu(self.fcH13(y1))
return y0, y1
這樣就相當(dāng)實(shí)現(xiàn)了一個(gè)多輸出模型,一個(gè)x同時(shí)輸出y0和y1.
訓(xùn)練的時(shí)候分別訓(xùn)練,也即分別建loss,代碼如下:
f_out_y0, _ = net(x0)
_, f_out_y1 = net(x1)
#實(shí)例化損失函數(shù)
criterion0 = Loss()
criterion1 = Loss()
loss0 = criterion0(f_y0, f_out_y0, w0)
loss1 = criterion1(f_y1, f_out_y1, w1)
print(loss0.item(), loss1.item())
#對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行初始化
optimizer.zero_grad()
loss0.backward()
loss1.backward()
#對(duì)網(wǎng)絡(luò)的參數(shù)進(jìn)行更新
optimizer.step()
先把x按t=0和t=1分為x0和x1,然后分別送入進(jìn)行訓(xùn)練。這樣就實(shí)現(xiàn)了一個(gè)two-head模型。
4.后記
我自以為多輸出模型可以分為以下兩類:
多個(gè)輸出不同時(shí)獲得,如本文情況。
多個(gè)輸出同時(shí)獲得。
多輸出不同時(shí)獲得的解決方法上文已說明。多輸出同時(shí)獲得則可以通過把y0和y1拼接起來一起輸出來實(shí)現(xiàn)。
補(bǔ)充:PyTorch 多輸入多輸出模型構(gòu)建
本篇教程基于 PyTorch 1.5版本
直接上代碼!
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
def __init__(self, n_input, n_hidden, n_output):
super(Net, self).__init__()
self.hidden1 = nn.Linear(n_input, n_hidden)
self.hidden2 = nn.Linear(n_hidden, n_hidden)
self.predict1 = nn.Linear(n_hidden*2, n_output)
self.predict2 = nn.Linear(n_hidden*2, n_output)
def forward(self, input1, input2): # 多輸入?。?!
out01 = self.hidden1(input1)
out02 = torch.relu(out01)
out03 = self.hidden2(out02)
out04 = torch.sigmoid(out03)
out11 = self.hidden1(input2)
out12 = torch.relu(out11)
out13 = self.hidden2(out12)
out14 = torch.sigmoid(out13)
out = torch.cat((out04, out14), dim=1) # 模型層拼合?。?!當(dāng)然你的模型中可能不需要~
out1 = self.predict1(out)
out2 = self.predict2(out)
return out1, out2 # 多輸出?。?!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 請(qǐng)不要關(guān)心這里,隨便弄一個(gè)數(shù)據(jù),為了說明問題而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
prediction1, prediction2 = net(x1, x2)
loss1 = loss_func(prediction1, y1)
loss2 = loss_func(prediction2, y2)
loss = loss1 + loss2 # 重點(diǎn)!
optimizer.zero_grad()
loss.backward()
optimizer.step()
if t % 100 == 0:
print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)
至此搞定!
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
您可能感興趣的文章:- pytorch構(gòu)建多模型實(shí)例
- pytorch模型存儲(chǔ)的2種實(shí)現(xiàn)方法
- 如何使用Pytorch搭建模型
- 詳解Pytorch 使用Pytorch擬合多項(xiàng)式(多項(xiàng)式回歸)