-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Add dummy dtypes #3195
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dummy dtypes #3195
Conversation
|
signed dtypes are nice 👌 I've been having to pass u32s as i32s in cuda launch code and have been worried that would blow up in my face at some point |
ivarflakstad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is gonna be a good one! 🙌
| DType::F64 => convert_slice::<f64>(data, shape, device), | ||
| DType::F8E4M3 => convert_slice::<F8E4M3>(data, shape, device), | ||
| DType::F8E4M3 => convert_slice::<float8::F8E4M3>(data, shape, device), | ||
| DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Doesn't have to be in this PR, but I'd prefer to hoist this out into a helper fn.
Perhaps use convert_slice::<u8>(data, shape, device) and manually change the storage dtype? Might not even need a dedicated fn now that I think about it 🤔
| let shape = view.shape(); | ||
|
|
||
| // Create storage with the appropriate dummy type variant | ||
| let storage = match device { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Déjà vu helper fn 👀
| #[test] | ||
| fn load_i8() { | ||
| let bytes = b"8\0\0\0\0\0\0\0{\"x\":{\"dtype\":\"I8\",\"shape\":[2],\"data_offsets\":[0,2]}} \x01\x03"; | ||
| std::fs::write("test_i8.safetensors", bytes).unwrap(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not related to this PR, just noting down while I'm here: we should use temp files for these kinds of tests.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolutely 👍
|
Addressed most of the review comments; left some as unresolved for posterity. |
| let data = unary_map(storage, layout, |v| v as f64); | ||
| Ok(Self::F64(data)) | ||
| } | ||
| (Self::I32(storage), DType::F8E4M3) => { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have an idea for how to reduce the massive size of this match.
Adding it to the ever growing list of things to do :)
| S::F16(s) => self.f(s, d, l, S::F16)?, | ||
| S::F32(s) => self.f(s, d, l, S::F32)?, | ||
| S::F64(s) => self.f(s, d, l, S::F64)?, | ||
| S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You resolved this but looks the same to me?
| (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
candle-core/src/op.rs
Outdated
| #[inline(always)] | ||
| fn f32(v: f32) -> f32 { | ||
| (crate::cpu::erf::erf_f32(v * std::f32::consts::FRAC_1_SQRT_2) + 1.) * 0.5 * v | ||
| Self::f64(v as f64) as f32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still revert! ;)
| } | ||
|
|
||
| fn get_current_seed(&self) -> Result<u64> { | ||
| crate::bail!("cannot get the CPU rng seed with get_current_seed") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll have a look into this later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
| S::F16(s) => self.f(s, d, l, S::F16)?, | ||
| S::F32(s) => self.f(s, d, l, S::F32)?, | ||
| S::F64(s) => self.f(s, d, l, S::F64)?, | ||
| S::F8E4M3(s) => self.f(s, d, l, S::F8E4M3)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You resolved this but looks the same to me?
| (S::F16(s1), S::F16(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F32(s1), S::F32(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F64(s1), S::F64(s2)) => self.f(s1, l1, s2, l2, d)?, | ||
| (S::F8E4M3(s1), S::F8E4M3(s2)) => self.f(s1, l1, s2, l2, d)?, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
ivarflakstad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lgtm! 🎉
Let's just wait for CI before we merge and see if we haven't missed something (ignore pyo3 though).
735c517 to
75e688c
Compare
Adds support for:
i32
i16
f6e2m3
f6e3m2
f4
f8e8m0
These are "dummy" dtypes: this just means a typed bitbucket essentially.
CPU compiles
CUDA compiles
Metal compiles