diff --git a/vmas/scenarios/road_traffic.py b/vmas/scenarios/road_traffic.py index a19aedab..ac944d80 100644 --- a/vmas/scenarios/road_traffic.py +++ b/vmas/scenarios/road_traffic.py @@ -745,7 +745,11 @@ def init_params(self, batch_dim, device, **kwargs): self.distances = Distances( agents=torch.zeros( - batch_dim, self.n_agents, self.n_agents, dtype=torch.float32 + batch_dim, + self.n_agents, + self.n_agents, + device=device, + dtype=torch.float32, ), left_boundaries=torch.zeros( (batch_dim, self.n_agents, 1 + 4), device=device, dtype=torch.float32