In [None]:
import sys
import boto3
import numpy as np
import tensorflow as tf # use tensorflow for local training
import json
import warnings
import local_data_processing
import client_fedlearn
import time
from datetime import timedelta
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 
start_time = 0
end_time = 0
hasStarted = True

# load configuration file
config_client = 'code_repo/client2_config.json'
# output the name of FL trained model
global_model = None 

##############################################################################################
# client info from the configuration file
# JSON is UTF-8 by specification; pin the encoding so reading the config does not
# depend on the platform's default locale encoding.
with open(config_client, 'r', encoding='utf-8') as config_file:
    client_config = json.load(config_file)
s3_fl_model_registry = client_config["s3_fl_model_registry"]
dynamodb_table_model_info = client_config["dynamodb_table_model_info"]
# member_ID is coerced to int because it is passed to input_fn and FLClient below.
member_ID = int(client_config["member_ID"])
sqs_region = client_config["sqs_region"]
client_queue_name = client_config["client_sqs_name"]

# local data processing at the client
# NOTE(review): sample counts are hard-coded here rather than read from the config.
num_train_per_client = 20000
num_test_per_client = 3000
x_train_client, y_train_client, x_test_client, y_test_client = local_data_processing.input_fn(
    member_ID, num_train_per_client, num_test_per_client
)

# Cross-account access is all-or-nothing: either every service (SQS, S3, DynamoDB)
# has a cross-account role configured, or none of them does.
_CROSS_ACCOUNT_ROLE_KEYS = (
    "cross_account_sqs_role",
    "cross_account_s3_role",
    "cross_account_dynamodb_role",
)
_present_role_keys = [key for key in _CROSS_ACCOUNT_ROLE_KEYS if key in client_config]

if len(_present_role_keys) == len(_CROSS_ACCOUNT_ROLE_KEYS):
    cross_account = True
elif not _present_role_keys:
    cross_account = False
else:
    # Explicit error instead of `assert`: asserts are stripped under `python -O`,
    # which would let a half-configured cross-account setup run silently.
    _missing_role_keys = [key for key in _CROSS_ACCOUNT_ROLE_KEYS if key not in client_config]
    raise ValueError(
        "Incomplete cross-account configuration: found {} but missing {}; "
        "configure all cross-account roles or none".format(_present_role_keys, _missing_role_keys)
    )

# create a FL client
client = client_fedlearn.FLClient(member_ID, x_train_client, y_train_client, x_test_client, y_test_client)

# Optional trailing role arguments for the FLClient AWS helpers: empty for same-account
# access, otherwise the configured cross-account role(s). Unpacking them with `*` yields
# exactly the same positional calls as separate if/else branches.
sqs_role_args = (client_config["cross_account_sqs_role"],) if cross_account else ()
s3_role_args = (client_config["cross_account_s3_role"],) if cross_account else ()
upload_role_args = (
    (client_config["cross_account_s3_role"], client_config["cross_account_dynamodb_role"])
    if cross_account
    else ()
)


def _metric_or_na(value):
    """Return ``value`` converted to float, or the placeholder string "NA" unchanged."""
    return float(value) if value != "NA" else "NA"


signalTerminate = False
while not signalTerminate:
    # receive a message from associate queue
    messages = client.receiveNotificationsFromServer(sqs_region, client_queue_name, *sqs_role_args)

    # collect all received notifications
    transactions = []
    curr_round = -1
    for msg in messages:
        # The queue message body is a JSON envelope whose "Message" field holds the
        # notification JSON (presumably an SNS-to-SQS envelope — confirm).
        transaction_dict = json.loads(json.loads(msg.body)["Message"])
        # The notification may be double-encoded (a JSON string containing JSON).
        if isinstance(transaction_dict, str):
            transaction_dict = json.loads(transaction_dict)

        # check if termination is signaled from the server
        if transaction_dict['roundId'] == "NA":
            global_model = client.downloadFLGlobalModel(s3_fl_model_registry, *s3_role_args)
            end_time = time.monotonic()
            print("FL training is finished")
            signalTerminate = True
            break

        transaction = transaction_dict["Input"]
        transaction["TaskToken"] = transaction_dict["TaskToken"]  # make a single dict
        transactions.append(transaction)

        # Keep curr_round as an int: roundId may arrive as a string, and storing the raw
        # value made the next `int(...) > curr_round` comparison raise TypeError.
        round_id = int(transaction['roundId'])
        if round_id > curr_round:
            curr_round = round_id

    if transactions:
        infoServer = client.processGlobalModelInfoFromServer(transactions)

        if infoServer is not None:
            if hasStarted:
                start_time = time.monotonic()
                hasStarted = False
            print("FL training round: " + str(curr_round) + "\n")
            print("1: Receive notification")
            # NOTE(review): `transaction` is the last notification collected in this batch,
            # not necessarily the one processGlobalModelInfoFromServer selected — confirm.
            printDiction = {
                "taskName": transaction["Task_Name"],
                "taskId": transaction["Task_ID"],
                "roundId": transaction["roundId"],
                "memberId": transaction["member_ID"],
                "numClientEpochs": transaction["numClientEpochs"],
                "trainAcc": _metric_or_na(transaction["trainAcc"]),
                "testAcc": _metric_or_na(transaction["testAcc"]),
                "trainLoss": _metric_or_na(transaction["trainLoss"]),
                "testLoss": _metric_or_na(transaction["testLoss"]),
                "weightsFile": transaction["weightsFile"],
                "numClientsRequired": transaction["numClientsRequired"],
                "source": transaction["source"],
                "taskToken": transaction["TaskToken"],
            }
            print("Received notification: {} \n".format(str(printDiction)))

            # obtain the model file
            global_model_name = infoServer["weightsFile"]

            # local training
            local_model_name, local_model_info = client.localTraining(
                global_model_name, s3_fl_model_registry, *s3_role_args
            )

            # write local model and transactions to server
            taskToken = infoServer["TaskToken"]
            local_model_info["TaskToken"] = taskToken

            print("4: Upload local model & its information ...")
            client.uploadToFLServer(
                s3_fl_model_registry,
                local_model_name,
                dynamodb_table_model_info,
                local_model_info,
                *upload_role_args,
            )

            # update the roundId to make sure no duplicated local training
            client.roundId = client.roundId + 1

In [None]:
from code_repo import MODEL

# The training cell only sets `global_model` when the server signals termination;
# fail with a clear message instead of an opaque TypeError on "models/" + None.
if global_model is None:
    raise RuntimeError(
        "global_model is None: run the FL training cell until the server signals termination"
    )
# print the FL trained global model name
print(global_model)

# evaluate the FL trained global model with the same set of central testing dataset
mlmodel = MODEL.MLMODEL()
model = mlmodel.getModel()
# Assumes downloadFLGlobalModel stored the weights under models/ — confirm.
# NOTE(review): allow_pickle=True lets np.load execute code embedded in a crafted file;
# only load weights from a model registry you trust.
weights = np.load("models/" + global_model, allow_pickle=True)
model.set_weights(weights)

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

print("Accuracy for the test dataset")
model.evaluate(x_test, y_test, verbose=2)

# start_time is only recorded when the first round begins (hasStarted flips to False)
# and end_time only on termination; otherwise their difference is meaningless.
if not hasStarted and end_time > 0:
    print("Training time spent: " + str(timedelta(seconds=end_time - start_time)))
else:
    print("Training time unavailable: no complete FL run was recorded in this kernel")