From 02c21c08487c55fa7200a6a2cc00e6c1df0e22bd Mon Sep 17 00:00:00 2001 From: taooo Date: Wed, 23 Apr 2025 15:52:46 +0800 Subject: [PATCH] fix bug of road_traffic senario when using cuda --- vmas/scenarios/road_traffic.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/vmas/scenarios/road_traffic.py b/vmas/scenarios/road_traffic.py index a19aedab..ac944d80 100644 --- a/vmas/scenarios/road_traffic.py +++ b/vmas/scenarios/road_traffic.py @@ -745,7 +745,11 @@ def init_params(self, batch_dim, device, **kwargs): self.distances = Distances( agents=torch.zeros( - batch_dim, self.n_agents, self.n_agents, dtype=torch.float32 + batch_dim, + self.n_agents, + self.n_agents, + device=device, + dtype=torch.float32, ), left_boundaries=torch.zeros( (batch_dim, self.n_agents, 1 + 4), device=device, dtype=torch.float32