package main import ( "context" "encoding/json" "fmt" "os" "strings" "github.com/aws/aws-lambda-go/lambda" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/ecr" ) // ScanSpec represents configuration for the target repository type ScanSpec struct { // ID is a unique identifier for the scan spec ID string `json:"id"` // CreationTime is the UTC timestamp of when the scan spec was created CreationTime string `json:"created"` // Region specifies the region the repository is in Region string `json:"region"` // RegistryID specifies the registry ID RegistryID string `json:"registry"` // Repository specifies the repository name Repository string `json:"repository"` // Tags to take into consideration, if empty, all tags will be scanned Tags []string `json:"tags"` } func startScan(scanspec ScanSpec) error { s := session.Must(session.NewSession(&aws.Config{ Region: aws.String(scanspec.Region), })) svc := ecr.New(s) scaninput := &ecr.StartImageScanInput{ RepositoryName: &scanspec.Repository, RegistryId: &scanspec.RegistryID, } switch len(scanspec.Tags) { case 0: // empty list of tags, scan all tags: fmt.Printf("DEBUG:: scanning all tags for repo %v\n", scanspec.Repository) lio, err := svc.ListImages(&ecr.ListImagesInput{ RepositoryName: &scanspec.Repository, RegistryId: &scanspec.RegistryID, Filter: &ecr.ListImagesFilter{ TagStatus: aws.String("TAGGED"), }, }) if err != nil { fmt.Println(err) return err } for _, iid := range lio.ImageIds { scaninput.ImageId = iid result, err := svc.StartImageScan(scaninput) if err != nil { fmt.Println(err) return err } fmt.Printf("DEBUG:: result for tag %v: %v\n", *iid.ImageTag, result) } default: // iterate over the tags specified in the config: fmt.Printf("DEBUG:: scanning tags %v for repo %v\n", scanspec.Tags, scanspec.Repository) for _, tag := range scanspec.Tags { scaninput.ImageId = &ecr.ImageIdentifier{ ImageTag: aws.String(tag), } result, err := svc.StartImageScan(scaninput) if err != nil { fmt.Println(err) return err } fmt.Printf("DEBUG:: result for tag %v: %v\n", tag, result) } } return nil } // fetchScanSpec returns the scan spec // in a given bucket, with a given scan ID func fetchScanSpec(configbucket, scanid string) (ScanSpec, error) { ss := ScanSpec{} cfg, err := config.LoadDefaultConfig(context.TODO()) if err != nil { return ss, err } // Create an S3 Client with the config client := s3.NewFromConfig(cfg) // Create an uploader passing it the client downloader := manager.NewDownloader(client) buf := aws.NewWriteAtBuffer([]byte{}) _, err = downloader.Download(context.TODO(), buf, &s3.GetObjectInput{ Bucket: aws.String(configbucket), Key: aws.String(scanid + ".json"), }) if err != nil { return ss, err } err = json.Unmarshal(buf.Bytes(), &ss) if err != nil { return ss, err } return ss, nil } func handler() error { configbucket := os.Getenv("ECR_SCAN_CONFIG_BUCKET") fmt.Printf("DEBUG:: scan start\n") cfg, err := config.LoadDefaultConfig(context.TODO()) if err != nil { fmt.Println(err) return err } svc := s3.NewFromConfig(cfg) fmt.Printf("Scanning bucket %v for scan specs\n", configbucket) resp, err := svc.ListObjectsV2(context.TODO(), &s3.ListObjectsV2Input{ Bucket: &configbucket, }, ) // resp, err := req.Send(context.TODO()) if err != nil { fmt.Println(err) return err } for _, obj := range resp.Contents { fn := *obj.Key scanID := strings.TrimSuffix(fn, ".json") scanspec, err := fetchScanSpec(configbucket, scanID) if err != nil { fmt.Println(err) return err } err = startScan(scanspec) if err != nil { fmt.Println(err) return err } } fmt.Printf("DEBUG:: scan done\n") return nil } func main() { lambda.Start(handler) }