diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..51378d65 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +vis/ +train2014 +val2014 +test2014 +models +data +__pycache__ +eval_results +.idea/ +ImageCaptioning.pytorch.iml +transfer_nlp.sh diff --git a/captioning/data/dataloader.py b/captioning/data/dataloader.py index 7f2ed030..237a4a7b 100644 --- a/captioning/data/dataloader.py +++ b/captioning/data/dataloader.py @@ -4,6 +4,7 @@ import json import h5py +import pickle from lmdbdict import lmdbdict from lmdbdict.methods import DUMPS_FUNC, LOADS_FUNC import os @@ -11,6 +12,10 @@ import numpy.random as npr import random from functools import partial +import tqdm + +from captioning.utils import misc as utils +from captioning.models import utils as model_utils import torch import torch.utils.data as data @@ -18,6 +23,8 @@ import multiprocessing import six +PAD_IX = 0 + class HybridLoader: """ If db_path is a director, then use normal file loading @@ -156,6 +163,10 @@ def __init__(self, opt): elif opt.train_only == 0: # restval self.split_ix['train'].append(ix) + if opt.max_images_per_split is not None: + for key in ['train', 'val', 'test']: + self.split_ix[key] = self.split_ix[key][:opt.max_images_per_split] + print('assigned %d images to split train' %len(self.split_ix['train'])) print('assigned %d images to split val' %len(self.split_ix['val'])) print('assigned %d images to split test' %len(self.split_ix['test'])) @@ -224,8 +235,8 @@ def collate_func(self, batch, split): # #sort by att_feat length # fc_batch, att_batch, label_batch, gts, infos = \ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) + # fc_batch, att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) data = {} data['fc_feats'] = np.stack(fc_batch) # merge att_feats @@ -301,28 +312,143 @@ def __getitem__(self, index): def __len__(self): return len(self.info['images']) +class NearestNeighborIndex: + def __getstate__(self): + # don't pickle the index + state = self.__dict__.copy() + del state['index'] + return state + + def __setstate__(self, newstate): + # recreate the index from the other pickled attributes + self.__dict__.update(newstate) + self._init_index() + + def _init_others(self, split, loader): + self.split = split + self.vocab = loader.dataset.get_vocab() + + fc_features = [] + self.loader_ixs = loader_ixs = [] + self.ids = ids = [] + self.file_paths = file_paths = [] + self.captions = captions = [] + loader.reset_iterator(split) + for batch_ix in tqdm.trange(len(loader.loaders[split]), ncols=80): + data = loader.get_batch(split) + infos = data['infos'] + loader_ixs.extend(d['ix'] for d in infos) + ids.extend(d['id'] for d in infos) + file_paths.extend(d['file_path'] for d in infos) + fc_features.append(data['fc_feats']) + batch_captions = [] + for batch_labels in data['labels']: + instance_captions = [] + for img_labels in batch_labels: + caption = [self.vocab[str(ix.item())] for ix in img_labels if ix != PAD_IX] + instance_captions.append(caption) + batch_captions.append(instance_captions) + captions.extend(batch_captions) + self.fc_features = torch.cat(fc_features, 0).numpy() + + def _init_index(self): + import faiss + index = faiss.IndexFlatL2(self.fc_features.shape[1]) + 
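# A minimal usage sketch for this exact-L2 index (assuming self.fc_features is a
# float32 numpy array of shape (n_images, d)); faiss.IndexFlatL2 is brute-force,
# so no training call is needed before vectors are added:
#   index = faiss.IndexFlatL2(d)        # d = feature dimensionality
#   index.add(features)                 # features: (n_images, d) float32
#   D, I = index.search(queries, k)     # distances and ids, each (n_queries, k)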
index.add(self.fc_features) + self.index = index + + def __init__(self, split, loader): + self._init_others(split, loader) + self._init_index() + + def get_neighbor_batch(self, loader, img_fc_feat, k_neighbors, include_self=False, self_indices=None, + neighbor_type='closest'): + # assert len(img_fc_feat.shape) == 1, "should be the features for a single image" + if len(img_fc_feat.shape) == 1: + # add a dimension + img_fc_feat = img_fc_feat[None] + n_images = img_fc_feat.shape[0] + per_image_dim = k_neighbors + (1 if include_self else 0) + if neighbor_type == 'closest': + D, I = self.index.search(img_fc_feat, k_neighbors) + n_images_, k_ = D.shape + assert n_images == n_images_ + # n_images x per_image_dim + indices = torch.tensor([[self.loader_ixs[i] for i in inds] for inds in I]).long() + if include_self: + indices = torch.cat((torch.tensor(self_indices).view(n_images, 1).long(), indices), dim=1) + assert indices.size() == (n_images, per_image_dim) + elif neighbor_type == 'batch': + # TODO: just use the batch without reloading from the loader + if self_indices is None: + raise ValueError("must pass self_indices") + indices = [ + torch.tensor(self_indices).flatten().roll(-i, dims=(0,)) + for i in range(len(self_indices)) + ] + assert len(indices) == n_images + # [0,1,2] -> [[0,1,2],[1,0,2],[2,1,0]] + indices = torch.stack(indices, 0) + if not include_self: + indices = indices[:,1:] + elif neighbor_type == 'random': + indices = torch.tensor(np.random.choice(self.loader_ixs, (n_images, k_neighbors))).long() + if include_self: + indices = torch.cat((torch.tensor(self_indices).view(n_images, 1).long(), indices), dim=1) + assert indices.size() == (n_images, per_image_dim) + else: + raise ValueError(f"invalid neighbor_type {neighbor_type}") + data = [loader.dataset[ix, 0, False] for ix in indices.view(-1)] + batch = loader.dataset.collate_func(data, self.split) + split_batch = {} + for k, v in batch.items(): + # each tensor will be n_images x k_neighbors or n_images x (k_neighbors+1) if include_self + split_batch[k] = model_utils.split_tensors_no_transpose_no_unbind(n_images, v) + return split_batch + class DataLoader: - def __init__(self, opt): + def __init__(self, opt, shuffle_override=None, wrap_override=None, + build_nearest_neighbor_indices_for_splits=None, index_serialization_root_path=None): self.opt = opt self.batch_size = self.opt.batch_size self.dataset = Dataset(opt) # Initialize loaders and iters self.loaders, self.iters = {}, {} + self.indices = {} for split in ['train', 'val', 'test']: - if split == 'train': - sampler = MySampler(self.dataset.split_ix[split], shuffle=True, wrap=True) + if shuffle_override is not None: + shuffle = shuffle_override + else: + shuffle = split == 'train' + if wrap_override is not None: + wrap = wrap_override else: - sampler = MySampler(self.dataset.split_ix[split], shuffle=False, wrap=False) + wrap = split == 'train' + sampler = MySampler(self.dataset.split_ix[split], shuffle=shuffle, wrap=wrap) self.loaders[split] = data.DataLoader(dataset=self.dataset, batch_size=self.batch_size, sampler=sampler, pin_memory=True, - num_workers=4, # 4 is usually enough + num_workers=opt.num_workers, # 4 is usually enough collate_fn=partial(self.dataset.collate_func, split=split), drop_last=False) self.iters[split] = iter(self.loaders[split]) + if build_nearest_neighbor_indices_for_splits and split in build_nearest_neighbor_indices_for_splits: + if index_serialization_root_path is not None: + index_fname = os.path.join(index_serialization_root_path, 
f'{split}_index.pkl') + if os.path.exists(index_fname): + with open(index_fname, 'rb') as f: + index = pickle.load(f) + else: + # TODO: make sure that the resulting iterator resetting doesn't cause an issue for training + index = NearestNeighborIndex(split, self) + os.makedirs(os.path.dirname(index_fname), exist_ok=True) + with open(index_fname, 'wb') as f: + pickle.dump(index, f) + self.indices[split] = index + def get_batch(self, split): try: data = next(self.iters[split]) @@ -422,4 +548,4 @@ def state_dict(self, prefetched_num=None): 'iter_counter': self.iter_counter - prefetched_num } - \ No newline at end of file + diff --git a/captioning/data/pth_loader.py b/captioning/data/pth_loader.py index e2fd39fe..90794d57 100644 --- a/captioning/data/pth_loader.py +++ b/captioning/data/pth_loader.py @@ -218,8 +218,8 @@ def collate_func(self, batch): # #sort by att_feat length # fc_batch, att_batch, label_batch, gts, infos = \ # zip(*sorted(zip(fc_batch, att_batch, np.vsplit(label_batch, batch_size), gts, infos), key=lambda x: len(x[1]), reverse=True)) - fc_batch, att_batch, label_batch, gts, infos = \ - zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) + # fc_batch, att_batch, label_batch, gts, infos = \ + # zip(*sorted(zip(fc_batch, att_batch, label_batch, gts, infos), key=lambda x: 0, reverse=True)) data = {} data['fc_feats'] = np.stack(fc_batch) # merge att_feats diff --git a/captioning/models/AttModel.py b/captioning/models/AttModel.py index 8c94754d..a86c6339 100644 --- a/captioning/models/AttModel.py +++ b/captioning/models/AttModel.py @@ -23,8 +23,11 @@ import torch.nn.functional as F from . import utils from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence +import einops from .CaptionModel import CaptionModel +from ..data.dataloader import NearestNeighborIndex +from ..modules.distractor_scorer import DistractorScorer bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] bad_endings += ['the'] @@ -34,6 +37,7 @@ def sort_pack_padded_sequence(input, lengths): tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True) inv_ix = indices.clone() inv_ix[indices] = torch.arange(0,len(indices)).type_as(inv_ix) + # inv_ix = torch.arange(0, len(indices)).type_as(indices)[indices] return tmp, inv_ix def pad_unsort_packed_sequence(input, inv_ix): @@ -41,13 +45,25 @@ def pad_unsort_packed_sequence(input, inv_ix): tmp = tmp[inv_ix] return tmp -def pack_wrapper(module, att_feats, att_masks): +def pack_wrapper_old(module, att_feats, att_masks): if att_masks is not None: packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1)) return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix) else: return module(att_feats) +def pack_wrapper(module, att_feats, att_masks): + if att_masks is not None: + packed = pack_padded_sequence(att_feats, att_masks.data.long().sum(1), enforce_sorted=False, batch_first=True) + padded = pad_packed_sequence(PackedSequence( + data=module(packed.data), sorted_indices=packed.sorted_indices, + unsorted_indices=packed.unsorted_indices, batch_sizes=packed.batch_sizes + ), + batch_first=True)[0] + return padded + else: + return module(att_feats) + class AttModel(CaptionModel): def __init__(self, opt): super(AttModel, self).__init__() @@ -95,9 +111,27 @@ def __init__(self, opt): self.vocab = opt.vocab self.bad_endings_ix = [int(k) for k,v in self.vocab.items() if v in 
bad_endings] + if opt.pragmatic_distractor_scoring == 'mlp': + self.distractor_scorer = DistractorScorer(opt) + else: + self.distractor_scorer = None + + def distractor_log_probs(self, fc_feats_target, fc_feats_distr): + # fc_feats_target: batch_size x 1 x d + # fc_feats_distr: batch_size x 1 x d + if hasattr(self, 'distractor_scorer') and self.distractor_scorer is not None: + distractor_log_probs = self.distractor_scorer(fc_feats_target, fc_feats_distr) + else: + # log p(i' | i) for target image i and distractor i' + batch_size, _, num_distractors = fc_feats_distr.size() + distractor_log_probs = torch.full((batch_size, num_distractors), 1./num_distractors).to(fc_feats_distr.device).log() + # log p(i' | i) for target image i and distractor i' + # batch_size x n_distractors + return distractor_log_probs + def init_hidden(self, bsz): weight = self.logit.weight \ - if hasattr(self.logit, "weight") \ + if hasattr(self.logit, "weight") \ else self.logit[0].weight return (weight.new_zeros(self.num_layers, bsz, self.rnn_size), weight.new_zeros(self.num_layers, bsz, self.rnn_size)) @@ -115,12 +149,12 @@ def _prepare_feature(self, fc_feats, att_feats, att_masks): # embed fc and att feats fc_feats = self.fc_embed(fc_feats) - att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) + att_feats_wrapped = pack_wrapper(self.att_embed, att_feats, att_masks) # Project the attention feats first to reduce memory and computation comsumptions. - p_att_feats = self.ctx2att(att_feats) + p_att_feats = self.ctx2att(att_feats_wrapped) - return fc_feats, att_feats, p_att_feats, att_masks + return fc_feats, att_feats_wrapped, p_att_feats, att_masks def _forward(self, fc_feats, att_feats, seq, att_masks=None): batch_size = fc_feats.size(0) @@ -162,6 +196,65 @@ def _forward(self, fc_feats, att_feats, seq, att_masks=None): return outputs + def score_seqs(self, fc_feats, att_feats, att_masks, seq, add_bos=True): + """ + Get the log-probabilities for producing each caption in seq from each image + :param fc_feats: resnet features. Tensor, batch x d_fc + :param att_feats: r-cnn features. Tensor, batch x obj x d_obj + :param att_masks: r-cnn feature mask. Tensor, batch x obj + :param seq: captions. Tensor, batch x t + :param add_bos: add BOS to the beginning of the captions (represented as padding) + :return: log probabilities for the captions. Tensor, batch: + """ + if add_bos: + # padding represents BOS + seq = torch.cat([torch.zeros(seq.size(0), 1).long().to(seq), seq], 1) + with torch.no_grad(): + scores = self(fc_feats, att_feats, seq, att_masks) + mask = (seq[:,:-1] > 0) | (seq[:,1:] > 0) + # TODO: does this include the EOS score? + # seq_t: words input at each position, with 0 for pad + # seq_t: 0, w_0, w_1, w_2, ..., w_k, 0, ... + # mask : 1, 1 , 1 , 1 , 1 , 1 , 0, ... + # selected_scores: w_0 + ... + w_k + # return scores[:,:-1].gather(2, seq[:,1:].unsqueeze(2)).squeeze(2) + selected_scores = (scores[:,:-1].gather(2, seq[:,1:].unsqueeze(2)).squeeze(2) * mask) + return selected_scores.sum(1) + + def cross_product_scores(self, fc_feats, att_feats, att_masks, seq, add_bos=True): + """ + Get log-probabilities for all images crossed with all captions. 
+ :param fc_feats: n_images x d_fc + :param att_feats: n_images x obj x d_obj + :param att_masks: n_images x obj + :param seq: n_captions x T + :param add_bos: add BOS to the beginning of captions + :return: n_captions x n_images + """ + n_captions = seq.size(0) + n_images = fc_feats.size(0) + assert att_feats.size(0) == n_images + if att_masks is not None: + assert att_masks.size(0) == n_images + fc_feats_t = fc_feats.unsqueeze(0).repeat_interleave(n_captions, dim=0) + att_feats_t = att_feats.unsqueeze(0).repeat_interleave(n_captions, dim=0) + + fc_feats_t = einops.rearrange(fc_feats_t, 'caps imgs d -> (caps imgs) d') + att_feats_t = einops.rearrange(att_feats_t, 'caps imgs obj d -> (caps imgs) obj d') + + if att_masks is not None: + att_masks_t = att_masks.unsqueeze(0).repeat_interleave(n_captions, dim=0) + att_masks_t = einops.rearrange(att_masks_t, 'caps imgs obj -> (caps imgs) obj') + else: + att_masks_t = None + seq_t = seq.unsqueeze(1).repeat_interleave(n_images, dim=1) + seq_t = einops.rearrange(seq_t, 'caps imgs d -> (caps imgs) d') + + seq_scores = self.score_seqs(fc_feats_t, att_feats_t, att_masks_t, seq_t, add_bos=add_bos) + seq_scores = einops.rearrange(seq_scores, '(caps imgs) -> caps imgs', caps=n_captions, imgs=n_images) + return seq_scores + + def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1): # 'it' contains a word index xt = self.embed(it) @@ -254,7 +347,148 @@ def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}): # return the samples and their log likelihoods return seq, seqLogprobs - def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): + def _sample_contrastive_beam( + self, fc_feats, att_feats, att_masks=None, opt={}, data=None, nearest_neighbor_index=None, loader=None, + ): + assert nearest_neighbor_index is not None + beam_size = opt.get('beam_size', 10) + sample_n = opt.get('sample_n', 10) + # when sample_n == beam_size then each beam is a sample. 
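# The per-step reasoning applied by contrastive_beam_search (called below) and by the
# contrastive loss in loss_wrapper.py can be sketched with 2-D tensors; the names here
# are illustrative, and the real code operates on batch x images x beams x vocab tensors.
import torch

def pragmatic_step(log_s0, log_prior, alpha):
    # log_s0:    (n_images, V) literal-speaker next-word log-probs, row 0 = target image
    # log_prior: (n_images,)   current listener belief over the candidate images
    log_l0 = (log_s0 + log_prior.unsqueeze(-1)).log_softmax(dim=0)  # listener: normalize over images
    log_s1 = (log_s0 + alpha * log_l0).log_softmax(dim=-1)          # pragmatic speaker: over vocab
    return log_s1, log_l0

# Decoding follows row 0 of log_s1 (the target image); once a word w is chosen, column w of
# log_l0 + log_s0 (or log_s1, per --pragmatic_incremental_l1_uses), renormalized over images,
# becomes the prior for the next step.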
+ assert sample_n == 1 or sample_n == beam_size, 'when beam search, sample_n == 1 or beam search' + batch_size = fc_feats.size(0) + device = fc_feats.device + + candidate_type = opt['pragmatic_distractor_candidate_type'] + + if candidate_type in ['closest', 'random']: + candidate_distractors = opt['pragmatic_distractors'] + elif candidate_type == 'batch': + candidate_distractors = batch_size - 1 + else: + raise ValueError( + f"invalid --pragmatic_distractor_candidate_type {candidate_type}" + ) + + neighbor_batch = nearest_neighbor_index.get_neighbor_batch( + loader, fc_feats.cpu().numpy(), k_neighbors=candidate_distractors, + include_self=True, self_indices=[data['infos'][img_ix]['ix'] for img_ix in range(batch_size)], + ) + + # assert torch.allclose(neighbor_batch['fc_feats'][:,0].cpu(), fc_feats.cpu()) + # neighbor_batch['fc_feats']: batch_size x k_neighbors+1 x d + + fc_feats = neighbor_batch['fc_feats'].to(device) + att_feats = neighbor_batch['att_feats'].to(device) + att_masks = neighbor_batch['att_masks'].to(device) + infos = neighbor_batch['infos'] + + if opt['pragmatic_distractor_type'] == 'choose_within_closest': + # fc_feats_target: batch_size x 1 x d + # fc_feats_distr: batch_size x candidate_distractors x d + fc_feats_target, fc_feats_distr = fc_feats.split((1, candidate_distractors), dim=1) + att_feats_target, att_feats_distr = att_feats.split((1, candidate_distractors), dim=1) + if att_masks is not None: + att_masks_target, att_masks_distr = att_masks.split((1, candidate_distractors), dim=1) + + num_distractors = opt['pragmatic_distractors_to_choose'] + + # log p(i' | i) for target image i and distractor i' + # batch_size x candidate_distractors + distractor_log_probs = self.distractor_log_probs(fc_feats_target, fc_feats_distr) + distractor_lps, distractor_indices = distractor_log_probs.topk(num_distractors, dim=-1) + + new_infos = [] + for b in range(batch_size): + # append info for target + new_infos.append(infos[b*(1+candidate_distractors)]) + for ix in distractor_indices[b]: + # add ix.item()+1 because we topk'd over distractors only + new_infos.append(infos[b*(1+candidate_distractors) + ix.item() + 1]) + assert len(new_infos) == batch_size * (num_distractors+1) + + fc_feats_distr = fc_feats_distr.gather( + 1, distractor_indices.unsqueeze(-1).expand(-1,-1,fc_feats_distr.size(-1)) + ) + fc_feats = torch.cat((fc_feats_target, fc_feats_distr), 1) + att_feats_distr = att_feats_distr.gather( + 1, distractor_indices.unsqueeze(-1).unsqueeze(-1).expand(-1,-1,*att_feats_distr.size()[-2:]) + ) + att_feats = torch.cat((att_feats_target, att_feats_distr), 1) + if att_masks is not None: + att_masks_distr = att_masks_distr.gather( + 1, distractor_indices.unsqueeze(-1).expand(-1,-1,att_masks_distr.size(-1)) + ) + att_masks = torch.cat((att_masks_target, att_masks_distr), 1) + else: + num_distractors = candidate_distractors + per_image_dim = num_distractors + 1 + + self.neighbor_infos = [ + infos[ix:ix+per_image_dim] + for ix in range(0, batch_size*per_image_dim, per_image_dim) + ] + + def combine_first_two(tensor): + return tensor.view((tensor.size(0) * tensor.size(1),) + tensor.size()[2:]) + + def flat_view(tensor): + if tensor is None: + return None + assert tensor.size(0) == batch_size + assert tensor.size(1) == per_image_dim + return combine_first_two(tensor) + + def unflat_view(tensor): + if tensor is None: + return None + assert tensor.size(0) == batch_size * per_image_dim + return tensor.view((batch_size, per_image_dim) + tensor.size()[1:]) + + # p_fc_feats_a, p_att_feats_a, 
pp_att_feats_a, p_att_masks_a = + prepped_feats = self._prepare_feature( + flat_view(fc_feats), + flat_view(att_feats), + flat_view(att_masks) if att_masks is not None else None, + ) + + assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed' + seq = fc_feats.new_full((batch_size*sample_n, self.seq_length), self.pad_idx, dtype=torch.long) + seqLogprobs = fc_feats.new_zeros(batch_size*sample_n, self.seq_length, self.vocab_size + 1) + # lets process every image independently for now, for simplicity + + self.done_beams = [[] for _ in range(batch_size)] + + state = self.init_hidden(batch_size*per_image_dim) + + # first step, feed bos + it = fc_feats.new_full([batch_size*per_image_dim], self.bos_idx, dtype=torch.long) + it_non_neighbor = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long) + # batch_size*per_image_dim x V + logprobs, state = self.get_logprobs_state(it, *(prepped_feats + (state,))) + + # (batch_size*beam_size*per_image_view) x ... + repeated_feats = (combine_first_two(ten) for ten in utils.repeat_tensors( + beam_size, + [unflat_view(t) for t in prepped_feats] + )) + self.done_beams = self.contrastive_beam_search( + num_distractors, state, logprobs, *repeated_feats, opt=opt + ) + assert len(self.neighbor_infos) == batch_size + for k in range(batch_size): + if sample_n == beam_size: + for _n in range(sample_n): + seq_len = self.done_beams[k][_n]['seq'].shape[0] + seq[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['seq'] + seqLogprobs[k*sample_n+_n, :seq_len] = self.done_beams[k][_n]['logps'] + else: + seq_len = self.done_beams[k][0]['seq'].shape[0] + seq[k, :seq_len] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score + seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps'] + # return the samples and their log likelihoods + return seq, seqLogprobs + + def _sample(self, fc_feats, att_feats, att_masks=None, opt={}, data=None, loader=None): sample_method = opt.get('sample_method', 'greedy') beam_size = opt.get('beam_size', 1) @@ -267,6 +501,12 @@ def _sample(self, fc_feats, att_feats, att_masks=None, opt={}): remove_bad_endings = opt.get('remove_bad_endings', 0) if beam_size > 1 and sample_method in ['greedy', 'beam_search']: return self._sample_beam(fc_feats, att_feats, att_masks, opt) + if sample_method in ['contrastive_beam_search']: + nearest_neighbor_index = loader.indices[opt['pragmatic_distractor_split']] + return self._sample_contrastive_beam( + fc_feats, att_feats, att_masks, opt, + data=data, nearest_neighbor_index=nearest_neighbor_index, loader=loader, + ) if group_size > 1: return self._diverse_sample(fc_feats, att_feats, att_masks, opt) @@ -965,4 +1205,4 @@ def core(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks): def _prepare_feature(self, fc_feats, att_feats, att_masks): fc_feats = self.fc_embed(fc_feats) - return fc_feats, None, None, None \ No newline at end of file + return fc_feats, None, None, None diff --git a/captioning/models/CaptionModel.py b/captioning/models/CaptionModel.py index 221ecd1e..8dd88eb2 100644 --- a/captioning/models/CaptionModel.py +++ b/captioning/models/CaptionModel.py @@ -17,6 +17,90 @@ from ..utils import misc as utils from . 
import utils as model_utils +import einops + +# does one step of classical beam search +def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state, log_l1=None): + #INPUTS: + #N: batch size + #b: beam size + #T: num timesteps + #V: vocab size + #logprobs: probabilities augmented after diversity N*bxV + #beam_size: obvious + #t : time instant + #beam_seq : tensor contanining the beams. N x b x T + #beam_seq_logprobs: tensor contanining the beam logprobs. batch x beam x T x V + #beam_logprobs_sum: tensor contanining joint logprobs + # log_l1: N x (num_distractors+1) x b x V + #OUPUTS: + #beam_seq : tensor containing the word indices of the decoded captions Nxbxl + #beam_seq_logprobs : log-probability of each decision made, NxbxlxV + #beam_logprobs_sum : joint log-probability of each beam Nxb + + batch_size = beam_logprobs_sum.shape[0] + vocab_size = logprobs.shape[-1] + logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV + if t == 0: + assert logprobs.shape[1] == 1 + beam_logprobs_sum = beam_logprobs_sum[:, :1] + candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV + ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) + ys, ix = ys[:,:beam_size], ix[:,:beam_size] + beam_ix = ix // vocab_size # Nxb which beam + selected_ix = ix % vocab_size # Nxb # which world + state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams + + if t > 0: + # gather according to beam_ix + assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() + beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) + + beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs)) + + if log_l1 is not None: + per_image_dim = log_l1.size(1) + log_l1_reshape = einops.rearrange( + log_l1, + "batch_size per_image_dim beam_size v -> (batch_size per_image_dim) (beam_size v)" + ) + ix_repeated = einops.rearrange( + ix.unsqueeze(1).expand(-1, per_image_dim, -1), + "batch_size per_image_dim beam_size -> (batch_size per_image_dim) beam_size" + ) + new_priors = einops.rearrange( + log_l1_reshape.gather(-1, ix_repeated), + "(batch_size per_image_dim) beam_size -> batch_size per_image_dim beam_size", + batch_size=batch_size, per_image_dim=per_image_dim + ) + # assert new_priors.size() == log_l1.size()[:-1] + # for n in range(batch_size): + # for b in range(beam_size): + # bix = beam_ix[n,b] + # six = selected_ix[n,b] + # assert torch.allclose(new_priors[n,:,b], log_l1[n,:,bix,six]) + + beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl + beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ + logprobs.reshape(batch_size, -1).gather(1, ix) + assert (beam_logprobs_sum == ys).all() + _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) + beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV + assert (_tmp_beam_logprobs == beam_logprobs).all() + beam_seq_logprobs = torch.cat([ + beam_seq_logprobs, + beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) + + new_state = [None for _ in state] + for _ix in range(len(new_state)): + # copy over state in previous beam q to new beam at vix + new_state[_ix] = state[_ix][:, 
state_ix] + state = new_state + if log_l1 is not None: + # log_priors: batch_size x per_image_dim x beam_size + return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state,new_priors + else: + return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state class CaptionModel(nn.Module): def __init__(self): @@ -55,60 +139,6 @@ def add_diversity(beam_seq_table, logprobs, t, divm, diversity_lambda, bdash): return logprobs, unaug_logprobs - # does one step of classical beam search - - def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): - #INPUTS: - #logprobs: probabilities augmented after diversity N*bxV - #beam_size: obvious - #t : time instant - #beam_seq : tensor contanining the beams - #beam_seq_logprobs: tensor contanining the beam logprobs - #beam_logprobs_sum: tensor contanining joint logprobs - #OUPUTS: - #beam_seq : tensor containing the word indices of the decoded captions Nxbxl - #beam_seq_logprobs : log-probability of each decision made, NxbxlxV - #beam_logprobs_sum : joint log-probability of each beam Nxb - - batch_size = beam_logprobs_sum.shape[0] - vocab_size = logprobs.shape[-1] - logprobs = logprobs.reshape(batch_size, -1, vocab_size) # NxbxV - if t == 0: - assert logprobs.shape[1] == 1 - beam_logprobs_sum = beam_logprobs_sum[:, :1] - candidate_logprobs = beam_logprobs_sum.unsqueeze(-1) + logprobs # beam_logprobs_sum Nxb logprobs is NxbxV - ys, ix = torch.sort(candidate_logprobs.reshape(candidate_logprobs.shape[0], -1), -1, True) - ys, ix = ys[:,:beam_size], ix[:,:beam_size] - beam_ix = ix // vocab_size # Nxb which beam - selected_ix = ix % vocab_size # Nxb # which world - state_ix = (beam_ix + torch.arange(batch_size).type_as(beam_ix).unsqueeze(-1) * logprobs.shape[1]).reshape(-1) # N*b which in Nxb beams - - - if t > 0: - # gather according to beam_ix - assert (beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) == beam_seq.reshape(-1, beam_seq.shape[-1])[state_ix].view_as(beam_seq)).all() - beam_seq = beam_seq.gather(1, beam_ix.unsqueeze(-1).expand_as(beam_seq)) - - beam_seq_logprobs = beam_seq_logprobs.gather(1, beam_ix.unsqueeze(-1).unsqueeze(-1).expand_as(beam_seq_logprobs)) - - beam_seq = torch.cat([beam_seq, selected_ix.unsqueeze(-1)], -1) # beam_seq Nxbxl - beam_logprobs_sum = beam_logprobs_sum.gather(1, beam_ix) + \ - logprobs.reshape(batch_size, -1).gather(1, ix) - assert (beam_logprobs_sum == ys).all() - _tmp_beam_logprobs = unaug_logprobs[state_ix].reshape(batch_size, -1, vocab_size) - beam_logprobs = unaug_logprobs.reshape(batch_size, -1, vocab_size).gather(1, beam_ix.unsqueeze(-1).expand(-1, -1, vocab_size)) # NxbxV - assert (_tmp_beam_logprobs == beam_logprobs).all() - beam_seq_logprobs = torch.cat([ - beam_seq_logprobs, - beam_logprobs.reshape(batch_size, -1, 1, vocab_size)], 2) - - new_state = [None for _ in state] - for _ix in range(len(new_state)): - # copy over state in previous beam q to new beam at vix - new_state[_ix] = state[_ix][:, state_ix] - state = new_state - return beam_seq,beam_seq_logprobs,beam_logprobs_sum,state - # Start diverse_beam_search opt = kwargs['opt'] temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs @@ -188,16 +218,20 @@ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprob final_beam = { 'seq': beam_seq_table[divm][b, vix].clone(), 'logps': beam_seq_logprobs_table[divm][b, vix].clone(), + # not sure what this would be used for: seems to be sum over even unused tokens 'unaug_p': 
beam_seq_logprobs_table[divm][b, vix].sum().item(), - 'p': beam_logprobs_sum_table[divm][b, vix].item() + 'p': beam_logprobs_sum_table[divm][b, vix].item(), + # same as p but more descriptively named and we can backprop through it + 'log_prob': beam_logprobs_sum_table[divm][b, vix].clone(), } final_beam['p'] = length_penalty(t-divm+1, final_beam['p']) done_beams_table[b][divm].append(final_beam) beam_logprobs_sum_table[divm][b, is_end] -= 1000 # move the current group one step forward in time - - it = beam_seq_table[divm][:, :, t-divm].reshape(-1).to(logprobs.device) + + it = beam_seq_table[divm][:, :, t-divm] + it = it.reshape(-1).to(logprobs.device) logprobs_table[divm], state_table[divm] = self.get_logprobs_state(it, *(args[divm] + [state_table[divm]])) logprobs_table[divm] = F.log_softmax(logprobs_table[divm] / temperature, dim=-1) @@ -206,6 +240,177 @@ def beam_step(logprobs, unaug_logprobs, beam_size, t, beam_seq, beam_seq_logprob done_beams = [sum(_, []) for _ in done_beams_table] return done_beams + def contrastive_beam_search(self, num_distractors, init_state, init_logprobs, *args, **kwargs): + # init_logprobs: batch_size*(num_distractors+1) x (vocab_size+1) + # args: each tensor is (batch_size*beam_size*per_image_dim) x ... + + # state: tuple of tensors (2 x batch_size*beam_size*per_image_dim x d) + # init_state: (2 x batch_size*1*per_image_dim x d) [beam_size initially like 0; call this "this_state_beam_size" later] + + # does one step of classical beam search + + # Start diverse_beam_search + opt = kwargs['opt'] + temperature = opt.get('temperature', 1) # This should not affect beam search, but will affect dbs + beam_size = opt.get('beam_size', 10) + diversity_lambda = opt.get('diversity_lambda', 0.5) + decoding_constraint = opt.get('decoding_constraint', 0) + remove_bad_endings = opt.get('remove_bad_endings', 0) + suppress_UNK = opt.get('suppress_UNK', 0) + length_penalty = utils.penalty_builder(opt.get('length_penalty', '')) + + alpha = opt['pragmatic_incremental_alpha'] + + per_image_dim = (num_distractors+1) + batch_size = init_logprobs.shape[0] // per_image_dim + assert args[0].size(0) == batch_size * beam_size * per_image_dim + + V = init_logprobs.size(-1) + + device = init_logprobs.device + + # INITIALIZATIONS + beam_seq_table = torch.LongTensor(batch_size, beam_size, 0).to(device) + beam_seq_logprobs_table = torch.FloatTensor(batch_size, beam_size, 0, self.vocab_size + 1).to(device) + beam_logprobs_sum_table = torch.zeros(batch_size, beam_size).to(device) + + log_priors = torch.full((batch_size, per_image_dim, beam_size), 1.0 / per_image_dim).log().to(device) + + # logprobs # logprobs predicted in last time step, shape (beam_size*per_image_dim, vocab_size+1) + done_beams_table = [[] for _ in range(batch_size)] + state_table = [_.clone() for _ in init_state] + # logprobs_table = list(init_logprobs.reshape(batch_size * bdash, group_size, -1).chunk(group_size, 0)) + + assert init_logprobs.size() == (batch_size*per_image_dim, V) + logprobs_table = init_logprobs.clone().view(batch_size, per_image_dim, 1, V).expand(batch_size, per_image_dim, beam_size, V) + # END INIT + + # Chunk elements in the args + args = list(args) + if self.__class__.__name__ == 'AttEnsemble': + raise NotImplementedError() + + for t in range(self.seq_length): + # log_s0: batch_size*(beam_size if t > 0 else 1) x V + assert logprobs_table.size() == (batch_size, per_image_dim, beam_size, V) + # state + this_state_beam_size = 1 if t == 0 else beam_size + # add vocab dimension. 
batch_size x per_image_dim x beam_size x V + log_s0 = logprobs_table + log_priors_expanded = log_priors.unsqueeze(-1).expand_as(log_s0) + log_l0 = (log_s0 + log_priors_expanded).log_softmax(1) + log_s1 = (log_s0 + (log_l0 * alpha)).log_softmax(3) + + # TODO: should this use s1? + # batch_size x per_image_dim x beam_size x V + if opt['pragmatic_incremental_l1_uses'] == 's0': + s_to_use = log_s0 + elif opt['pragmatic_incremental_l1_uses'] == 's1': + s_to_use = log_s1 + else: + raise ValueError("invalid --pragmatic_incremental_l1_uses {}".format(opt['pragmatic_incremental_l1_uses'])) + + log_l1 = (log_l0 + s_to_use).log_softmax(1) + + logprobs = log_s1[:,0] + if t == 0: + if beam_size > 1: + assert torch.allclose(logprobs[:,0], logprobs[:,1]) + logprobs = logprobs[:,0] + logprobs = logprobs.contiguous().view(-1, V) + + # suppress previous word + if decoding_constraint and t > 0: + raise NotImplementedError() + # logprobs.scatter_(1, beam_seq_table[:, :, t-1].reshape(-1, 1).to(device), float('-inf')) + if remove_bad_endings and t > 0: + raise NotImplementedError() + # logprobs[torch.from_numpy(np.isin(beam_seq_table[:, :, t-1].cpu().numpy(), self.bad_endings_ix)).reshape(-1), 0] = float('-inf') + # suppress UNK tokens in the decoding + if suppress_UNK and hasattr(self, 'vocab') and self.vocab[str(logprobs.size(1)-1)] == 'UNK': + logprobs[:,logprobs.size(1)-1] = logprobs[:, logprobs.size(1)-1] - 1000 + # diversity is added here + # the function directly modifies the logprobs values and hence, we need to return + # the unaugmented ones for sorting the candidates in the end. # for historical + # reasons :-) + unaug_logprobs = logprobs.clone() + + state_table_rearranged = tuple( + einops.rearrange( + t, "x (batch_size this_state_beam_size per_image_dim) d -> x (batch_size this_state_beam_size) per_image_dim d", + batch_size=batch_size, + per_image_dim=per_image_dim, + this_state_beam_size=this_state_beam_size, + ) + for t in state_table + ) + + # infer new beams + beam_seq_table, \ + beam_seq_logprobs_table, \ + beam_logprobs_sum_table, \ + state_table_rearranged, \ + log_priors = beam_step(logprobs, + unaug_logprobs, + beam_size, + t, + beam_seq_table, + beam_seq_logprobs_table, + beam_logprobs_sum_table, + state_table_rearranged, log_l1=log_l1) + # don't use this_state_beam_size because in t0 the beam expands to the full beam_size (from 1) + state_table = tuple( + einops.rearrange( + t, "x (batch_size beam_size) per_image_dim d -> x (batch_size beam_size per_image_dim) d", + batch_size=batch_size, + beam_size=beam_size, + ) + for t in state_table_rearranged + ) + + # if time's up... 
or if end token is reached then copy beams + for b in range(batch_size): + is_end = beam_seq_table[b, :, t] == self.eos_idx + assert beam_seq_table.shape[-1] == t+1 + if t == self.seq_length - 1: + is_end.fill_(1) + for vix in range(beam_size): + if is_end[vix]: + final_beam = { + 'seq': beam_seq_table[b, vix].clone(), + 'logps': beam_seq_logprobs_table[b, vix].clone(), + # not sure what this would be used for: seems to be sum over even unused tokens + 'unaug_p': beam_seq_logprobs_table[b, vix].sum().item(), + 'p': beam_logprobs_sum_table[b, vix].item(), + # same as p but more descriptively named and we can backprop through it + 'log_prob': beam_logprobs_sum_table[b, vix].clone(), + } + final_beam['p'] = length_penalty(t+1, final_beam['p']) + done_beams_table[b].append(final_beam) + beam_logprobs_sum_table[b, is_end] -= 1000 + + # move the current group one step forward in time + + it = beam_seq_table[:, :, t] + it = it.unsqueeze(2).expand(-1, -1, per_image_dim) + it = it.reshape(-1).to(logprobs.device) + + logprobs_table, state_table = self.get_logprobs_state(it, *(args + [state_table])) + logprobs_table = F.log_softmax(logprobs_table / temperature, dim=-1) + + # TODO: see if we can make state and args consistent so that we don't need to do this + logprobs_table = einops.rearrange( + logprobs_table, + "(batch_size beam_size per_image_dim) v -> batch_size per_image_dim beam_size v", + batch_size=batch_size, + beam_size=beam_size, + per_image_dim=per_image_dim + ) + + # all beams are sorted by their log-probabilities + done_beams = [sorted(done_beams_table[b], key=lambda x: -x['p'])[:beam_size] for b in range(batch_size)] + return done_beams + def old_beam_search(self, init_state, init_logprobs, *args, **kwargs): # function computes the similarity score to be augmented diff --git a/captioning/models/utils.py b/captioning/models/utils.py index feb130bc..0ee6aeca 100644 --- a/captioning/models/utils.py +++ b/captioning/models/utils.py @@ -13,7 +13,6 @@ def repeat_tensors(n, x): x = [repeat_tensors(n, _) for _ in x] return x - def split_tensors(n, x): if torch.is_tensor(x): assert x.shape[0] % n == 0 @@ -22,4 +21,14 @@ def split_tensors(n, x): x = [split_tensors(n, _) for _ in x] elif x is None: x = [None] * n + return x + +def split_tensors_no_transpose_no_unbind(first_dim, x): + if torch.is_tensor(x): + assert x.shape[0] % first_dim == 0 + x = x.reshape(first_dim, x.shape[0] // first_dim, *x.shape[1:]) + elif type(x) is list or type(x) is tuple: + x = [split_tensors_no_transpose_no_unbind(first_dim, _) for _ in x] + elif x is None: + x = [None] * first_dim return x \ No newline at end of file diff --git a/captioning/modules/distractor_scorer.py b/captioning/modules/distractor_scorer.py new file mode 100644 index 00000000..7784c9c2 --- /dev/null +++ b/captioning/modules/distractor_scorer.py @@ -0,0 +1,21 @@ +import torch +from torch import nn + +class DistractorScorer(torch.nn.Module): + def __init__(self, opt): + super(DistractorScorer, self).__init__() + self.opt = opt + hidden_size = opt.pragmatic_distractor_scoring_hidden_size + self.scorer = nn.Sequential(nn.Linear(opt.fc_feat_size*2, hidden_size), + nn.ReLU(), + nn.Dropout(opt.drop_prob_lm), + nn.Linear(hidden_size, 1)) + + def forward(self, fc_feats_target, fc_feats_distr): + # fc_feats_target: batch_size x 1 x d + # fc_feats_distr: batch_size x n_distractors x d + cat_feats = torch.cat((fc_feats_target.expand_as(fc_feats_distr), fc_feats_distr), -1) + # batch_size x n_distractors x 1 + scores = self.scorer(cat_feats) + log_probs = 
scores.squeeze(-1).log_softmax(-1) + return log_probs \ No newline at end of file diff --git a/captioning/modules/loss_wrapper.py b/captioning/modules/loss_wrapper.py index c22926a4..bcad9876 100644 --- a/captioning/modules/loss_wrapper.py +++ b/captioning/modules/loss_wrapper.py @@ -2,6 +2,13 @@ from . import losses from ..utils.rewards import init_scorer, get_self_critical_reward +import einops + +def combine_first_two(tensor): + if tensor is None: + return None + return tensor.view((tensor.size(0) * tensor.size(1),) + tensor.size()[2:]) + class LossWrapper(torch.nn.Module): def __init__(self, model, opt): super(LossWrapper, self).__init__() @@ -15,7 +22,7 @@ def __init__(self, model, opt): self.struc_crit = losses.StructureLosses(opt) def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices, - sc_flag, struc_flag): + sc_flag, struc_flag, contrastive_flag, ids=None): opt = self.opt out = {} @@ -41,6 +48,111 @@ def forward(self, fc_feats, att_feats, labels, masks, att_masks, gts, gt_indices out['lm_loss'] = lm_loss out['struc_loss'] = struc_loss['loss'] out['reward'] = struc_loss['reward'] + elif contrastive_flag: + # fc_feats: batch_size x (n_distractors+1) x d + # att_feats: batch_size x (n_distractors+1) x n_obj x d + # labels: batch_size x (n_distractors+1) x num_captions x T + # att_masks: batch_size x (n_distractors+1) x n_obj + alpha = opt.pragmatic_incremental_alpha + batch_size, per_image_dim, _ = fc_feats.size() + batch_size_, per_image_dim_, num_captions, T = labels.size() + assert batch_size == batch_size_ and per_image_dim == per_image_dim_ + + if opt.pragmatic_distractor_candidate_type in ['closest', 'random']: + num_distractors = opt.pragmatic_distractors + elif opt.pragmatic_distractor_candidate_type == 'batch': + num_distractors = batch_size - 1 + else: + raise ValueError( + f"invalid --pragmatic_distractor_candidate_type {opt.pragmatic_distractor_candidate_type}" + ) + assert num_distractors+1 == per_image_dim + + fc_feats_target, fc_feats_distr = fc_feats.split((1, num_distractors), dim=1) + att_feats_target, att_feats_distr = att_feats.split((1, num_distractors), dim=1) + att_masks_target, att_masks_distr = att_masks.split((1, num_distractors), dim=1) + + labels_from_target = labels[:,0] + masks_from_target = masks[:,0] + + labels_replace_distractor_target = labels.clone() + masks_replace_distractor_target = masks.clone() + + labels_replace_distractor_target[:,1:] = labels_from_target.unsqueeze(1).expand_as(labels_replace_distractor_target[:,1:]) + masks_replace_distractor_target[:,1:] = masks_from_target.unsqueeze(1).expand_as(masks_replace_distractor_target[:,1:]) + + outs = self.model(combine_first_two(fc_feats), combine_first_two(att_feats), + combine_first_two(labels_replace_distractor_target[..., :-1]), combine_first_two(att_masks)) + # batch_size x (n_distractors+1) x num_captions x T-1 x V + outs = outs.view(batch_size, per_image_dim, num_captions, T-1, -1) + V = outs.size(-1) + + num_choices = 2 + + outs_target, outs_distractors = outs.split((1, num_distractors), dim=1) + outs_target = outs_target.expand_as(outs_distractors) + + # batch_size x (n_distractors) x num_captions x num_choices x T-1 x V + outs_comparative = torch.stack((outs_target, outs_distractors), 3) + assert outs_comparative.size(-2) == T-1 + + labels_from_target_expanded = labels_from_target.unsqueeze(1).unsqueeze(3).unsqueeze(4)\ + .expand(batch_size, num_distractors, num_captions, num_choices, 1, T) + + masks_from_target_expanded = 
masks_from_target.unsqueeze(1).unsqueeze(3) \ + .expand(batch_size, num_distractors, num_captions, num_choices, T) + + def select_label(tensor, t): + return tensor.gather(-1, labels_from_target_expanded[...,t]).squeeze(-1) + + # batch_size x (n_distractors) x num_captions x 2 + log_priors = torch.full(outs_comparative.size()[:-2], 1./outs_comparative.size(3)).log().to(outs_comparative) + + log_s1_sums = torch.zeros_like(log_priors) + # batch_size x (n_distractors) x num_captions x num_choices + word_counts = torch.zeros_like(log_s1_sums) + + for t in range(T-1): + # batch_size x (n_distractors) x num_captions x num_choices x V + log_s0 = outs_comparative[...,t,:] + log_priors_expanded = log_priors.unsqueeze(-1).expand_as(log_s0) + log_l0 = (log_s0 + log_priors_expanded).log_softmax(3) + log_s1 = (log_s0 + (log_l0 * alpha)).log_softmax(4) + log_s1_chosen = select_label(log_s1, t+1) + this_mask = masks_from_target_expanded[...,t+1] + log_s1_sums += (log_s1_chosen * this_mask) + word_counts += this_mask + if opt.pragmatic_incremental_l1_uses == 's0': + s_to_use = log_s0 + elif opt.pragmatic_incremental_l1_uses == 's1': + s_to_use = log_s1 + else: + raise ValueError("invalid --pragmatic_incremental_l1_uses {}".format(opt.pragmatic_incremental_l1_uses)) + log_l1 = (select_label(log_l0, t+1) + select_label(s_to_use, t+1)).log_softmax(3) + log_priors = log_l1 + + # log p(c | i, i') for caption c, target image i and distractor i' + # batch_size x n_distractors x num_captions + target_log_seq_s1 = log_s1_sums[...,0] + + # TODO: a model that incorporates object features too + # log p(i' | i) for target image i and distractor i' + # batch_size x n_distractors + distractor_log_probs = self.model.distractor_log_probs(fc_feats_target, fc_feats_distr) + + # log p(c, i' | i)) + # batch_size x n_distractors x num_captions + joint_log_s1 = target_log_seq_s1 + distractor_log_probs.unsqueeze(-1).expand_as(target_log_seq_s1) + + # batch_size x num_captions + if opt.contrastive_em == 'hard': + obj_log_s1 = joint_log_s1.max(1).values + elif opt.contrastive_em == 'soft': + obj_log_s1 = joint_log_s1.sum(1) + + # TODO: this is a bit weird: scaling to be similar to the loss used in non-contrastive training, but + # will hopefully prevent having to mess with the LRs too much + loss = -(obj_log_s1.sum()) / word_counts[:,0,:,0].sum() elif not sc_flag: loss = self.crit(self.model(fc_feats, att_feats, labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:]) else: diff --git a/captioning/utils/eval_multi.py b/captioning/utils/eval_multi.py index 83907410..1c55571b 100644 --- a/captioning/utils/eval_multi.py +++ b/captioning/utils/eval_multi.py @@ -14,10 +14,12 @@ import os import sys from . import misc as utils -from eval_utils import getCOCO +from .eval_utils import getCOCO from .div_utils import compute_div_n, compute_global_div_n +SPICE_THREADS=4 + import sys try: sys.path.append("coco-caption") @@ -85,7 +87,7 @@ def eval_oracle(dataset, preds_n, model_id, split): json.dump(preds, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... 
cocoRes = coco.loadRes(cache_path) - cocoEval = COCOEvalCap(coco, cocoRes) + cocoEval = COCOEvalCap(coco, cocoRes, spice_threads=SPICE_THREADS) cocoEval.params['image_id'] = cocoRes.getImgIds() cocoEval.evaluate() diff --git a/captioning/utils/eval_utils.py b/captioning/utils/eval_utils.py index c4bc7f44..72dab6b0 100644 --- a/captioning/utils/eval_utils.py +++ b/captioning/utils/eval_utils.py @@ -5,7 +5,12 @@ import torch import torch.nn as nn import torch.nn.functional as F +import einops +from torch.nn.utils.rnn import pad_sequence + +import itertools +import tqdm import numpy as np import json from json import encoder @@ -17,16 +22,23 @@ from . import misc as utils # load coco-caption if available +from ..data.dataloader import DataLoader +from ..models import AttModel + +SPICE_THREADS=4 + try: sys.path.append("coco-caption") from pycocotools.coco import COCO from pycocoevalcap.eval import COCOEvalCap -except: +except Exception as e: + print(e) print('Warning: coco-caption not available') bad_endings = ['a','an','the','in','for','at','of','with','before','after','on','upon','near','to','is','are','am'] bad_endings += ['the'] +PAD_ID = 0 def count_bad(sen): sen = sen.split(' ') @@ -82,7 +94,8 @@ def language_eval(dataset, preds, preds_n, eval_kwargs, split): json.dump(preds_filt, open(cache_path, 'w')) # serialize to temporary json file. Sigh, COCO API... cocoRes = coco.loadRes(cache_path) - cocoEval = COCOEvalCap(coco, cocoRes) + cocoEval = COCOEvalCap(coco, cocoRes, spice_threads=SPICE_THREADS, + scorers_to_run=['bleu', 'meteor', 'rouge', 'cider', 'wmd']) cocoEval.params['image_id'] = cocoRes.getImgIds() cocoEval.evaluate() @@ -93,30 +106,30 @@ def language_eval(dataset, preds, preds_n, eval_kwargs, split): out['entropy'] = mean_entropy imgToEval = cocoEval.imgToEval - for k in list(imgToEval.values())[0]['SPICE'].keys(): - if k != 'All': - out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()]) - out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean() + # for k in list(imgToEval.values())[0]['SPICE'].keys(): + # if k != 'All': + # out['SPICE_'+k] = np.array([v['SPICE'][k]['f'] for v in imgToEval.values()]) + # out['SPICE_'+k] = (out['SPICE_'+k][out['SPICE_'+k]==out['SPICE_'+k]]).mean() for p in preds_filt: image_id, caption = p['image_id'], p['caption'] imgToEval[image_id]['caption'] = caption - if len(preds_n) > 0: - from . import eval_multi - cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json') - allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split) - out.update(allspice['overall']) - div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split) - out.update(div_stats['overall']) - if eval_oracle: - oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split) - out.update(oracle['overall']) - else: - oracle = None - self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split) - out.update(self_cider['overall']) - with open(cache_path_n, 'w') as outfile: - json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile) + # if len(preds_n) > 0: + # from . 
import eval_multi + # cache_path_n = os.path.join('eval_results/', '.cache_'+ model_id + '_' + split + '_n.json') + # allspice = eval_multi.eval_allspice(dataset, preds_n, model_id, split) + # out.update(allspice['overall']) + # div_stats = eval_multi.eval_div_stats(dataset, preds_n, model_id, split) + # out.update(div_stats['overall']) + # if eval_oracle: + # oracle = eval_multi.eval_oracle(dataset, preds_n, model_id, split) + # out.update(oracle['overall']) + # else: + # oracle = None + # self_cider = eval_multi.eval_self_cider(dataset, preds_n, model_id, split) + # out.update(self_cider['overall']) + # with open(cache_path_n, 'w') as outfile: + # json.dump({'allspice': allspice, 'div_stats': div_stats, 'oracle': oracle, 'self_cider': self_cider}, outfile) out['bad_count_rate'] = sum([count_bad(_['caption']) for _ in preds_filt]) / float(len(preds_filt)) outfile_path = os.path.join('eval_results/', model_id + '_' + split + '.json') @@ -125,8 +138,215 @@ def language_eval(dataset, preds, preds_n, eval_kwargs, split): return out +def generate_pragmatic(model: AttModel, loader: DataLoader, fc_feats, att_feats, att_masks, data, eval_kwargs): + # generate candidate utterances + keep_all_scores = eval_kwargs.get('pragmatic_serialize_all_scores', 0) + input_data = fc_feats, att_feats, att_masks, data + n_imgs = fc_feats.size(0) + n_predictions, seqs, log_probs = generate_caption_candidates( + model, input_data, eval_kwargs, loader=loader, + ) + # seqs: n_images x n_captions x T + # log_probs: n_images x n_captions + n_imgs_, n_captions = log_probs.size() + assert n_imgs == n_imgs_, (n_imgs, n_imgs_) + s0_weight = eval_kwargs['pragmatic_s0_weight'] + if not (0.0 <= s0_weight <= 1.0): + raise ValueError(f"s0_weight {s0_weight} not in [0, 1]") + all_s0_scores_s = [] + all_s1_scores_s = [] + all_s0s1_scores_s = [] + target_s0_scores_s = [] + target_s1_scores_s = [] + target_s0s1_scores_s = [] + best_seqs = [] + best_scores = [] + device = fc_feats.device + image_context_paths_s = [] + image_context_ids_s = [] + candidates_s = [] + k_neighbors = eval_kwargs['pragmatic_distractors'] + nearest_neighbor_index = loader.indices[eval_kwargs['pragmatic_distractor_split']] + all_neighbors = nearest_neighbor_index.get_neighbor_batch( + loader, fc_feats.cpu().numpy(), k_neighbors=k_neighbors, + include_self=True, self_indices=[data['infos'][img_ix]['ix'] for img_ix in range(n_imgs)], + neighbor_type=eval_kwargs['pragmatic_distractor_candidate_type'] + ) + for img_ix in range(n_imgs): + neighbor_infos = all_neighbors['infos'][img_ix*(k_neighbors+1):(img_ix+1)*(k_neighbors+1)] + assert len(neighbor_infos) == k_neighbors+1 + # n_captions x (1 + pragmatic_distractors) + # [:,0]: scores for the captions for the target image + s0_scores = model.cross_product_scores( + all_neighbors['fc_feats'][img_ix].to(device), + all_neighbors['att_feats'][img_ix].to(device), + all_neighbors['att_masks'][img_ix].to(device) if all_neighbors['att_masks'] is not None else None, + seqs[img_ix] + ) + candidates_s.append(utils.decode_sequence(model.vocab, seqs[img_ix])) + # n_captions + l1_scores = s0_scores.log_softmax(dim=1) + s1_scores = l1_scores.log_softmax(dim=0) + s0s1_scores = s0_scores * s0_weight + s1_scores * (1.0 - s0_weight) + + target_s0s1_scores = s0s1_scores[:,0] + + image_context_paths_s.append([d['file_path'] for d in neighbor_infos]) + image_context_ids_s.append([d['id'] for d in neighbor_infos]) + + all_s0_scores_s.append(s0_scores.detach().cpu().numpy()) + all_s1_scores_s.append(s1_scores.detach().cpu().numpy()) + 
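# Shape note (a sketch of the layout assumed here): get_neighbor_batch(..., include_self=True)
# puts the query image itself at position 0 of each group of k_neighbors+1 entries, so
#   all_neighbors['fc_feats']  -> (n_imgs, k+1, d)         tensor, [:, 0] equals fc_feats
#   all_neighbors['att_feats'] -> (n_imgs, k+1, n_obj, d') tensor
#   all_neighbors['infos']     -> flat list of n_imgs*(k+1) dicts, grouped per image, which is
#                                 why it is sliced with img_ix*(k+1):(img_ix+1)*(k+1) above.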
all_s0s1_scores_s.append(s0s1_scores.detach().cpu().numpy()) + target_s0_scores_s.append(s0_scores[:,0].detach().cpu().numpy()) + target_s1_scores_s.append(s1_scores[:,0].detach().cpu().numpy()) + target_s0s1_scores_s.append(target_s0s1_scores.detach().cpu().numpy()) + + best_score, best_ix = target_s0s1_scores.max(-1) + best_scores.append(best_score) + best_seqs.append(seqs[img_ix][best_ix]) + seq = pad_sequence(best_seqs, batch_first=True, padding_value=PAD_ID) + scores = torch.stack(best_scores, -1) + entropy = torch.zeros_like(scores) + perplexity = torch.zeros_like(scores) + extras = { + 'target_s0_scores': target_s0_scores_s, + 'target_s1_scores': target_s1_scores_s, + 'target_s0s1_scores': target_s0s1_scores_s, + 'chosen_target_s0s1_scores': scores.detach().cpu().numpy(), + 'candidates': candidates_s, + 'context_paths': image_context_paths_s, + 'context_ids': image_context_ids_s, + } + if keep_all_scores: + extras.update({ + 'all_s0_scores': all_s0_scores_s, + 'all_s1_scores': all_s1_scores_s, + 'all_s0s1_scores': all_s0s1_scores_s, + }) + return seq, entropy, perplexity, extras + +def search_distractors(s0_cap_by_img_score_mat, num_distractors_to_choose, s0_weight): + if not (0.0 <= s0_weight <= 1.0): + raise ValueError(f"s0_weight {s0_weight} must be in [0, 1]") + n_cap, n_img = s0_cap_by_img_score_mat.size() + best_distractors = None + best_score = None + best_cap = None + for distractors in itertools.combinations(range(1, n_img), num_distractors_to_choose): + img_indices = list(itertools.chain((0,), distractors)) + sub_mat = s0_cap_by_img_score_mat[:,img_indices] + l1 = sub_mat.log_softmax(1) + s1 = l1.log_softmax(0) + target_s0_scores = s0_cap_by_img_score_mat[:,0] + target_s1_scores = s1[:,0] + target_s0s1_scores = s0_weight * target_s0_scores + (1 - s0_weight) * target_s1_scores + this_best_score, this_best_cap = target_s0s1_scores.max(-1) + if best_score is None or this_best_score > best_score: + best_distractors = distractors + best_score = this_best_score + best_cap = this_best_cap + return best_cap, best_distractors, best_score + +def pragmatic_choose_from_candidates(instance_candidates: dict, eval_kwargs): + device = eval_kwargs.get('device', 'cuda') + # instance_candidates: candidate captions and scores for a single image + prediction = {} + for key in ['image_id', 'candidates', 'perplexity', 'entropy', 'context_paths', 'context_ids']: + prediction[key] = instance_candidates[key] + candidate_captions = instance_candidates['candidates'] + nonempty_indices, nonempty_captions = zip(*[(ix, cap) for ix, cap in enumerate(candidate_captions) if cap]) + nonempty_indices = torch.tensor(nonempty_indices).long() + if eval_kwargs['pragmatic_inference']: + assert not eval_kwargs['mbr_inference'] + num_distractors = eval_kwargs['pragmatic_distractors'] + s0_weight = eval_kwargs['pragmatic_s0_weight'] + # n_captions x n_images + + # target image has index 0; indices 1-end are for distractor images + s0_scores = torch.tensor(instance_candidates['all_s0_scores']) + if num_distractors >= s0_scores.size(1): + raise ValueError(f"not enough distractors in serialized candidates. 
{num_distractors} required; {s0_scores.size(1) - 1} available") + else: + s0_scores = s0_scores[:,:num_distractors+1] + s0_scores = s0_scores[nonempty_indices] + s0_scores = s0_scores.to(device) + distractor_type = eval_kwargs.get('pragmatic_distractor_type', 'closest') + if distractor_type == 'closest': + # use all distractors + num_to_choose = num_distractors + elif distractor_type == 'choose_within_closest': + num_to_choose = eval_kwargs.get('pragmatic_distractors_to_choose', 1) + else: + raise NotImplementedError(f"invalid --pragmatic_distractor_type {distractor_type}") + best_cap, best_distractors, best_scores = search_distractors(s0_scores, num_to_choose, s0_weight) + + caption = nonempty_captions[best_cap] + else: + raise NotImplementedError() + prediction['caption'] = caption + return prediction + +def mbr_choose_from_candidates(instance_candidates: dict, eval_kwargs, sent_rep_model, sent_rep_tokenizer): + device = eval_kwargs.get('device', 'cuda') + mbr_type = eval_kwargs.get('mbr_type', 'bert_cosine_sim') + if mbr_type != 'bert_cosine_sim': + raise NotImplementedError(f"--mbr_type: {mbr_type}") + s0_weight = eval_kwargs['mbr_s0_weight'] + if not (0.0 <= s0_weight <= 1.0): + raise ValueError(f"--mbr_s0_weight {s0_weight} must be in [0, 1]") + + prediction = {} + for key in ['image_id', 'candidates', 'perplexity', 'entropy', 'context_paths', 'context_ids']: + prediction[key] = instance_candidates[key] + candidate_captions = instance_candidates['candidates'] + nonempty_indices, nonempty_captions = zip(*[(ix, cap) for ix, cap in enumerate(candidate_captions) if cap]) + nonempty_captions = list(nonempty_captions) + nonempty_indices = torch.tensor(nonempty_indices).long() + + s0_scores = torch.tensor(instance_candidates['target_s0_scores']) + s0_scores = s0_scores[nonempty_indices] + inputs = sent_rep_tokenizer(nonempty_captions, return_tensors='pt', padding=True) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = sent_rep_model(**inputs) + h = outputs['last_hidden_state'][:,0] + h_rescaled = h / h.norm(2, dim=-1, keepdim=True) + cosine_sim = torch.einsum("xh,yh->xy", (h_rescaled, h_rescaled)) + mbr_scores = cosine_sim.mean(-1).log_softmax(-1) + mbr_scores = mbr_scores.to(s0_scores.device) + joint_scores = s0_weight * s0_scores + (1 - s0_weight) * mbr_scores + best_score, best_cap = joint_scores.max(-1) + caption = nonempty_captions[best_cap] + prediction['caption'] = caption + return prediction + + +def eval_split_from_serialized(path, eval_kwargs={}): + device = eval_kwargs.get('device', 'cuda') + pragmatic_inference = eval_kwargs.get('pragmatic_inference', 0) + mbr_inference = eval_kwargs.get('mbr_inference', 0) + + if mbr_inference and pragmatic_inference: + raise ValueError("can't do both --pragmatic_inference and --mbr_inference") + + if mbr_inference: + from transformers import AutoTokenizer, AutoModel + tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased") + model = AutoModel.from_pretrained("bert-large-uncased") + model = model.to(device) + + candidates = torch.load(path) + predictions = [] + for instance_candidates in tqdm.tqdm(candidates, ncols=80): + if mbr_inference: + prediction = mbr_choose_from_candidates(instance_candidates, eval_kwargs, model, tokenizer) + else: + prediction = pragmatic_choose_from_candidates(instance_candidates, eval_kwargs) + predictions.append(prediction) + return predictions + def eval_split(model, crit, loader, eval_kwargs={}): verbose = eval_kwargs.get('verbose', True) + verbose_captions = 
eval_kwargs.get('verbose_captions', 0) verbose_beam = eval_kwargs.get('verbose_beam', 0) verbose_loss = eval_kwargs.get('verbose_loss', 1) num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1)) @@ -139,6 +359,9 @@ def eval_split(model, crit, loader, eval_kwargs={}): os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # Use this nasty way to make other code clean since it's a global configuration device = eval_kwargs.get('device', 'cuda') + pragmatic_inference = eval_kwargs.get('pragmatic_inference', 0) + contrastive = eval_kwargs['sample_method'] == 'contrastive_beam_search' + # Make sure in the evaluation mode model.eval() @@ -150,6 +373,7 @@ def eval_split(model, crit, loader, eval_kwargs={}): loss_evals = 1e-8 predictions = [] n_predictions = [] # when sample_n > 1 + verbose_predictions = [] while True: data = loader.get_batch(split) n = n + len(data['infos']) @@ -164,14 +388,20 @@ def eval_split(model, crit, loader, eval_kwargs={}): loss_sum = loss_sum + loss loss_evals = loss_evals + 1 - # forward the model to also get generated samples for each image with torch.no_grad(): tmp_eval_kwargs = eval_kwargs.copy() - tmp_eval_kwargs.update({'sample_n': 1}) - seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') - seq = seq.data - entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) - perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) + if pragmatic_inference: + seq, entropy, perplexity, extras = generate_pragmatic(model, loader, fc_feats, att_feats, att_masks, data, tmp_eval_kwargs) + seq = seq.data + else: + tmp_eval_kwargs.update({'sample_n': 1}) + # forward the model to also get generated samples for each image + seq, seq_logprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample', + loader=loader, data=data) + seq = seq.data + extras = {} + entropy = - (F.softmax(seq_logprobs, dim=2) * seq_logprobs).sum(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) + perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / ((seq>0).to(seq_logprobs).sum(1)+1) # Print beam search if beam_size > 1 and verbose_beam: @@ -184,18 +414,33 @@ def eval_split(model, crit, loader, eval_kwargs={}): entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()} if eval_kwargs.get('dump_path', 0) == 1: entry['file_name'] = data['infos'][k]['file_path'] + verbose_entry = entry.copy() + if contrastive: + # DataParallel wrapper doesn't make attrs accessible + if isinstance(model, torch.nn.DataParallel): + underlying_model = model.module + else: + underlying_model = model + neighbor_infos = underlying_model.neighbor_infos[k] + verbose_entry['context_paths'] = [d['file_path'] for d in neighbor_infos] + verbose_entry['context_ids'] = [d['id'] for d in neighbor_infos] + if extras: + for key, value in extras.items(): + assert len(value) == len(sents) + verbose_entry[key] = value[k] predictions.append(entry) + verbose_predictions.append(verbose_entry) if eval_kwargs.get('dump_images', 0) == 1: # dump the raw image to vis/ folder cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + '" vis/imgs/img' + str(len(predictions)) + '.jpg' # bit gross print(cmd) os.system(cmd) - if verbose: + if verbose_captions: print('image %s: %s' %(entry['image_id'], entry['caption'])) if sample_n > 1: - 
eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs) + eval_split_n(model, n_predictions, [fc_feats, att_feats, att_masks, data], eval_kwargs, loader=loader) # ix0 = data['bounds']['it_pos_now'] ix1 = data['bounds']['it_max'] @@ -207,7 +452,7 @@ def eval_split(model, crit, loader, eval_kwargs={}): predictions.pop() if verbose: - print('evaluating validation preformance... %d/%d (%f)' %(n, ix1, loss)) + print('evaluating validation performance... %d/%d (%f)' %(n, ix1, loss)) if num_images >= 0 and n >= num_images: break @@ -217,7 +462,13 @@ def eval_split(model, crit, loader, eval_kwargs={}): n_predictions = sorted(n_predictions, key=lambda x: x['perplexity']) if not os.path.isdir('eval_results'): os.mkdir('eval_results') - torch.save((predictions, n_predictions), os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth')) + pred_fn = os.path.join('eval_results/', '.saved_pred_'+ eval_kwargs['id'] + '_' + split + '.pth') + print(f'saving to {pred_fn}') + torch.save((predictions, n_predictions), pred_fn) + verbose_pred_fname = os.path.join('eval_results/', f"pred_verbose_{eval_kwargs['id']}_{split}.pth") + if eval_kwargs.get('save_verbose_predictions', 0): + print(f"saving verbose predictions to {verbose_pred_fname}") + torch.save(verbose_predictions, verbose_pred_fname) if lang_eval == 1: lang_stats = language_eval(dataset, predictions, n_predictions, eval_kwargs, split) @@ -226,56 +477,92 @@ def eval_split(model, crit, loader, eval_kwargs={}): return loss_sum/loss_evals, predictions, lang_stats +def eval_split_n(model, n_predictions, input_data, eval_kwargs={}, loader=None): + new_predictions, seqs, log_probs = generate_caption_candidates(model, input_data, eval_kwargs, loader=loader) + n_predictions.extend(new_predictions) + # Only run when sample_n > 0 -def eval_split_n(model, n_predictions, input_data, eval_kwargs={}): - verbose = eval_kwargs.get('verbose', True) - beam_size = eval_kwargs.get('beam_size', 1) +def generate_caption_candidates(model, input_data, eval_kwargs={}, loader=None): + n_predictions = [] + verbose_captions = eval_kwargs.get('verbose_captions', 0) + # beam_size = eval_kwargs.get('beam_size', 1) sample_n = eval_kwargs.get('sample_n', 1) sample_n_method = eval_kwargs.get('sample_n_method', 'sample') fc_feats, att_feats, att_masks, data = input_data + n_imgs = fc_feats.size(0) + tmp_eval_kwargs = eval_kwargs.copy() - if sample_n_method == 'bs': + if sample_n_method in ['bs', 'contrastive_bs']: # case 1 sample_n == beam size + contrastive = sample_n_method == 'contrastive_bs' + if contrastive: + tmp_eval_kwargs['sample_method'] = 'contrastive_beam_search' tmp_eval_kwargs.update({'sample_n': 1, 'beam_size': sample_n, 'group_size': 1}) # randomness from softmax with torch.no_grad(): - model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample', data=data, loader=loader) + seqs = [] + log_probs = [] for k in range(fc_feats.shape[0]): - _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(sample_n)])) - for sent in _sents: - entry = {'image_id': data['infos'][k]['id'], 'caption': sent} + beams = [model.done_beams[k][_]['seq'] for _ in range(sample_n)] + stacked_beams = pad_sequence(beams, batch_first=True, padding_value=0) + seqs.extend(beams) + _log_prob = torch.stack([model.done_beams[k][i]['log_prob'] for i in range(sample_n)]).flatten() + log_probs.append(_log_prob) + 
_sents = utils.decode_sequence(model.vocab, stacked_beams) + for sent_ix, sent in enumerate(_sents): + entry = {'image_id': data['infos'][k]['id'], 'caption': sent, 'log_prob': _log_prob[sent_ix].item()} + if contrastive: + neighbor_infos = model.neighbor_infos[sent_ix] + entry['context_paths'] = [d['file_path'] for d in neighbor_infos] + entry['context_ids'] = [d['id'] for d in neighbor_infos] + n_predictions.append(entry) + seqs = pad_sequence(seqs, batch_first=True, padding_value=0) + log_probs = torch.cat(log_probs, 0) # case 2 sample / gumbel / topk sampling/ nucleus sampling elif sample_n_method == 'sample' or \ sample_n_method == 'gumbel' or \ sample_n_method.startswith('top'): tmp_eval_kwargs.update({'sample_n': sample_n, 'sample_method': sample_n_method, 'beam_size': 1}) # randomness from sample with torch.no_grad(): - _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample', + loader=loader, data=data) _sents = utils.decode_sequence(model.vocab, _seq) - _perplexity = - _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) / ((_seq>0).to(_sampleLogprobs).sum(1)+1) + _log_prob = _sampleLogprobs.gather(2, _seq.unsqueeze(2)).squeeze(2).sum(1) + _perplexity = - _log_prob / ((_seq>0).to(_sampleLogprobs).sum(1)+1) for k, sent in enumerate(_sents): - entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent, 'perplexity': _perplexity[k].item()} + entry = {'image_id': data['infos'][k // sample_n]['id'], 'seq': _seq[k], 'caption': sent, 'perplexity': _perplexity[k].item(), 'log_prob': _log_prob[k].item()} n_predictions.append(entry) + seqs = _seq + log_probs = _log_prob elif sample_n_method == 'dbs': # Use diverse beam search + raise NotImplementedError("set seqs to be the returned candidates (a batch_size*sample_n x T array) and log_probs to be log probabilities (batch_size*sample_n)") tmp_eval_kwargs.update({'beam_size': sample_n * beam_size, 'group_size': sample_n}) # randomness from softmax with torch.no_grad(): - model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample', loader=loader, data=data) for k in range(loader.batch_size): _sents = utils.decode_sequence(model.vocab, torch.stack([model.done_beams[k][_]['seq'] for _ in range(0, sample_n*beam_size, beam_size)])) for sent in _sents: entry = {'image_id': data['infos'][k]['id'], 'caption': sent} n_predictions.append(entry) - else: + elif sample_n_method in ['dgreedy', 'dsample', 'dtopk', 'dtopp']: + raise NotImplementedError("set log_probs to be log probabilities (batch_size*sample_n)") tmp_eval_kwargs.update({'sample_method': sample_n_method[1:], 'group_size': sample_n, 'beam_size':1}) # randomness from softmax with torch.no_grad(): - _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample') + _seq, _sampleLogprobs = model(fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample', + loader=loader, data=data) _sents = utils.decode_sequence(model.vocab, _seq) for k, sent in enumerate(_sents): entry = {'image_id': data['infos'][k // sample_n]['id'], 'caption': sent} n_predictions.append(entry) - if verbose: + seqs = _seq + else: + raise ValueError(f"invalid sample_n_method {sample_n_method}") + if verbose_captions: for entry in sorted(n_predictions[-fc_feats.shape[0] * sample_n:], key=lambda x: x['image_id']): - print('image %s: %s'
%(entry['image_id'], entry['caption'])) \ No newline at end of file + print('image %s: %s' %(entry['image_id'], entry['caption'])) + seqs = einops.rearrange(seqs, "(n_imgs n_caps) T -> n_imgs n_caps T", n_imgs=n_imgs, n_caps=sample_n) + log_probs = einops.rearrange(log_probs, "(n_imgs n_caps) -> n_imgs n_caps", n_imgs=n_imgs, n_caps=sample_n) + return n_predictions, seqs, log_probs diff --git a/captioning/utils/misc.py b/captioning/utils/misc.py index 8bd3193d..f6c2c6fe 100644 --- a/captioning/utils/misc.py +++ b/captioning/utils/misc.py @@ -8,12 +8,15 @@ import numpy as np import torch.optim as optim import os +import sys +import subprocess import torch.nn.functional as F import six from six.moves import cPickle + bad_endings = ['with','in','on','of','a','at','to','for','an','this','his','her','that'] bad_endings += ['the'] @@ -247,4 +250,9 @@ def get_std_opt(model, optim_func='adam', factor=1, warmup=2000): optim_func = dict(adam=torch.optim.Adam, adamw=torch.optim.AdamW)[optim_func] return NoamOpt(model.d_model, factor, warmup, - optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) \ No newline at end of file + optim_func(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9)) + +def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']): + subprocess.call('git rev-parse HEAD', shell=True, stdout=out_file) + exclude_string = ' '.join("':(exclude){}'".format(f) for f in exclude_file_patterns) + subprocess.call('git --no-pager diff -- . {}'.format(exclude_string), shell=True, stdout=out_file) diff --git a/captioning/utils/opts.py b/captioning/utils/opts.py index b97adddc..79d63f6f 100644 --- a/captioning/utils/opts.py +++ b/captioning/utils/opts.py @@ -164,7 +164,6 @@ def parse_opt(): parser.add_argument('--train_only', type=int, default=0, help='if true then use 80k, else use 110k') - # Reward parser.add_argument('--cider_reward_weight', type=float, default=1, help='The reward weight from cider') @@ -193,6 +192,10 @@ def parse_opt(): parser.add_argument('--train_beam_size', type=int, default=1, help='') + # contrastive loss + parser.add_argument('--contrastive_after', type=int, default=-1) + parser.add_argument('--contrastive_em', choices=['soft', 'hard'], default='hard') + # Used for self critical parser.add_argument('--sc_sample_method', type=str, default='greedy', help='') @@ -203,6 +206,10 @@ def parse_opt(): # For diversity evaluation during training add_diversity_opts(parser) + add_loader_options(parser) + + add_pragmatics_opts(parser) + # config parser.add_argument('--cfg', type=str, default=None, @@ -261,6 +268,11 @@ def parse_opt(): return args +def add_loader_options(parser): + parser.add_argument('--num_workers', type=int, default=4, + help='number of pytorch workers (will have k + 1 total processes)') + parser.add_argument('--max_images_per_split', type=int, help='limit to this many images in each split: train / val / test') + def add_eval_options(parser): # Basic options parser.add_argument('--batch_size', type=int, default=0, @@ -302,10 +314,14 @@ def add_eval_options(parser): # misc parser.add_argument('--id', type=str, default='', help='an id identifying this run/job. 
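# --- Illustrative sketch (not from the patch): the scoring inside mbr_choose_from_candidates
# --- above. Candidate captions are embedded with BERT, each candidate is scored by its mean
# --- cosine similarity to the other candidates (a consensus / minimum-Bayes-risk criterion),
# --- and the result is mixed with the literal-speaker score. The patch uses bert-large-uncased;
# --- bert-base-uncased, the captions, and the s0 scores below are placeholders for brevity.
import torch
from transformers import AutoTokenizer, AutoModel

captions = ["a red and white airplane flying in the sky",
            "a plane flying in the air",
            "a city street with tall buildings"]
s0_scores = torch.tensor([-5.4, -5.6, -20.0])   # made-up literal-speaker log-probs
s0_weight = 0.3

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
with torch.no_grad():
    inputs = tokenizer(captions, return_tensors="pt", padding=True)
    h = model(**inputs)["last_hidden_state"][:, 0]     # [CLS] vectors, one per caption
h = h / h.norm(2, dim=-1, keepdim=True)
cosine_sim = h @ h.t()                                 # pairwise cosine similarities
mbr_scores = cosine_sim.mean(-1).log_softmax(-1)       # consensus score per caption
joint = s0_weight * s0_scores + (1 - s0_weight) * mbr_scores
best = joint.argmax(-1)                                # the odd-one-out caption is not chosen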
used only if language_eval = 1 for appending to intermediate files') - parser.add_argument('--verbose_beam', type=int, default=1, + parser.add_argument('--verbose', type=int, default=1, + help='if we need to print out all beam search beams.') + parser.add_argument('--verbose_beam', type=int, default=1, help='if we need to print out all beam search beams.') parser.add_argument('--verbose_loss', type=int, default=0, help='If calculate loss using ground truth during evaluation') + parser.add_argument('--verbose_captions', type=int, default=0, + help='print candidate captions in evaluation') def add_diversity_opts(parser): parser.add_argument('--sample_n', type=int, default=1, @@ -342,6 +358,33 @@ def add_eval_sample_opts(parser): help='Not predicting UNK') +def add_pragmatics_opts(parser: argparse.ArgumentParser): + parser.add_argument('--index_serialization_root_path', default='data/cocobu_indices') + parser.add_argument('--pragmatic_inference', type=int, default=0, + help='') + parser.add_argument('--pragmatic_distractors', type=int, default=5, + help='number of distractor images to use (in addition to the target image)') + parser.add_argument('--pragmatic_distractor_split', choices=['train', 'val'], default='train') + parser.add_argument('--pragmatic_s0_weight', type=float, default=0.0, + help='lambda in lambda * log p_s0 + (1 - lambda) * log p_s1') + parser.add_argument('--pragmatic_serialize_all_scores', type=int, default=0, help='') + parser.add_argument('--pragmatic_distractor_type', choices=['closest', 'choose_within_closest'], default='closest') + parser.add_argument('--pragmatic_distractors_to_choose', type=int, default=1) + parser.add_argument('--pragmatic_incremental_alpha', type=float, default=1.0) + parser.add_argument('--pragmatic_incremental_l1_uses', choices=['s0', 's1'], default='s0') + + parser.add_argument('--pragmatic_distractor_candidate_type', + choices=['closest', 'batch', 'random'], + default='closest') + + parser.add_argument('--pragmatic_distractor_scoring', choices=['uniform', 'mlp'], default='uniform') + parser.add_argument('--pragmatic_distractor_scoring_hidden_size', type=int, default=200) + +def add_mbr_opts(parser: argparse.ArgumentParser): + parser.add_argument('--mbr_inference', type=int, default=0, help='') + parser.add_argument('--mbr_type', choices=['bert_cosine_sim'], default='bert_cosine_sim') + parser.add_argument('--mbr_s0_weight', type=float, default=0.0) + if __name__ == '__main__': import sys sys.argv = [sys.argv[0]] diff --git a/eval_literal.sh b/eval_literal.sh new file mode 100755 index 00000000..825f5713 --- /dev/null +++ b/eval_literal.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +model_dir="models/updown" + +split=$1 +beam_size=$2 + +id="literal_${split}_bs-${beam_size}_dl-0.0" + +python -u tools/eval.py \ + --id $id \ + --force 1 \ + --dump_images 0 \ + --verbose_captions 0 \ + --verbose_beam 0 \ + --save_verbose_predictions 1 \ + --split $split \ + --num_images 5000 \ + --model ${model_dir}/model-best.pth \ + --infos_path ${model_dir}/infos_tds-best.pkl \ + --language_eval 1 \ + --beam_size $beam_size \ + --diversity_lambda 0.0 \ + | tee expts/${id} diff --git a/eval_pragmatic.sh b/eval_pragmatic.sh new file mode 100755 index 00000000..e98cee5d --- /dev/null +++ b/eval_pragmatic.sh @@ -0,0 +1,47 @@ +model_dir="models/updown" + +# val or test +split=$1 + +# number of candidate captions to rescore +n_candidates=$2 + +# weight lambda in log p_s0 * lambda + log p_s1 * (1 - lambda) +s0_weight=$3 + +# number of additional images +n_distractors=$4 +if [ -z 
$n_distractors ] +then + n_distractors=5 +else + shift +fi + +shift 3; + +distractor_split="train" + +candidate_gen="bs" + +id="pragmatic_${split}_cand-${candidate_gen}-${n_candidates}_s0-weight-${s0_weight}" + +python -u tools/eval.py \ + --id $id \ + --force 1 \ + --dump_images 0 \ + --verbose_captions 0 \ + --verbose_beam 0 \ + --save_verbose_predictions 1 \ + --split $split \ + --num_images 5000 \ + --model ${model_dir}/model-best.pth \ + --infos_path ${model_dir}/infos_tds-best.pkl \ + --language_eval 1 \ + --pragmatic_inference 1 \ + --pragmatic_distractors $n_distractors \ + --pragmatic_distractor_split $distractor_split \ + --pragmatic_s0_weight $s0_weight \ + --sample_n_method $candidate_gen \ + --sample_n $n_candidates \ + | tee expts/${id} diff --git a/notebooks/.gitignore b/notebooks/.gitignore new file mode 100644 index 00000000..87620ac7 --- /dev/null +++ b/notebooks/.gitignore @@ -0,0 +1 @@ +.ipynb_checkpoints/ diff --git a/notebooks/nearest-neighbors.ipynb b/notebooks/nearest-neighbors.ipynb new file mode 100644 index 00000000..47622f8a --- /dev/null +++ b/notebooks/nearest-neighbors.ipynb @@ -0,0 +1,1103 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.core.display import display, HTML\n", + "display(HTML(\"\"))beam" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/dfried/projects/ImageCaptioning.pytorch\n" + ] + } + ], + "source": [ + "cd /home/dfried/projects/ImageCaptioning.pytorch" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import sys" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "sys.path.append(\"/home/dfried/projects/ImageCaptioning.pytorch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "import einops" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "import argparse\n", + "import captioning.utils.opts as opts\n", + "import captioning.utils.misc as utils\n", + "import captioning.models as models\n", + "from captioning.utils import eval_utils" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from captioning.data.dataloader import DataLoader\n", + "from captioning.data.dataloaderraw import DataLoaderRaw" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "opts.add_eval_options(parser)\n", + "opts.add_diversity_opts(parser)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "opt = parser.parse_args([])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(batch_size=0, 
beam_size=1, block_trigrams=0, coco_json='', decoding_constraint=0, diversity_lambda=0.5, dump_images=1, dump_json=1, dump_path=0, eval_oracle=1, group_size=1, id='', image_folder='', image_root='', input_att_dir='', input_box_dir='', input_fc_dir='', input_json='', input_label_h5='', language_eval=0, length_penalty='', max_length=20, num_images=-1, remove_bad_endings=0, sample_method='greedy', sample_n=1, sample_n_method='sample', split='test', suppress_UNK=1, temperature=1.0, verbose_beam=1, verbose_loss=0)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "opt" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "model_fname = 'models/updown/model-best.pth'\n", + "infos_fname = 'models/updown/infos_tds-best.pkl'" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['vocab', 'opt', 'best_val_score', 'iter', 'iterators', 'epoch', 'split_ix'])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "with open(infos_fname, 'rb') as f:\n", + " infos = utils.pickle_load(f)\n", + "infos.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id']\n", + "ignore = ['start_from']\n", + "\n", + "for k in vars(infos['opt']).keys():\n", + " if k in replace:\n", + " setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, ''))\n", + " elif k not in ignore:\n", + " if not k in vars(opt):\n", + " vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "opt.vocab = infos['vocab']" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "model = models.setup(opt)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "del opt.vocab" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "CUDA = True" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "UpDownModel(\n", + " (embed): Sequential(\n", + " (0): Embedding(9488, 1000)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (fc_embed): Sequential(\n", + " (0): Linear(in_features=2048, out_features=1000, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (att_embed): Sequential(\n", + " (0): Linear(in_features=2048, out_features=1000, bias=True)\n", + " (1): ReLU()\n", + " (2): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (logit): Linear(in_features=1000, out_features=9488, bias=True)\n", + " (ctx2att): Linear(in_features=1000, out_features=512, bias=True)\n", + " (core): UpDownCore(\n", + " (att_lstm): LSTMCell(3000, 1000)\n", + " (lang_lstm): LSTMCell(2000, 1000)\n", + " (attention): Attention(\n", + " (h2att): Linear(in_features=1000, out_features=512, bias=True)\n", + " (alpha_net): Linear(in_features=512, out_features=1, bias=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + 
"source": [ + "model.load_state_dict(torch.load(model_fname, map_location='cpu'))\n", + "if CUDA:\n", + " model.cuda()\n", + "model.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataLoader loading json file: data/cocotalk.json\n", + "vocab size is 9487\n", + "DataLoader loading h5 file: data/cocobu_fc data/cocobu_att data/cocotalk_box data/cocotalk_label.h5\n", + "max sequence length in data is 16\n", + "read 123287 image features\n", + "assigned 113287 images to split train\n", + "assigned 5000 images to split val\n", + "assigned 5000 images to split test\n" + ] + } + ], + "source": [ + "loader = DataLoader(opt, shuffle_override=False, wrap_override=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "vocab = loader.get_vocab()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "PAD_ID = 0" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "REGENERATE = False" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "split = 'train'" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "if REGENERATE:\n", + " ixs = []\n", + " ids = []\n", + " file_paths = []\n", + " features = []\n", + " captions = []\n", + "\n", + " loader.reset_iterator(split)\n", + "\n", + " for batch_ix in tqdm.trange(len(loader.loaders[split])):\n", + " data = loader.get_batch(split)\n", + " infos = data['infos']\n", + " ixs.extend([d['ix'] for d in infos])\n", + " ids.extend([d['id'] for d in infos])\n", + " file_paths.extend([d['file_path'] for d in infos])\n", + " features.append(data['fc_feats'])\n", + "\n", + " batch_captions = []\n", + " for batch_labels in data['labels']:\n", + " instance_captions = []\n", + " for img_labels in batch_labels:\n", + " caption = [vocab[str(ix.item())] for ix in img_labels if ix != 0]\n", + " instance_captions.append(caption)\n", + " batch_captions.append(instance_captions)\n", + " captions.extend(batch_captions)\n", + " features_array = torch.cat(features, 0).numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "def get_captions_from_batch(data):\n", + " batch_captions = []\n", + " for batch_labels in data['labels']:\n", + " instance_captions = []\n", + " for img_labels in batch_labels:\n", + " caption = [vocab[str(ix.item())] for ix in img_labels if ix != 0]\n", + " instance_captions.append(caption)\n", + " batch_captions.append(instance_captions)\n", + " return batch_captions" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "fname = 'data/cocobu_fc/all_{}.pkl'.format(split)\n", + "if REGENERATE:\n", + " with open(fname, 'wb') as f:\n", + " d = {\n", + " 'ixs': ixs,\n", + " 'ids': ids,\n", + " 'file_paths': file_paths,\n", + " 'features': features_array,\n", + " 'captions': captions,\n", + " }\n", + " pickle.dump(d, f)\n", + "else:\n", + " with open(fname, 'rb') as f:\n", + " d = pickle.load(f)\n", + " ixs, ids, file_paths, features_array, captions = d['ixs'], d['ids'], d['file_paths'], d['features'], d['captions']" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + 
"import faiss" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "index = faiss.IndexFlatL2(2048)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "index.add(features_array)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "113287" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "index.ntotal" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "def wrap_tag(tag, inner):\n", + " return f'<{tag}>{inner}'\n", + "\n", + "def image_html(image_path, width=300, border=False):\n", + " if border:\n", + " style = ' style=\"border: 5px solid #0FF\" '\n", + " else:\n", + " style = ''\n", + " return f''\n", + "\n", + "def captions_html(captions):\n", + " #return wrap_tag('p', '
'.join(' '.join(cap) for cap in captions))\n", + " return wrap_tag('ol', ''.join(wrap_tag('li', cap) for cap in captions))\n", + "\n", + "def images_html(image_paths, width=300, num_per_row=5, target=None, captions=None):\n", + " rows = []\n", + " for ix in range(0, len(image_paths), num_per_row):\n", + " items = [wrap_tag('td', image_html(image_paths[image_ix], width=width, border=image_ix == target)) \n", + " for image_ix in range(ix, ix+num_per_row) if image_ix < len(image_paths)]\n", + " rows.append(wrap_tag('tr', ''.join(items)))\n", + " if captions is not None:\n", + " cap_html = [\n", + " wrap_tag('td', captions_html(captions[image_ix]))\n", + " for image_ix in range(ix, ix+num_per_row)\n", + " if image_ix < len(image_paths)\n", + " ]\n", + " rows.append(wrap_tag('tr', ''.join(cap_html)))\n", + " return wrap_tag('table', ''.join(rows))\n", + "\n", + "def display_images(image_paths, width=300, num_per_row=5, target=None, captions=None):\n", + " display(HTML(images_html(image_paths, width=width, num_per_row=num_per_row, target=target, captions=captions)))" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def get_neighbor_batch(img_fc_feat, k, include_self=False, self_ix=None):\n", + " assert len(img_fc_feat.shape) == 1, \"should be the features for a single image\"\n", + " D, I = index.search(img_fc_feat[None], k)\n", + " n_images, k_ = D.shape\n", + " assert n_images == 1\n", + " indices = []\n", + " if include_self:\n", + " assert self_ix is not None\n", + " indices.append(self_ix)\n", + " indices.extend([ixs[i] for i in I.flatten()])\n", + " data = [loader.dataset[ix, 0, False] for ix in indices]\n", + " batch = loader.dataset.collate_func(data, 'train')\n", + " return batch" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "def display_neighbors(features, k=5, num_per_row=5):\n", + " neighbor_batch = get_neighbor_batch(features.flatten(), k)\n", + " paths_k = [d['file_path'] for d in neighbor_batch['infos']]\n", + " captions_k = [[' '.join(c) for c in cs] for cs in get_captions_from_batch(neighbor_batch)]\n", + " display_images(paths_k, captions=captions_k, num_per_row=num_per_row)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
  1. a young boy standing in front of a computer keyboard
  2. a little boy wearing headphones and looking at a computer monitor
  3. he is listening intently to the computer at school
  4. a young boy stares up at the computer monitor
  5. a young kid with head phones on using a computer
  1. a boy wearing headphones using one computer in a long row of computers
  2. a little boy with earphones on listening to something
  3. a group of people sitting at desk using computers
  4. children sitting at computer stations on a long table
  5. a small child wearing headphones plays on the computer
  1. some small children playing on laptop games excited
  2. two children work at a desk on laptop computers
  3. two young ladies work on laptops on a white counter top
  4. a little girl sitting in front of a laptop computer
  5. two girls who are sitting in front of laptops
  1. a man with glasses sitting at a desktop computer
  2. two men at a computer playing game with headphones on
  3. two men are wearing headphones and playing a computer game
  4. two men are in the dark by a laptop computer
  5. dark haired man playing a video game on computer
  1. a man with a hoodie and headphones on in front of a computer
  2. a person sitting in front of a laptop computer wearing glasses
  3. a man with headphones sitting at a desk looking at a computer
  4. a young man in a red sweatshirt is on the computer
  5. a man is sitting at the computer desk with a laptop on it
  1. people at a work bench table with laptops and other electronic equipment
  2. two people on computers amongst a table full of debris
  3. three people are working on two laptops
  4. hands are at work at a table repairing laptop computers
  5. a group of people sitting around a pair of laptops
  1. a man holding a smart phone while standing next to a credit card reader
  2. a man looking at something in his hands
  3. a young man is at a workstation with a phone
  4. an image of man that is looking at his cellphone
  5. a young man is using a cell phone near electronics
  1. two men sitting around a laptop looking at the screen
  2. a man at a laptop with another looking on at his screen
  3. two men stare intently at a computer screen while one works at the keyboard
  4. two men at a desk working with a laptop computer
  5. two people looking at a laptop on a desk
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_neighbors(features_array[1], k=8, num_per_row=4)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "# this_pred_captions = []\n", + "# for sent in seq:\n", + "# this_pred_captions.append([\n", + "# [model.vocab.get(str(ix.item()), 'IX_{}'.format(ix.item())) for ix in sent if ix.item() != 0]\n", + "# ])" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [], + "source": [ + "from itertools import groupby" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.core.debugger import set_trace" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [], + "source": [ + "def caption_scores(fc_feats, att_feats, att_masks, seq, add_bos=True):\n", + " if add_bos:\n", + " seq = torch.cat([torch.zeros(seq.size(0), 1).long().to(seq), seq], 1)\n", + " with torch.no_grad():\n", + " scores = model(fc_feats, att_feats, seq, att_masks)\n", + " mask = (seq[:,:-1] > 0) | (seq[:,1:] > 0)\n", + " # TODO: does this include the EOS score?\n", + " # seq_t: words input at each position, with 0 for pad\n", + " # seq_t: 0, w_0, w_1, w_2, ..., w_k, 0, ...\n", + " # mask : 1, 1 , 1 , 1 , 1 , 1 , 0, ...\n", + " # selected_scores: w_0 + ... + w_k\n", + "# return scores[:,:-1].gather(2, seq[:,1:].unsqueeze(2)).squeeze(2)\n", + " selected_scores = (scores[:,:-1].gather(2, seq[:,1:].unsqueeze(2)).squeeze(2) * mask)\n", + " return selected_scores" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": {}, + "outputs": [], + "source": [ + "def cross_product_scores(fc_feats, att_feats, att_masks, seq, add_bos=True):\n", + " n_captions = seq.size(0)\n", + " n_images = fc_feats.size(0)\n", + " assert att_feats.size(0) == n_images\n", + " assert att_masks.size(0) == n_images\n", + " fc_feats_t = fc_feats.unsqueeze(0).repeat_interleave(n_captions, dim=0)\n", + " att_feats_t = att_feats.unsqueeze(0).repeat_interleave(n_captions, dim=0)\n", + " att_masks_t = att_masks.unsqueeze(0).repeat_interleave(n_captions, dim=0)\n", + " \n", + " fc_feats_t = einops.rearrange(fc_feats_t, 'caps imgs d -> (caps imgs) d')\n", + " att_feats_t = einops.rearrange(att_feats_t, 'caps imgs obj d -> (caps imgs) obj d')\n", + " att_masks_t = einops.rearrange(att_masks_t, 'caps imgs obj -> (caps imgs) obj')\n", + " seq_t = seq.unsqueeze(1).repeat_interleave(n_images, dim=1)\n", + " seq_t = einops.rearrange(seq_t, 'caps imgs d -> (caps imgs) d')\n", + " \n", + " scores_per_timestep = caption_scores(fc_feats_t, att_feats_t, att_masks_t, seq_t, add_bos=add_bos)\n", + " scores = scores_per_timestep.sum(1)\n", + " scores = einops.rearrange(scores, '(caps imgs) -> caps imgs', caps=n_captions, imgs=n_images)\n", + " return scores" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Automatic pdb calling has been turned ON\n" + ] + } + ], + "source": [ + "pdb on" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "loader.reset_iterator('val')" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "data = loader.get_batch('val')" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + 
"metadata": {}, + "outputs": [], + "source": [ + "this_feats = data['fc_feats']\n", + "this_ixs = [d['ix'] for d in data['infos']]\n", + "this_ids = [d['id'] for d in data['infos']]\n", + "this_paths = [d['file_path'] for d in data['infos']]" + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": {}, + "outputs": [], + "source": [ + "# set_trace()\n", + "sample_n = 10\n", + "input_data = data['fc_feats'].cuda(), data['att_feats'].cuda(), data['att_masks'].cuda(), data\n", + "n_predictions = []\n", + "eval_kwargs = {\n", + " 'sample_n_method': 'bs',\n", + " 'sample_n': sample_n,\n", + " 'temperature': 0.25,\n", + " 'verbose': False,\n", + "}\n", + "eval_utils.eval_split_n(model, n_predictions, input_data=input_data, eval_kwargs=eval_kwargs)\n", + "captions_by_id = {}\n", + "log_prob_by_id = {}\n", + "seq_by_id = {}\n", + "for k, ds in groupby(n_predictions, lambda d: d['image_id']):\n", + " ds = list(ds)\n", + " captions_by_id[k] = [d['caption'] for d in ds]\n", + " log_prob_by_id[k] = [d['log_prob'] for d in ds]\n", + " seq_by_id[k] = [d['seq'] for d in ds]\n", + "this_captions = [\n", + " captions_by_id[id_] for id_ in this_ids\n", + "]\n", + "this_log_probs = [\n", + " log_prob_by_id[id_] for id_ in this_ids\n", + "]\n", + "this_seq = [\n", + " torch.stack(seq_by_id[id_], 0) for id_ in this_ids\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
  1. a red and white airplane flying in the sky -5.3846
  2. a red and white plane flying in the sky -5.6425
  3. a red and white airplane flying in the air -6.2955
  4. a red and white airplane flying through a blue sky -6.4906
  5. a red and white airplane flying in a blue sky -6.7558
  6. a red and white plane flying in a blue sky -6.8799
  7. the red and white airplane is flying in the sky -7.9610
  8. there is a red and white plane flying in the sky -8.8055
  9. an airplane flying in the sky with a red tail -12.0166
  10. an airplane is flying in the sky with a red tail -12.4322
  1. a red and white airplane flying in a blue sky -2.2066
  2. a red and white airplane flying in the sky -2.2307
  3. there is a red and white plane flying in the sky -2.2577
  4. the red and white airplane is flying in the sky -2.2766
  5. an airplane flying in the sky with a red tail -2.2893
  6. an airplane is flying in the sky with a red tail -2.3098
  7. a red and white plane flying in a blue sky -2.3132
  8. a red and white airplane flying in the air -2.3209
  9. a red and white airplane flying through a blue sky -2.4151
  10. a red and white plane flying in the sky -2.4287
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a white and black plane flying with sky in background
  2. a small passenger plane flying on a sunny day
  3. a propeller airplane flying under a cloudy blue sky
  4. looking under an airplane as it flies in the air
  5. the airplane is flying near the clouds in the sky
  1. a person is flying a biplane in the sky
  2. an old biplane that is flying low for the crowd
  3. a red and white plane flying through a cloudy sky
  4. a propeller plane that is flying in the sky
  5. white and red biplane flying through the air
  1. a bed and white propeller plane flying through a blue sky
  2. a plane in the middle of the air its a propeller plane
  3. the old plane has been recently painted in red and white
  4. a very small red and white air plane
  5. a plane flying in the air on a clear day
  1. a yellow and blue biplane flying through the ski
  2. a small single engine plane in flying in the sky
  3. a blue red and white airplane is flying
  4. a small biplane flies through the blue sky
  5. an old airplane flies through the sky on a nice day
  1. an airplane flying through the air on a clear day
  2. the airplane is flying high in the clear blue sky
  3. a white airplane with a single propeller and double stacked wings
  4. a red white and blue airplane flying in the sky
  5. a small aircraft flying low on a clear sky
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a city street with lots of tall buildings -7.8097
  2. a city street with a large building in the background -10.1655
  3. a city street with a tall building in the background -10.6663
  4. a city street with a large city in the background -10.7126
  5. a city street with tall buildings and a large city -11.6744
  6. a city street with cars and a tall building -12.1900
  7. a city street with tall buildings and a large building -12.2099
  8. a city street with tall buildings and a large building in the background -12.6306
  9. a city street with cars and people walking on the sidewalk -13.4368
  10. a city street with a tall building and a large city -15.2348
  1. a city street with tall buildings and a large building in the background -1.9513
  2. a city street with a tall building and a large city -1.9794
  3. a city street with tall buildings and a large city -2.0439
  4. a city street with tall buildings and a large building -2.1185
  5. a city street with a large city in the background -2.2861
  6. a city street with a tall building in the background -2.4150
  7. a city street with a large building in the background -2.5029
  8. a city street with lots of tall buildings -2.5392
  9. a city street with cars and a tall building -2.5657
  10. a city street with cars and people walking on the sidewalk -3.1856
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. heavy traffic in a city with a UNK bank building
  2. a busy downtown street is filled with cars waiting to move
  3. a green light is shown on this busy multi lane street
  4. there are many tall buildings and cars in this city
  5. a large city with UNK buildings and a green light
  1. a red and yellow double decker bus on street next to trees
  2. a double deck tour bus riding down a street through a traffic light
  3. a large red bus driving down the street
  4. a double deck bus that is driving down the road
  5. a red double decker bus driving down a street
  1. a man riding down the street in a horse and carriage
  2. a group of cars parked in a lot
  3. a horse drawn carriage comes down the street on a clear day
  4. a horse drawn carriage on a city street
  5. the cars are sharing the busy road with the horse
  1. traffic and people are standing in the downtown area
  2. a busy 2 way downtown intersection in the city
  3. there is a very tall tower with a clock on it on this street
  4. a busy semi busy street with three yellow taxis going down it
  5. a tall clock tower on a city street near buildings
  1. a transit bus moves through a crowded street
  2. the buses are lined up on the busy street
  3. view of down town in a city and traffic driving on the opposite side of the
  4. two public transit buses on a city street
  5. a picture of an outdoor area that seems great
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a bus with a advertisement on the side of it -9.9457
  2. a truck with a mural on the side of it -10.0262
  3. a bus with a mural on the side of it -10.1005
  4. a blue truck with a advertisement on the side of it -10.2522
  5. a blue bus with a advertisement on the side of it -10.5892
  6. a blue truck with a sign on the side of it -11.0073
  7. a bus with a large advertisement on the side of it -11.2638
  8. a blue truck with a blue and white bus -14.9023
  9. a blue truck with a blue and white bus on the side -16.5830
  10. a blue truck with a blue and white bus on the side of it -17.4545
  1. a blue truck with a advertisement on the side of it -1.9789
  2. a blue truck with a blue and white bus on the side of it -2.0041
  3. a blue truck with a blue and white bus on the side -2.0087
  4. a blue truck with a blue and white bus -2.1525
  5. a truck with a mural on the side of it -2.2120
  6. a blue truck with a sign on the side of it -2.2145
  7. a blue bus with a advertisement on the side of it -2.3565
  8. a bus with a mural on the side of it -2.7477
  9. a bus with a large advertisement on the side of it -2.9925
  10. a bus with a advertisement on the side of it -3.0165
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a bus that is sitting in the street with its door open
  2. a bus traveling down the street in a city
  3. the political tour bus serves as home base during the UNK
  4. a bus decorated for a presidential political campaign
  5. a bus for a politician driving down the street
  1. a large truck driving down a busy street filled with traffic
  2. a refrigerated semi truck drives on a rode beside a smaller car
  3. the semi truck and car turn the corner close to each other
  4. small car rides behind a large semi truck with a large bed
  5. a car and truck are navigating a turn together
  1. a long yellow bus advertising a musical play
  2. a yellow bus with a lion king ad on it
  3. a large yellow bus with pictures of people in lion costumes and the words the lion
  4. a bus is covered in an advertisement for a broadway show
  5. a bus with advertisement painted on the side
  1. a black bus on street with flags and buildings in background
  2. a black bus driving down the road in the middle of the city
  3. the team bus is parked at the building
  4. the large tour bus is painted with a dog mascot
  5. a large black truck driving down a city street
  1. a vehicle pulls up next to a building
  2. the large blue truck is parked at the curb
  3. a truck that is driving on the road
  4. a large blue tow truck sitting on the side of a road
  5. a truck is on the city street
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of giraffe standing next to each other -5.7969
  2. a couple of giraffe standing on top of a lush green field -6.3289
  3. a couple of giraffes are standing in a field -6.6013
  4. a couple of giraffe standing next to each other on a field -6.6032
  5. two giraffes are standing in a grassy field -6.6232
  6. two giraffes standing in the grass near trees -7.4514
  7. two giraffes standing in a field of grass -8.3326
  8. two giraffes standing in a grassy field next to trees -8.4241
  9. two giraffes standing in a field with trees -8.5482
  10. two giraffes standing in a grassy field with trees -8.6169
  1. two giraffes standing in a grassy field with trees -1.8652
  2. two giraffes standing in the grass near trees -1.9938
  3. two giraffes are standing in a grassy field -2.0416
  4. two giraffes standing in a grassy field next to trees -2.0757
  5. two giraffes standing in a field of grass -2.2234
  6. two giraffes standing in a field with trees -2.2874
  7. a couple of giraffe standing on top of a lush green field -2.4703
  8. a couple of giraffes are standing in a field -2.8429
  9. a couple of giraffe standing next to each other -2.9213
  10. a couple of giraffe standing next to each other on a field -3.0531
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a group of zebras on a grassy area next to trees
  2. a group of giraffes stand next in a field by a tree
  3. a group of giraffes in a field near trees
  4. a group of four giraffes standing next to each other
  5. a herd of giraffe walking across a field
  1. a herd of giraffe standing next to each other near a forested hillside
  2. a group of giraffes are standing together on the UNK
  3. an image of a herd of giraffes in a plain
  4. half a dozen healthy giraffes hanging out in a field
  5. a group of giraffes faces and stare in the same direction while one of the faces
  1. the giraffes stood together next to the bush
  2. two giraffes standing in open field with trees
  3. two giraffes are heading towards trees for leaves
  4. one giraffe is behind another giraffe on the grass
  5. two giraffes stand right next to each other
  1. two giraffes walking through a spacious grassy field
  2. two giraffes that are walking together in a field
  3. two giraffes next to one another near a rock
  4. two giraffes standing next to each other under a group of trees
  5. two giraffes standing on all fours next to one another with grass bushes and trees around
  1. several giraffes are standing on the short grass
  2. group of giraffes standing in grass lands
  3. adult giraffe surrounded by three younger giraffe in the wild
  4. the three giraffes are standing in the grassy field together
  5. some giraffes in a field with trees in the background
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a giraffe is laying down in the dirt -5.5802
  2. the giraffe is laying down in the dirt -6.8027
  3. a giraffe laying on the ground next to a building -7.1025
  4. a giraffe laying down in the dirt near a building -7.6570
  5. a giraffe laying down in the dirt next to a building -8.2647
  6. a giraffe laying on the ground in a zoo -8.3065
  7. a giraffe laying on the ground in a dirt -8.6407
  8. a giraffe laying down in the dirt in a zoo -8.9004
  9. a giraffe laying down on the ground in a zoo -9.0528
  10. a giraffe laying down on the ground in a dirt -9.4864
  1. a giraffe laying down in the dirt near a building -2.1925
  2. a giraffe laying down in the dirt next to a building -2.1933
  3. a giraffe laying on the ground next to a building -2.1935
  4. a giraffe laying down on the ground in a dirt -2.3129
  5. the giraffe is laying down in the dirt -2.3140
  6. a giraffe laying down on the ground in a zoo -2.3565
  7. a giraffe laying on the ground in a dirt -2.3627
  8. a giraffe laying down in the dirt in a zoo -2.3633
  9. a giraffe laying on the ground in a zoo -2.3737
  10. a giraffe is laying down in the dirt -2.3938
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. four giraffes out in a field of some sort
  2. three large giraffes with one small standing next to a building
  3. a group of giraffes in a large enclosure next to a stone barn
  4. three adult and one baby giraffe standing outside
  5. a stone barn at a zoo with four giraffe standing around
  1. a giraffe kneeling down on the ground next to a tree
  2. a giraffe on its knees in the sand
  3. a giraffe whose knees a UNK beneath itself
  4. a giraffe kneeling down on its front legs on sand in a fenced in area
  5. a giraffe in a fenced area down on its knees
  1. several giraffes are on display in a zoo exhibit
  2. four giraffes standing and lounging in an enclosure
  3. the giraffes are standing in the sand beside a fence
  4. four giraffes standing and sitting in an enclosure
  5. a herd of giraffe standing around an enclosure at a zoo
  1. giraffes standing in a dirt lot near a pool of water
  2. giraffes in a zoo enclosure with a stone wall
  3. two giraffes and a rhino in an enclosure
  4. two giraffes look over a fence in a zoo
  5. giraffes in an enclosure stand together by the water
  1. a giraffe standing next to another animal in a field
  2. a giraffe and a deer standing near a ravine
  3. a giraffe doing an odd pose in a field in front of a forest
  4. a giraffe with its back legs spread while it leans forward
  5. a giraffe is posing for the camera
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of giraffe standing next to each other -6.3707
  2. a couple of giraffe standing on top of a lush green field -6.4391
  3. a couple of giraffes are standing in a field -7.0572
  4. a couple of giraffe standing next to each other on a lush green field -7.4226
  5. two giraffes standing in the grass near trees -8.1003
  6. two giraffes standing in a grassy area next to trees -8.6406
  7. two giraffes standing in a field next to trees -9.1155
  8. two giraffes standing in a field with trees -9.2221
  9. two giraffes standing in a grassy field next to trees -9.4280
  10. two giraffes standing in a field of grass and trees -9.8691
  1. two giraffes standing in a grassy area next to trees -1.8909
  2. two giraffes standing in the grass near trees -1.9922
  3. two giraffes standing in a field of grass and trees -2.0097
  4. two giraffes standing in a field with trees -2.1959
  5. two giraffes standing in a grassy field next to trees -2.1976
  6. two giraffes standing in a field next to trees -2.2443
  7. a couple of giraffe standing on top of a lush green field -2.4730
  8. a couple of giraffe standing next to each other on a lush green field -2.6140
  9. a couple of giraffes are standing in a field -2.8668
  10. a couple of giraffe standing next to each other -3.3280
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a group of giraffe standing next to each other in an open field
  2. a herd of giraffe standing around a large tree stump
  3. a group of giraffes are nibbling on a large tree trunk
  4. group of giraffes outside standing around a stump
  5. a group of giraffes standing around a short UNK tree
  1. three giraffes are standing together surrounded by trees and shrubbery
  2. a group of three giraffe in the wilderness
  3. a group of giraffes foraging among the grass
  4. three giraffes gathered together in their own habitat
  5. giraffes standing next to each other near a forest
  1. three giraffes standing in grass with their heads in a tree
  2. a group of giraffes standing in front of a tree
  3. three giraffes who are eating from a large tree
  4. the three giraffes are standing by the tree
  5. a family of three giraffes is standing under a big tree
  1. a group of giraffe standing on top of a field
  2. one adult giraffe and two kid giraffes standing in the woods
  3. giraffes in the wild under trees on a sunny day
  4. the adult giraffe is in the field feeding with the two offspring
  5. three giraffes standing in the grass among trees and bushes
  1. two giraffe in a wooded area with an orange fence
  2. two giraffes standing on rocks in the middle of a field
  3. two giraffes in a wooded and grassy area
  4. two giraffes standing in a green shady field
  5. two giraffes standing next to each other in front of trees
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a fire hydrant with a hose attached to it -7.4004
  2. a fire hydrant is on the side of the road -8.1863
  3. a fire hydrant that is on the side of the road -8.7289
  4. a fire hydrant that is on the side of a road -9.0019
  5. the fire hydrant is on the side of the road -9.0924
  6. a fire hydrant is painted white and blue -10.1041
  7. a fire hydrant is painted white and black -10.1185
  8. a fire hydrant is on the sidewalk next to a car -10.4407
  9. a fire hydrant is on a sidewalk next to a car -10.7089
  10. a fire hydrant that has been painted white and blue -11.2274
  1. a fire hydrant is on the sidewalk next to a car -1.4475
  2. a fire hydrant is on a sidewalk next to a car -1.4559
  3. a fire hydrant with a hose attached to it -1.5761
  4. a fire hydrant is painted white and black -2.6243
  5. a fire hydrant that has been painted white and blue -2.8812
  6. a fire hydrant is on the side of the road -2.9437
  7. the fire hydrant is on the side of the road -3.1696
  8. a fire hydrant that is on the side of a road -3.2031
  9. a fire hydrant that is on the side of the road -3.2978
  10. a fire hydrant is painted white and blue -3.7296
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a motorcycle parked in a parking lot next to a car
  2. an antique indian motorcycle is parked next to the sidewalk
  3. motorcycle parked on the edge of a street
  4. an old indian motorcycle parked at the curb of a street
  5. a motorcycle parked on a sidewalk next to a street
  1. a motorcycle parked next to a sidewalk on the street
  2. the motorcycle is parked at the curb near the bicycles
  3. a street scene with the motorcycle and bicycles on the side of the road
  4. bicycles and a motorcycle parked on a city sidewalk
  5. a motorcycle and bicycles parked on a city street
  1. the yellow fire hydrant is on the curb as cars pass by
  2. a yellow fire hydrant sitting on the side of a road
  3. a yellow fire hydrant next to a street
  4. a yellow fire hydrant that is on a sidewalk
  5. a fire hydrant sits next to a city street
  1. a red white and blue fire hydrant covered in stars
  2. the fire hydrant is painted red white and blue
  3. a close up of a fire hydrant UNK red white and blue with stars
  4. a fire hydrant painted red white and blue are on the curb
  5. a fire hydrant painted red white and blue with white stars
  1. a hydrant that is sitting on the sidewalk
  2. a fire hydrant is next to a cone on a sidewalk
  3. a pipe sticking out of a paved surface next to a street grate
  4. there is a water hole on the street
  5. there is construction work being done on an urban street
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a red fire hydrant sitting in the middle of a sidewalk -7.0782
  2. a red fire hydrant in front of a tree -7.2148
  3. a red fire hydrant in the middle of a park -7.5111
  4. a fire hydrant in the middle of a park -7.6126
  5. a red fire hydrant in the middle of a sidewalk -7.9263
  6. a fire hydrant in the middle of a sidewalk -8.0315
  7. a fire hydrant is spraying water onto a street -8.3755
  8. a red fire hydrant is in the middle of a sidewalk -9.2053
  9. a red fire hydrant in a city street -9.6061
  10. a red fire hydrant in a park next to a tree -10.1319
  1. a red fire hydrant in a park next to a tree -2.0400
  2. a red fire hydrant in a city street -2.0684
  3. a red fire hydrant in front of a tree -2.0851
  4. a red fire hydrant is in the middle of a sidewalk -2.0884
  5. a red fire hydrant in the middle of a park -2.0914
  6. a red fire hydrant in the middle of a sidewalk -2.1153
  7. a red fire hydrant sitting in the middle of a sidewalk -2.1219
  8. a fire hydrant in the middle of a park -2.7492
  9. a fire hydrant in the middle of a sidewalk -3.1231
  10. a fire hydrant is spraying water onto a street -3.7515
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a blue and white fire hydrant on the sidewalk
  2. a blue and white fire hydrant on the side of the street
  3. a blue fire hydrant with a white top sits beside a road
  4. a blue fire hydrant sits in the middle of a sidewalk
  5. a blue and white fire hydrant sitting on top of a sidewalk
  1. a fire hydrant in a city with water pouring out of both sides
  2. a fire hydrant has water streaming out of two holes on its side
  3. a green fire hydrant pouring water from two of its spouts
  4. a fire hydrant that is open with water coming out of two holes
  5. a fire hydrant with water pouring out of it
  1. a fire hydrant next to a bush at a park
  2. a parking meter on the side of a wooded street
  3. a fire hydrant on a neighborhood street with trees and shrubs around it
  4. a street corner with a blue fire hydrant
  5. a scenic view of a wooded area with parking meter
  1. a blue and pink fire hydrant spewing out water onto a street
  2. a fire hydrant open spilling water onto the street
  3. a pink faded fire hydrant with dirty water coming out of it
  4. a fire hydrant is open with water coming out
  5. open fire hydrant with warning cone in urban city setting
  1. a fire hydrant on the corner of a neighborhood street
  2. a fire hydrant on the corner of a street
  3. a yellow and green fire hydrant sitting on the side of a road
  4. the fire hydrant is green and yellow
  5. a fire hydrant sitting near a sign beside the street
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a yellow and blue fire hydrant sitting on the side of a road -4.9590
  2. a yellow fire hydrant sitting on the side of a road -5.1579
  3. a blue and yellow fire hydrant sitting on the side of a road -5.4978
  4. a yellow and blue fire hydrant on a sidewalk -5.8938
  5. a blue and yellow fire hydrant on a sidewalk -6.5236
  6. the fire hydrant is on the side of the road -8.2083
  7. a fire hydrant on a sidewalk next to a street -8.2193
  8. a yellow fire hydrant on a sidewalk next to a street -8.2990
  9. a fire hydrant on a sidewalk near a street -8.3000
  10. a yellow fire hydrant on a sidewalk near a street -8.9389
  1. a blue and yellow fire hydrant sitting on the side of a road -1.7379
  2. a blue and yellow fire hydrant on a sidewalk -1.9939
  3. a yellow and blue fire hydrant sitting on the side of a road -2.1876
  4. a yellow fire hydrant sitting on the side of a road -2.2820
  5. a yellow fire hydrant on a sidewalk near a street -2.4456
  6. a yellow fire hydrant on a sidewalk next to a street -2.4725
  7. a yellow and blue fire hydrant on a sidewalk -2.5455
  8. a fire hydrant on a sidewalk next to a street -2.5564
  9. the fire hydrant is on the side of the road -2.5935
  10. a fire hydrant on a sidewalk near a street -2.6334
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a blue fire hydrant posed on a street corner in a city
  2. a blue water hydrant on a pavement near the road
  3. a blue and yellow fire hydrant sitting on the sidewalk next to a quiet street
  4. a blue and yellow fire hydrant on the side of a road
  5. a fire hydrant on a sidewalk of a city
  1. a green fire hydrant with three yellow concrete barriers around it
  2. pavement level view of green hydrant near a street corner
  3. a green fire hydrant surrounded by three yellow poles
  4. a green fire hydrant sitting between three yellow post
  5. a green fire hydrant and a bus on the road
  1. a fire hydrant that is sitting on the sidewalk
  2. an orange fire hydrant near the side of the street
  3. an orange fire hydrant sitting at the side of the street
  4. a fire hydrant on a sidewalk next to a street
  5. a UNK hydrant on a side walk near a city street
  1. a street intersection that has a traffic light and a direction sign on the corner along
  2. there is a telescope in the middle of a street
  3. a street sign near a traffic light pole
  4. that is a picture of an outside region
  5. a closeup of a telescope next to a street
  1. a fire hydrant is painted silver and blue
  2. two fire hydrants that are by the street
  3. two silver and blue fire hydrants side on either side of a road
  4. silver and blue fire hydrants are placed parallel to each other
  5. a fire hydrant is painted blue and grey
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a close up of a street sign with a sky background -5.2408
  2. a street sign that is on a pole -9.8394
  3. a street sign with a sticker on it -10.0734
  4. a street sign that says UNK and UNK -11.1867
  5. a street sign with a sign on it -12.2090
  6. a street sign with a sign that reads UNK -13.1063
  7. a sign that says UNK and UNK UNK -13.5002
  8. a street sign with a sign that says UNK -13.5340
  9. a street sign with a sticker of a man on it -14.0173
  10. a sign that says UNK UNK and a street -15.6800
  1. a sign that says UNK UNK and a street -1.8236
  2. a street sign with a sticker of a man on it -2.0566
  3. a sign that says UNK and UNK UNK -2.0826
  4. a close up of a street sign with a sky background -2.1726
  5. a street sign with a sticker on it -2.3357
  6. a street sign with a sign that says UNK -2.3788
  7. a street sign with a sign that reads UNK -2.3977
  8. a street sign that is on a pole -2.7060
  9. a street sign with a sign on it -2.7338
  10. a street sign that says UNK and UNK -2.8123
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
  1. a stop sign with graffiti on the UNK block of UNK street
  2. a close up of a vandalized stop sign on a pole
  3. a stop sign and street sign attached to a pole at an intersection
  4. a stop sign on UNK street has graffiti
  5. stop sign with intended UNK written in below it
  1. a sign prohibiting bicycle parking with UNK of towing
  2. a sign indicating that bicycle parking is not allowed
  3. a red and white street sign stating no bicycle parking
  4. a picture of a no bicycle parking sign
  5. a street sign that tells UNK not to park
  1. a black and white street sign that reads end bird
  2. looking up at a street sign that reads end bird
  3. sign on a street pole saying end bird
  4. a street sign stands under some power lines
  5. a sign on a post that reads end bird on it
  1. a red stop sign with the word them under it
  2. a one way sign is attached to a stop sign
  3. a stop sign with the word stop them on it below a one way
  4. a stop sign vandalized to read stop them
  5. a street stop sign with a one way sign attached on top
  1. a red sign warning people about pedestrians UNK hit by a crossing guard
  2. a round red danger railroad crossing sign with a red umbrella in the background
  3. a warning sign about danger at a railroad crossing
  4. a sign show that there is danger ahead
  5. a red danger sign with a person on it
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "K = 5\n", + "for i in range(10):\n", + " caps = this_captions[i]\n", + "# this_s0_scores = this_log_probs[i]\n", + " seq = this_seq[i]\n", + " neighbor_batch = get_neighbor_batch(this_feats[i].numpy(), K, include_self=True, self_ix=this_ixs[i])\n", + " \n", + " # num_sampled_captions x (1+K)\n", + " cp_scores = cross_product_scores(\n", + " neighbor_batch['fc_feats'].cuda(),\n", + " neighbor_batch['att_feats'].cuda(), \n", + " neighbor_batch['att_masks'].cuda(),\n", + " this_seq[i].cuda()\n", + " )\n", + " this_s0_scores = cp_scores[:,0].detach().cpu().tolist()\n", + " \n", + " l1_scores = cp_scores.log_softmax(1)\n", + " s1_scores = l1_scores.log_softmax(0)\n", + " \n", + " this_s1_scores = s1_scores[:,0].detach().cpu().tolist()\n", + " \n", + " def make_strings(lps, caps):\n", + " scored_caps = sorted(list(zip(lps, caps)), reverse=True)\n", + " deduped_caps = [next(g) for k, g in groupby(scored_caps, lambda t: t[1])]\n", + " cap_strings = [\"{} {:.4f}\".format(cap, lp) for lp, cap in deduped_caps]\n", + " return cap_strings\n", + " \n", + " display_images([this_paths[i], this_paths[i]], \n", + " captions=[\n", + " make_strings(this_s0_scores, caps)[:10],\n", + " make_strings(this_s1_scores, caps)[:10],\n", + " ])\n", + " display_neighbors(data['fc_feats'][i].numpy(), k=min(K, 12), num_per_row=4)\n", + " print()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/train2014 b/notebooks/train2014 new file mode 120000 index 00000000..8964def3 --- /dev/null +++ b/notebooks/train2014 @@ -0,0 +1 @@ +/home/dfried/data/coco/train2014 \ No newline at end of file diff --git a/notebooks/val2014 b/notebooks/val2014 new file mode 120000 index 00000000..ee55c0b7 --- /dev/null +++ b/notebooks/val2014 @@ -0,0 +1 @@ +/home/dfried/data/coco/val2014 \ No newline at end of file diff --git a/notebooks/visualize.ipynb b/notebooks/visualize.ipynb new file mode 100644 index 00000000..fb2e8612 --- /dev/null +++ b/notebooks/visualize.ipynb @@ -0,0 +1,1434 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.core.display import display, HTML\n", + "display(HTML(\"\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/dfried/projects/ImageCaptioning.pytorch\n" + ] + } + ], + "source": [ + "cd /home/dfried/projects/ImageCaptioning.pytorch" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def wrap_tag(tag, inner):\n", + " return f'<{tag}>{inner}'\n", + "\n", + "def image_html(image_path, width=300, border=False):\n", + " if border:\n", + " style = ' style=\"border: 5px solid #0FF\" '\n", + " else:\n", + " style = ''\n", + " return f''\n", + "\n", + 
"def captions_html(captions):\n", + " #return wrap_tag('p', '
'.join(' '.join(cap) for cap in captions))\n", + " return wrap_tag('ol', ''.join(wrap_tag('li', cap) for cap in captions))\n", + "\n", + "def images_html(image_paths, width=300, num_per_row=5, target=None, captions=None):\n", + " rows = []\n", + " for ix in range(0, len(image_paths), num_per_row):\n", + " items = [wrap_tag('td', image_html(image_paths[image_ix], width=width, border=image_ix == target)) \n", + " for image_ix in range(ix, ix+num_per_row) if image_ix < len(image_paths)]\n", + " rows.append(wrap_tag('tr', ''.join(items)))\n", + " if captions is not None:\n", + " cap_html = [\n", + " wrap_tag('td', captions_html(captions[image_ix]))\n", + " for image_ix in range(ix, ix+num_per_row)\n", + " if image_ix < len(image_paths)\n", + " ]\n", + " rows.append(wrap_tag('tr', ''.join(cap_html)))\n", + " return wrap_tag('table', ''.join(rows))\n", + "\n", + "def display_images(image_paths, width=300, num_per_row=5, target=None, captions=None):\n", + " display(HTML(images_html(image_paths, width=width, num_per_row=num_per_row, target=target, captions=captions)))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def display_neighbors(features, k=5, num_per_row=5):\n", + " neighbor_batch = loader.indices['train'].get_neighbor_batch(loader, features.flatten(), k)\n", + "# neighbor_batch = get_neighbor_batch(features.flatten(), k)\n", + " paths_k = [d['file_path'] for d in neighbor_batch['infos']]\n", + " captions_k = [[' '.join(c) for c in cs] for cs in get_captions_from_batch(neighbor_batch)]\n", + " display_images(paths_k, captions=captions_k, num_per_row=num_per_row)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "results = torch.load('eval_results/pred_verbose_pragmatic_val_distr-10_cand-bs-10_s0-weight-0.0_val.pth')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'image_id': 184613,\n", + " 'caption': 'a group of people standing in a field with animals',\n", + " 'perplexity': 0.0,\n", + " 'entropy': 0.0,\n", + " 'target_s0_scores': array([ -9.350447, -9.878253, -10.036623, -10.061303, -10.295622,\n", + " -10.518627, -10.526928, -10.881914, -11.104778, -12.62778 ],\n", + " dtype=float32),\n", + " 'target_s1_scores': array([-1.9426843, -2.6765547, -2.9688997, -2.054984 , -1.8441135,\n", + " -2.028884 , -2.0909386, -2.18645 , -3.2791595, -3.0959177],\n", + " dtype=float32),\n", + " 'target_s0s1_scores': array([-1.9426843, -2.6765547, -2.9688997, -2.054984 , -1.8441135,\n", + " -2.028884 , -2.0909386, -2.18645 , -3.2791595, -3.0959177],\n", + " dtype=float32),\n", + " 'chosen_target_s0s1_scores': -1.8441135,\n", + " 'candidates': ['a group of people in a field with animals',\n", + " 'a woman standing next to a herd of cattle',\n", + " 'a woman standing next to a herd of cows',\n", + " 'a group of people in a field with cows',\n", + " 'a group of people standing in a field with animals',\n", + " 'a woman standing next to a group of people',\n", + " 'a group of people in a field with some animals',\n", + " 'a group of people standing in a field with cows',\n", + " 'a group of people standing around a herd of animals',\n", + " 'a group of people standing in a field with a'],\n", + " 'context_paths': ['val2014/COCO_val2014_000000184613.jpg',\n", + " 
'val2014/COCO_val2014_000000250804.jpg',\n", + " 'train2014/COCO_train2014_000000433662.jpg',\n", + " 'train2014/COCO_train2014_000000020966.jpg',\n", + " 'train2014/COCO_train2014_000000063043.jpg',\n", + " 'train2014/COCO_train2014_000000077693.jpg',\n", + " 'train2014/COCO_train2014_000000378214.jpg',\n", + " 'train2014/COCO_train2014_000000407590.jpg',\n", + " 'val2014/COCO_val2014_000000031471.jpg',\n", + " 'train2014/COCO_train2014_000000007953.jpg',\n", + " 'train2014/COCO_train2014_000000086329.jpg'],\n", + " 'context_ids': [184613,\n", + " 250804,\n", + " 433662,\n", + " 20966,\n", + " 63043,\n", + " 77693,\n", + " 378214,\n", + " 407590,\n", + " 31471,\n", + " 7953,\n", + " 86329]}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def display_record(record):\n", + " paths = record['context_paths']\n", + " formatted_captions = [\n", + " \"{} | {:.2f} | {:.2f} | {:.2f}\".format(cap, s0, s1, s0s1)\n", + " for cap, s0, s1, s0s1 in zip(record['candidates'], record['target_s0_scores'], record['target_s1_scores'], record['target_s0s1_scores'])\n", + " ]\n", + " print(record['caption'])\n", + " display_images([paths[0]], captions=[formatted_captions], width=400)\n", + " display_images(paths[1:], num_per_row=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a group of people standing in a field with animals\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a group of people in a field with animals | -9.35 | -1.94 | -1.94
  2. a woman standing next to a herd of cattle | -9.88 | -2.68 | -2.68
  3. a woman standing next to a herd of cows | -10.04 | -2.97 | -2.97
  4. a group of people in a field with cows | -10.06 | -2.05 | -2.05
  5. a group of people standing in a field with animals | -10.30 | -1.84 | -1.84
  6. a woman standing next to a group of people | -10.52 | -2.03 | -2.03
  7. a group of people in a field with some animals | -10.53 | -2.09 | -2.09
  8. a group of people standing in a field with cows | -10.88 | -2.19 | -2.19
  9. a group of people standing around a herd of animals | -11.10 | -3.28 | -3.28
  10. a group of people standing in a field with a | -12.63 | -3.10 | -3.10
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a small kitchen with white appliances and a white stove\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a kitchen with a white stove top oven | -6.87 | -2.90 | -2.90
  2. a kitchen with white appliances and white appliances | -6.93 | -2.09 | -2.09
  3. a small kitchen with white appliances and white appliances | -7.03 | -1.83 | -1.83
  4. a kitchen with a stove and a sink | -7.43 | -3.09 | -3.09
  5. a small kitchen with a white stove top oven | -7.51 | -2.18 | -2.18
  6. a kitchen with a stove a sink and a stove | -7.70 | -2.95 | -2.95
  7. a kitchen with white appliances and a white stove | -7.90 | -1.87 | -1.87
  8. a kitchen with a stove a sink and a refrigerator | -8.03 | -3.53 | -3.53
  9. a kitchen with a white stove top oven next to a sink | -8.33 | -2.38 | -2.38
  10. a small kitchen with white appliances and a white stove | -8.37 | -1.75 | -1.75
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1])" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "two elephants standing next to each other in a dirt field\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of elephants standing next to each other | -5.85 | -2.65 | -2.65
  2. a couple of elephants that are standing in the dirt | -6.51 | -2.64 | -2.64
  3. a couple of elephants standing in a dirt field | -6.74 | -2.03 | -2.03
  4. a couple of elephants are standing in a pin | -6.95 | -2.45 | -2.45
  5. a group of elephants standing next to each other | -7.12 | -3.41 | -3.41
  6. a herd of elephants standing next to each other | -7.63 | -3.94 | -3.94
  7. a couple of elephants standing next to each other on a dirt field | -7.90 | -2.05 | -2.05
  8. two elephants standing next to each other in a dirt field | -8.18 | -1.74 | -1.74
  9. a couple of elephants standing next to a pile of logs | -8.57 | -2.04 | -2.04
  10. two elephants standing next to each other in a dirt | -8.60 | -1.88 | -1.88
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[300])" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a couple of giraffe standing in front of a building\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of giraffe standing next to each other | -5.37 | -2.77 | -2.77
  2. two giraffes standing in front of a building | -5.97 | -1.80 | -1.80
  3. a couple of giraffe standing next to a building | -6.36 | -2.22 | -2.22
  4. a couple of giraffes that are standing in the dirt | -6.95 | -2.66 | -2.66
  5. a couple of giraffe standing on top of a dirt field | -7.11 | -3.03 | -3.03
  6. a couple of giraffe standing in front of a building | -7.12 | -1.73 | -1.73
  7. two giraffe standing next to each other in front of a building | -7.32 | -1.98 | -1.98
  8. a couple of giraffe standing next to each other on a field | -7.44 | -3.14 | -3.14
  9. two giraffes standing next to each other in front of a building | -7.55 | -2.01 | -2.01
  10. two giraffe standing next to each other on a dirt field | -7.61 | -2.94 | -2.94
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[400])" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a vase filled with lots of yellow flowers on top of a table\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a vase filled with yellow and yellow flowers | -8.07 | -2.28 | -2.28
  2. a vase filled with lots of yellow flowers | -8.25 | -2.00 | -2.00
  3. a vase filled with yellow flowers on top of a table | -8.32 | -2.16 | -2.16
  4. a vase filled with flowers on top of a table | -8.39 | -4.29 | -4.29
  5. a vase filled with flowers sitting on top of a table | -8.56 | -4.66 | -4.66
  6. a close up of a vase with flowers in it | -8.64 | -2.27 | -2.27
  7. a vase filled with yellow flowers sitting on top of a table | -9.37 | -2.50 | -2.50
  8. a vase filled with lots of yellow flowers on top of a table | -9.50 | -1.40 | -1.40
  9. a close up of a vase with many flowers in it | -9.75 | -1.94 | -1.94
  10. a vase filled with yellow flowers sitting on a table | -9.75 | -3.07 | -3.07
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[500])" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a little boy sitting on top of a surfboard\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a child sitting on a surfboard in the water | -6.02 | -2.32 | -2.32
  2. a little boy sitting on a surfboard in the water | -6.41 | -2.51 | -2.51
  3. a young boy sitting on a surfboard in the water | -6.47 | -2.82 | -2.82
  4. a little boy sitting on top of a surfboard | -6.74 | -1.86 | -1.86
  5. a young boy sitting on top of a surfboard | -6.87 | -2.11 | -2.11
  6. a child sitting on a surf board in the water | -6.93 | -2.28 | -2.28
  7. a young boy is sitting on a surfboard in the water | -7.39 | -2.78 | -2.78
  8. a young boy sitting on a surf board in the water | -7.51 | -2.67 | -2.67
  9. a little boy sitting on top of a surfboard in the water | -7.53 | -1.88 | -1.88
  10. a little boy sitting on top of a surfboard in the ocean | -8.40 | -2.32 | -2.32
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[600])" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a woman riding on the back of a skateboard\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a black and white photo of a woman on a skateboard | -7.71 | -2.47 | -2.47
  2. a woman riding a skateboard down a sidewalk | -7.78 | -2.83 | -2.83
  3. a black and white photo of a person on a skateboard | -7.89 | -4.17 | -4.17
  4. a black and white photo of a child on a skateboard | -8.00 | -1.86 | -1.86
  5. a woman riding on the back of a skateboard | -8.07 | -1.69 | -1.69
  6. a black and white photo of a boy on a skateboard | -8.12 | -3.24 | -3.24
  7. a black and white photo of a girl on a skateboard | -8.29 | -2.35 | -2.35
  8. a black and white photo of a young girl on a skateboard | -9.00 | -1.84 | -1.84
  9. a black and white photo of a young boy on a skateboard | -9.10 | -2.64 | -2.64
  10. a black and white photo of a young woman on a skateboard | -9.37 | -2.00 | -2.00
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[700])" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a sign that is on a pole near a tree\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of bikes parked next to a river | -10.07 | -3.21 | -3.21
  2. a couple of bikes parked next to a lake | -10.27 | -2.31 | -2.31
  3. a couple of bikes parked next to a tree | -10.29 | -2.15 | -2.15
  4. a couple of bikes parked next to a park | -10.35 | -2.18 | -2.18
  5. a sign that is on a pole in the grass | -10.44 | -2.14 | -2.14
  6. a couple of bikes parked next to a body of water | -10.88 | -3.27 | -3.27
  7. a sign that is on a pole near a tree | -11.92 | -2.13 | -2.13
  8. a couple of bikes parked next to a park bench | -12.09 | -2.13 | -2.13
  9. a sign on the side of a road next to a river | -12.95 | -2.13 | -2.13
  10. a sign on the side of a road near a park | -13.08 | -2.14 | -2.14
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[800])" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a woman standing in a room while holding a cell phone\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a woman standing in a room holding a cell phone | -8.31 | -2.26 | -2.26
  2. a woman standing in a room with a cell phone | -8.74 | -2.32 | -2.32
  3. a woman standing in a room holding a camera | -8.99 | -2.29 | -2.29
  4. a woman standing in a room with a camera | -9.04 | -2.41 | -2.41
  5. a woman standing in a room holding a phone | -9.09 | -2.28 | -2.28
  6. a woman standing in a room while holding a cell phone | -9.97 | -2.25 | -2.25
  7. a woman standing in a room while holding a camera | -10.36 | -2.26 | -2.26
  8. a woman is standing in a room with a cell phone | -10.58 | -2.31 | -2.31
  9. a woman standing in front of a doorway holding a cell phone | -10.96 | -2.35 | -2.35
  10. a woman standing in front of a door holding a cell phone | -11.07 | -2.30 | -2.30
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[900])" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a kitchen with a microwave and a sink in it\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a kitchen with a sink and a refrigerator | -8.00 | -2.80 | -2.80
  2. a kitchen with a sink and a microwave | -8.30 | -2.44 | -2.44
  3. a kitchen with a microwave and a sink | -8.37 | -1.98 | -1.98
  4. a kitchen with a sink microwave and refrigerator | -8.69 | -2.01 | -2.01
  5. a kitchen with a sink a refrigerator and a sink | -9.44 | -2.22 | -2.22
  6. a kitchen with a sink a refrigerator and a microwave | -9.54 | -2.37 | -2.37
  7. a small kitchen with a sink and a refrigerator | -10.06 | -2.34 | -2.34
  8. a kitchen with a microwave and a sink in it | -10.12 | -1.96 | -1.96
  9. a kitchen with white cabinets and a white refrigerator | -10.24 | -3.72 | -3.72
  10. a kitchen with white cabinets and a white microwave | -10.35 | -2.16 | -2.16
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1000])" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a couple of glazed donuts sitting on top of a table\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a couple of doughnuts that are on a plate | -8.16 | -3.04 | -3.04
  2. a couple of doughnuts that are on a table | -8.47 | -2.44 | -2.44
  3. a bunch of doughnuts that are on a plate | -8.51 | -3.24 | -3.24
  4. a couple of doughnuts that are sitting on a table | -8.65 | -2.05 | -2.05
  5. a couple of doughnuts that are on a tray | -8.77 | -2.58 | -2.58
  6. a close up of a doughnuts on a plate | -9.02 | -1.95 | -1.95
  7. a bunch of doughnuts that are on a table | -9.02 | -2.65 | -2.65
  8. a bunch of doughnuts that are sitting on a table | -9.12 | -2.25 | -2.25
  9. a close up of a doughnut on a plate | -9.16 | -2.27 | -2.27
  10. a couple of glazed donuts sitting on top of a table | -9.37 | -1.61 | -1.61
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1100])" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a large black dog laying on top of a blanket\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a large black dog laying on top of a couch | -5.49 | -2.29 | -2.29
  2. a large black dog laying on top of a bed | -5.59 | -2.12 | -2.12
  3. a dog laying on a couch in a room | -7.04 | -3.18 | -3.18
  4. a black and brown dog laying on a couch | -7.07 | -1.97 | -1.97
  5. a black dog laying on top of a couch | -7.08 | -2.71 | -2.71
  6. a black and brown dog laying on top of a couch | -7.27 | -1.93 | -1.93
  7. a large black dog laying on top of a blanket | -7.39 | -1.64 | -1.64
  8. a black dog laying on top of a bed | -7.47 | -3.29 | -3.29
  9. a large dog laying on top of a couch | -7.53 | -2.87 | -2.87
  10. a black and brown dog laying on top of a bed | -7.62 | -2.33 | -2.33
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1200])" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a group of motorcycles parked outside of a store\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a group of motorcycles parked in front of a store | -6.37 | -1.84 | -1.84
  2. a row of motorcycles parked in front of a store | -6.71 | -2.11 | -2.11
  3. a group of motorcycles parked in front of a building | -7.01 | -3.33 | -3.33
  4. a row of motorcycles parked in front of a building | -7.43 | -3.74 | -3.74
  5. a group of motorcycles parked outside of a store | -7.67 | -1.82 | -1.82
  6. a bunch of motorcycles parked in front of a store | -7.72 | -1.83 | -1.83
  7. a group of motorcycles parked outside of a building | -8.09 | -2.93 | -2.93
  8. a group of motorcycles parked in front of a shop | -8.20 | -2.16 | -2.16
  9. a bunch of motorcycles parked in front of a building | -8.21 | -3.08 | -3.08
  10. a group of motorcycles are parked in front of a store | -8.67 | -2.10 | -2.10
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1300])" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a woman swinging a tennis racket at a tennis ball\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a woman swinging a tennis racket at a tennis ball | -5.18 | -2.25 | -2.25
  2. a woman hitting a tennis ball with a racquet | -5.26 | -2.27 | -2.27
  3. a woman swinging a tennis racquet at a tennis ball | -5.27 | -2.26 | -2.26
  4. a woman hitting a tennis ball with a tennis racquet | -5.39 | -2.27 | -2.27
  5. a woman holding a tennis racquet on a tennis court | -5.52 | -2.46 | -2.46
  6. a woman holding a tennis racquet on top of a tennis court | -5.70 | -2.40 | -2.40
  7. a woman swinging a tennis racket at a ball | -5.72 | -2.26 | -2.26
  8. a woman is swinging a tennis racket at a ball | -5.89 | -2.26 | -2.26
  9. a woman swinging a tennis racquet at a ball | -6.00 | -2.27 | -2.27
  10. a woman playing tennis on a tennis court | -6.04 | -2.36 | -2.36
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1400])" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a slice of pizza sitting on top of a white paper plate\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a slice of pizza on a paper plate | -4.10 | -2.33 | -2.33
  2. a slice of pizza sitting on top of a white plate | -4.80 | -2.73 | -2.73
  3. a slice of pizza sitting on top of a paper plate | -5.20 | -2.16 | -2.16
  4. a piece of pizza on a paper plate | -5.21 | -2.40 | -2.40
  5. a slice of cheese pizza on a paper plate | -5.33 | -2.04 | -2.04
  6. a slice of pizza on a white plate | -5.38 | -3.09 | -3.09
  7. a slice of pizza sits on a paper plate | -5.39 | -2.05 | -2.05
  8. a piece of pizza sitting on top of a white plate | -5.70 | -2.85 | -2.85
  9. a slice of pizza sitting on top of a white paper plate | -5.79 | -1.79 | -1.79
  10. a slice of pizza sitting on a paper plate | -6.01 | -2.27 | -2.27
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1500])" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a close up of a hot dog on a plate with ketchup\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a close up of a hot dog on a plate | -6.37 | -2.67 | -2.67
  2. a hot dog sitting on top of a white plate | -6.63 | -2.88 | -2.88
  3. a hot dog with ketchup and mustard on a plate | -7.70 | -2.19 | -2.19
  4. a hot dog with mustard and ketchup on a plate | -7.74 | -2.25 | -2.25
  5. a close up of a hot dog with ketchup and mustard | -8.07 | -2.34 | -2.34
  6. a hot dog sitting on top of a white plate next to a fork | -8.95 | -2.20 | -2.20
  7. a hot dog sitting on top of a bun covered in ketchup | -9.00 | -2.59 | -2.59
  8. a hot dog sitting on top of a bun covered in toppings | -9.04 | -2.63 | -2.63
  9. a close up of a hot dog on a plate with ketchup | -9.17 | -1.85 | -1.85
  10. a close up of a hot dog on a plate with mustard | -9.29 | -1.92 | -1.92
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1600])" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a bunch of stuffed animals sitting on top of a window sill\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a group of stuffed animals sitting on top of a window sill | -9.90 | -1.73 | -1.73
  2. a bunch of stuffed animals on display in a window | -10.11 | -2.16 | -2.16
  3. a bunch of stuffed animals sitting in a window | -10.21 | -2.74 | -2.74
  4. a bunch of stuffed animals are on display | -10.36 | -4.84 | -4.84
  5. a group of stuffed animals sitting in a window | -10.45 | -2.29 | -2.29
  6. a group of stuffed animals sitting next to each other | -10.48 | -4.00 | -4.00
  7. a bunch of stuffed animals sitting on a window sill | -10.61 | -2.16 | -2.16
  8. a group of stuffed animals sitting on top of a window | -10.63 | -1.70 | -1.70
  9. a bunch of stuffed animals sitting on top of a window sill | -11.18 | -1.55 | -1.55
  10. a bunch of stuffed animals sitting on a shelf | -11.19 | -5.36 | -5.36
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1700])" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a stop sign on a pole on a road\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a red stop sign sitting on the side of a road | -4.14 | -2.73 | -2.73
  2. a red stop sign sitting next to a road | -7.01 | -1.88 | -1.88
  3. a red stop sign sitting on top of a metal pole | -7.39 | -3.33 | -3.33
  4. a red stop sign sitting on top of a road | -7.52 | -2.10 | -2.10
  5. a red stop sign sitting on the side of a road road | -7.57 | -2.34 | -2.34
  6. a red stop sign sitting next to a street | -7.64 | -2.04 | -2.04
  7. a red stop sign sitting on the side of a street | -7.65 | -3.18 | -3.18
  8. a stop sign on a pole on a road | -7.77 | -1.62 | -1.62
  9. a close up of a stop sign with a sky background | -7.90 | -2.58 | -2.58
  10. a red stop sign sitting on the side of the road | -7.92 | -2.49 | -2.49
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1800])" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a man riding on the back of a blue motorcycle\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a man riding on the back of a motorcycle | -6.67 | -2.30 | -2.30
  2. a man riding a motorcycle next to a bus | -7.29 | -2.31 | -2.31
  3. a man riding a motorcycle down a street | -7.54 | -2.31 | -2.31
  4. a man riding a motorcycle in front of a bus | -7.57 | -2.32 | -2.32
  5. a man riding on the back of a motorcycle down a street | -7.58 | -2.30 | -2.30
  6. a man riding on the back of a blue motorcycle | -7.76 | -2.29 | -2.29
  7. a man riding a motorcycle down a street next to a bus | -7.95 | -2.32 | -2.32
  8. a man riding on the back of a motorcycle next to a bus | -8.18 | -2.30 | -2.30
  9. a man riding a motorcycle next to a blue bus | -8.31 | -2.29 | -2.29
  10. a man riding on the back of a motorcycle down a road | -9.04 | -2.29 | -2.29
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[1900])" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a kitchen with wood cabinets and a white stove\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a kitchen with wooden cabinets and white appliances | -6.01 | -2.73 | -2.73
  2. a kitchen with wood cabinets and white appliances | -6.63 | -2.53 | -2.53
  3. a kitchen with wooden cabinets and a white stove | -6.98 | -1.96 | -1.96
  4. a kitchen with a stove and a refrigerator | -7.66 | -2.37 | -2.37
  5. a kitchen with wood cabinets and a white stove | -7.95 | -1.77 | -1.77
  6. a kitchen with wooden cabinets and a stove | -8.13 | -2.18 | -2.18
  7. a kitchen with wooden cabinets and a stove top oven | -8.33 | -2.49 | -2.49
  8. a kitchen with a stove a sink and a refrigerator | -8.56 | -2.52 | -2.52
  9. a kitchen with a stove a refrigerator and a microwave | -8.71 | -2.30 | -2.30
  10. a kitchen with a stove a sink and a microwave | -9.00 | -2.62 | -2.62
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2000])" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a woman is sitting at a table in a kitchen\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a woman standing in a kitchen next to a table | -7.83 | -2.80 | -2.80
  2. a woman sitting at a table in a kitchen | -7.98 | -1.68 | -1.68
  3. a woman standing in a kitchen next to a kitchen | -9.02 | -3.08 | -3.08
  4. a woman standing in a kitchen next to a counter | -9.11 | -3.18 | -3.18
  5. a woman sitting at a table in a room | -9.19 | -1.96 | -1.96
  6. a woman is sitting at a table in a kitchen | -9.35 | -1.67 | -1.67
  7. a woman standing in a kitchen next to a table and chairs | -9.43 | -2.51 | -2.51
  8. a woman standing in a kitchen next to a dining room table | -9.59 | -2.40 | -2.40
  9. a woman standing in a kitchen next to a dining table | -9.69 | -2.61 | -2.61
  10. a woman standing in a kitchen next to a table with chairs | -9.85 | -2.41 | -2.41
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2100])" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor(-2.3979)" + ] + }, + "execution_count": 58, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.log(torch.tensor(1./11))" + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a display in a grocery store with lots of vegetables\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a display in a grocery store filled with lots of vegetables | -7.25 | -2.30 | -2.30
  2. a display in a grocery store filled with lots of fresh produce | -7.58 | -2.61 | -2.61
  3. a display in a grocery store filled with lots of produce | -8.77 | -2.48 | -2.48
  4. a display in a grocery store filled with lots of fresh vegetables | -8.86 | -2.63 | -2.63
  5. a variety of vegetables are displayed in a market | -8.89 | -2.18 | -2.18
  6. a bunch of vegetables are displayed in a market | -8.93 | -2.08 | -2.08
  7. a display in a grocery store with lots of vegetables | -9.02 | -1.94 | -1.94
  8. a display in a grocery store filled with lots of different vegetables | -9.26 | -2.43 | -2.43
  9. a bunch of vegetables are in a market | -9.42 | -2.33 | -2.33
  10. a display in a grocery store filled with vegetables | -9.44 | -2.26 | -2.26
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2200])" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a dog that is standing in the grass with a dog\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a dog that is standing in the grass | -5.96 | -2.31 | -2.31
  2. a dog in a field with mountains in the background | -6.95 | -2.30 | -2.30
  3. a dog in a field with a mountain in the background | -7.55 | -2.30 | -2.30
  4. a dog standing in a field with mountains in the background | -7.67 | -2.30 | -2.30
  5. a dog standing in a field with a mountain in the background | -8.46 | -2.30 | -2.30
  6. a dog that is standing in the grass near a mountain | -8.91 | -2.30 | -2.30
  7. a black and white dog in a field with mountains in the background | -8.95 | -2.30 | -2.30
  8. a dog is standing in a field with a mountain in the background | -9.57 | -2.30 | -2.30
  9. a dog that is standing in the grass with a dog | -10.16 | -2.30 | -2.30
  10. a dog that is standing in the grass with a mountain in the background | -10.16 | -2.30 | -2.30
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2300])" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a man riding a motorcycle down a street next to a tent\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a man riding on the back of a motorcycle | -5.31 | -2.27 | -2.27
  2. a man riding a motorcycle down a street | -5.65 | -2.45 | -2.45
  3. a man riding on the back of a motorcycle down a street | -5.70 | -2.38 | -2.38
  4. a man is riding a motorcycle down the street | -6.71 | -2.43 | -2.43
  5. a man is riding a motorcycle down a street | -6.90 | -2.29 | -2.29
  6. a man riding a motorcycle down the street | -7.07 | -2.79 | -2.79
  7. a man riding a motorcycle on a street | -7.08 | -2.21 | -2.21
  8. a man riding on the back of a motorcycle down a road | -7.63 | -2.25 | -2.25
  9. a man riding a motorcycle down a street next to a tent | -7.84 | -2.01 | -2.01
  10. a man on a motorcycle in the middle of a street | -7.90 | -2.13 | -2.13
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2400])" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a chair and a chair on a sidewalk next to a bench\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a chair and a chair on a sidewalk | -7.79 | -2.30 | -2.30
  2. a chair and a chair on a city street | -8.00 | -2.30 | -2.30
  3. a chair and a chair sitting on a sidewalk | -8.18 | -2.30 | -2.30
  4. a chair and a chair on a street | -8.41 | -2.30 | -2.30
  5. a chair and a chair sit on a sidewalk | -9.05 | -2.30 | -2.30
  6. a chair and a chair next to a bench | -9.28 | -2.30 | -2.30
  7. a chair and a chair sitting next to a bench | -9.80 | -2.30 | -2.30
  8. a chair and a chair sitting next to a table | -9.89 | -2.30 | -2.30
  9. a chair and a chair on a sidewalk next to a bench | -10.00 | -2.30 | -2.30
  10. a chair and a chair on a sidewalk near a bench | -10.62 | -2.30 | -2.30
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2500])" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a black and white photo of a person doing a trick on a skateboard\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a man riding a skateboard up the side of a ramp | -5.42 | -3.51 | -3.51
  2. a person riding a skate board on a skate park | -6.29 | -2.35 | -2.35
  3. a man riding a skateboard on top of a ramp | -6.34 | -2.61 | -2.61
  4. a man doing a trick on a skateboard on a ramp | -6.84 | -2.60 | -2.60
  5. a person riding a skate board on a ramp | -7.04 | -2.38 | -2.38
  6. a black and white photo of a skateboarder doing a trick | -7.42 | -2.11 | -2.11
  7. a black and white photo of a man on a skateboard | -7.50 | -2.31 | -2.31
  8. a black and white photo of a man doing a trick on a skateboard | -8.06 | -2.15 | -2.15
  9. a black and white photo of a person doing a trick on a skateboard | -8.20 | -1.72 | -1.72
  10. a black and white photo of a man doing a skateboard trick | -8.83 | -2.14 | -2.14
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2600])" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a man diving for a frisbee in a field\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a man in a field catching a frisbee | -8.04 | -3.13 | -3.13
  2. a man diving to catch a frisbee in a field | -8.08 | -1.82 | -1.82
  3. a man diving to catch a frisbee in the grass | -8.24 | -1.92 | -1.92
  4. a man catching a frisbee in a field | -8.25 | -3.18 | -3.18
  5. a man diving for a frisbee in a field | -8.31 | -1.77 | -1.77
  6. a man catching a frisbee in a grassy field | -8.65 | -3.29 | -3.29
  7. a man diving to catch a frisbee in the park | -8.71 | -2.28 | -2.28
  8. a man diving for a frisbee in the grass | -8.73 | -1.95 | -1.95
  9. a man catching a frisbee on a field | -8.79 | -2.95 | -2.95
  10. a man diving to catch a frisbee in a park | -9.03 | -2.29 | -2.29
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2700])" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a woman standing on a train platform looking at a train\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a woman standing on a platform next to a train | -8.56 | -2.33 | -2.33
  2. a woman standing on a train platform next to a train | -9.01 | -2.32 | -2.32
  3. a woman is standing on a train platform | -9.06 | -2.31 | -2.31
  4. a woman standing on the side of a train | -9.44 | -2.28 | -2.28
  5. a woman standing on a train platform near a train | -9.76 | -2.31 | -2.31
  6. a woman standing on the side of a train track | -9.77 | -2.31 | -2.31
  7. a woman standing on the side of a train train | -9.84 | -2.28 | -2.28
  8. a woman is looking at a train on the tracks | -9.96 | -2.30 | -2.30
  9. a woman standing on a train platform looking at a train | -10.08 | -2.28 | -2.28
  10. a woman standing on the side of a train station | -10.20 | -2.32 | -2.32
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2800])" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "a bath room with a toilet a bath tub and a mirror\n" + ] + }, + { + "data": { + "text/html": [ + "
  1. a white toilet sitting next to a bath tub | -5.15 | -2.17 | -2.17
  2. a bath room with a toilet and a bath tub | -6.29 | -2.17 | -2.17
  3. a bath room with a toilet a sink and a bath tub | -6.37 | -2.16 | -2.16
  4. a white toilet sitting next to a bath tub in a bathroom | -6.68 | -2.38 | -2.38
  5. a bath room with a toilet and a sink | -6.76 | -2.48 | -2.48
  6. a bathroom with a toilet sink and bathtub | -6.79 | -2.31 | -2.31
  7. a bathroom with a toilet and a sink | -6.94 | -2.60 | -2.60
  8. a bathroom with a toilet sink and shower | -6.95 | -2.51 | -2.51
  9. a bath room with a toilet a bath tub and a sink | -7.11 | -2.36 | -2.36
  10. a bath room with a toilet a bath tub and a mirror | -7.33 | -2.03 | -2.03
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "display_record(results[2900])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/plot.py b/plot.py new file mode 100644 index 00000000..160b6935 --- /dev/null +++ b/plot.py @@ -0,0 +1,69 @@ +import sys + +import argparse +import matplotlib.pyplot as plt +import pandas + +STATS = [ + 'Bleu_4', + 'METEOR', + 'ROUGE_L', + 'CIDEr', + 'WMD', +] + +DEFAULT_STATS = STATS + +def parse(log_file): + last_iter = None + stats_by_iter = {} + try: + with open(log_file, 'r') as f: + for line in f: + if line.startswith('iter'): + toks = line.split() + last_iter = int(toks[1]) + else: + for stat in STATS: + if line.startswith('{}: '.format(stat)): + if last_iter not in stats_by_iter: + stats_by_iter[last_iter] = {} + stat_val = float(line.split()[1]) + stats_by_iter[last_iter][stat] = stat_val + except Exception as e: + print(e) + return None + return stats_by_iter + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("log_files", nargs="+") + parser.add_argument("--stats", nargs="+", default=DEFAULT_STATS) + args = parser.parse_args() + dfs_by_name = {} + for log_file in args.log_files: + stats_by_iter = parse(log_file) + if stats_by_iter is None: + print("error parsing {}".format(log_file)) + df = pandas.DataFrame(stats_by_iter).transpose() + dfs_by_name[log_file] = df + last_stats = df[DEFAULT_STATS].iloc[-1] + # print(last_stats) + print("{}\t{}\t{}".format( + log_file, + last_stats.name, + ','.join('{:.4f}'.format(x) for x in last_stats) + )) + for stat in args.stats: + data = {} + for name, df in dfs_by_name.items(): + if stat in df.columns: + data[name] = df[stat] + else: + print('df {} does not have stat {}'.format(name, stat)) + collected_df = pandas.DataFrame(data) + if collected_df.empty: + print("stat {} is empty".format(stat)) + else: + collected_df.plot(title=stat) + plt.show() diff --git a/scripts/make_bu_data.py b/scripts/make_bu_data.py index 211f3e93..c77e94b4 100644 --- a/scripts/make_bu_data.py +++ b/scripts/make_bu_data.py @@ -11,6 +11,7 @@ import time import mmap import argparse +import tqdm parser = argparse.ArgumentParser() @@ -37,7 +38,7 @@ print('Reading ' + infile) with open(os.path.join(args.downloaded_feats, infile), "r") as tsv_in_file: reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames = FIELDNAMES) - for item in reader: + for item in tqdm.tqdm(list(reader), ncols=80): item['image_id'] = int(item['image_id']) item['num_boxes'] = int(item['num_boxes']) for field in ['boxes', 'features']: diff --git a/tools/eval.py b/tools/eval.py index 97c3498f..ba66b321 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -7,8 +7,11 @@ import time import os +import sys from six.moves import cPickle +import pprint + import captioning.utils.opts as opts import captioning.models as models from captioning.data.dataloader import * @@ -19,107 +22,143 @@ import captioning.modules.losses as losses import torch -# Input arguments and options -parser = 
argparse.ArgumentParser() -# Input paths -parser.add_argument('--model', type=str, default='', - help='path to model to evaluate') -parser.add_argument('--cnn_model', type=str, default='resnet101', - help='resnet101, resnet152') -parser.add_argument('--infos_path', type=str, default='', - help='path to infos to evaluate') -parser.add_argument('--only_lang_eval', type=int, default=0, - help='lang eval on saved results') -parser.add_argument('--force', type=int, default=0, - help='force to evaluate no matter if there are results available') -parser.add_argument('--device', type=str, default='cuda', - help='cpu or cuda') -opts.add_eval_options(parser) -opts.add_diversity_opts(parser) -opt = parser.parse_args() - -# Load infos -with open(opt.infos_path, 'rb') as f: - infos = utils.pickle_load(f) - -# override and collect parameters -replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] -ignore = ['start_from'] - -for k in vars(infos['opt']).keys(): - if k in replace: - setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) - elif k not in ignore: - if not k in vars(opt): - vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model - -vocab = infos['vocab'] # ix -> word mapping - -pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') -result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') - -if opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): - # if results existed, then skip, unless force is on +if __name__ == "__main__": + # Input arguments and options + parser = argparse.ArgumentParser() + # Input paths + parser.add_argument('--model', type=str, default='', + help='path to model to evaluate') + parser.add_argument('--cnn_model', type=str, default='resnet101', + help='resnet101, resnet152') + parser.add_argument('--infos_path', type=str, default='', + help='path to infos to evaluate') + parser.add_argument('--only_lang_eval', type=int, default=0, + help='lang eval on saved results') + parser.add_argument('--force', type=int, default=0, + help='force to evaluate no matter if there are results available') + parser.add_argument('--device', type=str, default='cuda', + help='cpu or cuda') + parser.add_argument('--from_serialized_candidates') + parser.add_argument('--save_verbose_predictions', type=int, default=0, help='write predictions to eval_results/pred_verbose_{id}_{split}.pth') + opts.add_loader_options(parser) + opts.add_eval_options(parser) + opts.add_diversity_opts(parser) + opts.add_pragmatics_opts(parser) + opts.add_mbr_opts(parser) + opt = parser.parse_args() + + print(' '.join(sys.argv)) + utils.dump_git_status() + pprint.pprint(vars(opt)) + + # Load infos + with open(opt.infos_path, 'rb') as f: + infos = utils.pickle_load(f) + + # override and collect parameters + replace = ['input_fc_dir', 'input_att_dir', 'input_box_dir', 'input_label_h5', 'input_json', 'batch_size', 'id'] + ignore = ['start_from'] + + for k in vars(infos['opt']).keys(): + if k in replace: + setattr(opt, k, getattr(opt, k) or getattr(infos['opt'], k, '')) + elif k not in ignore: + if not k in vars(opt): + vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model + + vocab = infos['vocab'] # ix -> word mapping + + pred_fn = os.path.join('eval_results/', '.saved_pred_'+ opt.id + '_' + opt.split + '.pth') + result_fn = os.path.join('eval_results/', opt.id + '_' + opt.split + '.json') + + def print_lang_stats(lang_stats): + stat_keys = 
['Bleu_1', 'Bleu_4', 'METEOR', 'ROUGE_L', 'CIDEr', 'SPICE'] + if lang_stats: + pprint.pprint(lang_stats) + print(','.join(stat_keys)) + print(','.join('{:.4f}'.format(lang_stats[key]) if key in lang_stats else '----' + for key in stat_keys )) + + def vis_predictions(split_predictions): + if opt.dump_json == 1: + # dump the json + json.dump(split_predictions, open('vis/vis.json', 'w')) + + if opt.from_serialized_candidates or opt.only_lang_eval == 1 or (not opt.force and os.path.isfile(pred_fn)): + # if results existed, then skip, unless force is on + if opt.from_serialized_candidates: + predictions = eval_utils.eval_split_from_serialized( + opt.from_serialized_candidates, vars(opt) + ) + n_predictions = [] + print(f'saving to {pred_fn}') + torch.save((predictions, n_predictions), pred_fn) + else: + if not opt.force: + try: + if os.path.isfile(result_fn): + print(result_fn) + json.load(open(result_fn, 'r')) + print('already evaluated') + os._exit(0) + except: + pass + + predictions, n_predictions = torch.load(pred_fn) + lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) + print_lang_stats(lang_stats) + vis_predictions(predictions) + os._exit(0) + + # At this point only_lang_eval if 0 if not opt.force: + # Check out if try: - if os.path.isfile(result_fn): - print(result_fn) + # if no pred exists, then continue + tmp = torch.load(pred_fn) + # if language_eval == 1, and no pred exists, then continue + if opt.language_eval == 1: json.load(open(result_fn, 'r')) - print('already evaluated') - os._exit(0) + print('Result is already there') + os._exit(0) except: pass - predictions, n_predictions = torch.load(pred_fn) - lang_stats = eval_utils.language_eval(opt.input_json, predictions, n_predictions, vars(opt), opt.split) - print(lang_stats) - os._exit(0) - -# At this point only_lang_eval if 0 -if not opt.force: - # Check out if - try: - # if no pred exists, then continue - tmp = torch.load(pred_fn) - # if language_eval == 1, and no pred exists, then continue - if opt.language_eval == 1: - json.load(open(result_fn, 'r')) - print('Result is already there') - os._exit(0) - except: - pass - -# Setup the model -opt.vocab = vocab -model = models.setup(opt) -del opt.vocab -model.load_state_dict(torch.load(opt.model, map_location='cpu')) -model.to(opt.device) -model.eval() -crit = losses.LanguageModelCriterion() - -# Create the Data Loader instance -if len(opt.image_folder) == 0: - loader = DataLoader(opt) -else: - loader = DataLoaderRaw({'folder_path': opt.image_folder, - 'coco_json': opt.coco_json, - 'batch_size': opt.batch_size, - 'cnn_model': opt.cnn_model}) -# When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json -# So make sure to use the vocab in infos file. 
-loader.dataset.ix_to_word = infos['vocab'] - - -# Set sample options -opt.dataset = opt.input_json -loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, - vars(opt)) - -print('loss: ', loss) -if lang_stats: - print(lang_stats) - -if opt.dump_json == 1: - # dump the json - json.dump(split_predictions, open('vis/vis.json', 'w')) + # Setup the model + opt.vocab = vocab + model = models.setup(opt) + del opt.vocab + model.load_state_dict(torch.load(opt.model, map_location='cpu')) + model.to(opt.device) + model.eval() + crit = losses.LanguageModelCriterion() + + # Create the Data Loader instance + if len(opt.image_folder) == 0: + # sample_method and sample_n_method take different specifications of contrastive_beam_search + if opt.pragmatic_inference or opt.sample_method == 'contrastive_beam_search' or opt.sample_n_method == 'contrastive_bs': + loader = DataLoader(opt, + build_nearest_neighbor_indices_for_splits=['train'], + index_serialization_root_path=opt.index_serialization_root_path) + else: + loader = DataLoader(opt) + else: + assert not opt.pragmatic_inference + loader = DataLoaderRaw({'folder_path': opt.image_folder, + 'coco_json': opt.coco_json, + 'batch_size': opt.batch_size, + 'cnn_model': opt.cnn_model}) + # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json + # So make sure to use the vocab in infos file. + loader.dataset.ix_to_word = infos['vocab'] + + + # Set sample options + opt.dataset = opt.input_json + loss, split_predictions, lang_stats = eval_utils.eval_split(model, crit, loader, + vars(opt)) + + print('loss: ', loss) + print_lang_stats(lang_stats) + vis_predictions(split_predictions) + diff --git a/tools/train.py b/tools/train.py index a9919482..7800b517 100644 --- a/tools/train.py +++ b/tools/train.py @@ -9,6 +9,9 @@ import numpy as np +import pprint +import sys + import time import os from six.moves import cPickle @@ -34,7 +37,13 @@ def train(opt): ################################ # Build dataloader ################################ - loader = DataLoader(opt) + if 0 <= opt.contrastive_after <= opt.max_epochs: + build_nearest_neighbor_indices_for_splits = ['train'] + else: + build_nearest_neighbor_indices_for_splits = None + loader = DataLoader(opt, + build_nearest_neighbor_indices_for_splits=build_nearest_neighbor_indices_for_splits, + index_serialization_root_path=opt.index_serialization_root_path) opt.vocab_size = loader.vocab_size opt.seq_length = loader.seq_length @@ -160,6 +169,11 @@ def train(opt): else: struc_flag = False + if opt.contrastive_after != -1 and epoch >= opt.contrastive_after: + contrastive_flag = True + else: + contrastive_flag = False + epoch_done = False start = time.time() @@ -173,12 +187,26 @@ def train(opt): torch.cuda.synchronize() start = time.time() - tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']] + if contrastive_flag: + ixs = [d['ix'] for d in data['infos']] + neighbor_data = loader.indices['train'].get_neighbor_batch( + loader, data['fc_feats'].cpu().numpy(), opt.pragmatic_distractors, + include_self=True, self_indices=ixs, + neighbor_type=opt.pragmatic_distractor_candidate_type + ) + data_to_use = neighbor_data + else: + data_to_use = data + + tmp = [data_to_use['fc_feats'], data_to_use['att_feats'], data_to_use['labels'], + data_to_use['masks'], data_to_use['att_masks']] tmp = [_ if _ is None else _.cuda() for _ in tmp] fc_feats, att_feats, labels, masks, att_masks = tmp - + optimizer.zero_grad() - 
model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)
+            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks,
+                                    data_to_use['gts'], torch.arange(0, len(data_to_use['gts'])),
+                                    sc_flag, struc_flag, contrastive_flag)

             loss = model_out['loss'].mean()

@@ -284,5 +312,10 @@ def train(opt):
         print(stack_trace)


-opt = opts.parse_opt()
-train(opt)
+if __name__ == "__main__":
+    print(' '.join(sys.argv))
+    opt = opts.parse_opt()
+    utils.dump_git_status()
+    pprint.pprint(vars(opt))
+
+    train(opt)
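
Note on the new plot.py: parse() attributes each "METRIC: value" line to the most recent "iter N ..." line seen before it, so the log it consumes must interleave iteration markers with metric lines. Below is a minimal sketch of that contract; the log contents and the file name example_train.log are fabricated for illustration and are not produced by this patch.

# Sketch only: a fabricated log in the shape plot.py's parse() expects.
from plot import parse  # assumes the repository root is on PYTHONPATH

sample_log = """iter 3000 (epoch 1), train loss 2.41
Bleu_4: 0.3120
CIDEr: 0.9850
iter 6000 (epoch 2), train loss 2.17
Bleu_4: 0.3205
CIDEr: 1.0010
"""

with open("example_train.log", "w") as f:  # hypothetical file name
    f.write(sample_log)

stats_by_iter = parse("example_train.log")
print(stats_by_iter)
# Each iteration marker maps to the metrics that appeared under it, e.g.
# {3000: {'Bleu_4': 0.312, 'CIDEr': 0.985}, 6000: {'Bleu_4': 0.3205, 'CIDEr': 1.001}}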
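
For quick reference, print_lang_stats() in the revised tools/eval.py prints the metric dict via pprint and then a comma-separated header/value pair, substituting '----' for any metric that was not computed. A standalone copy with a fabricated input, just to show the output shape:

import pprint

# Standalone copy of the helper added in tools/eval.py; the metric values
# passed below are fabricated, purely to illustrate the printed format.
def print_lang_stats(lang_stats):
    stat_keys = ['Bleu_1', 'Bleu_4', 'METEOR', 'ROUGE_L', 'CIDEr', 'SPICE']
    if lang_stats:
        pprint.pprint(lang_stats)
        print(','.join(stat_keys))
        print(','.join('{:.4f}'.format(lang_stats[key]) if key in lang_stats else '----'
                       for key in stat_keys))

print_lang_stats({'Bleu_1': 0.7512, 'Bleu_4': 0.3120, 'CIDEr': 1.0010})
# After the pprint of the dict, the last two printed lines read:
# Bleu_1,Bleu_4,METEOR,ROUGE_L,CIDEr,SPICE
# 0.7512,0.3120,----,----,1.0010,----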
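
The tools/train.py change switches to neighbor ("contrastive") batches only once the current epoch reaches opt.contrastive_after. A minimal sketch of that gating, assuming the option defaults to -1 as the "never" sentinel (inferred from the != -1 check in the diff, not stated elsewhere in the patch):

# Sketch of the epoch gating used in the training loop; the -1 default is an
# assumption inferred from the `opt.contrastive_after != -1` check above.
def use_contrastive_batch(epoch, contrastive_after=-1):
    """True once plain batches should be replaced by nearest-neighbor batches."""
    return contrastive_after != -1 and epoch >= contrastive_after

assert use_contrastive_batch(epoch=2, contrastive_after=5) is False  # before the threshold
assert use_contrastive_batch(epoch=5, contrastive_after=5) is True   # at/after the threshold
assert use_contrastive_batch(epoch=9) is False                       # default: stay off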