Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion imagenet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@
'fastest way to use PyTorch for either single node or '
'multi node data parallel training')
parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
parser.add_argument(
'--sync-bn',
action='store_true',
help='Convert BatchNorm layers to SyncBatchNorm for multi-GPU training'
)


best_acc1 = 0

Expand Down Expand Up @@ -160,10 +166,15 @@ def main_worker(gpu, ngpus_per_node, args):
if args.pretrained:
print("=> using pre-trained model '{}'".format(args.arch))
model = models.__dict__[args.arch](pretrained=True)

else:
print("=> creating model '{}'".format(args.arch))
model = models.__dict__[args.arch]()


# Convert BN → SyncBatchNorm if requested AND distributed training is enabled
if args.distributed and args.sync_bn:
print("=> Converting BatchNorm layers to SyncBatchNorm")
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
if not use_accel:
print('using CPU, this will be slow')
elif args.distributed:
Expand Down