1616from dataclasses import dataclass
1717from typing import Optional , Tuple , Union
1818
19- import numpy as np
2019import torch
2120
2221from ..configuration_utils import ConfigMixin , register_to_config
@@ -210,13 +209,13 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
210209 """
211210 self .num_inference_steps = num_inference_steps
212211
213- ramp = np .linspace (0 , 1 , self .num_inference_steps )
212+ ramp = torch .linspace (0 , 1 , self .num_inference_steps )
214213 if self .config .sigma_schedule == "karras" :
215214 sigmas = self ._compute_karras_sigmas (ramp )
216215 elif self .config .sigma_schedule == "exponential" :
217216 sigmas = self ._compute_exponential_sigmas (ramp )
218217
219- sigmas = torch . from_numpy ( sigmas ) .to (dtype = torch .float32 , device = device )
218+ sigmas = sigmas .to (dtype = torch .float32 , device = device )
220219 self .timesteps = self .precondition_noise (sigmas )
221220
222221 self .sigmas = torch .cat ([sigmas , torch .zeros (1 , device = sigmas .device )])
@@ -234,7 +233,6 @@ def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.
234233 min_inv_rho = sigma_min ** (1 / rho )
235234 max_inv_rho = sigma_max ** (1 / rho )
236235 sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho )) ** rho
237-
238236 return sigmas
239237
240238 def _compute_exponential_sigmas (self , ramp , sigma_min = None , sigma_max = None ) -> torch .Tensor :
0 commit comments