-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathDistributedWeightedSampler.py
More file actions
43 lines (33 loc) · 1.65 KB
/
DistributedWeightedSampler.py
File metadata and controls
43 lines (33 loc) · 1.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch.distributed as dist
import torch
from torch.utils.data.distributed import DistributedSampler
class DistributedProxySampler(DistributedSampler):
    """Shard the indices produced by another sampler across distributed replicas.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass an instance of this class as a DataLoader sampler and
    receive a subset of the wrapped sampler's indices that is exclusive to it.

    Before drawing from the wrapped sampler, the CPU RNG is seeded with the
    current epoch so that every replica obtains the same index sequence; the
    replicas' strided slices of that sequence therefore do not overlap. Call
    ``set_epoch`` each epoch to get a different draw. The caller's CPU RNG
    state is restored afterwards, so training code is not reseeded as a side
    effect.

    .. note::
        Input sampler is assumed to be of constant size.

    Args:
        sampler: Input data sampler. It is iterated once per ``__iter__`` call
            and must also support ``len()``.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, sampler, num_replicas=None, rank=None):
        # The wrapped sampler is handed to the parent in place of a dataset.
        # NOTE: relies on DistributedSampler sizing num_samples/total_size
        # from len(dataset). The parent's own shuffling is disabled because
        # __iter__ below is fully overridden.
        super(DistributedProxySampler, self).__init__(
            sampler, num_replicas=num_replicas, rank=rank, shuffle=False
        )
        self.sampler = sampler

    def __iter__(self):
        # Draw the wrapped sampler's indices under an epoch-dependent seed so
        # all replicas see the same sequence, then put the caller's CPU RNG
        # state back so later randomness (dropout, augmentation, ...) is not
        # forced to seed == epoch.
        rng_state = torch.get_rng_state()
        try:
            torch.default_generator.manual_seed(self.epoch)
            indices = list(self.sampler)
        finally:
            torch.set_rng_state(rng_state)

        # Add extra samples to make the total evenly divisible by
        # num_replicas. The indices may need to be repeated more than once
        # when the sampler is shorter than the required padding.
        padding_size = self.total_size - len(indices)
        if padding_size > 0 and indices:
            repeats = -(-padding_size // len(indices))  # ceil division
            indices += (indices * repeats)[:padding_size]
        if len(indices) != self.total_size:
            raise RuntimeError("{} vs {}".format(len(indices), self.total_size))

        # Subsample: this rank takes every num_replicas-th index, starting at
        # its own rank, giving disjoint slices across replicas.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        if len(indices) != self.num_samples:
            raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
        return iter(indices)