Commit 4ec0819

Fix dropout problems:
1. For ShowTellModel and CaptionModel, the RNN uses the official PyTorch implementation, which applies no dropout after the last layer; we manually add dropout before computing the logits.
2. For Att2in models, the state passed to the next time step was also being dropped out; now only the output is dropped out.
1 parent 93ddeb4 commit 4ec0819

File tree

misc/Att2inModel.py
misc/AttModel.py
misc/CaptionModel.py
misc/ShowTellModel.py

4 files changed: +10 -12 lines changed


misc/Att2inModel.py

Lines changed: 1 addition & 3 deletions
@@ -68,9 +68,7 @@ def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * F.tanh(next_c)

-        next_h = self.dropout(next_h)
-
-        output = next_h
+        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state


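The point of the Att2in change: the old code applied dropout to next_h before it was packed into state, so the randomly zeroed activations were fed back into the recurrence at the next time step. A minimal sketch of the corrected placement, using a hypothetical single-layer LSTM-style cell (ToyCell is illustrative, not the repo's Att2in core):

import torch
import torch.nn as nn

class ToyCell(nn.Module):
    # Hypothetical cell, only to show where the dropout belongs.
    def __init__(self, rnn_size, drop_prob_lm):
        super(ToyCell, self).__init__()
        self.i2h = nn.Linear(rnn_size, 4 * rnn_size)
        self.h2h = nn.Linear(rnn_size, 4 * rnn_size)
        self.dropout = nn.Dropout(drop_prob_lm)

    def forward(self, xt, state):
        gates = self.i2h(xt) + self.h2h(state[0])
        i, f, o, g = gates.chunk(4, dim=1)
        next_c = torch.sigmoid(f) * state[1] + torch.sigmoid(i) * torch.tanh(g)
        next_h = torch.sigmoid(o) * torch.tanh(next_c)
        output = self.dropout(next_h)   # dropout only on what is emitted
        state = (next_h, next_c)        # the un-dropped state feeds the next step
        return output, state

Dropping next_h itself, as before, would zero random entries of the recurrent state at every step, corrupting the memory the LSTM carries forward rather than just regularizing its output.
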
misc/AttModel.py

Lines changed: 1 addition & 3 deletions
@@ -464,9 +464,7 @@ def forward(self, xt, fc_feats, att_feats, p_att_feats, state):
        next_c = forget_gate * state[1][-1] + in_gate * in_transform
        next_h = out_gate * F.tanh(next_c)

-        next_h = self.dropout(next_h)
-
-        output = next_h
+        output = self.dropout(next_h)
        state = (next_h.unsqueeze(0), next_c.unsqueeze(0))
        return output, state


misc/CaptionModel.py

Lines changed: 4 additions & 3 deletions
@@ -33,6 +33,7 @@ def __init__(self, opt):
        self.linear = nn.Linear(self.fc_feat_size, self.num_layers * self.rnn_size) # feature to rnn_size
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
+        self.dropout = nn.Dropout(self.drop_prob_lm)

        self.init_weights()

@@ -78,7 +79,7 @@ def forward(self, fc_feats, att_feats, seq):
            xt = self.embed(it)

            output, state = self.core(xt, fc_feats, att_feats, state)
-            output = F.log_softmax(self.logit(output))
+            output = F.log_softmax(self.logit(self.dropout(output)))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs], 1)
@@ -165,7 +166,7 @@ def sample_beam(self, fc_feats, att_feats, opt={}):
                state = new_state

                output, state = self.core(xt, tmp_fc_feats, tmp_att_feats, state)
-                logprobs = F.log_softmax(self.logit(output))
+                logprobs = F.log_softmax(self.logit(self.dropout(output)))

            self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
@@ -216,7 +217,7 @@ def sample(self, fc_feats, att_feats, opt={}):
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt, fc_feats, att_feats, state)
-            logprobs = F.log_softmax(self.logit(output))
+            logprobs = F.log_softmax(self.logit(self.dropout(output)))

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)


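For reference, the reason the explicit self.dropout is needed at all: the dropout argument of PyTorch's stacked RNN modules (nn.LSTM/nn.GRU) is applied between layers only, never to the output of the last layer, so the hidden state handed to self.logit was previously unregularized. A minimal sketch with illustrative sizes (not the repo's actual configuration):

import torch
import torch.nn as nn
import torch.nn.functional as F

input_encoding_size, rnn_size, vocab_size, drop_prob_lm = 512, 512, 9000, 0.5

core = nn.LSTM(input_encoding_size, rnn_size, num_layers=2, dropout=drop_prob_lm)
logit = nn.Linear(rnn_size, vocab_size + 1)
dropout = nn.Dropout(drop_prob_lm)  # covers the last layer, which nn.LSTM's dropout skips

xt = torch.randn(1, 16, input_encoding_size)  # (seq_len=1, batch, features)
output, state = core(xt)                      # output: (1, 16, rnn_size)
logprobs = F.log_softmax(logit(dropout(output.squeeze(0))), dim=1)
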
misc/ShowTellModel.py

Lines changed: 4 additions & 3 deletions
@@ -26,6 +26,7 @@ def __init__(self, opt):
        self.core = getattr(nn, self.rnn_type.upper())(self.input_encoding_size, self.rnn_size, self.num_layers, bias=False, dropout=self.drop_prob_lm)
        self.embed = nn.Embedding(self.vocab_size + 1, self.input_encoding_size)
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
+        self.dropout = nn.Dropout(self.drop_prob_lm)

        self.init_weights()

@@ -73,7 +74,7 @@ def forward(self, fc_feats, att_feats, seq):
            xt = self.embed(it)

            output, state = self.core(xt.unsqueeze(0), state)
-            output = F.log_softmax(self.logit(output.squeeze(0)))
+            output = F.log_softmax(self.logit(self.dropout(output.squeeze(0))))
            outputs.append(output)

        return torch.cat([_.unsqueeze(1) for _ in outputs[1:]], 1).contiguous()
@@ -159,7 +160,7 @@ def sample_beam(self, fc_feats, att_feats, opt={}):
                state = new_state

                output, state = self.core(xt.unsqueeze(0), state)
-                logprobs = F.log_softmax(self.logit(output.squeeze(0)))
+                logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))))

            self.done_beams[k] = sorted(self.done_beams[k], key=lambda x: -x['p'])
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
@@ -212,6 +213,6 @@ def sample(self, fc_feats, att_feats, opt={}):
                seqLogprobs.append(sampleLogprobs.view(-1))

            output, state = self.core(xt.unsqueeze(0), state)
-            logprobs = F.log_softmax(self.logit(output.squeeze(0)))
+            logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))))

        return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)

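One consequence worth noting: the added self.dropout also sits on the sample() and sample_beam() paths. This is harmless at inference time because nn.Dropout becomes the identity once the model is put in evaluation mode, e.g.:

import torch
import torch.nn as nn

drop = nn.Dropout(0.5)
x = torch.ones(8)

drop.eval()                     # what model.eval() sets before sampling
assert torch.equal(drop(x), x)  # dropout is the identity in eval mode

drop.train()                    # training mode: zeroed with p=0.5, survivors scaled by 1/(1-p)
print(drop(x))                  # e.g. tensor([2., 0., 2., 2., 0., 2., 0., 2.])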