import os
import random
import time

import awswrangler as wr

from malware_detection_utils.utils import logMessage

# TODO: add a new input for the maximum number of images to train on.
# TODO: remove IMAGES_TO_TRAIN / IMAGES_TO_TEST and replace them with a
#       static 80%-20% split.
# TODO: keep the training and test object selections as sets so that
#       duplicates are not present.

LOGTYPE_ERROR = 'ERROR'
LOGTYPE_INFO = 'INFO'
LOGTYPE_DEBUG = 'DEBUG'


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: if the variable is not set. Without this check the
            caller would fail later with an opaque error on ``None``.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value


def _select_training_objects(objects: list, max_count: int) -> list:
    """Return at most *max_count* distinct entries of *objects*, chosen at random.

    When *objects* holds *max_count* or fewer distinct entries, all of them
    are returned. A non-positive *max_count* yields an empty list unless
    *objects* itself is empty (in which case the empty list is returned).
    """
    # dict.fromkeys drops duplicates while keeping a deterministic order.
    unique_objects = list(dict.fromkeys(objects))
    if len(unique_objects) <= max_count:
        return unique_objects
    # random.sample picks without replacement, so it always terminates and
    # returns exactly the requested number of distinct objects.
    return random.sample(unique_objects, max(max_count, 0))


def main() -> None:
    """Copy a random selection of images for one prefix into the training bucket.

    Configuration is read from environment variables:

    * ``AWS_BATCH_JOB_ARRAY_INDEX`` - index into the prefix list (default 0).
    * ``AWS_REGION`` - only logged.
    * ``IMAGE_BUCKET`` - source bucket that is listed and copied from.
    * ``TRAINING_BUCKET`` - destination bucket.
    * ``PREFIX_LIST`` - comma-separated prefixes (required), e.g.
      ``adware,flooder,ransomware`` or ``benign``.
    * ``IMAGES_TO_TRAIN`` - maximum number of images to copy (required).

    Raises:
        RuntimeError: if ``PREFIX_LIST`` or ``IMAGES_TO_TRAIN`` is not set.
        ValueError: if an integer variable cannot be parsed.
        IndexError: if the array index is outside the prefix list.
    """
    start_time = time.time()

    array_index = int(os.environ.get('AWS_BATCH_JOB_ARRAY_INDEX', '0'))
    aws_region = os.environ.get('AWS_REGION')
    image_bucket = os.environ.get('IMAGE_BUCKET')
    training_bucket = os.environ.get('TRAINING_BUCKET')
    prefix_list = _require_env('PREFIX_LIST').split(",")
    images_to_train = int(_require_env('IMAGES_TO_TRAIN'))

    logMessage(f"AWS_REGION {aws_region}", LOGTYPE_INFO)
    logMessage(f"IMAGE_BUCKET {image_bucket}", LOGTYPE_INFO)
    logMessage(f"TRAINING_BUCKET {training_bucket}", LOGTYPE_INFO)
    logMessage(f"PREFIX_LIST {prefix_list}", LOGTYPE_INFO)

    # Random image copy: this job handles the prefix at its array index.
    prefix = prefix_list[array_index]
    logMessage("Reading Images", LOGTYPE_INFO)
    logMessage(f"Reading Images in {prefix}", LOGTYPE_INFO)
    available_objects = wr.s3.list_objects(f's3://{image_bucket}/' + prefix)

    # The previous hand-rolled loop (`while len(selected) <= images_to_train`)
    # never terminated when exactly `images_to_train` objects existed and
    # otherwise selected one object too many.
    training_objects = _select_training_objects(available_objects, images_to_train)
    logMessage(
        "Length of list_of_trainging_objects" + str(len(training_objects)),
        LOGTYPE_INFO,
    )

    logMessage("Copying Images", LOGTYPE_INFO)
    # NOTE(review): this tests the whole prefix list, not the current prefix,
    # so a list mixing 'benign' with malware prefixes sends everything to the
    # bucket root - confirm that is intended.
    if 'benign' in prefix_list:
        target_path = f"s3://{training_bucket}/"
    else:
        target_path = f"s3://{training_bucket}/malware/"
    wr.s3.copy_objects(
        paths=training_objects,
        source_path=f"s3://{image_bucket}",
        target_path=target_path,
        use_threads=True,
    )

    end_time = time.time()
    time_spent = (end_time - start_time) * 1000
    logMessage(f"Time Spent Processing {time_spent} ms", LOGTYPE_INFO)
    print("Python file invoked")


if __name__ == '__main__':
    main()