diff --git a/tests/test_vmas.py b/tests/test_vmas.py
index 3a782b36..ca033782 100644
--- a/tests/test_vmas.py
+++ b/tests/test_vmas.py
@@ -302,3 +302,21 @@ def test_vmas_differentiable(scenario, n_steps=10, n_envs=10):

     loss = obs[-1].mean() + rews[-1].mean()
     grad = torch.autograd.grad(loss, first_action)
+
+
+def test_seeding():
+    env = make_env(scenario="balance", num_envs=2, seed=0)
+    env.seed(0)
+    random_obs = env.reset()[0][0, 0]
+    env.seed(0)
+    assert random_obs == env.reset()[0][0, 0]
+    env.seed(0)
+    torch.manual_seed(1)
+    assert random_obs == env.reset()[0][0, 0]
+
+    torch.manual_seed(0)
+    random_obs = torch.randn(1)
+    torch.manual_seed(0)
+    env.seed(1)
+    env.reset()
+    assert random_obs == torch.randn(1)
diff --git a/vmas/simulator/environment/environment.py b/vmas/simulator/environment/environment.py
index 39d2d720..a061c914 100644
--- a/vmas/simulator/environment/environment.py
+++ b/vmas/simulator/environment/environment.py
@@ -1,6 +1,7 @@
-# Copyright (c) 2022-2024.
+# Copyright (c) 2022-2025.
 # ProrokLab (https://www.proroklab.org/)
 # All rights reserved.
+import contextlib
 import math
 import random
 from ctypes import byref
@@ -26,14 +27,41 @@
 )


-# environment for all agents in the multiagent world
-# currently code assumes that no agents will be created/destroyed at runtime!
+@contextlib.contextmanager
+def local_seed(vmas_random_state):
+    torch_state = torch.random.get_rng_state()
+    np_state = np.random.get_state()
+    py_state = random.getstate()
+
+    torch.random.set_rng_state(vmas_random_state[0])
+    np.random.set_state(vmas_random_state[1])
+    random.setstate(vmas_random_state[2])
+    yield
+    vmas_random_state[0] = torch.random.get_rng_state()
+    vmas_random_state[1] = np.random.get_state()
+    vmas_random_state[2] = random.getstate()
+
+    torch.random.set_rng_state(torch_state)
+    np.random.set_state(np_state)
+    random.setstate(py_state)
+
+
 class Environment(TorchVectorizedObject):
+    """
+    The VMAS environment
+    """
+
     metadata = {
         "render.modes": ["human", "rgb_array"],
         "runtime.vectorized": True,
     }
+    vmas_random_state = [
+        torch.random.get_rng_state(),
+        np.random.get_state(),
+        random.getstate(),
+    ]

+    @local_seed(vmas_random_state)
     def __init__(
         self,
         scenario: BaseScenario,
@@ -68,7 +96,7 @@ def __init__(
         self.grad_enabled = grad_enabled
         self.terminated_truncated = terminated_truncated

-        observations = self.reset(seed=seed)
+        observations = self._reset(seed=seed)

         # configure spaces
         self.multidiscrete_actions = multidiscrete_actions
@@ -81,6 +109,7 @@ def __init__(
         self.visible_display = None
         self.text_lines = None

+    @local_seed(vmas_random_state)
     def reset(
         self,
         seed: Optional[int] = None,
@@ -92,13 +121,104 @@ def reset(
         Resets the environment in a vectorized way
         Returns observations for all envs and agents
         """
+        return self._reset(
+            seed=seed,
+            return_observations=return_observations,
+            return_info=return_info,
+            return_dones=return_dones,
+        )
+
+    @local_seed(vmas_random_state)
+    def reset_at(
+        self,
+        index: int,
+        return_observations: bool = True,
+        return_info: bool = False,
+        return_dones: bool = False,
+    ):
+        """
+        Resets the environment at index
+        Returns observations for all agents in that environment
+        """
+        return self._reset_at(
+            index=index,
+            return_observations=return_observations,
+            return_info=return_info,
+            return_dones=return_dones,
+        )
+
+    @local_seed(vmas_random_state)
+    def get_from_scenario(
+        self,
+        get_observations: bool,
+        get_rewards: bool,
+        get_infos: bool,
+        get_dones: bool,
+        dict_agent_names: Optional[bool] = None,
+    ):
+        """
+        Get the environment data from the scenario
+
+        Args:
+            get_observations (bool): whether to return the observations
+            get_rewards (bool): whether to return the rewards
+            get_infos (bool): whether to return the infos
+            get_dones (bool): whether to return the dones
+            dict_agent_names (bool, optional): whether to return the information in a dictionary with agent names as keys
+                or in a list
+
+        Returns:
+            The agents' data
+
+        """
+        return self._get_from_scenario(
+            get_observations=get_observations,
+            get_rewards=get_rewards,
+            get_infos=get_infos,
+            get_dones=get_dones,
+            dict_agent_names=dict_agent_names,
+        )
+
+    @local_seed(vmas_random_state)
+    def seed(self, seed=None):
+        """
+        Sets the seed for the environment
+        Args:
+            seed (int, optional): Seed for the environment. Defaults to None.
+
+        """
+        return self._seed(seed=seed)
+
+    @local_seed(vmas_random_state)
+    def done(self):
+        """
+        Get the done flags for the scenario.
+
+        Returns:
+            Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
+
+        """
+        return self._done()
+
+    def _reset(
+        self,
+        seed: Optional[int] = None,
+        return_observations: bool = True,
+        return_info: bool = False,
+        return_dones: bool = False,
+    ):
+        """
+        Resets the environment in a vectorized way
+        Returns observations for all envs and agents
+        """
+
         if seed is not None:
-            self.seed(seed)
+            self._seed(seed)
         # reset world
         self.scenario.env_reset_world_at(env_index=None)
         self.steps = torch.zeros(self.num_envs, device=self.device)

-        result = self.get_from_scenario(
+        result = self._get_from_scenario(
             get_observations=return_observations,
             get_infos=return_info,
             get_rewards=False,
@@ -106,7 +226,7 @@ def reset(
         )
         return result[0] if result and len(result) == 1 else result

-    def reset_at(
+    def _reset_at(
         self,
         index: int,
         return_observations: bool = True,
@@ -121,7 +241,7 @@ def reset_at(
         self.scenario.env_reset_world_at(index)
         self.steps[index] = 0

-        result = self.get_from_scenario(
+        result = self._get_from_scenario(
             get_observations=return_observations,
             get_infos=return_info,
             get_rewards=False,
@@ -130,7 +250,7 @@ def reset_at(

         return result[0] if result and len(result) == 1 else result

-    def get_from_scenario(
+    def _get_from_scenario(
         self,
         get_observations: bool,
         get_rewards: bool,
@@ -178,16 +298,22 @@ def get_from_scenario(

         if self.terminated_truncated:
             if get_dones:
-                terminated, truncated = self.done()
+                terminated, truncated = self._done()
             result = [obs, rewards, terminated, truncated, infos]
         else:
             if get_dones:
-                dones = self.done()
+                dones = self._done()
             result = [obs, rewards, dones, infos]

         return [data for data in result if data is not None]

-    def seed(self, seed=None):
+    def _seed(self, seed=None):
+        """
+        Sets the seed for the environment
+        Args:
+            seed (int, optional): Seed for the environment. Defaults to None.
+
+        """
         if seed is None:
             seed = 0
         torch.manual_seed(seed)
@@ -195,18 +321,18 @@ def seed(self, seed=None):
         np.random.seed(seed)
         random.seed(seed)
         return [seed]

+    @local_seed(vmas_random_state)
     def step(self, actions: Union[List, Dict]):
         """Performs a vectorized step on all sub environments using `actions`.
+
         Args:
-            actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape
-                '(self.num_envs, action_size_of_agent)'.
+            actions: Is a list on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, action_size_of_agent)'.
+
         Returns:
-            obs: List on len 'self.n_agents' of which each element is a torch.Tensor
-                of shape '(self.num_envs, obs_size_of_agent)'
+            obs: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs, obs_size_of_agent)'
             rewards: List on len 'self.n_agents' of which each element is a torch.Tensor of shape '(self.num_envs)'
             dones: Tensor of len 'self.num_envs' of which each element is a bool
-            infos : List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric
-                and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'
+            infos: List on len 'self.n_agents' of which each element is a dictionary for which each key is a metric and the value is a tensor of shape '(self.num_envs, metric_size_per_agent)'
         Examples:
             >>> import vmas
@@ -222,6 +348,7 @@ def step(self, actions: Union[List, Dict]):
             >>> obs = env.reset()
             >>> for _ in range(10):
             ...     obs, rews, dones, info = env.step(env.get_random_actions())
+
         """
         if isinstance(actions, Dict):
             actions_dict = actions
@@ -269,14 +396,21 @@

         self.steps += 1

-        return self.get_from_scenario(
+        return self._get_from_scenario(
             get_observations=True,
             get_infos=True,
             get_rewards=True,
             get_dones=True,
         )

-    def done(self):
+    def _done(self):
+        """
+        Get the done flags for the scenario.
+
+        Returns:
+            Either terminated, truncated (if self.terminated_truncated==True) or terminated + truncated (if self.terminated_truncated==False)
+
+        """
         terminated = self.scenario.done().clone()

         if self.max_steps is not None:
@@ -387,6 +521,7 @@ def get_agent_observation_space(self, agent: Agent, obs: AGENT_OBS_TYPE):
                 f"Invalid type of observation {obs} for agent {agent.name}"
             )

+    @local_seed(vmas_random_state)
     def get_random_action(self, agent: Agent) -> torch.Tensor:
         """Returns a random action for the given agent.

@@ -447,7 +582,7 @@ def get_random_action(self, agent: Agent) -> torch.Tensor:
         return action

     def get_random_actions(self) -> Sequence[torch.Tensor]:
-        """Returns random actions for all agents that you can feed to :class:`step`
+        """Returns random actions for all agents that you can feed to :meth:`step`

         Returns:
             Sequence[torch.tensor]: the random actions for the agents
@@ -612,6 +747,7 @@ def _set_action(self, action, agent):
                 )
                 agent.action.c += noise

+    @local_seed(vmas_random_state)
     def render(
         self,
         mode="human",
@@ -635,6 +771,7 @@
         Render function for environment using pyglet

         On servers use mode="rgb_array" and set
+        ```
         export DISPLAY=':99.0'
         Xvfb :99 -screen 0 1400x900x24 > /dev/null 2>&1 &
         ```
@@ -642,8 +779,7 @@

         :param mode: One of human or rgb_array
         :param env_index: Index of the environment to render
-        :param agent_index_focus: If specified the camera will stay on the agent with this index.
-            If None, the camera will stay in the center and zoom out to contain all agents
+        :param agent_index_focus: If specified the camera will stay on the agent with this index. If None, the camera will stay in the center and zoom out to contain all agents
         :param visualize_when_rgb: Also run human visualization when mode=="rgb_array"
         :param plot_position_function: A function to plot under the rendering.
             The function takes a numpy array with shape (n_points, 2), which represents a set of x,y values to evaluate f over and plot it
@@ -657,6 +793,7 @@ def render(
         :param plot_position_function_cmap_range: The range of the cmap in case plot_position_function outputs a single value
         :param plot_position_function_cmap_alpha: The alpha of the cmap in case plot_position_function outputs a single value
         :return: Rgb array or None, depending on the mode
+
         """
         self._check_batch_index(env_index)
         assert (
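Reviewer note, not part of the patch: a minimal usage sketch of the behaviour the new `local_seed` wrapper is meant to provide, mirroring `test_seeding` above. It assumes `vmas.make_env` as used in the `step` docstring example; the point is that seeding and resetting the environment is reproducible regardless of the global torch seed, and that environment calls no longer advance the caller's global `torch`/`numpy`/`random` RNG streams.

```python
# Illustrative sketch only -- mirrors test_seeding from the diff above.
import torch
import vmas

env = vmas.make_env(scenario="balance", num_envs=2, seed=0)

# Re-seeding the environment reproduces the same reset observation,
# even if the global torch seed changes in between, because env calls
# run under the environment's own RNG state (Environment.vmas_random_state).
env.seed(0)
obs = env.reset()[0][0, 0]
env.seed(0)
torch.manual_seed(1)
assert obs == env.reset()[0][0, 0]

# Conversely, seeding and resetting the environment does not perturb
# the caller's global torch RNG stream.
torch.manual_seed(0)
expected = torch.randn(1)
torch.manual_seed(0)
env.seed(1)
env.reset()
assert expected == torch.randn(1)
```

The `local_seed` context manager achieves this by swapping the class-level `vmas_random_state` in for the global `torch`/`numpy`/`random` state around each public entry point (`__init__`, `reset`, `reset_at`, `step`, `seed`, `get_random_action`, `render`) and restoring the caller's state afterwards; the public methods delegate to the undecorated `_reset`/`_reset_at`/`_seed`/`_done` helpers, presumably so that internal calls do not re-enter the state swap.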