package software.amazon.ssm.patchbaseline; import software.amazon.awssdk.services.ssm.model.UpdatePatchBaselineRequest; import software.amazon.awssdk.services.ssm.model.UpdatePatchBaselineResponse; import software.amazon.awssdk.services.ssm.model.GetPatchBaselineRequest; import software.amazon.awssdk.services.ssm.model.GetPatchBaselineResponse; import software.amazon.awssdk.services.ssm.model.DeregisterPatchBaselineForPatchGroupRequest; import software.amazon.awssdk.services.ssm.model.DeregisterPatchBaselineForPatchGroupResponse; import software.amazon.awssdk.services.ssm.model.RegisterPatchBaselineForPatchGroupRequest; import software.amazon.awssdk.services.ssm.model.RegisterPatchBaselineForPatchGroupResponse; import software.amazon.awssdk.services.ssm.model.RegisterDefaultPatchBaselineRequest; import software.amazon.awssdk.services.ssm.model.RegisterDefaultPatchBaselineResponse; import software.amazon.awssdk.utils.CollectionUtils; import software.amazon.cloudformation.proxy.AmazonWebServicesClientProxy; import software.amazon.cloudformation.proxy.Logger; import software.amazon.cloudformation.proxy.ProgressEvent; import software.amazon.cloudformation.proxy.OperationStatus; import software.amazon.cloudformation.proxy.ResourceHandlerRequest; import software.amazon.awssdk.services.ssm.SsmClient; import software.amazon.ssm.patchbaseline.translator.request.UpdatePatchBaselineRequestTranslator; import software.amazon.ssm.patchbaseline.utils.SsmClientBuilder; import static software.amazon.ssm.patchbaseline.ResourceModel.TYPE_NAME; import java.util.ArrayList; import java.util.List; import org.apache.commons.lang3.BooleanUtils; public class UpdateHandler extends BaseHandler<CallbackContext> { private static final SsmClient ssmClient = SsmClientBuilder.getClient(); protected static final String PATCH_BASELINE_RESOURCE_NAME = "PatchBaseline"; private final TagHelper tagHelper; public UpdateHandler() { this(new TagHelper()); } public UpdateHandler(TagHelper tagHelper) { this.tagHelper = tagHelper; } @Override public ProgressEvent<ResourceModel, CallbackContext> handleRequest( final AmazonWebServicesClientProxy proxy, final ResourceHandlerRequest<ResourceModel> request, final CallbackContext callbackContext, final Logger logger) { final ResourceModel model = request.getDesiredResourceState(); // if failed, return previous resource state final ResourceModel previousModel = request.getPreviousResourceState(); String baselineId = model.getId(); logger.log(String.format("INFO Activity %s request with clientRequestToken: %s %n", TYPE_NAME, request.getClientRequestToken())); try { //Build Update request and send SSM UpdatePatchBaselineRequest updatePatchBaselineRequest = UpdatePatchBaselineRequestTranslator.updatePatchBaseline(model); final UpdatePatchBaselineResponse updatePatchBaselineResponse = proxy.injectCredentialsAndInvokeV2(updatePatchBaselineRequest, ssmClient::updatePatchBaseline); logger.log(String.format("INFO Updated patch baseline %s successfully %n", baselineId)); //Get List of current groups GetPatchBaselineRequest getPatchBaselineRequest = GetPatchBaselineRequest.builder() .baselineId(baselineId) .build(); GetPatchBaselineResponse getPatchBaselineResponse = proxy.injectCredentialsAndInvokeV2(getPatchBaselineRequest, ssmClient::getPatchBaseline); List<String> originalGroups = new ArrayList<>(getPatchBaselineResponse.patchGroups()); //Get the new/desired patch groups List<String> newGroups = CollectionUtils.isNullOrEmpty(model.getPatchGroups()) ? new ArrayList<>() : model.getPatchGroups(); //Compute the intersection of the two lists (the groups that don't need to be changed) List<String> intersectingGroups = new ArrayList<>(originalGroups); intersectingGroups.retainAll(newGroups); //The groups we need to remove are ORIGINAL - INTERSECT //The groups we need to add are DESIRED - INTERSECT newGroups.removeAll(intersectingGroups); originalGroups.removeAll(intersectingGroups); //Remove the old groups first for (String group : originalGroups) { DeregisterPatchBaselineForPatchGroupRequest deregisterRequest = DeregisterPatchBaselineForPatchGroupRequest.builder() .baselineId(baselineId) .patchGroup(group) .build(); DeregisterPatchBaselineForPatchGroupResponse deregisterResponse = proxy.injectCredentialsAndInvokeV2(deregisterRequest, ssmClient::deregisterPatchBaselineForPatchGroup); } logger.log(String.format("INFO Deregistered old group(s) from patch baseline %s %n", getPatchBaselineResponse.baselineId())); //Add the new groups after for (String group : newGroups) { RegisterPatchBaselineForPatchGroupRequest groupRequest = RegisterPatchBaselineForPatchGroupRequest.builder() .baselineId(baselineId) .patchGroup(group) .build(); RegisterPatchBaselineForPatchGroupResponse groupResponse = proxy.injectCredentialsAndInvokeV2(groupRequest, ssmClient::registerPatchBaselineForPatchGroup); } logger.log(String.format("INFO Registered new group(s) from patch baseline %s %n", baselineId)); //Remove old tags (except those that are overwritten) then add new tags tagHelper.updateTagsForResource(request, PATCH_BASELINE_RESOURCE_NAME, ssmClient, proxy); logger.log(String.format("INFO Updated tags for patch baseline %s %n", baselineId)); // Set to default patch baseline if (BooleanUtils.isTrue(model.getDefaultBaseline()) && !BooleanUtils.isTrue(previousModel.getDefaultBaseline())){ RegisterDefaultPatchBaselineRequest registerDefaultPatchBaselineRequest = RegisterDefaultPatchBaselineRequest.builder() .baselineId(baselineId) .build(); RegisterDefaultPatchBaselineResponse registerDefaultPatchBaselineResponse = proxy.injectCredentialsAndInvokeV2(registerDefaultPatchBaselineRequest, ssmClient::registerDefaultPatchBaseline); logger.log(String.format("INFO Registered patch baseline %s to default patch baseline successfully %n", baselineId)); } //If we made it here, we're done logger.log(String.format("INFO Successfully updated patch baseline %s %n", baselineId)); return ProgressEvent.<ResourceModel, CallbackContext>builder() .resourceModel(model) .status(OperationStatus.SUCCESS) .build(); } catch (Exception e) { return Resource.handleException(e, previousModel, baselineId, logger); } } }