Skip to content

Commit 31e2fd9

Browse files
author
Bai Yifan
authored
some fix about CE (#1242)
* ce fix
1 parent d51ebfe commit 31e2fd9

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

fluid/image_classification/.run_ce.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
# This file is only used for continuous evaluation.
44
export FLAGS_cudnn_deterministic=True
5+
BATCH_SIZE=56
56
cudaid=${object_detection_cudaid:=0}
67
export CUDA_VISIBLE_DEVICES=$cudaid
7-
python train.py --batch_size=64 --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py
8+
python train.py --batch_size=${BATCH_SIZE} --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py
89

910
cudaid=${object_detection_cudaid_m:=0, 1, 2, 3}
1011
export CUDA_VISIBLE_DEVICES=$cudaid
11-
python train.py --batch_size=64 --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py
12+
python train.py --batch_size=${BATCH_SIZE} --num_epochs=5 --enable_ce=True --lr_strategy=cosine_decay | python _ce.py

fluid/object_detection/reader.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -292,24 +292,24 @@ def train(settings,
292292
shuffle=True,
293293
use_multiprocessing=True,
294294
num_workers=8,
295-
max_queue=24):
295+
max_queue=24,
296+
enable_ce=False):
296297
file_list = os.path.join(settings.data_dir, file_list)
297298

298-
def infinite_reader(gen):
299-
while True:
300-
for data in gen():
301-
yield data
302-
303299
if 'coco' in settings.dataset:
304300
generator = coco(settings, file_list, "train", batch_size, shuffle)
305301
else:
306302
generator = pascalvoc(settings, file_list, "train", batch_size, shuffle)
307303

304+
def infinite_reader():
305+
while True:
306+
for data in generator():
307+
yield data
308+
308309
def reader():
309310
try:
310311
enqueuer = GeneratorEnqueuer(
311-
infinite_reader(generator),
312-
use_multiprocessing=use_multiprocessing)
312+
infinite_reader(), use_multiprocessing=use_multiprocessing)
313313
enqueuer.start(max_queue_size=max_queue, workers=num_workers)
314314
generator_output = None
315315
while True:
@@ -325,7 +325,10 @@ def reader():
325325
if enqueuer is not None:
326326
enqueuer.stop()
327327

328-
return reader
328+
if enable_ce:
329+
return infinite_reader
330+
else:
331+
return reader
329332

330333

331334
def test(settings, file_list, batch_size):

fluid/object_detection/train.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
train_parameters = {
3434
"pascalvoc": {
35-
"train_images": 19200,
35+
"train_images": 16551,
3636
"image_shape": [3, 300, 300],
3737
"class_num": 21,
3838
"batch_size": 64,
@@ -143,7 +143,6 @@ def train(args,
143143
startup_prog.random_seed = 111
144144
train_prog.random_seed = 111
145145
test_prog.random_seed = 111
146-
num_workers = 1
147146

148147
train_py_reader, loss = build_program(
149148
main_prog=train_prog,
@@ -170,14 +169,14 @@ def if_exist(var):
170169
if parallel:
171170
train_exe = fluid.ParallelExecutor(main_program=train_prog,
172171
use_cuda=use_gpu, loss_name=loss.name)
173-
174172
train_reader = reader.train(data_args,
175173
train_file_list,
176174
batch_size_per_device,
177175
shuffle=is_shuffle,
178176
use_multiprocessing=True,
179177
num_workers=num_workers,
180-
max_queue=24)
178+
max_queue=24,
179+
enable_ce=enable_ce)
181180
test_reader = reader.test(data_args, val_file_list, batch_size)
182181
train_py_reader.decorate_paddle_reader(train_reader)
183182
test_py_reader.decorate_paddle_reader(test_reader)

0 commit comments

Comments
 (0)