11import time
22import pickle
33import pika
4+ import random
5+ import torch
6+ import torchvision
7+ import torchvision .transforms as transforms
8+
49from torch import nn
10+ from collections import defaultdict
511
612import src .Log
713import src .Model
@@ -23,6 +29,24 @@ def __init__(self, client_id, layer_id, address, username, password, train_func,
2329 self .model = None
2430 self .connect ()
2531
32+ self .train_set = None
33+ self .label_to_indices = None
34+ if self .layer_id == 1 :
35+ # Read and load dataset
36+ transform_train = transforms .Compose ([
37+ transforms .RandomCrop (32 , padding = 4 ),
38+ transforms .RandomHorizontalFlip (),
39+ transforms .ToTensor (),
40+ transforms .Normalize ((0.4914 , 0.4822 , 0.4465 ), (0.2023 , 0.1994 , 0.2010 )),
41+ ])
42+
43+ self .train_set = torchvision .datasets .CIFAR10 (
44+ root = './data' , train = True , download = True , transform = transform_train )
45+
46+ self .label_to_indices = defaultdict (list )
47+ for idx , (_ , label ) in enumerate (self .train_set ):
48+ self .label_to_indices [label ].append (idx )
49+
2650 def wait_response (self ):
2751 status = True
2852 reply_queue_name = f'reply_{ self .client_id } '
@@ -66,12 +90,20 @@ def response_message(self, body):
6690 batch_size = self .response ["batch_size" ]
6791 lr = self .response ["lr" ]
6892 momentum = self .response ["momentum" ]
69- validation = self .response ["validation" ]
7093 control_count = self .response ["control_count" ]
7194
7295 # Start training
73- result , size = self .train_func (self .model , control_count , batch_size , lr , momentum ,
74- validation , label_count , num_layers )
96+ if self .layer_id == 1 :
97+ selected_indices = []
98+ for label , count in enumerate (label_count ):
99+ selected_indices .extend (random .sample (self .label_to_indices [label ], count ))
100+
101+ subset = torch .utils .data .Subset (self .train_set , selected_indices )
102+ train_loader = torch .utils .data .DataLoader (subset , batch_size = batch_size , shuffle = True )
103+
104+ result , size = self .train_func (self .model , lr , momentum , num_layers , control_count , train_loader )
105+ else :
106+ result , size = self .train_func (self .model , lr , momentum , num_layers , control_count )
75107
76108 # Stop training, then send parameters to server
77109 model_state_dict = self .model .state_dict ()
0 commit comments