diff --git a/src/stream/tcb.rs b/src/stream/tcb.rs index 93a484d..f5ee464 100644 --- a/src/stream/tcb.rs +++ b/src/stream/tcb.rs @@ -43,6 +43,7 @@ pub(super) enum PacketType { /// - `unordered_packets` is the bytes stream received from the lower device, /// which can be acknowledged and extracted by `consume_unordered_packets` method /// then can be read by upstream application via `Tcp::poll_read` method. +/// - `ordered_packets` is the list of contiguous packets ready to be read by the application. #[derive(Debug, Clone)] pub(crate) struct Tcb { seq: SeqNum, @@ -53,6 +54,7 @@ pub(crate) struct Tcb { state: TcpState, inflight_packets: BTreeMap, unordered_packets: BTreeMap>, + ordered_packets: Vec>, duplicate_ack_count: usize, duplicate_ack_count_helper: SeqNum, } @@ -72,6 +74,7 @@ impl Tcb { state: TcpState::Listen, inflight_packets: BTreeMap::new(), unordered_packets: BTreeMap::new(), + ordered_packets: Vec::new(), duplicate_ack_count: 0, duplicate_ack_count_helper: seq.into(), } @@ -106,46 +109,40 @@ impl Tcb { self.unordered_packets.insert(seq, buf); } pub(super) fn get_available_read_buffer_size(&self) -> usize { - READ_BUFFER_SIZE.saturating_sub(self.get_unordered_packets_total_len()) + let total_buffered = self.get_unordered_packets_total_len() + self.get_ordered_packets_total_len(); + READ_BUFFER_SIZE.saturating_sub(total_buffered) } #[inline] pub(crate) fn get_unordered_packets_total_len(&self) -> usize { self.unordered_packets.values().map(|p| p.len()).sum() } + #[inline] + pub(crate) fn get_ordered_packets_total_len(&self) -> usize { + self.ordered_packets.iter().map(|p| p.len()).sum() + } - pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option> { - let mut data = Vec::new(); - let mut remaining_bytes = max_bytes; - - while remaining_bytes > 0 { - if let Some(seq) = self.unordered_packets.keys().next().copied() { - if seq != self.ack { - break; // sequence number is not continuous, stop extracting - } - - // remove and get the first packet - let mut payload = self.unordered_packets.remove(&seq).unwrap(); - let payload_len = payload.len(); - - if payload_len <= remaining_bytes { - // current packet can be fully extracted - data.extend(payload); - self.ack += payload_len as u32; - remaining_bytes -= payload_len; - } else { - // current packet can only be partially extracted - let remaining_payload = payload.split_off(remaining_bytes); - data.extend_from_slice(&payload); - self.ack += remaining_bytes as u32; - self.unordered_packets.insert(self.ack, remaining_payload); - break; - } - } else { - break; // no more packets to extract + pub(super) fn consume_unordered_packets(&mut self) { + while let Some(seq) = self.unordered_packets.keys().next().copied() { + if seq != self.ack { + break; // sequence number is not continuous, stop extracting } + + // remove and get the first packet + let payload = self.unordered_packets.remove(&seq).unwrap(); + let payload_len = payload.len(); + + // Move the packet to ordered_packets + self.ordered_packets.push(payload); + self.ack += payload_len as u32; } + } - if data.is_empty() { None } else { Some(data) } + pub(super) fn pop_ordered_packet(&mut self) -> Option> { + if self.ordered_packets.is_empty() { + None + } else { + Some(self.ordered_packets.remove(0)) + } } pub(super) fn increase_seq(&mut self) { @@ -359,26 +356,32 @@ mod tests { tcb.add_unordered_packet(SeqNum(1500), vec![2; 500]); // seq=1500, len=500 tcb.add_unordered_packet(SeqNum(2000), vec![3; 500]); // seq=2000, len=500 - // test 1: extract up to 700 bytes - let data = tcb.consume_unordered_packets(700).unwrap(); - assert_eq!(data.len(), 700); // extract 500 + 200 - assert_eq!(data[..500], vec![1; 500]); // the first packet - assert_eq!(data[500..700], vec![2; 200]); // the first 200 bytes of the second packet - assert_eq!(tcb.ack, SeqNum(1700)); // ack increased by 700 - assert_eq!(tcb.unordered_packets.len(), 2); // remaining two packets - assert_eq!(tcb.unordered_packets.get(&SeqNum(1700)).unwrap().len(), 300); // the second packet remaining 300 bytes - assert_eq!(tcb.unordered_packets.get(&SeqNum(2000)).unwrap().len(), 500); // the third packet unchanged - - // test 2: extract up to 800 bytes - let data = tcb.consume_unordered_packets(800).unwrap(); - assert_eq!(data.len(), 800); // extract 300 bytes of the second packet and the third packet - assert_eq!(data[..300], vec![2; 300]); // the remaining 300 bytes of the second packet - assert_eq!(data[300..800], vec![3; 500]); // the third packet - assert_eq!(tcb.ack, SeqNum(2500)); // ack increased by 800 - assert_eq!(tcb.unordered_packets.len(), 0); // no remaining packets - - // test 3: no data to extract - let data = tcb.consume_unordered_packets(1000); + // test 1: consume contiguous packets + tcb.consume_unordered_packets(); + assert_eq!(tcb.ack, SeqNum(2500)); // ack increased by 1500 + assert_eq!(tcb.unordered_packets.len(), 0); // all packets consumed + assert_eq!(tcb.ordered_packets.len(), 3); // three packets in ordered list + + // test 2: pop first ordered packet + let data = tcb.pop_ordered_packet().unwrap(); + assert_eq!(data.len(), 500); + assert_eq!(data, vec![1; 500]); + assert_eq!(tcb.ordered_packets.len(), 2); + + // test 3: pop second ordered packet + let data = tcb.pop_ordered_packet().unwrap(); + assert_eq!(data.len(), 500); + assert_eq!(data, vec![2; 500]); + assert_eq!(tcb.ordered_packets.len(), 1); + + // test 4: pop third ordered packet + let data = tcb.pop_ordered_packet().unwrap(); + assert_eq!(data.len(), 500); + assert_eq!(data, vec![3; 500]); + assert_eq!(tcb.ordered_packets.len(), 0); + + // test 5: no more data to extract + let data = tcb.pop_ordered_packet(); assert!(data.is_none()); } diff --git a/src/stream/tcp.rs b/src/stream/tcp.rs index c1ad80e..aea1988 100644 --- a/src/stream/tcp.rs +++ b/src/stream/tcp.rs @@ -623,7 +623,7 @@ async fn tcp_main_logic_loop( if flags & ACK == ACK { if len > 0 { tcb.add_unordered_packet(incoming_seq, payload); - extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?; + extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?; } tcb.change_state(TcpState::Established); } @@ -650,7 +650,7 @@ async fn tcp_main_logic_loop( PacketType::NewPacket => { tcb.add_unordered_packet(incoming_seq, payload); let nt = network_tuple; - extract_data_n_write_upstream(&up_packet_sender, &mut tcb, nt, &data_tx, &read_notify)?; + extract_data_n_write_upstream(&up_packet_sender, &mut tcb, nt, &read_notify)?; write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); } PacketType::Ack => { @@ -711,7 +711,7 @@ async fn tcp_main_logic_loop( } else if flags == (ACK | PSH) && pkt_type == PacketType::NewPacket { if !payload.is_empty() && tcb.get_ack() == incoming_seq { tcb.add_unordered_packet(incoming_seq, payload); - extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?; + extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?; } } else { // unnormal case, we do nothing here @@ -769,7 +769,7 @@ async fn tcp_main_logic_loop( if len > 0 { // if the other side is still sending data, we need to deal with it like PacketStatus::NewPacket tcb.add_unordered_packet(incoming_seq, payload); - extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?; + extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?; write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); } let new_state = tcb.get_state(); @@ -799,7 +799,7 @@ async fn tcp_main_logic_loop( } else { // if the other side is still sending data, we need to deal with it like PacketStatus::NewPacket tcb.add_unordered_packet(incoming_seq, payload); - extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?; + extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?; write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); } if flags & FIN == FIN { @@ -833,7 +833,6 @@ fn extract_data_n_write_upstream( up_packet_sender: &PacketSender, tcb: &mut Tcb, network_tuple: NetworkTuple, - data_tx: &tokio::sync::mpsc::UnboundedSender>, read_notify: &std::sync::Arc>>, ) -> std::io::Result<()> { let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack()); @@ -843,10 +842,13 @@ fn extract_data_n_write_upstream( return Ok(()); } - if let Some(data) = tcb.consume_unordered_packets(8192) { + let before_len = tcb.get_ordered_packets_total_len(); + tcb.consume_unordered_packets(); + let after_len = tcb.get_ordered_packets_total_len(); + + if after_len > before_len { let hint = if state == TcpState::Established { "normally" } else { "still" }; - log::trace!("{network_tuple} {state:?}: {l_info} {hint} receiving data, len = {}", data.len()); - data_tx.send(data).map_err(|e| std::io::Error::new(BrokenPipe, e))?; + log::trace!("{network_tuple} {state:?}: {l_info} {hint} receiving data, new ordered bytes = {}", after_len - before_len); read_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(()); write_packet_to_device(up_packet_sender, network_tuple, tcb, None, ACK, None, None)?; }