# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import os
import json

import gym
import ray
from ray.tune.registry import register_env
from ray.rllib.models import ModelCatalog

from procgen_ray_launcher import ProcgenSageMakerRayLauncher
from ray_experiment_builder import RayExperimentBuilder
from utils.loader import load_algorithms, load_preprocessors

# Prefer the user-supplied "custom" package; fall back to the bundled one.
try:
    from custom.envs.procgen_env_wrapper import ProcgenEnvWrapper
except ModuleNotFoundError:
    from envs.procgen_env_wrapper import ProcgenEnvWrapper


class MyLauncher(ProcgenSageMakerRayLauncher):
    """Launcher describing a PPO experiment on the Procgen "coinrun" game.

    The methods below supply the Ray cluster options, the Tune/RLlib
    experiment definition, and the registration of the custom environment,
    algorithms, preprocessors and model that the experiment refers to.
    """

    def register_env_creator(self):
        """Register a 4-frame-stacked Procgen environment with Ray Tune.

        NOTE(review): the experiment returned by ``_get_rllib_config`` selects
        the env id "procgen_env_wrapper", not the "stacked_procgen_env" id
        registered here. Presumably "procgen_env_wrapper" is registered
        elsewhere (not visible in this file) -- confirm which one is intended.
        """

        def _make_stacked_env(config):
            # Wrap the base Procgen env so observations hold the last 4 frames.
            return gym.wrappers.FrameStack(ProcgenEnvWrapper(config), 4)

        # This id should be different from procgen_env_wrapper.
        register_env("stacked_procgen_env", _make_stacked_env)

    def _get_ray_config(self):
        """Return the Ray-level options merged into the experiment builder args."""
        return dict(
            ray_num_cpus=8,  # adjust based on selected instance type
            ray_num_gpus=1,
            eager=False,
            v=True,  # required for CW to catch the progress
        )

    def _get_rllib_config(self):
        """Return the Tune experiment settings, including the PPO trainer config."""
        stop_criteria = {
            # 'time_total_s': 60,
            'training_iteration': 500,
        }

        # === Settings for the Procgen Environment ===
        # See https://github.com/AIcrowd/neurips2020-procgen-starter-kit/blob/master/experiments/procgen-starter-example.yaml#L34
        # for an explanation.
        procgen_env_config = {
            "env_name": "coinrun",
            "num_levels": 0,
            "start_level": 0,
            "paint_vel_info": False,
            "use_generated_assets": False,
            "center_agent": True,
            "use_sequential_levels": False,
            "distribution_mode": "easy"
        }

        trainer_config = {
            # === Environment Settings ===
            "gamma": 0.999,
            "lambda": 0.95,
            "lr": 5.0e-4,
            "num_sgd_iter": 3,
            "kl_coeff": 0.0,
            "kl_target": 0.01,
            "vf_loss_coeff": 0.5,
            "entropy_coeff": 0.01,
            "clip_param": 0.2,
            "vf_clip_param": 0.2,
            "grad_clip": 0.5,
            "observation_filter": "NoFilter",
            "vf_share_layers": True,
            "soft_horizon": False,
            "no_done_at_end": False,
            "normalize_actions": False,
            "clip_actions": True,
            "ignore_worker_failures": True,
            "use_pytorch": False,
            "sgd_minibatch_size": 2048,  # 8 minibatches per epoch
            "train_batch_size": 16384,  # 2048 * 8

            # === Settings for Model ===
            "model": {
                "custom_model": "impala_cnn_tf",
            },

            # === Settings for Rollout Worker processes ===
            # Adjust num_workers based on the total number of CPUs available
            # in the cluster, e.g., p3.2xlarge has 8 CPUs.
            "num_workers": 6,
            "rollout_fragment_length": 140,
            "batch_mode": "truncate_episodes",

            # === Advanced Resource Settings ===
            "num_envs_per_worker": 12,
            "num_cpus_per_worker": 1,
            "num_cpus_for_driver": 1,
            "num_gpus_per_worker": 0.1,

            # === Settings for the Trainer process ===
            # Adjust num_gpus based on the number of GPUs available in a
            # single node, e.g., p3.2xlarge has 1 GPU.
            "num_gpus": 0.3,

            # === Exploration Settings ===
            "explore": True,
            "exploration_config": {
                "type": "StochasticSampling",
            },

            "env_config": procgen_env_config,
        }

        return {
            "experiment_name": "training",
            "run": "PPO",
            "env": "procgen_env_wrapper",
            "stop": stop_criteria,
            "checkpoint_freq": 20,
            "checkpoint_at_end": True,
            "keep_checkpoints_num": 5,
            "queue_trials": False,
            "config": trainer_config,
        }

    def register_algorithms_and_preprocessors(self):
        """Load custom algorithms/preprocessors and register the ImpalaCNN model.

        The "custom" package is tried first; when it is missing, the bundled
        top-level packages are used instead.
        """
        try:
            from custom.algorithms import CUSTOM_ALGORITHMS
            from custom.preprocessors import CUSTOM_PREPROCESSORS
            from custom.models.impala_cnn_tf import ImpalaCNN
        except ModuleNotFoundError:
            from algorithms import CUSTOM_ALGORITHMS
            from preprocessors import CUSTOM_PREPROCESSORS
            from models.impala_cnn_tf import ImpalaCNN

        load_algorithms(CUSTOM_ALGORITHMS)
        load_preprocessors(CUSTOM_PREPROCESSORS)
        # The name must match "custom_model" in the trainer config.
        ModelCatalog.register_custom_model("impala_cnn_tf", ImpalaCNN)

    def get_experiment_config(self):
        """Merge the Ray and RLlib settings and build the experiment definition.

        RLlib settings are applied after the Ray settings, so on a key clash
        the RLlib value wins.
        """
        builder_kwargs = {**self._get_ray_config(), **self._get_rllib_config()}
        return RayExperimentBuilder(**builder_kwargs).get_experiment_definition()


if __name__ == "__main__":
    MyLauncher().train_main()