If you wish to use the `enhanced_pyspark_processor`, be sure that `from sagemaker.spark.processing import PySparkProcessor` is commented out and that you're using `from enhanced_pyspark_processor import PySparkProcessor` instead.

In [None]:
import sagemaker
from sagemaker.local import LocalSession
#from sagemaker.spark.processing import PySparkProcessor
from enhanced_pyspark_processor import PySparkProcessor

In [None]:
# Update with your SM execution role
role_arn = ""

# Local Mode session; its config is replaced so that "local_code" is enabled.
sagemaker_session = LocalSession()
sagemaker_session.config = {
    "local": {
        "local_code": True,
    },
}

In [None]:
%%writefile processing.py

import argparse
import logging
import os
import sys

from pyspark.sql import SparkSession
from pyspark.sql.functions import (udf, col)
from pyspark.sql.types import StringType, StructField, StructType, FloatType

# Stream handler that prints timestamped messages to stdout
handler = logging.StreamHandler(stream=sys.stdout)
handler.setFormatter(logging.Formatter(fmt="%(asctime)s %(message)s"))

# Module logger: INFO and above goes through the handler defined above
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)

def main(data_path):
    """Read the header-less abalone CSV and return a subset of its columns.

    Args:
        data_path: Location of the CSV data; passed unchanged to
            ``spark.read.csv``.

    Returns:
        A Spark DataFrame containing the columns ``sex``, ``length``,
        ``diameter`` and ``rings``.
    """
    spark = SparkSession.builder.appName("PySparkJob").getOrCreate()
    # Only surface Spark's own ERROR-level messages.
    spark.sparkContext.setLogLevel("ERROR")

    # The file is read with header=False, so columns are bound to this schema
    # by position and every column must be listed in file order. The abalone
    # data has nine columns; without "shell_weight" in the eighth slot,
    # "rings" was populated with the shell-weight values instead.
    schema = StructType(
        [
            StructField("sex", StringType(), True),
            StructField("length", FloatType(), True),
            StructField("diameter", FloatType(), True),
            StructField("height", FloatType(), True),
            StructField("whole_weight", FloatType(), True),
            StructField("shucked_weight", FloatType(), True),
            StructField("viscera_weight", FloatType(), True),
            StructField("shell_weight", FloatType(), True),
            StructField("rings", FloatType(), True),
        ]
    )

    df = spark.read.csv(data_path, header=False, schema=schema)
    return df.select("sex", "length", "diameter", "rings")

if __name__ == "__main__":

    # Both paths are mandatory: without them the job would only fail later,
    # inside the Spark read or when building the output path, so let argparse
    # reject the invocation up front with a usage message.
    parser = argparse.ArgumentParser(description="app inputs")
    parser.add_argument(
        "--data_path", type=str, required=True, help="path to the channel data"
    )
    parser.add_argument(
        "--output_path", type=str, required=True, help="path to the output data"
    )
    args = parser.parse_args()

    df = main(args.data_path)

    logger.info("Writing transformed data")
    # mode="overwrite" replaces any existing output at the target path.
    df.write.csv(os.path.join(args.output_path, "transformed.csv"), header=True, mode="overwrite")

Local Mode only supports an `instance_count` value of 1.

In [None]:
# Constructor settings collected in one place and unpacked as keyword arguments.
processor_settings = {
    "role": role_arn,
    "instance_type": "local",
    "instance_count": 1,
    "framework_version": "2.4",
}
spark_processor = PySparkProcessor(**processor_settings)

For the `enhanced_pyspark_processor`, you need to make sure you use `s3a` rather than `s3` for your S3 paths.

In [None]:
# Input and output locations handed to processing.py on its command line.
data_path = f"s3a://sagemaker-servicecatalog-seedcode-{sagemaker_session.boto_region_name}/dataset/abalone-dataset.csv"
output_path = f"s3a://{sagemaker_session.default_bucket()}/enhanced_pyspark_processor/output/"

spark_processor.run(
    "processing.py",
    arguments=["--data_path", data_path, "--output_path", output_path],
)