44use std:: collections:: VecDeque ;
55use std:: fs:: File ;
66use std:: io:: { Read , Write } ;
7+ use std:: os:: unix:: io:: FromRawFd ;
78
89use crate :: common:: ascii:: { CR , CRLF_LEN , LF } ;
910use crate :: common:: Body ;
@@ -15,6 +16,7 @@ use crate::server::MAX_PAYLOAD_SIZE;
1516use vmm_sys_util:: sock_ctrl_msg:: ScmSocket ;
1617
1718const BUFFER_SIZE : usize = 1024 ;
19+ const SCM_MAX_FD : usize = 253 ;
1820
1921/// Describes the state machine of an HTTP connection.
2022enum ConnectionState {
@@ -52,9 +54,9 @@ pub struct HttpConnection<T> {
5254 /// A buffer containing the bytes of a response that is currently
5355 /// being sent.
5456 response_buffer : Option < Vec < u8 > > ,
55- /// The latest file that has been received and which must be associated
57+ /// The list of files that has been received and which must be associated
5658 /// with the pending request.
57- file : Option < File > ,
59+ files : Vec < File > ,
5860 /// Optional payload max size.
5961 payload_max_size : usize ,
6062}
@@ -73,7 +75,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
7375 parsed_requests : VecDeque :: new ( ) ,
7476 response_queue : VecDeque :: new ( ) ,
7577 response_buffer : None ,
76- file : None ,
78+ files : Vec :: new ( ) ,
7779 payload_max_size : MAX_PAYLOAD_SIZE ,
7880 }
7981 }
@@ -123,7 +125,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
123125 self . state = ConnectionState :: WaitingForRequestLine ;
124126 self . body_bytes_to_be_read = 0 ;
125127 let mut pending_request = self . pending_request . take ( ) . unwrap ( ) ;
126- pending_request. file = self . file . take ( ) ;
128+ pending_request. files = self . files . drain ( .. ) . collect ( ) ;
127129 self . parsed_requests . push_back ( pending_request) ;
128130 }
129131 } ;
@@ -143,15 +145,11 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
143145 }
144146 // Append new bytes to what we already have in the buffer.
145147 // The slice access is safe, the index is checked above.
146- let ( bytes_read, file) = self
147- . stream
148- . recv_with_fd ( & mut self . buffer [ self . read_cursor ..] )
149- . map_err ( ConnectionError :: StreamReadError ) ?;
150-
151- // Update the internal file that must be associated with the request.
152- if file. is_some ( ) {
153- self . file = file;
154- }
148+ let ( bytes_read, new_files) = self . recv_with_fds ( ) ?;
149+
150+ // Update the internal list of files that must be associated with the
151+ // request.
152+ self . files . extend ( new_files) ;
155153
156154 // If the read returned 0 then the client has closed the connection.
157155 if bytes_read == 0 {
@@ -162,6 +160,43 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
162160 . ok_or ( ConnectionError :: ParseError ( RequestError :: Overflow ) )
163161 }
164162
163+ /// Receive data along with optional files descriptors.
164+ /// It is a wrapper around the same function from vmm-sys-util.
165+ ///
166+ /// # Errors
167+ /// `StreamError` is returned if any error occurred while reading the stream.
168+ fn recv_with_fds ( & mut self ) -> Result < ( usize , Vec < File > ) , ConnectionError > {
169+ let buf = & mut self . buffer [ self . read_cursor ..] ;
170+ // We must allocate the maximum number of receivable file descriptors
171+ // if don't want to miss any of them. Allocating a too small number
172+ // would lead to the incapacity of receiving the file descriptors.
173+ let mut fds = [ 0 ; SCM_MAX_FD ] ;
174+ let mut iovecs = [ libc:: iovec {
175+ iov_base : buf. as_mut_ptr ( ) as * mut libc:: c_void ,
176+ iov_len : buf. len ( ) ,
177+ } ] ;
178+
179+ // Safe because we have mutably borrowed buf and it's safe to write
180+ // arbitrary data to a slice.
181+ let ( read_count, fd_count) = unsafe {
182+ self . stream
183+ . recv_with_fds ( & mut iovecs, & mut fds)
184+ . map_err ( ConnectionError :: StreamReadError ) ?
185+ } ;
186+
187+ Ok ( (
188+ read_count,
189+ fds. iter ( )
190+ . take ( fd_count)
191+ . map ( |fd| {
192+ // Safe because all fds are owned by us after they have been
193+ // received through the socket.
194+ unsafe { File :: from_raw_fd ( * fd) }
195+ } )
196+ . collect ( ) ,
197+ ) )
198+ }
199+
165200 /// Parses bytes in `buffer` for a valid request line.
166201 /// Returns `false` if there are no more bytes to be parsed in the buffer.
167202 ///
@@ -197,7 +232,7 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
197232 . map_err ( ConnectionError :: ParseError ) ?,
198233 headers : Headers :: default ( ) ,
199234 body : None ,
200- file : None ,
235+ files : Vec :: new ( ) ,
201236 } ) ;
202237 self . state = ConnectionState :: WaitingForHeaders ;
203238 Ok ( true )
@@ -517,13 +552,17 @@ impl<T: Read + Write + ScmSocket> HttpConnection<T> {
517552
518553#[ cfg( test) ]
519554mod tests {
555+ use std:: io:: { Seek , SeekFrom } ;
520556 use std:: net:: Shutdown ;
557+ use std:: os:: unix:: io:: IntoRawFd ;
521558 use std:: os:: unix:: net:: UnixStream ;
522559
523560 use super :: * ;
524561 use crate :: common:: { Method , Version } ;
525562 use crate :: server:: MAX_PAYLOAD_SIZE ;
526563
564+ use vmm_sys_util:: tempfile:: TempFile ;
565+
527566 #[ test]
528567 fn test_try_read_expect ( ) {
529568 // Test request with `Expect` header.
@@ -548,7 +587,7 @@ mod tests {
548587 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
549588 headers : Headers :: new ( 26 , true , true ) ,
550589 body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
551- file : None ,
590+ files : Vec :: new ( ) ,
552591 } ;
553592
554593 assert_eq ! ( request, expected_request) ;
@@ -585,7 +624,7 @@ mod tests {
585624 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
586625 headers : Headers :: new ( 26 , true , true ) ,
587626 body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
588- file : None ,
627+ files : Vec :: new ( ) ,
589628 } ;
590629 assert_eq ! ( request, expected_request) ;
591630 }
@@ -619,7 +658,7 @@ mod tests {
619658 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
620659 headers : Headers :: new ( 26 , true , true ) ,
621660 body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
622- file : None ,
661+ files : Vec :: new ( ) ,
623662 } ;
624663 assert_eq ! ( request, expected_request) ;
625664 }
@@ -684,7 +723,7 @@ mod tests {
684723 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
685724 headers : Headers :: new ( 1400 , true , true ) ,
686725 body : Some ( Body :: new ( request_body) ) ,
687- file : None ,
726+ files : Vec :: new ( ) ,
688727 } ;
689728
690729 assert_eq ! ( request, expected_request) ;
@@ -755,7 +794,7 @@ mod tests {
755794 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
756795 headers : Headers :: new ( 0 , true , true ) ,
757796 body : None ,
758- file : None ,
797+ files : Vec :: new ( ) ,
759798 } ;
760799 assert_eq ! ( request, expected_request) ;
761800 }
@@ -777,7 +816,7 @@ mod tests {
777816 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
778817 headers : Headers :: new ( 0 , false , false ) ,
779818 body : None ,
780- file : None ,
819+ files : Vec :: new ( ) ,
781820 } ;
782821 assert_eq ! ( request, expected_request) ;
783822 }
@@ -806,7 +845,7 @@ mod tests {
806845 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
807846 headers : Headers :: new ( 0 , false , false ) ,
808847 body : None ,
809- file : None ,
848+ files : Vec :: new ( ) ,
810849 } ;
811850 assert_eq ! ( request, expected_request) ;
812851
@@ -825,7 +864,7 @@ mod tests {
825864 ) ,
826865 headers : Headers :: new ( 0 , false , false ) ,
827866 body : None ,
828- file : None ,
867+ files : Vec :: new ( ) ,
829868 } ;
830869 assert_eq ! ( request, expected_request) ;
831870 }
@@ -853,7 +892,7 @@ mod tests {
853892 request_line : RequestLine :: new ( Method :: Patch , "http://localhost/home" , Version :: Http11 ) ,
854893 headers : Headers :: new ( 26 , false , true ) ,
855894 body : Some ( Body :: new ( b"this is not\n \r \n a json \n body" . to_vec ( ) ) ) ,
856- file : None ,
895+ files : Vec :: new ( ) ,
857896 } ;
858897
859898 conn. try_read ( ) . unwrap ( ) ;
@@ -864,7 +903,7 @@ mod tests {
864903 request_line : RequestLine :: new ( Method :: Put , "http://farhost/away" , Version :: Http11 ) ,
865904 headers : Headers :: new ( 23 , false , false ) ,
866905 body : Some ( Body :: new ( b"this is another request" . to_vec ( ) ) ) ,
867- file : None ,
906+ files : Vec :: new ( ) ,
868907 } ;
869908 assert_eq ! ( request_first, expected_request_first) ;
870909 assert_eq ! ( request_second, expected_request_second) ;
@@ -999,6 +1038,77 @@ mod tests {
9991038 ) ;
10001039 }
10011040
1041+ #[ test]
1042+ fn test_read_bytes_with_files ( ) {
1043+ let ( sender, receiver) = UnixStream :: pair ( ) . unwrap ( ) ;
1044+ receiver. set_nonblocking ( true ) . expect ( "Can't modify socket" ) ;
1045+ let mut conn = HttpConnection :: new ( receiver) ;
1046+
1047+ // Create 3 files, edit the content and rewind back to the start.
1048+ let mut file1 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1049+ let mut file2 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1050+ let mut file3 = TempFile :: new ( ) . unwrap ( ) . into_file ( ) ;
1051+ file1. write ( b"foo" ) . unwrap ( ) ;
1052+ file1. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1053+ file2. write ( b"bar" ) . unwrap ( ) ;
1054+ file2. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1055+ file3. write ( b"foobar" ) . unwrap ( ) ;
1056+ file3. seek ( SeekFrom :: Start ( 0 ) ) . unwrap ( ) ;
1057+
1058+ // Send 2 file descriptors along with 3 bytes of data.
1059+ assert_eq ! (
1060+ sender. send_with_fds(
1061+ & [ [ 1 , 2 , 3 ] . as_ref( ) ] ,
1062+ & [ file1. into_raw_fd( ) , file2. into_raw_fd( ) ]
1063+ ) ,
1064+ Ok ( 3 )
1065+ ) ;
1066+
1067+ // Check we receive the right amount of data along with the right
1068+ // amount of file descriptors.
1069+ assert_eq ! ( conn. read_bytes( ) , Ok ( 3 ) ) ;
1070+ assert_eq ! ( conn. files. len( ) , 2 ) ;
1071+
1072+ // Check the content of the data received
1073+ assert_eq ! ( conn. buffer[ 0 ] , 1 ) ;
1074+ assert_eq ! ( conn. buffer[ 1 ] , 2 ) ;
1075+ assert_eq ! ( conn. buffer[ 2 ] , 3 ) ;
1076+
1077+ // Check the file descriptors are usable by checking the content that
1078+ // can be read.
1079+ let mut buf = [ 0 ; 10 ] ;
1080+ assert_eq ! ( conn. files[ 0 ] . read( & mut buf) . unwrap( ) , 3 ) ;
1081+ assert_eq ! ( & buf[ ..3 ] , b"foo" ) ;
1082+ assert_eq ! ( conn. files[ 1 ] . read( & mut buf) . unwrap( ) , 3 ) ;
1083+ assert_eq ! ( & buf[ ..3 ] , b"bar" ) ;
1084+
1085+ // Send the 3rd file descriptor along with 1 byte of data.
1086+ assert_eq ! (
1087+ sender. send_with_fds( & [ [ 10 ] . as_ref( ) ] , & [ file3. into_raw_fd( ) ] ) ,
1088+ Ok ( 1 )
1089+ ) ;
1090+
1091+ // Check the amount of data along with the amount of file descriptors
1092+ // are updated.
1093+ assert_eq ! ( conn. read_bytes( ) , Ok ( 1 ) ) ;
1094+ assert_eq ! ( conn. files. len( ) , 3 ) ;
1095+
1096+ // Check the content of the new data received
1097+ assert_eq ! ( conn. buffer[ 0 ] , 10 ) ;
1098+
1099+ // Check the latest file descriptor is usable by checking the content
1100+ // that can be read.
1101+ let mut buf = [ 0 ; 10 ] ;
1102+ assert_eq ! ( conn. files[ 2 ] . read( & mut buf) . unwrap( ) , 6 ) ;
1103+ assert_eq ! ( & buf[ ..6 ] , b"foobar" ) ;
1104+
1105+ sender. shutdown ( Shutdown :: Write ) . unwrap ( ) ;
1106+ assert_eq ! (
1107+ conn. read_bytes( ) . unwrap_err( ) ,
1108+ ConnectionError :: ConnectionClosed
1109+ ) ;
1110+ }
1111+
10021112 #[ test]
10031113 fn test_shift_buffer_left ( ) {
10041114 let ( _, receiver) = UnixStream :: pair ( ) . unwrap ( ) ;
@@ -1095,7 +1205,7 @@ mod tests {
10951205 request_line : RequestLine :: new ( Method :: Get , "http://foo/bar" , Version :: Http11 ) ,
10961206 headers : Headers :: new ( 0 , true , true ) ,
10971207 body : None ,
1098- file : None ,
1208+ files : Vec :: new ( ) ,
10991209 } ) ;
11001210 assert_eq ! (
11011211 conn. parse_headers( & mut 0 , BUFFER_SIZE ) . unwrap_err( ) ,
@@ -1153,7 +1263,7 @@ mod tests {
11531263 request_line : RequestLine :: new ( Method :: Get , "http://foo/bar" , Version :: Http11 ) ,
11541264 headers : Headers :: new ( 0 , true , true ) ,
11551265 body : None ,
1156- file : None ,
1266+ files : Vec :: new ( ) ,
11571267 } ) ;
11581268 conn. body_vec = vec ! [ 0xde , 0xad , 0xbe , 0xef ] ;
11591269 assert_eq ! (
0 commit comments