@@ -1507,15 +1507,19 @@ class TorchBackend(Backend):
 
     def __init__(self):
 
-        self.rng_ = torch.Generator()
+        self.rng_ = torch.Generator("cpu")
         self.rng_.seed()
 
         self.__type_list__ = [torch.tensor(1, dtype=torch.float32),
                               torch.tensor(1, dtype=torch.float64)]
 
         if torch.cuda.is_available():
+            self.rng_cuda_ = torch.Generator("cuda")
+            self.rng_cuda_.seed()
             self.__type_list__.append(torch.tensor(1, dtype=torch.float32, device='cuda'))
             self.__type_list__.append(torch.tensor(1, dtype=torch.float64, device='cuda'))
+        else:
+            self.rng_cuda_ = torch.Generator("cpu")
 
         from torch.autograd import Function
 
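The `__init__` change keeps one `torch.Generator` per device: a generator's device is fixed when it is constructed, so the existing CPU generator cannot drive sampling on CUDA tensors. On CPU-only machines `rng_cuda_` falls back to a CPU generator so the attribute is always defined. A minimal sketch of this setup pattern using only plain `torch` calls (variable names below are illustrative, not part of this codebase):

import torch

# A generator's device is fixed at construction time, so a single CPU
# generator cannot drive sampling on a CUDA tensor; hence one generator
# per device, with a CPU fallback so the attribute always exists.
rng_cpu = torch.Generator("cpu")
rng_cpu.seed()                        # seed from system entropy
print(rng_cpu.device)                 # cpu

if torch.cuda.is_available():
    rng_gpu = torch.Generator("cuda")
    rng_gpu.seed()
    print(rng_gpu.device)             # cuda:0
else:
    rng_gpu = torch.Generator("cpu")  # fallback on CPU-only machines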
@@ -1761,20 +1765,26 @@ def reshape(self, a, shape):
     def seed(self, seed=None):
         if isinstance(seed, int):
             self.rng_.manual_seed(seed)
+            self.rng_cuda_.manual_seed(seed)
         elif isinstance(seed, torch.Generator):
-            self.rng_ = seed
+            if self.device_type(seed) == "GPU":
+                self.rng_cuda_ = seed
+            else:
+                self.rng_ = seed
         else:
             raise ValueError("Non compatible seed : {}".format(seed))
 
     def rand(self, *size, type_as=None):
         if type_as is not None:
-            return torch.rand(size=size, generator=self.rng_, dtype=type_as.dtype, device=type_as.device)
+            generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+            return torch.rand(size=size, generator=generator, dtype=type_as.dtype, device=type_as.device)
         else:
             return torch.rand(size=size, generator=self.rng_)
 
     def randn(self, *size, type_as=None):
         if type_as is not None:
-            return torch.randn(size=size, dtype=type_as.dtype, generator=self.rng_, device=type_as.device)
+            generator = self.rng_cuda_ if self.device_type(type_as) == "GPU" else self.rng_
+            return torch.randn(size=size, dtype=type_as.dtype, generator=generator, device=type_as.device)
         else:
             return torch.randn(size=size, generator=self.rng_)
 
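Downstream, `seed` now pushes an integer seed into both generators, and `rand`/`randn` select whichever generator matches the device of `type_as`. A hedged usage sketch of that selection logic with plain `torch` calls; `x.device.type` stands in here for the backend's `device_type` helper, and all variable names are illustrative:

import torch

# Two generators, as set up in __init__ above (CPU fallback when CUDA is absent).
rng_ = torch.Generator("cpu")
rng_cuda_ = torch.Generator("cuda") if torch.cuda.is_available() else torch.Generator("cpu")

# Mirror of seed(42): the same integer seeds both generators.
rng_.manual_seed(42)
rng_cuda_.manual_seed(42)

# Mirror of rand(2, 3, type_as=x): pick the generator that matches x's device.
x = torch.ones(3, device="cuda" if torch.cuda.is_available() else "cpu")
generator = rng_cuda_ if x.device.type == "cuda" else rng_
sample = torch.rand((2, 3), generator=generator, dtype=x.dtype, device=x.device)
print(sample.device)  # same device as x, and reproducible for a fixed seed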