From ce4d42937e6cc4fa018d5a50c561da9bc9b9c439 Mon Sep 17 00:00:00 2001 From: truongtruong373 Date: Tue, 22 Apr 2025 16:19:54 +0700 Subject: [PATCH 1/3] add calculate manual grad and update manual grad of parameter each layer --- src/Schedule_zb.py | 152 +++++++++++++++++++++++++++++++++++++++++++++ src/Scheduler.py | 6 ++ src/Utils.py | 112 +++++++++++++++++++++++++++++++++ 3 files changed, 270 insertions(+) create mode 100644 src/Schedule_zb.py diff --git a/src/Schedule_zb.py b/src/Schedule_zb.py new file mode 100644 index 0000000..78cf717 --- /dev/null +++ b/src/Schedule_zb.py @@ -0,0 +1,152 @@ +import time +import uuid +import pickle +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as f + +from src.Utils import manual_W, hook_model +import src.Log + + +class Scheduler: + def __init__(self, client_id, layer_id, channel, device): + self.client_id = client_id + self.layer_id = layer_id + self.channel = channel + self.device = device + self.data_count = 0 + + def balanced_softmax_loss(self, logits, labels, class_counts, epsilon=1e-6): + class_counts = torch.tensor(class_counts, dtype=torch.int64).to(self.device) + log_probs = f.log_softmax(logits, dim=1) + class_probs = class_counts / (class_counts.sum() + epsilon) + weights = 1.0 / (class_probs + epsilon) + weights = weights / weights.sum() + loss = (-weights[labels] * log_probs[range(labels.shape[0]), labels]).mean() + return loss + + def send_intermediate_output(self, data_id, label_count, output, labels, trace, cluster=None, special=False): + if special: + forward_queue_name = f"intermediate_queue_{self.layer_id}" + else: + forward_queue_name = f"intermediate_queue_{self.layer_id}_{cluster}" + self.channel.queue_declare(forward_queue_name, durable=False) + + if trace: + trace.append(self.client_id) + message = pickle.dumps( + {"data_id": data_id, "label_count": label_count, "data": output.detach().cpu().numpy(), + "labels": labels, "trace": trace} + ) + else: + message = pickle.dumps( + {"data_id": data_id, "label_count": label_count, "data": output.detach().cpu().numpy(), + "labels": labels, "trace": [self.client_id]} + ) + + self.channel.basic_publish( + exchange='', + routing_key=forward_queue_name, + body=message + ) + + def send_gradient(self, data_id, gradient, trace): + to_client_id = trace[-1] + trace.pop(-1) + backward_queue_name = f'gradient_queue_{to_client_id}' + self.channel.queue_declare(backward_queue_name, durable=False) + + message = pickle.dumps( + {"data_id": data_id, "data": gradient.detach().cpu().numpy(), "trace": trace} + ) + + self.channel.basic_publish( + exchange='', + routing_key=backward_queue_name, + body=message + ) + + def send_to_server(self, message): + self.channel.queue_declare('rpc_queue', durable=False) + self.channel.basic_publish( + exchange='', + routing_key='rpc_queue', + body=pickle.dumps(message) + ) + + def train_on_first_layer(self, model, global_model, label_count, lr, momentum, clip_grad_norm, computer_loss, + control_count=5, train_loader=None, cluster=None, special=False): + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + data_iter = iter(train_loader) + num_forward = 0 + num_backward = 0 + end_data = False + micro_batch = 8 + data_store = {} + + backward_queue_name = f"gradient_queue_{self.client_id}" + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(perfetch_count=1) + + model = model.to(self.device) + neural_layers, inputs_per_layer, 
outputs_per_layer = hook_model(model) + storage_to_calculate_W = [] + chunk = 0 + count_w = 0 + + with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: + while True: + # Train model + model.train() + optimizer.zero_grad() + + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + # B process + num_backward += 1 + received_data = pickle.loads(body) + grad_output = received_data["data"] + grad_output = torch.tensor(grad_output).to(self.device) + data_id = received_data["data_id"] + + load_data = data_store.pop(data_id) + torch.autograd.grad(load_data[1], load_data[0], grad_outputs=grad_output, retain_graph=True) + grad_of_outputs_per_layer = [_.grad for _ in outputs_per_layer] + outputs_per_layer = [] + storage_to_calculate_W.append([inputs_per_layer, grad_of_outputs_per_layer]) + else: + if (len(data_store) <= control_count) and (chunk < micro_batch): + # F process + try: + data, label = next(data_iter) + data = data.to(self.device) + data_id = uuid.uuid4() + output = model(data) + load_data.append([data, output]) + intermediate_output = output.detach().requires_grad_(True) + + num_forward += 1 + chunk += 1 + self.data_count += 1 + + self.send_intermediate_output(data_id, label_count, intermediate_output, label, trace=None, cluster=cluster, special=special) + except StopIteration: + end_data = True + else: + # W process + if len(storage_to_calculate_W) > 0: + load_w = storage_to_calculate_W.pop(0) + # load_w : [inputs_per_layer, grad_of_outputs_per_layer] + manual_W(load_w[0], load_w[1], neural_layers) # calculate grad of parameter per layer and sum update it + count_w += 1 + pbar.update(1) + else: + if count_w == micro_batch: + count_w = 0 + chunk = 0 + optimizer.step() + diff --git a/src/Scheduler.py b/src/Scheduler.py index c0be1a5..0d521b1 100644 --- a/src/Scheduler.py +++ b/src/Scheduler.py @@ -111,6 +111,10 @@ def train_on_first_layer(self, model, global_model, label_count, lr, momentum, c data_input = data_store.pop(data_id) output = model(data_input) output.backward(gradient=gradient) + + if clip_grad_norm and clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) + optimizer.step() if self.event_time: self.time_event_backward.append(time.time()) @@ -269,6 +273,8 @@ def train_on_middle_layer(self, model, global_model, label_count, lr, momentum, output = model(data_input) data_input.retain_grad() output.backward(gradient=gradient, retain_graph=True) + if clip_grad_norm and clip_grad_norm > 0: + torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm) optimizer.step() gradient = data_input.grad diff --git a/src/Utils.py b/src/Utils.py index d944be8..f8f9d87 100644 --- a/src/Utils.py +++ b/src/Utils.py @@ -1,3 +1,7 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + import numpy as np import random import pika @@ -38,6 +42,7 @@ def change_name(name): number = int(parts[0]) + i name = f"{number}" + "." 
+ parts[1] return name + new_state_dict = {} for key, value in state_dicts.items(): new_key = change_name(key) @@ -62,3 +67,110 @@ def num_client_in_cluster(client_cluster_label): count_list[num] += 1 count_list = [[x] for x in count_list] return count_list + + +def check_layer(layer): + return isinstance(layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)) + + +def hook_model(model): + input_layers = [] + output_layers = [] + neural_layers = [] + + def save_output_hook(module, input, output): + if check_layer(module): + input_layers.append(input[0].detach()) + output.retain_grad() + output_layers.append(output) + + for layer in model.modules(): + if check_layer(layer): + neural_layers.append(layer) + layer.register_forward_hook(save_output_hook) + return neural_layers, input_layers, output_layers + +def manual_linear_grad_weight(x, grad_out, linear_layer): + """ + Calculate manual ∂L/∂W for nn.Linear + Args: + x: input of layer, shape [B, in_features] + grad_out: ∂L/∂z, shape [B, out_features] + linear_layer: nn.Linear + Returns: + grad_w: [out_features, in_features] + """ + grad_w = grad_out.T @ x # [out_features, in_features] + return grad_w + + +def manual_BatchNorm2d_grad_weight(x_in, grad_out, bn_layer, eps=1e-5): + """ + Calculate manual ∂L/∂gamma, ∂L/∂beta for BatchNorm2d + x_in: [B, C, H, W] + grad_out: ∂L/∂y, [B, C, H, W] + bn_layer: nn.BatchNorm2d + """ + mu = x_in.mean(dim=(0, 2, 3), keepdim=True) + var = x_in.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + x_hat = (x_in - mu) / torch.sqrt(var + eps) + + grad_gamma = (grad_out * x_hat).sum(dim=(0, 2, 3)) + grad_beta = grad_out.sum(dim=(0, 2, 3)) + return grad_gamma, grad_beta + + +def manual_conv_grad_weight(x_in, grad_out, conv_layer): + """ + Calculate manual ∂L/∂W for Conv_2d + Args: + x_in: đầu vào của layer này, shape [N, Cin, H, W] + grad_out: ∂L/∂z (gradient output), shape [N, Cout, H_out, W_out] + conv_layer: đối tượng nn.Conv2d + Returns: + grad_w: ∂L/∂W, shape [Cout, Cin, Kh, Kw] + """ + Kh, Kw = conv_layer.kernel_size + stride = conv_layer.stride + padding = conv_layer.padding + dilation = conv_layer.dilation + + x_unfold = F.unfold(x_in, kernel_size=(Kh, Kw), stride=stride, padding=padding, dilation=dilation) + + N = grad_out.shape[0] + grad_out = grad_out.reshape(N, grad_out.shape[1], -1) # [N, Cout, L] + grad_w_batch = torch.bmm(grad_out, x_unfold.transpose(1, 2)) # [N, Cout, Cin*Kh*Kw] + grad_w = grad_w_batch.sum(dim=0).view(conv_layer.out_channels, conv_layer.in_channels, *conv_layer.kernel_size) + return grad_w + + +def manual_W(inputs, grads_z_per_layer, neural_layers): + grads_w = [] + for i in range(len(neural_layers)): + x_in = inputs[i] + grad_out = grads_z_per_layer[i] # ∂L/∂z tại layer[i] + + Layer = neural_layers[i] + if isinstance(Layer, nn.Conv2d): + grad_w = manual_conv_grad_weight(x_in, grad_out, Layer) + # grad_w = torch.tensor(grad_w) + Layer.weight.grad = grad_w.clone().detach() + if Layer.bias is not None: + # ∂L/∂b = sum over batch, height, width + grad_b = grad_out.sum(dim=(0, 2, 3)) + Layer.bias.grad = grad_b + # print(f'Complete layer {i}') + + elif isinstance(Layer, nn.Linear): + grad_w = manual_linear_grad_weight(x_in, grad_out, Layer) + # grad_w = torch.tensor(grad_w) + Layer.weight.grad = grad_w.clone().detach() + if Layer.bias is not None: + # ∂L/∂b = sum over batch + grad_b = grad_out.sum(dim=0) + Layer.bias.grad = grad_b + + else: + raise print(f"Layer not define backward functions for layer {Layer}") + # grads_w.append(grad_w) + # return grads_w From 
885167a994522a628c2130275c6c79d81b5eb61e Mon Sep 17 00:00:00 2001 From: truongtruong373 Date: Thu, 29 May 2025 11:48:10 +0700 Subject: [PATCH 2/3] FBW with architecture 1-1, manual chunk in last layer --- client.py | 10 +- config.yaml | 3 + src/RpcClient.py | 8 +- src/Schedule_zb.py | 286 ++++++++++++++++++++++++++++++++++++--------- src/Scheduler.py | 2 +- src/Server.py | 19 ++- src/Utils.py | 48 +++----- src/Validation.py | 14 +-- 8 files changed, 285 insertions(+), 105 deletions(-) diff --git a/client.py b/client.py index 529fc87..7d1b380 100644 --- a/client.py +++ b/client.py @@ -7,14 +7,14 @@ import src.Log from src.RpcClient import RpcClient -from src.Scheduler import Scheduler +from src.Schedule_zb import Scheduler parser = argparse.ArgumentParser(description="Split learning framework") parser.add_argument('--layer_id', type=int, required=True, help='ID of layer, start from 1') parser.add_argument('--device', type=str, required=False, help='Device of client') parser.add_argument('--event_time', type=bool, default=False, required=False, help='Log event time for debug mode') -parser.add_argument('--performance', type=int, required=False, help='Cluster by device') +parser.add_argument('--p', type=int, required=False, help='Cluster by device') args = parser.parse_args() @@ -45,15 +45,15 @@ connection = pika.BlockingConnection(pika.ConnectionParameters(address, 5672, f'{virtual_host}', credentials)) channel = connection.channel() -if args.performance is None: +if args.p is None: performance = -1 else: - performance = args.performance + performance = args.p if __name__ == "__main__": src.Log.print_with_color("[>>>] Client sending registration message to server...", "red") data = {"action": "REGISTER", "client_id": client_id, "layer_id": args.layer_id, "performance": performance, "message": "Hello from Client!"} - scheduler = Scheduler(client_id, args.layer_id, channel, device, args.event_time) + scheduler = Scheduler(client_id, args.layer_id, channel, device) client = RpcClient(client_id, args.layer_id, address, username, password, scheduler.train_on_device, device) client.send_to_server(data) client.wait_response() diff --git a/config.yaml b/config.yaml index 2d4d26c..e388c1a 100644 --- a/config.yaml +++ b/config.yaml @@ -35,6 +35,9 @@ server: AffinityPropagation: damping: 0.9 max_iter: 1000 + fbw: + enable: True + chunk: 1 rabbit: address: 127.0.0.1 diff --git a/src/RpcClient.py b/src/RpcClient.py index b12a9b1..f035fb5 100644 --- a/src/RpcClient.py +++ b/src/RpcClient.py @@ -74,6 +74,7 @@ def response_message(self, body): label_count = self.response['label_count'] num_layers = self.response['num_layers'] clip_grad_norm = self.response['clip_grad_norm'] + chunk = self.response['chunk'] if self.label_count is None: self.label_count = label_count if self.response['cluster'] is not None: @@ -116,11 +117,11 @@ def response_message(self, body): subset = torch.utils.data.Subset(self.train_set, selected_indices) train_loader = torch.utils.data.DataLoader(subset, batch_size=batch_size, shuffle=True) if cut_layers[1] != 0: - result, size = self.train_func(self.model, self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader, self.cluster, special, alone_train=False) + result, size = self.train_func(self.model, self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader, self.cluster, special, alone_train=False, chunk=chunk) else: - result, size = self.train_func(self.model, 
self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader, self.cluster, special, alone_train=True) + result, size = self.train_func(self.model, self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader, self.cluster, special, alone_train=True, chunk=chunk) else: - result, size = self.train_func(self.model, self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, None, self.cluster, special) + result, size = self.train_func(self.model, self.global_model, self.label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, None, self.cluster, special, chunk=chunk) # Stop training, then send parameters to server model_state_dict = self.model.state_dict() @@ -132,6 +133,7 @@ def response_message(self, body): "message": "Sent parameters to Server", "parameters": model_state_dict} src.Log.print_with_color("[>>>] Client sent parameters to server", "red") self.send_to_server(data) + self.model = None return True elif action == "STOP": return False diff --git a/src/Schedule_zb.py b/src/Schedule_zb.py index 78cf717..5a176f7 100644 --- a/src/Schedule_zb.py +++ b/src/Schedule_zb.py @@ -8,7 +8,7 @@ import torch.optim as optim import torch.nn.functional as f -from src.Utils import manual_W, hook_model +from src.Utils import manual_W import src.Log @@ -79,74 +79,256 @@ def send_to_server(self, message): ) def train_on_first_layer(self, model, global_model, label_count, lr, momentum, clip_grad_norm, computer_loss, - control_count=5, train_loader=None, cluster=None, special=False): + control_count=5, train_loader=None, cluster=None, special=False, chunks=1): optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) data_iter = iter(train_loader) + model = model.to(self.device) num_forward = 0 num_backward = 0 + num_weight = 0 end_data = False - micro_batch = 8 data_store = {} + dict_outputs_per_layer = {} + inputs_per_layer = [] + outputs_per_layer = [] + neural_layers = [] + storage_to_calculate_W = [] + def check_layer(the_layer): + return isinstance(the_layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)) + def save_output_hook(module, the_input, the_output): + if check_layer(module): + inputs_per_layer.append(the_input[0].detach()) + the_output.retain_grad() + outputs_per_layer.append(the_output) + + for layer in model.modules(): + if check_layer(layer): + neural_layers.append(layer) + layer.register_forward_hook(save_output_hook) backward_queue_name = f"gradient_queue_{self.client_id}" self.channel.queue_declare(queue=backward_queue_name, durable=False) - self.channel.basic_qos(perfetch_count=1) - - model = model.to(self.device) - neural_layers, inputs_per_layer, outputs_per_layer = hook_model(model) - storage_to_calculate_W = [] - chunk = 0 - count_w = 0 + self.channel.basic_qos(prefetch_count=1) with tqdm(total=len(train_loader), desc="Processing", unit="step") as pbar: while True: - # Train model - model.train() - optimizer.zero_grad() + try: + images , labels = next(data_iter) + image_chunks = torch.chunk(images, chunks, dim=0) + label_chunks = torch.chunk(labels, chunks, dim=0) + micro_batch = iter(zip(image_chunks, label_chunks)) + end_batch = False + model.train() + optimizer.zero_grad() + # Train model + while True: + # Check_backward + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + # B process + num_backward += 1 
+ # Read message + received_data = pickle.loads(body) + grad_output = received_data["data"] + grad_output = torch.tensor(grad_output).to(self.device) + data_id = received_data["data_id"] + + load_data = data_store.pop(data_id) + # Calculate gradient loss from x + grad_x = torch.autograd.grad(load_data[1], load_data[0], grad_outputs=grad_output, retain_graph=True)[0] + # take out grad of each layer in model + grad_of_outputs_per_layer = [_.grad for _ in dict_outputs_per_layer[data_id]] + # clear + dict_outputs_per_layer[data_id].clear() + del dict_outputs_per_layer[data_id] + storage_to_calculate_W.append([load_data[2], grad_of_outputs_per_layer]) + else: + if not end_batch: + # F process + try: + data, label = next(micro_batch) + data = data.to(self.device).requires_grad_() + + data_id = uuid.uuid4() + output = model(data) + output.retain_grad() + data_store[data_id] = [data, output, inputs_per_layer] + inputs_per_layer = [] + dict_outputs_per_layer[data_id] = outputs_per_layer + outputs_per_layer = [] + intermediate_output = output.detach().requires_grad_(True) + + num_forward += 1 + self.data_count += 1 + + self.send_intermediate_output(data_id, label_count, intermediate_output, label, trace=None, cluster=cluster, special=special) + except StopIteration: + end_batch = True + else: + # W process + if len(storage_to_calculate_W) > 0: + load_w = storage_to_calculate_W.pop(0) + # load_w : [inputs_per_layer, grad_of_outputs_per_layer] + manual_W(load_w[0], load_w[1], neural_layers) # calculate grad of parameter per layer and sum update it + num_weight += 1 + + if (num_forward == num_backward == num_weight) and end_batch: + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + num_forward = 0 + num_backward = 0 + num_weight = 0 + break + except StopIteration: + end_data = True + break + notify_data = {"action": "NOTIFY", "client_id": self.client_id, "layer_id": self.layer_id, + "message": "Finish training!", "cluster": cluster} + + # Finish epoch training, send notify to server + src.Log.print_with_color("[>>>] Finish training!", "red") + self.send_to_server(notify_data) + + broadcast_queue_name = f'reply_{self.client_id}' + while True: # Wait for broadcast + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f"[<<<] Received message from server {received_data}", "blue") + if received_data["action"] == "PAUSE": + return True + time.sleep(0.5) + + + def train_on_last_layer(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, cluster, + special=False, chunks=1): + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + optimizer.zero_grad() + result = True + + criterion = nn.CrossEntropyLoss() + if special: + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' + else: + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{cluster}' + self.channel.queue_declare(queue=forward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + print('Waiting for intermediate output. 
To exit press CTRL+C') + model.to(self.device) + inputs_per_layer = [] + outputs_per_layer = [] + neural_layers = [] + + def check_layer(the_layer): + return isinstance(the_layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)) + + def save_output_hook(module, the_input, the_output): + if check_layer(module): + inputs_per_layer.append(the_input[0].detach()) + the_output.retain_grad() + outputs_per_layer.append(the_output) - method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + for layer in model.modules(): + if check_layer(layer): + neural_layers.append(layer) + layer.register_forward_hook(save_output_hook) + + num_forward = 0 + num_backward = 0 + num_weight = 0 + storage_to_calculate_W = [] + while True: + if num_forward < chunks: + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) if method_frame and body: - # B process - num_backward += 1 received_data = pickle.loads(body) - grad_output = received_data["data"] - grad_output = torch.tensor(grad_output).to(self.device) - data_id = received_data["data_id"] + intermediate_output_numpy = received_data['data'] + trace = received_data['trace'] + data_id = received_data['data_id'] + labels = received_data['labels'].to(self.device) + label_count = received_data['label_count'] - load_data = data_store.pop(data_id) - torch.autograd.grad(load_data[1], load_data[0], grad_outputs=grad_output, retain_graph=True) - grad_of_outputs_per_layer = [_.grad for _ in outputs_per_layer] + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + intermediate_output.retain_grad() + + # F process + output = model(intermediate_output) + num_forward += 1 + self.data_count += 1 + loss = criterion(output, labels) + print(f'Loss: {loss.item()}') + if torch.isnan(loss).any(): + src.Log.print_with_color("NaN detected in loss", 'yellow') + result = False + + # B process + grad_out = torch.autograd.grad(loss, output, retain_graph=True)[0] + gradient_intermediate = torch.autograd.grad(output, intermediate_output, grad_outputs=grad_out, retain_graph=True)[0] + self.send_gradient(data_id, gradient_intermediate, trace) + grads_of_output_per_layer = [_.grad for _ in outputs_per_layer] outputs_per_layer = [] - storage_to_calculate_W.append([inputs_per_layer, grad_of_outputs_per_layer]) + storage_to_calculate_W.append([inputs_per_layer, grads_of_output_per_layer]) + num_backward += 1 + inputs_per_layer = [] + else: - if (len(data_store) <= control_count) and (chunk < micro_batch): - # F process - try: - data, label = next(data_iter) - data = data.to(self.device) - data_id = uuid.uuid4() - output = model(data) - load_data.append([data, output]) - intermediate_output = output.detach().requires_grad_(True) - - num_forward += 1 - chunk += 1 - self.data_count += 1 - - self.send_intermediate_output(data_id, label_count, intermediate_output, label, trace=None, cluster=cluster, special=special) - except StopIteration: - end_data = True + # W process + if len(storage_to_calculate_W) > 0: + load_w = storage_to_calculate_W.pop(0) + # load_w : [inputs_per_layer, grad_of_outputs_per_layer] + manual_W(load_w[0], load_w[1], neural_layers) # calculate grad of parameter per layer and sum update it + num_weight += 1 + else: - # W process - if len(storage_to_calculate_W) > 0: - load_w = storage_to_calculate_W.pop(0) - # load_w : [inputs_per_layer, grad_of_outputs_per_layer] - manual_W(load_w[0], load_w[1], neural_layers) # calculate grad of parameter per layer and sum 
update it - count_w += 1 - pbar.update(1) - else: - if count_w == micro_batch: - count_w = 0 - chunk = 0 - optimizer.step() + # Check PAUSE from server + if num_forward == num_backward == num_weight: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f'[<<<] Received message from server {received_data}', 'blue') + if received_data['action'] == 'PAUSE': + return result + + # perform remaining W and otp step + else: + if len(storage_to_calculate_W) > 0: + load_w = storage_to_calculate_W.pop(0) + manual_W(load_w[0], load_w[1], neural_layers) + num_weight += 1 + else: + if num_forward == num_backward == num_weight: + optimizer.step() + optimizer.zero_grad() + num_forward = 0 + num_backward = 0 + num_weight = 0 + + def train_on_middle_layer(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, control_count=5, cluster=None, special=False, chunks=1): + return True + + def alone_training(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, train_loader=None, cluster=None, chunks=1): + return True + + def train_on_device(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, + control_count, train_loader=None, cluster=None, special=False, alone_train=False, chunk=1): + self.data_count = 0 + if self.layer_id == 1: + if alone_train is False: + result = self.train_on_first_layer(model, global_model, label_count, lr, momentum, clip_grad_norm, + compute_loss, control_count, train_loader, cluster, special, chunks=chunk) + else: + result = self.alone_training(model, global_model, label_count, lr, momentum, clip_grad_norm, + compute_loss, train_loader=train_loader, cluster=cluster, chunks=chunk) + elif self.layer_id == num_layers: + result = self.train_on_last_layer(model, global_model, label_count, lr, momentum, clip_grad_norm, + compute_loss, cluster=cluster, special=special, chunks=chunk) + else: + result = self.train_on_middle_layer(model, global_model, label_count, lr, momentum, clip_grad_norm, + compute_loss, control_count, cluster=cluster, special=special, chunks=chunk) + + return result, self.data_count +# Scheduler_zb: chunks diff --git a/src/Scheduler.py b/src/Scheduler.py index 0d521b1..dede979 100644 --- a/src/Scheduler.py +++ b/src/Scheduler.py @@ -368,7 +368,7 @@ def alone_training(self, model, global_model, label_count, lr, momentum, clip_gr return True time.sleep(0.5) - def train_on_device(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader=None, cluster=None, special=False, alone_train=False): + def train_on_device(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, num_layers, control_count, train_loader=None, cluster=None, special=False, alone_train=False, chunk=1): self.data_count = 0 if self.layer_id == 1: if alone_train is False: diff --git a/src/Server.py b/src/Server.py index 5e193ba..0b0a318 100644 --- a/src/Server.py +++ b/src/Server.py @@ -58,6 +58,10 @@ def __init__(self, config): self.random_seed = config["server"]["random-seed"] self.label_counts = None + # FBW + self.fbw = config["server"]["fbw"] + self.chunk = self.fbw["chunk"] + if self.random_seed: random.seed(self.random_seed) @@ -273,7 +277,8 @@ def notify_clients(self, start=True, register=True, cluster=None, special=False) 
"clip_grad_norm": self.clip_grad_norm, "label_count": None, "cluster": None, - "special": False} + "special": False, + "chunk": self.chunk} else: response = {"action": "START", "message": "Server accept the connection!", @@ -289,7 +294,8 @@ def notify_clients(self, start=True, register=True, cluster=None, special=False) "clip_grad_norm": self.clip_grad_norm, "label_count": None, "cluster": None, - "special": False} + "special": False, + "chunk": self.chunk} self.send_to_response(client_id, pickle.dumps(response)) if cluster is None: # Send message to clients when consumed all clients @@ -345,7 +351,8 @@ def notify_clients(self, start=True, register=True, cluster=None, special=False) "compute_loss": self.compute_loss, "label_count": label_counts.pop(), "cluster": clustering, - "special": self.special} + "special": self.special, + "chunk": self.chunk} else: response = {"action": "START", "message": "Server accept the connection!", @@ -361,7 +368,8 @@ def notify_clients(self, start=True, register=True, cluster=None, special=False) "compute_loss": self.compute_loss, "label_count": None, "cluster": clustering, - "special": self.special} + "special": self.special, + "chunk": self.chunk} else: src.Log.print_with_color(f"[>>>] Sent stop training request to client {client_id}", "red") @@ -395,7 +403,8 @@ def notify_clients(self, start=True, register=True, cluster=None, special=False) "clip_grad_norm": self.clip_grad_norm, "label_count": None, "cluster": None, - "special": True} + "special": True, + "chunk": self.chunk} self.send_to_response(client_id, pickle.dumps(response)) def cluster_client(self): diff --git a/src/Utils.py b/src/Utils.py index f8f9d87..a2d8796 100644 --- a/src/Utils.py +++ b/src/Utils.py @@ -69,27 +69,6 @@ def num_client_in_cluster(client_cluster_label): return count_list -def check_layer(layer): - return isinstance(layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)) - - -def hook_model(model): - input_layers = [] - output_layers = [] - neural_layers = [] - - def save_output_hook(module, input, output): - if check_layer(module): - input_layers.append(input[0].detach()) - output.retain_grad() - output_layers.append(output) - - for layer in model.modules(): - if check_layer(layer): - neural_layers.append(layer) - layer.register_forward_hook(save_output_hook) - return neural_layers, input_layers, output_layers - def manual_linear_grad_weight(x, grad_out, linear_layer): """ Calculate manual ∂L/∂W for nn.Linear @@ -145,7 +124,6 @@ def manual_conv_grad_weight(x_in, grad_out, conv_layer): def manual_W(inputs, grads_z_per_layer, neural_layers): - grads_w = [] for i in range(len(neural_layers)): x_in = inputs[i] grad_out = grads_z_per_layer[i] # ∂L/∂z tại layer[i] @@ -154,23 +132,29 @@ def manual_W(inputs, grads_z_per_layer, neural_layers): if isinstance(Layer, nn.Conv2d): grad_w = manual_conv_grad_weight(x_in, grad_out, Layer) # grad_w = torch.tensor(grad_w) - Layer.weight.grad = grad_w.clone().detach() - if Layer.bias is not None: + if Layer.weight.grad is None: + Layer.weight.grad = grad_w.clone().detach() + else: + Layer.weight.grad += grad_w.clone().detach() + if Layer.bias.grad is not None: # ∂L/∂b = sum over batch, height, width + grad_b = grad_out.sum(dim=(0, 2, 3)) + Layer.bias.grad += grad_b + else: grad_b = grad_out.sum(dim=(0, 2, 3)) Layer.bias.grad = grad_b - # print(f'Complete layer {i}') elif isinstance(Layer, nn.Linear): grad_w = manual_linear_grad_weight(x_in, grad_out, Layer) # grad_w = torch.tensor(grad_w) - Layer.weight.grad = grad_w.clone().detach() - if Layer.bias is not 
None: + if Layer.weight.grad is None: + Layer.weight.grad = grad_w.clone().detach() + else: + Layer.weight.grad += grad_w.clone().detach() + if Layer.bias.grad is not None: # ∂L/∂b = sum over batch + grad_b = grad_out.sum(dim=0) + Layer.bias.grad += grad_b + else: grad_b = grad_out.sum(dim=0) Layer.bias.grad = grad_b - - else: - raise print(f"Layer not define backward functions for layer {Layer}") - # grads_w.append(grad_w) - # return grads_w diff --git a/src/Validation.py b/src/Validation.py index 27c00c4..6602ec0 100644 --- a/src/Validation.py +++ b/src/Validation.py @@ -10,8 +10,7 @@ import src.Model - -def test(model_name, state_dict_full, logger): +def test(model_name, state_dict_full, logger=None): transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), @@ -32,11 +31,12 @@ def test(model_name, state_dict_full, logger): model.eval() test_loss = 0 correct = 0 - for data, target in tqdm(test_loader): - output = model(data) - test_loss += F.nll_loss(output, target, reduction='sum').item() - pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability - correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() + with torch.no_grad(): + for data, target in tqdm(test_loader): + output = model(data) + test_loss += F.nll_loss(output, target, reduction='sum').item() + pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.data.view_as(pred)).long().cpu().sum() test_loss /= len(test_loader.dataset) accuracy = 100.0 * correct / len(test_loader.dataset) From bc7c1d1ea11e2f3c6707dd5e1187524ba42a45b0 Mon Sep 17 00:00:00 2001 From: truongtruong373 Date: Wed, 16 Jul 2025 21:49:42 +0700 Subject: [PATCH 3/3] update train middle layer --- src/RpcClient.py | 1 + src/Schedule_zb.py | 122 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 122 insertions(+), 1 deletion(-) diff --git a/src/RpcClient.py b/src/RpcClient.py index f035fb5..331ff4c 100644 --- a/src/RpcClient.py +++ b/src/RpcClient.py @@ -81,6 +81,7 @@ def response_message(self, body): self.cluster = self.response['cluster'] if self.label_count is not None: src.Log.print_with_color(f"Label distribution of client: {self.label_count}", "yellow") + if self.model is None: klass = getattr(src.Model, model_name) full_model = klass() diff --git a/src/Schedule_zb.py b/src/Schedule_zb.py index 5a176f7..1ef5ce8 100644 --- a/src/Schedule_zb.py +++ b/src/Schedule_zb.py @@ -290,6 +290,8 @@ def save_output_hook(module, the_input, the_output): received_data = pickle.loads(body) src.Log.print_with_color(f'[<<<] Received message from server {received_data}', 'blue') if received_data['action'] == 'PAUSE': + optimizer.step() + optimizer.zero_grad() return result # perform remaining W and otp step @@ -307,7 +309,125 @@ def save_output_hook(module, the_input, the_output): num_weight = 0 def train_on_middle_layer(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, control_count=5, cluster=None, special=False, chunks=1): - return True + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum) + optimizer.zero_grad() + criterion = nn.CrossEntropyLoss() + result = True + if special: + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}' + else: + forward_queue_name = f'intermediate_queue_{self.layer_id - 1}_{cluster}' + backward_queue_name = f'gradient_queue_{self.client_id}' + self.channel.queue_declare(queue=forward_queue_name, 
durable=False) + self.channel.queue_declare(queue=backward_queue_name, durable=False) + self.channel.basic_qos(prefetch_count=1) + data_store = {} + dict_outputs_per_layer = {} + print('Waiting for intermediate output. To exit press CTRL+C') + + model.to(self.device) + inputs_per_layer = [] + outputs_per_layer = [] + neural_layers = [] + def check_layer(the_layer): + return isinstance(the_layer, (nn.Conv2d, nn.Linear, nn.BatchNorm2d)) + + def save_output_hook(module, the_input, the_output): + if check_layer(module): + inputs_per_layer.append(the_input[0].detach()) + the_output.retain_grad() + outputs_per_layer.append(the_output) + + for layer in model.modules(): + if check_layer(layer): + neural_layers.append(layer) + layer.register_forward_hook(save_output_hook) + + num_forward = 0 + num_backward = 0 + num_weight = 0 + storage_to_calculate_W = [] + + while True: + # B process + method_frame, header_frame, body = self.channel.basic_get(queue=backward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + grad_output = received_data['data'] + grad_output = torch.tensor(grad_output).to(self.device) + trace = received_data["trace"] + data_id = received_data["data_id"] + + num_backward += 1 + load_data = data_store.pop(data_id) + # Calculate gradient loss from x + grad_x = torch.autograd.grad(load_data[1], load_data[0], grad_outputs=grad_output, retain_graph=True)[0] + self.send_gradient(data_id, grad_x, trace) + # take out grad of each layer in model + grad_of_outputs_per_layer = [_.grad for _ in dict_outputs_per_layer[data_id]] + dict_outputs_per_layer[data_id].clear() + del dict_outputs_per_layer[data_id] + storage_to_calculate_W.append([load_data[2], grad_of_outputs_per_layer]) + continue + + if num_forward < chunks: + method_frame, header_frame, body = self.channel.basic_get(queue=forward_queue_name, auto_ack=True) + if method_frame and body: + received_data = pickle.loads(body) + intermediate_output_numpy = received_data["data"] + intermediate_output = torch.tensor(intermediate_output_numpy, requires_grad=True).to(self.device) + intermediate_output.retain_grad() + trace = received_data["trace"] + data_id = received_data["data_id"] + labels = received_data["labels"].to(self.device) + label_count = received_data["label_count"] + + # F process + output = model(intermediate_output) + output.retain_grad() + data_store[data_id] = [intermediate_output, output, inputs_per_layer] + inputs_per_layer = [] + dict_outputs_per_layer[data_id] = outputs_per_layer + outputs_per_layer = [] + intermediate_output = output.detach().requires_grad_(True) + + num_forward += 1 + self.data_count += 1 + + self.send_intermediate_output(data_id, label_count, intermediate_output, labels, trace=trace, cluster=cluster, special=special) + continue + + if len(storage_to_calculate_W) > 0: + # W process + load_w = storage_to_calculate_W.pop(0) + manual_W(load_w[0], load_w[1], neural_layers) + num_weight += 1 + continue + else: + if num_forward == num_backward == num_weight == chunks: + optimizer.step() + optimizer.zero_grad() + num_forward = 0 + num_backward = 0 + num_weight = 0 + continue + + if num_forward == num_backward == num_weight: + broadcast_queue_name = f'reply_{self.client_id}' + method_frame, header_frame, body = self.channel.basic_get(queue=broadcast_queue_name, auto_ack=True) + + if body: + received_data = pickle.loads(body) + src.Log.print_with_color(f'[<<<] Received message from server {received_data}', 'blue') + if received_data['action'] == 'PAUSE': + for Layer in 
neural_layers:
+                            if Layer.weight.grad is not None:
+                                Layer.weight.grad /= chunks
+                            if Layer.bias is not None and Layer.bias.grad is not None:
+                                Layer.bias.grad /= chunks
+                        optimizer.step()
+                        optimizer.zero_grad()
+                        return result
 
     def alone_training(self, model, global_model, label_count, lr, momentum, clip_grad_norm, compute_loss, train_loader=None, cluster=None, chunks=1):
         return True
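
A note on the manual weight gradients introduced in src/Utils.py: the W step relies on manual_conv_grad_weight (im2col via F.unfold plus a batched matmul) and manual_linear_grad_weight (grad_out^T @ x). The standalone script below is a minimal sanity check of that math against PyTorch autograd; it is not part of the patches, it re-implements the two formulas locally, and the layer sizes and input shapes are arbitrary test values.

import torch
import torch.nn as nn
import torch.nn.functional as F

def conv_weight_grad(x_in, grad_out, conv):
    # dL/dW = sum_n grad_out[n] (viewed as [Cout, L]) @ unfold(x_in)[n].T, groups=1 assumed
    kh, kw = conv.kernel_size
    x_unf = F.unfold(x_in, kernel_size=(kh, kw), stride=conv.stride,
                     padding=conv.padding, dilation=conv.dilation)   # [N, Cin*kh*kw, L]
    g = grad_out.reshape(grad_out.shape[0], grad_out.shape[1], -1)   # [N, Cout, L]
    gw = torch.bmm(g, x_unf.transpose(1, 2)).sum(dim=0)              # [Cout, Cin*kh*kw]
    return gw.view(conv.out_channels, conv.in_channels, kh, kw)

def linear_weight_grad(x_in, grad_out):
    # dL/dW = grad_out^T @ x for a [B, in_features] -> [B, out_features] linear layer
    return grad_out.T @ x_in

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)
lin = nn.Linear(8 * 8 * 8, 10)

x = torch.randn(4, 3, 16, 16)
z = conv(x)                   # [4, 8, 8, 8]
y = lin(z.flatten(1))         # [4, 10]
loss = y.pow(2).mean()

# Reference gradients of the loss w.r.t. the two intermediate outputs, then w.r.t. the weights.
grad_z, grad_y = torch.autograd.grad(loss, [z, y], retain_graph=True)
loss.backward()

# Both comparisons are expected to print True (up to float32 tolerance).
print(torch.allclose(conv_weight_grad(x, grad_z, conv), conv.weight.grad, atol=1e-5))
print(torch.allclose(linear_weight_grad(z.flatten(1).detach(), grad_y), lin.weight.grad, atol=1e-5))

Running a check like this before wiring manual_W into the F/B/W loop is cheap insurance: a wrong per-layer weight gradient would otherwise only surface later as silently degraded accuracy.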