{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Federated-learning (FL) client driver.\n",
    "#\n",
    "# Loads the client configuration, obtains the local train/test data from\n",
    "# local_data_processing.input_fn, then polls the client's notification queue.\n",
    "# Every notified round is trained locally and the local model is uploaded,\n",
    "# until the server signals termination (a notification whose roundId is \"NA\").\n",
    "import sys\n",
    "import boto3\n",
    "import numpy as np\n",
    "import tensorflow as tf # use tensorflow for local training\n",
    "import json\n",
    "import warnings\n",
    "import local_data_processing\n",
    "import client_fedlearn\n",
    "import time\n",
    "from datetime import timedelta\n",
    "warnings.filterwarnings(\"ignore\", category=np.VisibleDeprecationWarning)\n",
    "\n",
    "# time.monotonic() stamps taken at the first processed round and at termination\n",
    "start_time = 0\n",
    "end_time = 0\n",
    "# NOTE: despite its name this flag is True while training has NOT started yet;\n",
    "# it is flipped to False as soon as the first round is processed.\n",
    "hasStarted = True\n",
    "\n",
    "# load configuration file\n",
    "config_client = 'code_repo/client2_config.json'\n",
    "# output the name of FL trained model\n",
    "global_model = None\n",
    "\n",
    "##############################################################################################\n",
    "# client info from the configuration file\n",
    "with open(config_client, 'r') as config_file:\n",
    "    client_config = json.load(config_file)\n",
    "s3_fl_model_registry = client_config[\"s3_fl_model_registry\"]\n",
    "dynamodb_table_model_info = client_config[\"dynamodb_table_model_info\"]\n",
    "member_ID = int(client_config[\"member_ID\"])\n",
    "sqs_region = client_config[\"sqs_region\"]\n",
    "client_queue_name = client_config[\"client_sqs_name\"]\n",
    "\n",
    "# local data processing at the client\n",
    "num_train_per_client = 20000\n",
    "num_test_per_client = 3000\n",
    "x_train_client, y_train_client, x_test_client, y_test_client = local_data_processing.input_fn(\n",
    "    member_ID, num_train_per_client, num_test_per_client\n",
    ")\n",
    "\n",
    "# if cross account, all access should be cross account:\n",
    "# either every role below is configured, or none of them is\n",
    "cross_account_role_keys = (\n",
    "    \"cross_account_sqs_role\",\n",
    "    \"cross_account_s3_role\",\n",
    "    \"cross_account_dynamodb_role\",\n",
    ")\n",
    "configured_role_keys = [key for key in cross_account_role_keys if key in client_config]\n",
    "cross_account = len(configured_role_keys) == len(cross_account_role_keys)\n",
    "assert cross_account or not configured_role_keys, (\n",
    "    \"cross-account roles must be configured all together or not at all; found only: \"\n",
    "    + \", \".join(configured_role_keys)\n",
    ")\n",
    "\n",
    "# Optional trailing role arguments, appended positionally to the client calls below.\n",
    "# They are empty for same-account access, so one call site serves both modes.\n",
    "sqs_role_args = (client_config[\"cross_account_sqs_role\"],) if cross_account else ()\n",
    "s3_role_args = (client_config[\"cross_account_s3_role\"],) if cross_account else ()\n",
    "upload_role_args = (\n",
    "    (client_config[\"cross_account_s3_role\"], client_config[\"cross_account_dynamodb_role\"])\n",
    "    if cross_account\n",
    "    else ()\n",
    ")\n",
    "\n",
    "\n",
    "def float_or_na(value):\n",
    "    \"\"\"Return ``value`` converted to float, passing the \"NA\" placeholder through unchanged.\"\"\"\n",
    "    return float(value) if value != \"NA\" else \"NA\"\n",
    "\n",
    "\n",
    "# create a FL client\n",
    "client = client_fedlearn.FLClient(member_ID, x_train_client, y_train_client, x_test_client, y_test_client)\n",
    "\n",
    "signalTerminate = False\n",
    "while not signalTerminate:\n",
    "    # receive a message from associate queue\n",
    "    messages = client.receiveNotificationsFromServer(sqs_region, client_queue_name, *sqs_role_args)\n",
    "\n",
    "    # collect all received notifications\n",
    "    transactions = []\n",
    "    curr_round = -1\n",
    "    for msg in messages:\n",
    "        # the notification payload is JSON nested in the \"Message\" field of the body\n",
    "        transaction_dict = json.loads(json.loads(msg.body)[\"Message\"])\n",
    "\n",
    "        # check if termination is signaled from the server\n",
    "        # (a payload that decodes to a string is decoded once more)\n",
    "        if isinstance(transaction_dict, str):\n",
    "            transaction_dict = json.loads(transaction_dict)\n",
    "            if transaction_dict['roundId'] == \"NA\":\n",
    "                global_model = client.downloadFLGlobalModel(s3_fl_model_registry, *s3_role_args)\n",
    "                end_time = time.monotonic()\n",
    "                print(\"FL training is finished\")\n",
    "                signalTerminate = True\n",
    "                break\n",
    "\n",
    "        transaction = transaction_dict[\"Input\"]\n",
    "        transaction[\"TaskToken\"] = transaction_dict[\"TaskToken\"] # make a single dict\n",
    "        transactions.append(transaction)\n",
    "\n",
    "        # Always store the round as an int: keeping the raw value made the next\n",
    "        # comparison in the same batch fail (int vs str) when roundId is a string.\n",
    "        curr_round = max(curr_round, int(transaction['roundId']))\n",
    "\n",
    "    # NOTE(review): notifications collected before a termination message in the same\n",
    "    # batch are still processed below - confirm this is intended.\n",
    "    if not transactions:\n",
    "        continue\n",
    "\n",
    "    infoServer = client.processGlobalModelInfoFromServer(transactions)\n",
    "    if infoServer is None:\n",
    "        continue\n",
    "\n",
    "    if hasStarted:\n",
    "        start_time = time.monotonic()\n",
    "        hasStarted = False\n",
    "    print(\"FL training round: \" + str(curr_round) + \"\\n\")\n",
    "    print(\"1: Receive notification\")\n",
    "\n",
    "    # the summary describes the last notification collected in this batch\n",
    "    last_transaction = transactions[-1]\n",
    "    printDiction = {\n",
    "        \"taskName\": last_transaction[\"Task_Name\"],\n",
    "        \"taskId\": last_transaction[\"Task_ID\"],\n",
    "        \"roundId\": last_transaction[\"roundId\"],\n",
    "        \"memberId\": last_transaction[\"member_ID\"],\n",
    "        \"numClientEpochs\": last_transaction[\"numClientEpochs\"],\n",
    "        \"trainAcc\": float_or_na(last_transaction[\"trainAcc\"]),\n",
    "        \"testAcc\" : float_or_na(last_transaction[\"testAcc\"]),\n",
    "        \"trainLoss\": float_or_na(last_transaction[\"trainLoss\"]),\n",
    "        \"testLoss\" : float_or_na(last_transaction[\"testLoss\"]),\n",
    "        \"weightsFile\": last_transaction[\"weightsFile\"],\n",
    "        \"numClientsRequired\": last_transaction[\"numClientsRequired\"],\n",
    "        \"source\": last_transaction[\"source\"],\n",
    "        \"taskToken\": last_transaction[\"TaskToken\"],\n",
    "    }\n",
    "    print(\"Received notification: {} \\n\".format(str(printDiction)))\n",
    "\n",
    "    # obtain the model file\n",
    "    global_model_name = infoServer[\"weightsFile\"]\n",
    "\n",
    "    # local training\n",
    "    local_model_name, local_model_info = client.localTraining(\n",
    "        global_model_name, s3_fl_model_registry, *s3_role_args\n",
    "    )\n",
    "\n",
    "    # write local model and transactions to server\n",
    "    taskToken = infoServer[\"TaskToken\"]\n",
    "    local_model_info[\"TaskToken\"] = taskToken\n",
    "\n",
    "    print(\"4: Upload local model & its information ...\")\n",
    "    client.uploadToFLServer(\n",
    "        s3_fl_model_registry,\n",
    "        local_model_name,\n",
    "        dynamodb_table_model_info,\n",
    "        local_model_info,\n",
    "        *upload_role_args\n",
    "    )\n",
    "\n",
    "    # update the roundId to make sure no duplicated local training\n",
    "    client.roundId += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the FL trained global model on the MNIST test split.\n",
    "# Relies on np, tf, timedelta, global_model, hasStarted, start_time and end_time\n",
    "# from the training cell above.\n",
    "from code_repo import MODEL\n",
    "\n",
    "# fail with a clear message instead of a TypeError on the path concatenation below\n",
    "if global_model is None:\n",
    "    raise RuntimeError(\n",
    "        \"No global model has been downloaded; run the training cell until the server signals termination\"\n",
    "    )\n",
    "\n",
    "# print the FL trained global model name\n",
    "print(global_model)\n",
    "\n",
    "# evaluate the FL trained global model with the same set of central testing dataset\n",
    "mlmodel = MODEL.MLMODEL()\n",
    "model = mlmodel.getModel()\n",
    "# NOTE(review): allow_pickle=True can execute arbitrary code while unpickling;\n",
    "# only load weight files that come from a trusted model registry.\n",
    "weights = np.load(\"models/\" + global_model, allow_pickle=True)\n",
    "model.set_weights(weights)\n",
    "\n",
    "mnist = tf.keras.datasets.mnist\n",
    "# only the test split is evaluated, so the training split is left untouched\n",
    "_train_split, (x_test, y_test) = mnist.load_data()\n",
    "x_test = x_test / 255.0\n",
    "\n",
    "print(\"Accuracy for the test dataset\")\n",
    "model.evaluate(x_test, y_test, verbose=2)\n",
    "\n",
    "# hasStarted is still True when no round was trained before termination; start_time\n",
    "# is then unset, so report zero instead of a raw monotonic clock value\n",
    "elapsed_seconds = 0 if hasStarted else end_time - start_time\n",
    "print(\"Training time spent: \" + str(timedelta(seconds=elapsed_seconds)))"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.m5.large",
  "kernelspec": {
   "display_name": "Python 3 (TensorFlow 2.1 Python 3.6 CPU Optimized)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/tensorflow-2.1-cpu-py36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}