@@ -164,10 +164,10 @@ def _dtypes(self, kind):
164164 int16 = torch .int16
165165 int32 = torch .int32
166166 int64 = torch .int64
167- uint8 = getattr ( torch , " uint8" , None )
168- uint16 = getattr ( torch , "uint16" , None )
169- uint32 = getattr ( torch , "uint32" , None )
170- uint64 = getattr ( torch , "uint64" , None )
167+ uint8 = torch . uint8
168+ # uint16, uint32, and uint64 are present in newer versions of pytorch,
169+ # but they aren't generally supported by the array API functions, so
170+ # we omit them from this function.
171171 float32 = torch .float32
172172 float64 = torch .float64
173173 complex64 = torch .complex64
@@ -181,9 +181,6 @@ def _dtypes(self, kind):
181181 "int32" : int32 ,
182182 "int64" : int64 ,
183183 "uint8" : uint8 ,
184- "uint16" : uint16 ,
185- "uint32" : uint32 ,
186- "uint64" : uint64 ,
187184 "float32" : float32 ,
188185 "float64" : float64 ,
189186 "complex64" : complex64 ,
@@ -201,9 +198,6 @@ def _dtypes(self, kind):
201198 if kind == "unsigned integer" :
202199 return {
203200 "uint8" : uint8 ,
204- "uint16" : uint16 ,
205- "uint32" : uint32 ,
206- "uint64" : uint64 ,
207201 }
208202 if kind == "integral" :
209203 return {
@@ -212,9 +206,6 @@ def _dtypes(self, kind):
212206 "int32" : int32 ,
213207 "int64" : int64 ,
214208 "uint8" : uint8 ,
215- "uint16" : uint16 ,
216- "uint32" : uint32 ,
217- "uint64" : uint64 ,
218209 }
219210 if kind == "real floating" :
220211 return {
@@ -233,9 +224,6 @@ def _dtypes(self, kind):
233224 "int32" : int32 ,
234225 "int64" : int64 ,
235226 "uint8" : uint8 ,
236- "uint16" : uint16 ,
237- "uint32" : uint32 ,
238- "uint64" : uint64 ,
239227 "float32" : float32 ,
240228 "float64" : float64 ,
241229 "complex64" : complex64 ,
@@ -305,9 +293,6 @@ def dtypes(self, *, device=None, kind=None):
305293 """
306294 res = self ._dtypes (kind )
307295 for k , v in res .copy ().items ():
308- if v is None :
309- del res [k ]
310- continue
311296 try :
312297 torch .empty ((0 ,), dtype = v , device = device )
313298 except :
0 commit comments