Skip to content

Commit 672ee24

Browse files
committed
Modify the data loading mechanism and separate validation class
1 parent 01b4d61 commit 672ee24

File tree

8 files changed

+138
-158
lines changed

8 files changed

+138
-158
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ learning:
102102
momentum: 0.5
103103
batch-size: 256
104104
control-count: 3 # control count on client
105-
validation: False # run validate on client side
106105
```
107106
108107
This configuration is use for server.
@@ -152,7 +151,7 @@ If the `*.pth` file exists, the server will read the file and send the parameter
152151

153152
---
154153

155-
Version 1.8.0
154+
Version 2.0.0
156155

157156
The application is under development...
158157

client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parser = argparse.ArgumentParser(description="Split learning framework")
1414
parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1')
1515
parser.add_argument('--device', type=str, required=False, help='Device of client')
16+
parser.add_argument('--event_time', type=bool, default=False, required=False, help='Log event time for debug mode')
1617

1718
args = parser.parse_args()
1819

@@ -46,7 +47,7 @@
4647
if __name__ == "__main__":
4748
src.Log.print_with_color("[>>>] Client sending registration message to server...", "red")
4849
data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id, "message": "Hello from Client!"}
49-
scheduler = Scheduler(client_id, args.layer_id, channel, device)
50+
scheduler = Scheduler(client_id, args.layer_id, channel, device, args.event_time)
5051
client = RpcClient(client_id, args.layer_id, address, username, password, scheduler.train_on_device, device)
5152
client.send_to_server(data)
5253
client.wait_response()

config.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,3 @@ learning:
2626
momentum: 0.5
2727
batch-size: 128
2828
control-count: 3
29-
validation: False

src/Model.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,5 @@
11
import torch
22
import torch.nn as nn
3-
import numpy as np
4-
import math
5-
from tqdm import tqdm
6-
7-
import torchvision
8-
import torchvision.transforms as transforms
9-
import torch.nn.functional as F
103

114
if torch.cuda.is_available():
125
device = "cuda"
@@ -464,44 +457,3 @@ def forward(self, x):
464457
out83 = self.layer83(out82)
465458
out84 = self.layer84(out83)
466459
return out84
467-
468-
469-
def test(model_name, state_dict_full, logger):
470-
transform_test = transforms.Compose([
471-
transforms.ToTensor(),
472-
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
473-
])
474-
475-
testset = torchvision.datasets.CIFAR10(
476-
root='./data', train=False, download=True, transform=transform_test)
477-
test_loader = torch.utils.data.DataLoader(
478-
testset, batch_size=100, shuffle=False, num_workers=2)
479-
480-
klass = globals().get(model_name)
481-
if klass is None:
482-
raise ValueError(f"Class '{model_name}' does not exist.")
483-
model = klass()
484-
model = nn.Sequential(*nn.ModuleList(model.children()))
485-
model.load_state_dict(state_dict_full)
486-
# evaluation mode
487-
model.eval()
488-
test_loss = 0
489-
correct = 0
490-
for data, target in tqdm(test_loader):
491-
output = model(data)
492-
test_loss += F.nll_loss(output, target, reduction='sum').item()
493-
pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
494-
correct += pred.eq(target.data.view_as(pred)).long().cpu().sum()
495-
496-
test_loss /= len(test_loader.dataset)
497-
accuracy = 100.0 * correct / len(test_loader.dataset)
498-
print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
499-
test_loss, correct, len(test_loader.dataset), accuracy))
500-
501-
if np.isnan(test_loss) or math.isnan(test_loss) or abs(test_loss) > 10e5:
502-
return False
503-
else:
504-
logger.log_info('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
505-
test_loss, correct, len(test_loader.dataset), accuracy))
506-
507-
return True

src/RpcClient.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
import time
22
import pickle
33
import pika
4+
import random
5+
import torch
6+
import torchvision
7+
import torchvision.transforms as transforms
8+
49
from torch import nn
10+
from collections import defaultdict
511

612
import src.Log
713
import src.Model
@@ -23,6 +29,24 @@ def __init__(self, client_id, layer_id, address, username, password, train_func,
2329
self.model = None
2430
self.connect()
2531

32+
self.train_set = None
33+
self.label_to_indices = None
34+
if self.layer_id == 1:
35+
# Read and load dataset
36+
transform_train = transforms.Compose([
37+
transforms.RandomCrop(32, padding=4),
38+
transforms.RandomHorizontalFlip(),
39+
transforms.ToTensor(),
40+
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
41+
])
42+
43+
self.train_set = torchvision.datasets.CIFAR10(
44+
root='./data', train=True, download=True, transform=transform_train)
45+
46+
self.label_to_indices = defaultdict(list)
47+
for idx, (_, label) in enumerate(self.train_set):
48+
self.label_to_indices[label].append(idx)
49+
2650
def wait_response(self):
2751
status = True
2852
reply_queue_name = f'reply_{self.client_id}'
@@ -66,12 +90,20 @@ def response_message(self, body):
6690
batch_size = self.response["batch_size"]
6791
lr = self.response["lr"]
6892
momentum = self.response["momentum"]
69-
validation = self.response["validation"]
7093
control_count = self.response["control_count"]
7194

7295
# Start training
73-
result, size = self.train_func(self.model, control_count, batch_size, lr, momentum,
74-
validation, label_count, num_layers)
96+
if self.layer_id == 1:
97+
selected_indices = []
98+
for label, count in enumerate(label_count):
99+
selected_indices.extend(random.sample(self.label_to_indices[label], count))
100+
101+
subset = torch.utils.data.Subset(self.train_set, selected_indices)
102+
train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True)
103+
104+
result, size = self.train_func(self.model, lr, momentum, num_layers, control_count, train_loader)
105+
else:
106+
result, size = self.train_func(self.model, lr, momentum, num_layers, control_count)
75107

76108
# Stop training, then send parameters to server
77109
model_state_dict = self.model.state_dict()

0 commit comments

Comments
 (0)