
Commit b58e41d

Use torch.utils.data to load file.
1 parent 0b8c057 commit b58e41d

File tree

1 file changed: +45 / -41 lines

dataloader.py

Lines changed: 45 additions & 41 deletions
@@ -7,19 +7,22 @@
 import os
 import numpy as np
 import random
-from multiprocessing.dummy import Pool
+
+import torch
+import torch.utils.data as data
+
+import multiprocessing
 
 def get_npy_data(ix, fc_file, att_file, use_att):
     if use_att == True:
         return (np.load(fc_file), np.load(att_file)['feat'], ix)
     else:
         return (np.load(fc_file), np.zeros((1,1,1)), ix)
 
-class DataLoader():
+class DataLoader(data.Dataset):
 
     def reset_iterator(self, split):
-        self._prefetch_process[split].terminate()
-        self._prefetch_process[split].join()
+        del self._prefetch_process[split]
         self._prefetch_process[split] = BlobFetcher(split, self, split=='train')
         self.iterators[split] = 0

@@ -89,8 +92,7 @@ def __init__(self, opt):
         def cleanup():
             print('Terminating BlobFetcher')
             for split in self.iterators.keys():
-                self._prefetch_process[split].terminate()
-                self._prefetch_process[split].join()
+                del self._prefetch_process[split]
         import atexit
         atexit.register(cleanup)

@@ -167,6 +169,22 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):
 
         return data
 
+    # It's not coherent to make DataLoader a subclass of Dataset, but essentially we only need to implement the following
+    # two functions so that torch.utils.data.DataLoader can load the data according to the index.
+    # However, this is the minimal change needed to switch to PyTorch data loading.
+    def __getitem__(self, index):
+        """This function returns a tuple that is further passed to collate_fn.
+        """
+        ix = index  # self.split_ix[index]
+        return get_npy_data(ix, \
+            os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'),
+            os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'),
+            self.use_att
+            )
+
+    def __len__(self):
+        return len(self.info['images'])
+
 class BlobFetcher():
     """Experimental class for prefetching blobs in a separate process."""
     def __init__(self, split, dataloader, if_shuffle=False):
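
Implementing __getitem__ and __len__ above is all torch.utils.data needs from a dataset object. As an aside, a minimal sketch of that protocol (a toy dataset with made-up feature shapes standing in for the real .npy/.npz files, so names and sizes here are illustrative only):

import numpy as np
import torch.utils.data as data

class ToyFeatureDataset(data.Dataset):
    """Illustrative stand-in: returns (fc_feat, att_feat, ix) for each index."""
    def __len__(self):
        return 8

    def __getitem__(self, index):
        fc = np.random.randn(2048).astype(np.float32)    # fake fc feature
        att = np.zeros((1, 1, 1), dtype=np.float32)      # fake att feature
        return (fc, att, index)

loader = data.DataLoader(ToyFeatureDataset(), batch_size=1,
                         collate_fn=lambda x: x[0])      # unwrap the 1-sample batch
for fc, att, ix in loader:
    print(ix, fc.shape)
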
@@ -177,41 +195,24 @@ def __init__(self, split, dataloader, if_shuffle=False):
         self.dataloader = dataloader
         self.if_shuffle = if_shuffle
 
-        self.pool = Pool()
-        self.fifo = []
-
     # Add more in the queue
     def reset(self):
-        if len(self.fifo) == 0:
-            self.cur_idx = self.dataloader.iterators[self.split]
-            self.cur_split_ix = self.dataloader.split_ix[self.split][:] # copy
-        for i in range(512 - len(self.fifo)):
-            ix = self.cur_split_ix[self.cur_idx]
-            if self.cur_idx + 1 >= len(self.cur_split_ix):
-                self.cur_idx = 0
-                if self.if_shuffle:
-                    random.shuffle(self.cur_split_ix)
-            else:
-                self.cur_idx += 1
-            self.fifo.append(self.pool.apply_async(get_npy_data, \
-                (ix, \
-                os.path.join(self.dataloader.input_fc_dir, str(self.dataloader.info['images'][ix]['id']) + '.npy'),
-                os.path.join(self.dataloader.input_att_dir, str(self.dataloader.info['images'][ix]['id']) + '.npz'),
-                self.dataloader.use_att
-                )))
-
-    def terminate(self):
-        while len(self.fifo) > 0:
-            self.fifo.pop(0).get()
-        self.pool.terminate()
-        print(self.split, 'terminated')
-
-    def join(self):
-        self.pool.join()
-        print(self.split, 'joined')
+        """
+        Two cases:
+        1. not hasattr(self, 'split_loader'): resuming from previous training; create the loader given the saved split_ix and iterator.
+        2. wrapped: a new epoch; split_ix and the iterator have already been updated in _get_next_minibatch_inds.
+        """
+        # batch_size is 1; merging items into batches is done in the DataLoader class
+        self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
+                                                 batch_size=1,
+                                                 sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:],
+                                                 shuffle=False,
+                                                 pin_memory=True,
+                                                 num_workers=multiprocessing.cpu_count(),
+                                                 collate_fn=lambda x: x[0]))
 
     def _get_next_minibatch_inds(self):
-        max_index = len(self.cur_split_ix)
+        max_index = len(self.dataloader.split_ix[self.split])
         wrapped = False
 
         ri = self.dataloader.iterators[self.split]
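
The sampler slice in reset() above is what allows resuming mid-epoch: the DataLoader just walks the remaining indices of split_ix, and because batch_size is 1 the collate_fn simply unwraps the singleton batch. A rough illustration of the same idea on a toy dataset (names and numbers invented for the example):

import torch.utils.data as data

class Squares(data.Dataset):
    # toy dataset: item i is just the pair (i, i*i)
    def __getitem__(self, i):
        return (i, i * i)
    def __len__(self):
        return 10

split_ix = list(range(10))
iterator = 4                      # pretend training stopped after 4 items

# A plain list of indices works as the sampler: the loader just iterates over it.
loader = iter(data.DataLoader(dataset=Squares(),
                              batch_size=1,
                              sampler=split_ix[iterator:],  # resume at index 4
                              shuffle=False,
                              collate_fn=lambda x: x[0]))   # unwrap the singleton batch
print(next(loader))               # -> (4, 16)
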
@@ -220,19 +221,22 @@ def _get_next_minibatch_inds(self):
         ri_next = ri + 1
         if ri_next >= max_index:
             ri_next = 0
-            self.dataloader.split_ix[self.split] = self.cur_split_ix[:] # copy
+            if self.if_shuffle:
+                random.shuffle(self.dataloader.split_ix[self.split])
             wrapped = True
         self.dataloader.iterators[self.split] = ri_next
 
         return ix, wrapped
 
     def get(self):
-        if len(self.fifo) < 400:
+        if not hasattr(self, 'split_loader'):
             self.reset()
 
         ix, wrapped = self._get_next_minibatch_inds()
-        tmp = self.fifo.pop(0).get()
+        tmp = self.split_loader.next()
+        if wrapped:
+            self.reset()
 
         assert tmp[2] == ix, "ix not equal"
 
-        return tmp + (wrapped,)
+        return tmp + [wrapped]
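
Put together, the new get() boils down to: build the iterator lazily, pull one prefetched item per call, and rebuild the iterator whenever the epoch wraps so the freshly shuffled split_ix takes effect. A stripped-down, hypothetical version of that wrap-and-reset pattern (TinyFetcher and Counter are invented for illustration):

import torch.utils.data as data

class Counter(data.Dataset):
    # toy dataset used only to demonstrate the wrap-and-reset flow
    def __getitem__(self, i):
        return (i,)
    def __len__(self):
        return 3

class TinyFetcher:
    """Hypothetical stand-in for BlobFetcher: rebuilds its iterator on wrap."""
    def __init__(self, dataset):
        self.dataset = dataset
        self.ix = 0

    def reset(self):
        # restart the underlying DataLoader from the current position
        self.loader = iter(data.DataLoader(self.dataset, batch_size=1,
                                           sampler=list(range(self.ix, len(self.dataset))),
                                           collate_fn=lambda x: x[0]))

    def get(self):
        if not hasattr(self, 'loader'):
            self.reset()
        tmp = next(self.loader)
        self.ix += 1
        wrapped = self.ix >= len(self.dataset)
        if wrapped:
            self.ix = 0
            self.reset()
        return tmp + (wrapped,)

fetcher = TinyFetcher(Counter())
print([fetcher.get() for _ in range(4)])   # wraps after the third item
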
