1+ use std:: arch:: x86_64:: {
2+ __m256i, _mm256_cmpeq_epi8, _mm256_cmpgt_epi8, _mm256_loadu_si256, _mm256_or_si256,
3+ _mm256_set1_epi8, _mm256_storeu_si256, _mm256_testz_si256,
4+ } ;
5+
6+ use crate :: { encode_str_fallback, ESCAPE , HEX_BYTES , UU } ;
7+
/// Number of input bytes consumed per main-loop iteration: four 32-byte
/// AVX2 registers loaded back to back.
const CHUNK: usize = 4 * 32;
/// How far ahead of the current read position (in bytes) the main loop
/// issues its prefetch hint.
///
/// Four iterations' worth of input (4 × `CHUNK` = 512 B) is intended as a
/// compromise between hiding memory latency and not evicting cache lines
/// that are still useful.
const PREFETCH_DISTANCE: usize = 4 * CHUNK;
14+
15+ pub fn encode_str < S : AsRef < str > > ( input : S ) -> String {
16+ let s = input. as_ref ( ) ;
17+ let mut out = Vec :: with_capacity ( s. len ( ) + 2 ) ;
18+ let bytes = s. as_bytes ( ) ;
19+ let n = bytes. len ( ) ;
20+ out. push ( b'"' ) ;
21+
22+ unsafe {
23+ let slash = _mm256_set1_epi8 ( b'\\' as i8 ) ;
24+ let quote = _mm256_set1_epi8 ( b'"' as i8 ) ;
25+ let tab = _mm256_set1_epi8 ( b'\t' as i8 ) ;
26+ let newline = _mm256_set1_epi8 ( b'\n' as i8 ) ;
27+ let carriage = _mm256_set1_epi8 ( b'\r' as i8 ) ;
28+ let backspace = _mm256_set1_epi8 ( 0x08i8 ) ;
29+ let formfeed = _mm256_set1_epi8 ( 0x0ci8 ) ;
30+ let ctrl_upper_bound = _mm256_set1_epi8 ( 0x20i8 ) ;
31+
32+ let mut i = 0 ;
33+
34+ // Re-usable scratch – *uninitialised*, so no memset in the loop.
35+ #[ allow( invalid_value) ]
36+ let mut placeholder: [ u8 ; 32 ] = core:: mem:: MaybeUninit :: uninit ( ) . assume_init ( ) ;
37+
38+ while i + CHUNK <= n {
39+ let ptr = bytes. as_ptr ( ) . add ( i) ;
40+
41+ // Prefetch data ahead
42+ #[ cfg( any( target_arch = "x86" , target_arch = "x86_64" ) ) ]
43+ {
44+ core:: arch:: x86_64:: _mm_prefetch (
45+ ptr. add ( PREFETCH_DISTANCE ) as * const i8 ,
46+ core:: arch:: x86_64:: _MM_HINT_T0,
47+ ) ;
48+ }
49+
50+ // Load 128 bytes (four 32-byte chunks)
51+ let a = _mm256_loadu_si256 ( ptr as * const __m256i ) ;
52+ let b = _mm256_loadu_si256 ( ptr. add ( 32 ) as * const __m256i ) ;
53+ let c = _mm256_loadu_si256 ( ptr. add ( 64 ) as * const __m256i ) ;
54+ let d = _mm256_loadu_si256 ( ptr. add ( 96 ) as * const __m256i ) ;
55+
56+ // For each chunk, check if it needs escaping
57+ let mask_1 = process_chunk (
58+ a, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
59+ ) ;
60+ let mask_2 = process_chunk (
61+ b, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
62+ ) ;
63+ let mask_3 = process_chunk (
64+ c, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
65+ ) ;
66+ let mask_4 = process_chunk (
67+ d, slash, quote, tab, newline, carriage, backspace, formfeed, ctrl_upper_bound,
68+ ) ;
69+
70+ // Check if any chunk needs escaping
71+ let any_escape = _mm256_or_si256 (
72+ _mm256_or_si256 ( mask_1, mask_2) ,
73+ _mm256_or_si256 ( mask_3, mask_4) ,
74+ ) ;
75+
76+ // Fast path: nothing needs escaping
77+ if _mm256_testz_si256 ( any_escape, any_escape) != 0 {
78+ out. extend_from_slice ( std:: slice:: from_raw_parts ( ptr, CHUNK ) ) ;
79+ i += CHUNK ;
80+ continue ;
81+ }
82+
83+ // Slow path: handle each 32-byte chunk
84+ macro_rules! handle {
85+ ( $mask: expr, $off: expr) => {
86+ if _mm256_testz_si256( $mask, $mask) != 0 {
87+ // No escapes in this chunk
88+ out. extend_from_slice( std:: slice:: from_raw_parts( ptr. add( $off) , 32 ) ) ;
89+ } else {
90+ // Store mask and process byte by byte
91+ _mm256_storeu_si256( placeholder. as_mut_ptr( ) as * mut __m256i, $mask) ;
92+ handle_block( & bytes[ i + $off..i + $off + 32 ] , & placeholder, & mut out) ;
93+ }
94+ } ;
95+ }
96+
97+ handle ! ( mask_1, 0 ) ;
98+ handle ! ( mask_2, 32 ) ;
99+ handle ! ( mask_3, 64 ) ;
100+ handle ! ( mask_4, 96 ) ;
101+
102+ i += CHUNK ;
103+ }
104+
105+ // Handle remaining bytes using the fallback
106+ if i < n {
107+ let remaining_str = std:: str:: from_utf8 ( & bytes[ i..] ) . unwrap ( ) ;
108+ let escaped = encode_str_fallback ( remaining_str) ;
109+ // Remove the quotes that encode_str_fallback adds
110+ let escaped_bytes = escaped. as_bytes ( ) ;
111+ out. extend_from_slice ( & escaped_bytes[ 1 ..escaped_bytes. len ( ) - 1 ] ) ;
112+ }
113+ }
114+ out. push ( b'"' ) ;
115+ // SAFETY: we only emit valid UTF-8
116+ unsafe { String :: from_utf8_unchecked ( out) }
117+ }
118+
/// Builds the per-byte "needs escaping" mask for one 32-byte register.
///
/// A byte of the result is `0xFF` when the corresponding byte of `data`
/// equals one of the supplied special characters (`slash`, `quote`, `tab`,
/// `newline`, `carriage`, `backspace`, `formfeed`) or lies in the range
/// `0 ..ctrl_upper_bound`; otherwise it is `0x00`. Each of the character
/// arguments is expected to be that value broadcast to all 32 lanes, and
/// `ctrl_upper_bound` is expected to be a non-negative `i8` broadcast.
///
/// # Safety
///
/// The CPU must support AVX2.
#[inline(always)]
unsafe fn process_chunk(
    data: __m256i,
    slash: __m256i,
    quote: __m256i,
    tab: __m256i,
    newline: __m256i,
    carriage: __m256i,
    backspace: __m256i,
    formfeed: __m256i,
    ctrl_upper_bound: __m256i,
) -> __m256i {
    // Exact matches for each special character.
    let slash_mask = _mm256_cmpeq_epi8(data, slash);
    let quote_mask = _mm256_cmpeq_epi8(data, quote);
    let tab_mask = _mm256_cmpeq_epi8(data, tab);
    let newline_mask = _mm256_cmpeq_epi8(data, newline);
    let carriage_mask = _mm256_cmpeq_epi8(data, carriage);
    let backspace_mask = _mm256_cmpeq_epi8(data, backspace);
    let formfeed_mask = _mm256_cmpeq_epi8(data, formfeed);

    // Control characters (< ctrl_upper_bound). AVX2 only offers a *signed*
    // byte compare, under which every byte >= 0x80 (all bytes of multi-byte
    // UTF-8 sequences) is negative and therefore also "less than" the bound.
    // Mask those out so non-ASCII text is not reported as needing escapes.
    let below_bound = _mm256_cmpgt_epi8(ctrl_upper_bound, data);
    // `x == x` sets every bit, i.e. -1 in each signed lane.
    let minus_one = _mm256_cmpeq_epi8(data, data);
    let non_negative = _mm256_cmpgt_epi8(data, minus_one);
    let ctrl_mask = std::arch::x86_64::_mm256_and_si256(below_bound, non_negative);

    // Combine all masks with a balanced OR tree.
    _mm256_or_si256(
        _mm256_or_si256(
            _mm256_or_si256(slash_mask, quote_mask),
            _mm256_or_si256(tab_mask, newline_mask),
        ),
        _mm256_or_si256(
            _mm256_or_si256(carriage_mask, backspace_mask),
            _mm256_or_si256(formfeed_mask, ctrl_mask),
        ),
    )
}
159+
160+ #[ inline( always) ]
161+ unsafe fn handle_block ( src : & [ u8 ] , mask : & [ u8 ; 32 ] , dst : & mut Vec < u8 > ) {
162+ for ( j, & m) in mask. iter ( ) . enumerate ( ) {
163+ let c = src[ j] ;
164+ if m == 0 {
165+ dst. push ( c) ;
166+ } else {
167+ let escape_byte = ESCAPE [ c as usize ] ;
168+ if escape_byte != 0 {
169+ // Handle the escape
170+ dst. push ( b'\\' ) ;
171+ if escape_byte == UU {
172+ // Unicode escape for control characters
173+ dst. extend_from_slice ( b"u00" ) ;
174+ let hex_digits = & HEX_BYTES [ c as usize ] ;
175+ dst. push ( hex_digits. 0 ) ;
176+ dst. push ( hex_digits. 1 ) ;
177+ } else {
178+ // Simple escape
179+ dst. push ( escape_byte) ;
180+ }
181+ } else if c == b'\\' {
182+ // Backslash needs escaping
183+ dst. extend_from_slice ( b"\\ \\ " ) ;
184+ } else {
185+ // Should not happen if mask is correct
186+ dst. push ( c) ;
187+ }
188+ }
189+ }
190+ }
0 commit comments