diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py
index 015a32b..2ed459f 100644
--- a/stable_baselines3/common/off_policy_algorithm.py
+++ b/stable_baselines3/common/off_policy_algorithm.py
@@ -105,6 +105,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         sde_support: bool = True,
         remove_time_limit_termination: bool = False,
         supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+        metric_path: Optional[str] = None,
     ):
 
         super(OffPolicyAlgorithm, self).__init__(
@@ -152,6 +153,20 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             self.policy_kwargs["use_sde"] = self.use_sde
         # For gSDE only
         self.use_sde_at_warmup = use_sde_at_warmup
+
+        self._metric_path = metric_path
+        self._iteration_num = 0
+
+    def increment_iteration_number(self) -> None:
+        self._iteration_num += 1
+
+    def write_metrics(self, path: Optional[str], data: List[str]) -> None:
+        # No-op when no metric file is configured or no step times were collected
+        if path is None or not data:
+            return
+        with open(path, "a") as fp:
+            fp.write("\n".join(data))
+            fp.write("\n")
 
     def _convert_train_freq(self) -> None:
         """
@@ -575,6 +590,8 @@ class OffPolicyAlgorithm(BaseAlgorithm):
         callback.on_rollout_start()
         continue_training = True
 
+        step_time_list = []
+        n_steps = 0
         while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes):
             if self.use_sde and self.sde_sample_freq > 0 and num_collected_steps % self.sde_sample_freq == 0:
                 # Sample a new noise matrix
@@ -584,7 +601,11 @@ class OffPolicyAlgorithm(BaseAlgorithm):
             actions, buffer_actions = self._sample_action(learning_starts, action_noise, env.num_envs)
 
             # Rescale and perform action
+            start_step_time = time.time()
+            n_steps += 1
             new_obs, rewards, dones, infos = env.step(actions)
+            step_time_list.append("{}|{}|{}|{}|{}".format(self.seed, self._iteration_num, n_steps,
+                                                          start_step_time, time.time() - start_step_time))
 
             self.num_timesteps += env.num_envs
             num_collected_steps += 1
@@ -622,7 +643,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
                 # Log training infos
                 if log_interval is not None and self._episode_num % log_interval == 0:
                     self._dump_logs()
-
+        self.write_metrics(self._metric_path, step_time_list)
         callback.on_rollout_end()
 
         return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training)
diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py
index 48cb365..996602b 100644
--- a/stable_baselines3/common/on_policy_algorithm.py
+++ b/stable_baselines3/common/on_policy_algorithm.py
@@ -72,6 +72,7 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
         supported_action_spaces: Optional[Tuple[gym.spaces.Space, ...]] = None,
+        metric_path: Optional[str] = None,
     ):
 
         super(OnPolicyAlgorithm, self).__init__(
@@ -101,7 +102,13 @@ class OnPolicyAlgorithm(BaseAlgorithm):
 
         if _init_setup_model:
             self._setup_model()
 
+        self._metric_path = metric_path
+        self._iteration_num = 0
+
+    def increment_iteration_number(self) -> None:
+        self._iteration_num += 1
+
     def _setup_model(self) -> None:
         self._setup_lr_schedule()
         self.set_random_seed(self.seed)
@@ -126,6 +133,14 @@ class OnPolicyAlgorithm(BaseAlgorithm):
         )
         self.policy = self.policy.to(self.device)
 
+    def write_metrics(self, path: Optional[str], data: List[str]) -> None:
+        # No-op when no metric file is configured or no step times were collected
+        if path is None or not data:
+            return
+        with open(path, "a") as fp:
+            fp.write("\n".join(data))
+            fp.write("\n")
+
     def collect_rollouts(
         self,
         env: VecEnv,
@@ -157,7 +172,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             self.policy.reset_noise(env.num_envs)
 
         callback.on_rollout_start()
 
+        step_time_list = []
         while n_steps < n_rollout_steps:
             if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:
                 # Sample a new noise matrix
@@ -174,8 +190,11 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             # Clip the actions to avoid out of bound error
             if isinstance(self.action_space, gym.spaces.Box):
                 clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
 
+            start_step_time = time.time()
             new_obs, rewards, dones, infos = env.step(clipped_actions)
+            step_time_list.append("{}|{}|{}|{}|{}".format(self.seed, self._iteration_num, n_steps,
+                                                          start_step_time, time.time() - start_step_time))
 
             self.num_timesteps += env.num_envs
 
@@ -207,6 +226,8 @@ class OnPolicyAlgorithm(BaseAlgorithm):
             rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)
             self._last_obs = new_obs
             self._last_episode_starts = dones
+
+        self.write_metrics(self._metric_path, step_time_list)
 
         with th.no_grad():
             # Compute value for the last timestep
diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py
index 14293ca..78e514c 100644
--- a/stable_baselines3/ddpg/ddpg.py
+++ b/stable_baselines3/ddpg/ddpg.py
@@ -76,6 +76,7 @@ class DDPG(TD3):
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
+        metric_path: Optional[str] = None,
     ):
 
         super(DDPG, self).__init__(
@@ -105,6 +106,7 @@ class DDPG(TD3):
             target_noise_clip=0.0,
             target_policy_noise=0.1,
             _init_setup_model=False,
+            metric_path=metric_path,
         )
 
         # Use only one critic
diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py
index 9e16e04..ff9e53e 100644
--- a/stable_baselines3/ppo/ppo.py
+++ b/stable_baselines3/ppo/ppo.py
@@ -89,8 +89,9 @@ class PPO(OnPolicyAlgorithm):
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
+        metric_path: Optional[str] = None,
     ):
 
         super(PPO, self).__init__(
             policy,
             env,
@@ -116,6 +117,7 @@ class PPO(OnPolicyAlgorithm):
                 spaces.MultiDiscrete,
                 spaces.MultiBinary,
             ),
+            metric_path=metric_path,
         )
 
         # Sanity check, otherwise it will lead to noisy gradient and NaN
diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py
index 5f3a833..598c4d7 100644
--- a/stable_baselines3/sac/sac.py
+++ b/stable_baselines3/sac/sac.py
@@ -101,6 +101,7 @@ class SAC(OffPolicyAlgorithm):
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
+        metric_path: Optional[str] = None,
     ):
 
         super(SAC, self).__init__(
@@ -130,6 +131,7 @@ class SAC(OffPolicyAlgorithm):
             optimize_memory_usage=optimize_memory_usage,
             supported_action_spaces=(gym.spaces.Box),
             support_multi_env=True,
+            metric_path=metric_path,
         )
 
         self.target_entropy = target_entropy
diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py
index eb257a6..4208f72 100644
--- a/stable_baselines3/td3/td3.py
+++ b/stable_baselines3/td3/td3.py
@@ -86,6 +86,7 @@ class TD3(OffPolicyAlgorithm):
         seed: Optional[int] = None,
         device: Union[th.device, str] = "auto",
         _init_setup_model: bool = True,
+        metric_path: Optional[str] = None,
     ):
 
         super(TD3, self).__init__(
@@ -113,6 +114,7 @@ class TD3(OffPolicyAlgorithm):
            optimize_memory_usage=optimize_memory_usage,
            supported_action_spaces=(gym.spaces.Box),
            support_multi_env=True,
+           metric_path=metric_path,
        )
 
        self.policy_delay = policy_delay
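Usage note: the metric_path plumbing above is driven from the training script. A minimal sketch (not part of the patch; the env id, output path, and timestep budget are placeholder choices):

    from stable_baselines3 import PPO

    # One learn() call per "iteration" in the metrics file; the iteration
    # column is advanced manually with the new increment_iteration_number().
    model = PPO("MlpPolicy", "CartPole-v1", seed=0, metric_path="/tmp/step_times.psv")
    for _ in range(3):
        model.learn(total_timesteps=2048, reset_num_timesteps=False)
        model.increment_iteration_number()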
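Each record written by write_metrics() is one pipe-separated line, seed|iteration|step|start_time|duration, where start_time is the wall-clock time.time() taken just before env.step() and duration is the seconds spent inside it. A sketch of reading the file back (same placeholder path as above):

    import csv

    # Every line has exactly five pipe-delimited fields, so unpacking is safe.
    with open("/tmp/step_times.psv") as fp:
        for seed, iteration, step, start_time, duration in csv.reader(fp, delimiter="|"):
            print(seed, iteration, step, float(duration))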