{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e87658c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade pip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9731b65",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install --upgrade sagemaker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc62919",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install flwr==1.3.0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e6a1bc",
   "metadata": {},
   "source": [
    "## Assume role to kick off training job in client account\n",
    "\n",
    "Assume the `FL-kickoff-client-job` IAM role and build a SageMaker client from its temporary credentials, so the training job below is created with that role's permissions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a59a09f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "# Role used to kick off the training job in the client account\n",
    "# (account ID is left blank here - fill it in before running).\n",
    "CLIENT_KICKOFF_ROLE_ARN = \"arn:aws:iam:::role/FL-kickoff-client-job\"\n",
    "\n",
    "sts_client = boto3.client('sts')\n",
    "assumed_role_object = sts_client.assume_role(\n",
    "    RoleArn=CLIENT_KICKOFF_ROLE_ARN,\n",
    "    RoleSessionName=\"AssumeRoleSession1\",\n",
    ")\n",
    "\n",
    "# Temporary AccessKeyId / SecretAccessKey / SessionToken for the assumed role.\n",
    "credentials = assumed_role_object['Credentials']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7375e19e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SageMaker client that acts with the assumed role's temporary credentials.\n",
    "sagemaker_client = boto3.client(\n",
    "    'sagemaker',\n",
    "    aws_access_key_id=credentials['AccessKeyId'],\n",
    "    aws_secret_access_key=credentials['SecretAccessKey'],\n",
    "    aws_session_token=credentials['SessionToken'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e6b085e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker import image_uris\n",
    "\n",
    "framework_version = \"1.0-1\"\n",
    "instance_type = \"ml.m5.xlarge\"  # used for the image lookup and for ResourceConfig below\n",
    "# Take the region from the client that creates the job, so the training image URI\n",
    "# can never point at a different region than the training job itself.\n",
    "region = sagemaker_client.meta.region_name\n",
    "\n",
    "training_image = image_uris.retrieve(\n",
    "    framework=\"sklearn\",\n",
    "    region=region,\n",
    "    version=framework_version,\n",
    "    py_version=\"py3\",\n",
    "    instance_type=instance_type,\n",
    ")\n",
    "print(training_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de700458",
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "# Timestamp suffix keeps the job name unique across runs.\n",
    "training_job_name = \"client-training-job-\" + datetime.datetime.now().strftime(\"%Y-%m-%d-%H-%M-%S\")\n",
    "\n",
    "# Account IDs, bucket names, the server IP and the security group / subnet IDs\n",
    "# are left blank below and must be filled in before running.\n",
    "sagemaker_client.create_training_job(\n",
    "    TrainingJobName=training_job_name,\n",
    "    HyperParameters={\n",
    "        \"penalty\": \"l2\",\n",
    "        \"max-iter\": \"10\",\n",
    "        \"server-address\": \":8080\",  # FL server address, \"<server IP>:8080\"\n",
    "        \"sagemaker_program\": \"client.py\",\n",
    "        \"sagemaker_submit_directory\": \"s3:///client_code/source.tar.gz\",\n",
    "    },\n",
    "    AlgorithmSpecification={\n",
    "        \"TrainingImage\": training_image,\n",
    "        \"TrainingInputMode\": \"File\",\n",
    "    },\n",
    "    RoleArn=\"arn:aws:iam:::role/service-role/AmazonSageMaker-ExecutionRole-\",\n",
    "    InputDataConfig=[\n",
    "        {\n",
    "            \"ChannelName\": \"train\",\n",
    "            \"DataSource\": {\n",
    "                \"S3DataSource\": {\n",
    "                    \"S3DataType\": \"S3Prefix\",\n",
    "                    \"S3Uri\": \"s3:///data_prep/\",\n",
    "                    \"S3DataDistributionType\": \"FullyReplicated\",\n",
    "                }\n",
    "            },\n",
    "        },\n",
    "    ],\n",
    "    OutputDataConfig={\n",
    "        \"S3OutputPath\": \"s3:///client_artifact/\"\n",
    "    },\n",
    "    ResourceConfig={\n",
    "        \"InstanceType\": instance_type,\n",
    "        \"InstanceCount\": 1,\n",
    "        \"VolumeSizeInGB\": 10,\n",
    "    },\n",
    "    VpcConfig={\n",
    "        'SecurityGroupIds': [\n",
    "            \"sg-\",\n",
    "        ],\n",
    "        'Subnets': [\n",
    "            \"subnet-\",\n",
    "        ]\n",
    "    },\n",
    "    StoppingCondition={\n",
    "        \"MaxRuntimeInSeconds\": 86400  # 24 hours\n",
    "    },\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60011a4c",
   "metadata": {},
   "source": [
    "## FL server code\n",
    "\n",
    "Run a Flower server in this notebook: it aggregates client updates with FedAvg for 3 rounds, evaluates the global model on local test data after every round, and saves the final model to `--model-dir`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9e84ffcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from typing import Dict\n",
    "\n",
    "import flwr as fl\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import log_loss\n",
    "\n",
    "import utils  # local helper module: load_data / set_initial_params / set_model_params / save_model\n",
    "\n",
    "NUM_ROUNDS = 3  # number of federated training rounds\n",
    "\n",
    "\n",
    "def fit_round(rnd: int) -> Dict:\n",
    "    \"\"\"Send round number to client (used as FedAvg's `on_fit_config_fn`).\"\"\"\n",
    "    return {\"rnd\": rnd}\n",
    "\n",
    "\n",
    "def get_evaluate_fn(model: LogisticRegression, X_test, y_test):\n",
    "    \"\"\"Return an evaluation function for server-side evaluation.\n",
    "\n",
    "    The returned function is called after every round; it loads the current\n",
    "    global parameters into `model` and reports log loss and accuracy on\n",
    "    (X_test, y_test).\n",
    "    \"\"\"\n",
    "\n",
    "    # Flower passes (server_round, parameters, config); a one-argument\n",
    "    # `evaluate(parameters)` fails with\n",
    "    # \"TypeError: evaluate() takes 1 positional argument but 3 were given\".\n",
    "    def evaluate(server_round: int, parameters: fl.common.NDArrays, config: Dict):\n",
    "        # Update model with the latest parameters\n",
    "        utils.set_model_params(model, parameters)\n",
    "        loss = log_loss(y_test, model.predict_proba(X_test))\n",
    "        accuracy = model.score(X_test, y_test)\n",
    "        return loss, {\"accuracy\": accuracy}\n",
    "\n",
    "    return evaluate\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    parser = argparse.ArgumentParser()\n",
    "\n",
    "    # Data/model locations. Defaults point at local notebook paths; inside a SageMaker\n",
    "    # job these would come from SM_MODEL_DIR / SM_CHANNEL_TRAIN / SM_CHANNEL_TEST.\n",
    "    parser.add_argument(\"--model-dir\", type=str, default=\"/home/ec2-user/SageMaker/SM_test/model\")\n",
    "    parser.add_argument(\"--train\", type=str, default=\"/home/ec2-user/SageMaker/SM_test/data\")\n",
    "    parser.add_argument(\"--test\", type=str, default=\"/home/ec2-user/SageMaker/SM_test/data\")\n",
    "\n",
    "    parser.add_argument(\"--train-file\", type=str, default=\"cms_payment_test.csv\")\n",
    "    parser.add_argument(\"--test-file\", type=str, default=\"cms_payment_test.csv\")\n",
    "\n",
    "    # \"<server IP>:8080\"; use \"0.0.0.0:8080\" when running on the same machine.\n",
    "    parser.add_argument(\"--server-address\", type=str, default=\":8080\")\n",
    "\n",
    "    # parse_known_args tolerates extra argv entries (e.g. the Jupyter kernel's own flags).\n",
    "    args, _ = parser.parse_known_args()\n",
    "\n",
    "    # Load evaluation data (not the same dataset as on the client); only the test split is used.\n",
    "    _, (X_test, y_test) = utils.load_data(args.train, args.train_file, args.test, args.test_file)\n",
    "\n",
    "    # Initialize the model and federation strategy, then start the server.\n",
    "    model = LogisticRegression()\n",
    "    utils.set_initial_params(model)\n",
    "\n",
    "    strategy = fl.server.strategy.FedAvg(\n",
    "        min_available_clients=1,  # clients that must be connected before a training round can start\n",
    "        min_fit_clients=1,        # minimum clients sampled for training in each round\n",
    "        min_evaluate_clients=1,   # minimum clients sampled for evaluation in each round\n",
    "        evaluate_fn=get_evaluate_fn(model, X_test, y_test),\n",
    "        on_fit_config_fn=fit_round,\n",
    "    )\n",
    "\n",
    "    fl.server.start_server(\n",
    "        server_address=args.server_address,\n",
    "        strategy=strategy,\n",
    "        config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),\n",
    "    )\n",
    "\n",
    "    # `model` was last updated by the server-side evaluation after the final round,\n",
    "    # so it holds the final global parameters.\n",
    "    utils.save_model(args.model_dir, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b9ef3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !route -n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7d4e2a9",
   "metadata": {},
   "source": [
    "## Test the final federated model\n",
    "\n",
    "Load the model saved by the server and print a per-class classification report on the test CSV."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5925f294",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import joblib\n",
    "import pandas as pd\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "test_path = \"/home/ec2-user/SageMaker/SM_test/data\"\n",
    "# Testing dataset is from data_prep.\n",
    "test_data = pd.read_csv(os.path.join(test_path, \"cms_payment_test.csv\"), delimiter=\",\")\n",
    "\n",
    "# First column is the label, the remaining columns are the features.\n",
    "test_y = test_data.iloc[:, 0].to_numpy()\n",
    "test_X = test_data.iloc[:, 1:].to_numpy()\n",
    "\n",
    "model_path = \"/home/ec2-user/SageMaker/SM_test/model\"\n",
    "# joblib.load unpickles the file: only load model files you produced yourself.\n",
    "final_model = joblib.load(os.path.join(model_path, \"model.joblib\"))\n",
    "\n",
    "test_preds = final_model.predict(test_X)\n",
    "print(classification_report(test_y, test_preds, target_names=['non-fraud', 'fraud']))"
   ]
  }
 ],
 "metadata": { "availableInstances": [ { "_defaultOrder": 0, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "memoryGiB": 4, "name": "ml.t3.medium", "vcpuNum": 2 }, { "_defaultOrder": 1, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0,
"memoryGiB": 8, "name": "ml.t3.large", "vcpuNum": 2 }, { "_defaultOrder": 2, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 16, "name": "ml.t3.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 3, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 32, "name": "ml.t3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 4, "_isFastLaunch": true, "category": "General purpose", "gpuNum": 0, "memoryGiB": 8, "name": "ml.m5.large", "vcpuNum": 2 }, { "_defaultOrder": 5, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 16, "name": "ml.m5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 6, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 32, "name": "ml.m5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 7, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 64, "name": "ml.m5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 8, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 128, "name": "ml.m5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 9, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 192, "name": "ml.m5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 10, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 256, "name": "ml.m5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 11, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 384, "name": "ml.m5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 12, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 8, "name": "ml.m5d.large", "vcpuNum": 2 }, { "_defaultOrder": 13, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 16, "name": "ml.m5d.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 14, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 32, "name": "ml.m5d.2xlarge", "vcpuNum": 8 }, { 
"_defaultOrder": 15, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 64, "name": "ml.m5d.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 16, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 128, "name": "ml.m5d.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 17, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 192, "name": "ml.m5d.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 18, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 256, "name": "ml.m5d.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 19, "_isFastLaunch": false, "category": "General purpose", "gpuNum": 0, "memoryGiB": 384, "name": "ml.m5d.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 20, "_isFastLaunch": true, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 4, "name": "ml.c5.large", "vcpuNum": 2 }, { "_defaultOrder": 21, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 8, "name": "ml.c5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 22, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 16, "name": "ml.c5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 23, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 32, "name": "ml.c5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 24, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 72, "name": "ml.c5.9xlarge", "vcpuNum": 36 }, { "_defaultOrder": 25, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 96, "name": "ml.c5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 26, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 144, "name": "ml.c5.18xlarge", "vcpuNum": 72 }, { "_defaultOrder": 27, "_isFastLaunch": false, "category": "Compute optimized", "gpuNum": 0, "memoryGiB": 192, "name": "ml.c5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 28, 
"_isFastLaunch": true, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 16, "name": "ml.g4dn.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 29, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 32, "name": "ml.g4dn.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 30, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 64, "name": "ml.g4dn.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 31, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 128, "name": "ml.g4dn.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 32, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "memoryGiB": 192, "name": "ml.g4dn.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 33, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 256, "name": "ml.g4dn.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 34, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 61, "name": "ml.p3.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 35, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "memoryGiB": 244, "name": "ml.p3.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 36, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "memoryGiB": 488, "name": "ml.p3.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 37, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "memoryGiB": 768, "name": "ml.p3dn.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 38, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 16, "name": "ml.r5.large", "vcpuNum": 2 }, { "_defaultOrder": 39, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 32, "name": "ml.r5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 40, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 64, "name": "ml.r5.2xlarge", "vcpuNum": 8 
}, { "_defaultOrder": 41, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 128, "name": "ml.r5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 42, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 256, "name": "ml.r5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 43, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 384, "name": "ml.r5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 44, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 512, "name": "ml.r5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 45, "_isFastLaunch": false, "category": "Memory Optimized", "gpuNum": 0, "memoryGiB": 768, "name": "ml.r5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 46, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 16, "name": "ml.g5.xlarge", "vcpuNum": 4 }, { "_defaultOrder": 47, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 32, "name": "ml.g5.2xlarge", "vcpuNum": 8 }, { "_defaultOrder": 48, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 64, "name": "ml.g5.4xlarge", "vcpuNum": 16 }, { "_defaultOrder": 49, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 128, "name": "ml.g5.8xlarge", "vcpuNum": 32 }, { "_defaultOrder": 50, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 1, "memoryGiB": 256, "name": "ml.g5.16xlarge", "vcpuNum": 64 }, { "_defaultOrder": 51, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "memoryGiB": 192, "name": "ml.g5.12xlarge", "vcpuNum": 48 }, { "_defaultOrder": 52, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 4, "memoryGiB": 384, "name": "ml.g5.24xlarge", "vcpuNum": 96 }, { "_defaultOrder": 53, "_isFastLaunch": false, "category": "Accelerated computing", "gpuNum": 8, "memoryGiB": 768, "name": 
"ml.g5.48xlarge", "vcpuNum": 192 } ], "instance_type": "ml.t3.medium", "kernelspec": { "display_name": "Python 3 (Data Science)", "language": "python", "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.10" } }, "nbformat": 4, "nbformat_minor": 5 }