@@ -7,19 +7,22 @@
import os
import numpy as np
import random
- from multiprocessing.dummy import Pool
+
+ import torch
+ import torch.utils.data as data
+
+ import multiprocessing

def get_npy_data(ix, fc_file, att_file, use_att):
    if use_att == True:
        return (np.load(fc_file), np.load(att_file)['feat'], ix)
    else:
        return (np.load(fc_file), np.zeros((1, 1, 1)), ix)

- class DataLoader():
+ class DataLoader(data.Dataset):

    def reset_iterator(self, split):
-         self._prefetch_process[split].terminate()
-         self._prefetch_process[split].join()
+         del self._prefetch_process[split]
        self._prefetch_process[split] = BlobFetcher(split, self, split == 'train')
        self.iterators[split] = 0

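Note: the Pool-based prefetcher is replaced by torch.utils.data below, while get_npy_data keeps returning one (fc, att, ix) item per call. A minimal standalone check of that contract, assuming the get_npy_data defined above is in scope; the file names and feature shapes are illustrative only:

    import numpy as np

    # Toy feature files in the layout the loader expects: a .npy vector for the
    # fc feature and a .npz archive whose 'feat' entry holds the att feature.
    np.save('123.npy', np.random.randn(2048).astype('float32'))
    np.savez('123.npz', feat=np.random.randn(14, 14, 2048).astype('float32'))

    fc, att, ix = get_npy_data(123, '123.npy', '123.npz', use_att=True)
    print(fc.shape, att.shape, ix)   # (2048,) (14, 14, 2048) 123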
@@ -89,8 +92,7 @@ def __init__(self, opt):
        def cleanup():
            print('Terminating BlobFetcher')
            for split in self.iterators.keys():
-                 self._prefetch_process[split].terminate()
-                 self._prefetch_process[split].join()
+                 del self._prefetch_process[split]
        import atexit
        atexit.register(cleanup)

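The cleanup hook relies on the standard-library atexit API; a minimal standalone sketch of the same pattern (ToyLoader and its dicts are illustrative, not part of this repo):

    import atexit

    class ToyLoader:
        def __init__(self):
            self._prefetch_process = {'train': object(), 'val': object()}
            self.iterators = {'train': 0, 'val': 0}

            def cleanup():
                print('Terminating BlobFetcher')
                for split in self.iterators.keys():
                    del self._prefetch_process[split]
            atexit.register(cleanup)

    loader = ToyLoader()   # cleanup() runs automatically at interpreter exit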
@@ -167,6 +169,22 @@ def get_batch(self, split, batch_size=None, seq_per_img=None):

        return data

+     # It's not coherent to make DataLoader a subclass of Dataset, but essentially, we only need to implement the following two functions,
+     # so that the torch.utils.data.DataLoader can load the data according to the index.
+     # However, it's a minimal change to switch to pytorch data loading.
+     def __getitem__(self, index):
+         """This function returns a tuple that is further passed to collate_fn
+         """
+         ix = index  # self.split_ix[index]
+         return get_npy_data(ix, \
+             os.path.join(self.input_fc_dir, str(self.info['images'][ix]['id']) + '.npy'),
+             os.path.join(self.input_att_dir, str(self.info['images'][ix]['id']) + '.npz'),
+             self.use_att
+             )
+
+     def __len__(self):
+         return len(self.info['images'])
+
class BlobFetcher():
    """Experimental class for prefetching blobs in a separate process."""
    def __init__(self, split, dataloader, if_shuffle=False):
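The comment above is the crux of the change: torch.utils.data.DataLoader only needs __getitem__ and __len__ on the dataset object, plus a collate_fn to unwrap its per-batch packaging when batch_size=1. A small self-contained sketch of that protocol (TinySet is illustrative, not part of this repo):

    import torch.utils.data as data

    class TinySet(data.Dataset):
        def __getitem__(self, index):
            return (index * 10, 'dummy', index)   # any tuple, like get_npy_data's

        def __len__(self):
            return 3

    # collate_fn=lambda x: x[0] undoes the single-sample "batch" wrapping.
    loader = data.DataLoader(TinySet(), batch_size=1, collate_fn=lambda x: x[0])
    for item in loader:
        print(item)   # (0, 'dummy', 0), (10, 'dummy', 1), (20, 'dummy', 2)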
@@ -177,41 +195,24 @@ def __init__(self, split, dataloader, if_shuffle=False):
        self.dataloader = dataloader
        self.if_shuffle = if_shuffle

-         self.pool = Pool()
-         self.fifo = []
-
    # Add more in the queue
    def reset(self):
-         if len(self.fifo) == 0:
-             self.cur_idx = self.dataloader.iterators[self.split]
-             self.cur_split_ix = self.dataloader.split_ix[self.split][:]  # copy
-         for i in range(512 - len(self.fifo)):
-             ix = self.cur_split_ix[self.cur_idx]
-             if self.cur_idx + 1 >= len(self.cur_split_ix):
-                 self.cur_idx = 0
-                 if self.if_shuffle:
-                     random.shuffle(self.cur_split_ix)
-             else:
-                 self.cur_idx += 1
-             self.fifo.append(self.pool.apply_async(get_npy_data, \
-                 (ix, \
-                 os.path.join(self.dataloader.input_fc_dir, str(self.dataloader.info['images'][ix]['id']) + '.npy'),
-                 os.path.join(self.dataloader.input_att_dir, str(self.dataloader.info['images'][ix]['id']) + '.npz'),
-                 self.dataloader.use_att
-                 )))
-
-     def terminate(self):
-         while len(self.fifo) > 0:
-             self.fifo.pop(0).get()
-         self.pool.terminate()
-         print(self.split, 'terminated')
-
-     def join(self):
-         self.pool.join()
-         print(self.split, 'joined')
+         """
+         Two cases:
+         1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
+         2. wrapped: a new epoch, the split_ix and iterator have been updated in _get_next_minibatch_inds already.
+         """
+         # batch_size is 1, the merge is done in the DataLoader class
+         self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
+                                                  batch_size=1,
+                                                  sampler=self.dataloader.split_ix[self.split][self.dataloader.iterators[self.split]:],
+                                                  shuffle=False,
+                                                  pin_memory=True,
+                                                  num_workers=multiprocessing.cpu_count(),
+                                                  collate_fn=lambda x: x[0]))

    def _get_next_minibatch_inds(self):
-         max_index = len(self.cur_split_ix)
+         max_index = len(self.dataloader.split_ix[self.split])
        wrapped = False

        ri = self.dataloader.iterators[self.split]
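reset() above hands DataLoader the tail of split_ix (from the current iterator position) directly as the sampler, which is how a mid-epoch resume is expressed. A plain Python list works there because it provides __iter__ and __len__; a small sketch under that assumption (TinySet again illustrative):

    import torch.utils.data as data

    class TinySet(data.Dataset):
        def __getitem__(self, index):
            return index

        def __len__(self):
            return 100

    remaining = list(range(100))[37:]   # "resume" from position 37
    loader = iter(data.DataLoader(TinySet(), batch_size=1,
                                  sampler=remaining, shuffle=False,
                                  collate_fn=lambda x: x[0]))
    print(next(loader), next(loader))   # 37 38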
@@ -220,19 +221,22 @@ def _get_next_minibatch_inds(self):
        ri_next = ri + 1
        if ri_next >= max_index:
            ri_next = 0
-             self.dataloader.split_ix[self.split] = self.cur_split_ix[:]  # copy
+             if self.if_shuffle:
+                 random.shuffle(self.dataloader.split_ix[self.split])
            wrapped = True
        self.dataloader.iterators[self.split] = ri_next

        return ix, wrapped

    def get(self):
-         if len(self.fifo) < 400:
+         if not hasattr(self, 'split_loader'):
            self.reset()

        ix, wrapped = self._get_next_minibatch_inds()
-         tmp = self.fifo.pop(0).get()
+         tmp = self.split_loader.next()
+         if wrapped:
+             self.reset()

        assert tmp[2] == ix, "ix not equal"

-         return tmp + (wrapped,)
+         return tmp + [wrapped]
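Two notes on get(): self.split_loader.next() is the Python 2 iterator call (next(self.split_loader) under Python 3), and a wrap triggers reset() so the next epoch rebuilds the iterator over the possibly reshuffled split_ix. A toy sketch of that wrap-and-reset flow, with illustrative data and a list() conversion so the '+ [wrapped]' concatenation is well-typed in isolation:

    items = [('fc0', 'att0', 0), ('fc1', 'att1', 1)]
    it = iter(items)

    def get(ix, wrapped):
        global it
        tmp = next(it)            # Python 3 spelling of it.next()
        if wrapped:
            it = iter(items)      # like reset(): rebuild the iterator for the new epoch
        assert tmp[2] == ix, "ix not equal"
        return list(tmp) + [wrapped]

    print(get(0, False))   # ['fc0', 'att0', 0, False]
    print(get(1, True))    # ['fc1', 'att1', 1, True]  (wrapped: iterator rebuilt)
    print(get(0, False))   # ['fc0', 'att0', 0, False]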