Skip to content

Commit 5cacde3

Browse files
committed
memoize: remove ignore + fixes
1 parent 1bc7b5a commit 5cacde3

File tree

2 files changed

+28
-59
lines changed

2 files changed

+28
-59
lines changed

macros/src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ mod memoize;
44
use crate::memoize::memoize_impl;
55

66
#[proc_macro_attribute]
7-
pub fn memoize(args: proc_macro::TokenStream, item: proc_macro::TokenStream) -> TokenStream {
8-
memoize_impl(args, item)
7+
pub fn memoize(attr: TokenStream, item: TokenStream) -> TokenStream {
8+
memoize_impl(attr, item)
99
}

macros/src/memoize.rs

Lines changed: 26 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -3,48 +3,32 @@ use quote::{quote, ToTokens};
33
use syn::{parse::*, punctuated::*, spanned::*, *};
44

55
mod kw {
6-
syn::custom_keyword!(ignore);
76
syn::custom_keyword!(key_function);
87
}
98

109
#[derive(Default)]
1110
struct CacheOptions {
12-
ignore: Vec<Ident>,
1311
key_function: Option<(Ident, Ident)>,
1412
}
1513

1614
enum CacheOption {
17-
Ignore(Vec<Ident>),
1815
KeyFunction(Ident, Ident),
1916
}
2017

2118
impl Parse for CacheOption {
2219
fn parse(input: ParseStream) -> syn::Result<Self> {
2320
let la = input.lookahead1();
2421

25-
if la.peek(kw::ignore) {
26-
input.parse::<kw::ignore>().unwrap();
27-
input.parse::<Token![=]>().unwrap();
28-
let bracketed_content;
29-
bracketed!(bracketed_content in input);
30-
31-
let result = Punctuated::<LitStr, Token![,]>::parse_terminated(&bracketed_content)
32-
.unwrap()
33-
.into_iter()
34-
.map(|lit_str| lit_str.parse::<Ident>().unwrap())
35-
.collect::<Vec<_>>();
36-
37-
return Ok(CacheOption::Ignore(result));
38-
}
39-
4022
if la.peek(kw::key_function) {
41-
input.parse::<kw::key_function>().unwrap();
42-
input.parse::<Token![=]>().unwrap();
43-
let input = input.parse::<LitStr>().unwrap();
23+
input.parse::<kw::key_function>()?;
24+
input.parse::<Token![=]>()?;
25+
let input = input.parse::<LitStr>()?;
4426
let input_value = input.value();
4527

4628
let (key_function_name_str, key_function_return_str) =
47-
input_value.split_once(" -> ").unwrap();
29+
input_value
30+
.split_once(" -> ")
31+
.ok_or(syn::Error::new(input.span(), "Can't split by ` -> `"))?;
4832
let key_function_name = Ident::new(key_function_name_str, input.span());
4933
let key_function_return = Ident::new(key_function_return_str, input.span());
5034

@@ -62,11 +46,8 @@ impl Parse for CacheOptions {
6246
fn parse(input: ParseStream) -> syn::Result<Self> {
6347
let mut opts = Self::default();
6448

65-
let attrs = Punctuated::<CacheOption, syn::Token![,]>::parse_terminated(input)?;
66-
67-
for opt in attrs {
49+
for opt in Punctuated::<CacheOption, syn::Token![,]>::parse_terminated(input)? {
6850
match opt {
69-
CacheOption::Ignore(ident) => opts.ignore.extend(ident),
7051
CacheOption::KeyFunction(name, return_type) => {
7152
opts.key_function = Some((name, return_type));
7253
}
@@ -77,8 +58,18 @@ impl Parse for CacheOptions {
7758
}
7859
}
7960

80-
pub fn memoize_impl(args: TokenStream, item: TokenStream) -> TokenStream {
81-
let options: CacheOptions = syn::parse(args).unwrap();
61+
fn parse_sig_inputs(sig: &Signature) -> (Vec<Pat>, Vec<Type>) {
62+
sig.inputs
63+
.iter()
64+
.filter_map(|arg| match arg {
65+
FnArg::Typed(PatType { pat, ty, .. }) => Some((*pat.clone(), *ty.clone())),
66+
FnArg::Receiver(_) => None,
67+
})
68+
.unzip()
69+
}
70+
71+
pub fn memoize_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
72+
let options: CacheOptions = syn::parse(attr).unwrap();
8273

8374
let ItemFn {
8475
sig,
@@ -87,43 +78,21 @@ pub fn memoize_impl(args: TokenStream, item: TokenStream) -> TokenStream {
8778
attrs,
8879
} = parse_macro_input!(item as ItemFn);
8980

90-
let fn_input_names = sig
91-
.inputs
92-
.iter()
93-
.filter_map(|arg| match arg {
94-
FnArg::Typed(PatType { pat, .. }) => Some(*pat.clone()),
95-
FnArg::Receiver(_) => None,
96-
})
97-
.collect::<Vec<_>>();
98-
99-
let cache_input_names = sig
100-
.inputs
101-
.iter()
102-
.filter_map(|arg| match arg {
103-
FnArg::Typed(PatType { pat, .. }) => Some(*pat.clone()),
104-
FnArg::Receiver(_) => None,
105-
})
106-
.filter(|pat| match pat {
107-
Pat::Ident(PatIdent { ident, .. }) => {
108-
!options.ignore.iter().any(|ignore| ignore == ident)
109-
}
110-
_ => true,
111-
})
112-
.collect::<Vec<_>>();
81+
let (fn_input_names, fn_input_types) = parse_sig_inputs(&sig);
11382

11483
let fn_return_type = match &sig.output {
11584
ReturnType::Default => quote! { () },
11685
ReturnType::Type(_, ty) => ty.to_token_stream(),
11786
};
11887

119-
let cache_key_name = match &options.key_function {
88+
let cache_key_value = match &options.key_function {
12089
Some((name, _)) => quote! { #name(#(#fn_input_names),*) },
121-
None => quote! { (#(#cache_input_names.clone()),*) },
90+
None => quote! { (#(#fn_input_names),*) },
12291
};
12392

124-
let cache_key_return_type = match &options.key_function {
93+
let cache_key_type = match &options.key_function {
12594
Some((_, return_type)) => quote! { #return_type },
126-
None => fn_return_type.clone(),
95+
None => quote! { (#(#fn_input_types),*) },
12796
};
12897

12998
let internal_fn_name = format!("__{}_internal", sig.ident);
@@ -139,15 +108,15 @@ pub fn memoize_impl(args: TokenStream, item: TokenStream) -> TokenStream {
139108

140109
quote!(
141110
thread_local! {
142-
static #cache_static_var_ident: std::cell::RefCell<advent_of_code::maneatingape::hash::FastMap<#cache_key_return_type, #fn_return_type>> = std::cell::RefCell::new(advent_of_code::maneatingape::hash::FastMapBuilder::new());
111+
static #cache_static_var_ident: std::cell::RefCell<advent_of_code::maneatingape::hash::FastMap<#cache_key_type, #fn_return_type>> = std::cell::RefCell::new(advent_of_code::maneatingape::hash::FastMapBuilder::new());
143112
}
144113

145114
#(#attrs)*
146115
#vis #internal_sig #block
147116

148117
#(#attrs)*
149118
#vis #sig {
150-
let cache_key = #cache_key_name;
119+
let cache_key = #cache_key_value;
151120

152121
let cached_result_option = #cache_static_var_ident.with(|cache| {
153122
cache.borrow().get(&cache_key).cloned()

0 commit comments

Comments
 (0)