Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 54 additions & 51 deletions src/stream/tcb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ pub(super) enum PacketType {
/// - `unordered_packets` is the byte stream received from the lower device,
/// which can be acknowledged and extracted by the `consume_unordered_packets` method
/// and then read by the upstream application via the `Tcp::poll_read` method.
/// - `ordered_packets` is the list of contiguous packets ready to be read by the application.
#[derive(Debug, Clone)]
pub(crate) struct Tcb {
seq: SeqNum,
Expand All @@ -53,6 +54,7 @@ pub(crate) struct Tcb {
state: TcpState,
inflight_packets: BTreeMap<SeqNum, InflightPacket>,
unordered_packets: BTreeMap<SeqNum, Vec<u8>>,
ordered_packets: Vec<Vec<u8>>,
duplicate_ack_count: usize,
duplicate_ack_count_helper: SeqNum,
}
Expand All @@ -72,6 +74,7 @@ impl Tcb {
state: TcpState::Listen,
inflight_packets: BTreeMap::new(),
unordered_packets: BTreeMap::new(),
ordered_packets: Vec::new(),
duplicate_ack_count: 0,
duplicate_ack_count_helper: seq.into(),
}
Expand Down Expand Up @@ -106,46 +109,40 @@ impl Tcb {
self.unordered_packets.insert(seq, buf);
}
pub(super) fn get_available_read_buffer_size(&self) -> usize {
READ_BUFFER_SIZE.saturating_sub(self.get_unordered_packets_total_len())
let total_buffered = self.get_unordered_packets_total_len() + self.get_ordered_packets_total_len();
READ_BUFFER_SIZE.saturating_sub(total_buffered)
}
/// Returns the combined payload length, in bytes, of every packet currently
/// stored in `unordered_packets`.
#[inline]
pub(crate) fn get_unordered_packets_total_len(&self) -> usize {
    self.unordered_packets.values().map(Vec::len).sum()
}
/// Returns the combined payload length, in bytes, of every packet currently
/// stored in `ordered_packets`.
#[inline]
pub(crate) fn get_ordered_packets_total_len(&self) -> usize {
    self.ordered_packets.iter().map(Vec::len).sum()
}

pub(super) fn consume_unordered_packets(&mut self, max_bytes: usize) -> Option<Vec<u8>> {
let mut data = Vec::new();
let mut remaining_bytes = max_bytes;

while remaining_bytes > 0 {
if let Some(seq) = self.unordered_packets.keys().next().copied() {
if seq != self.ack {
break; // sequence number is not continuous, stop extracting
}

// remove and get the first packet
let mut payload = self.unordered_packets.remove(&seq).unwrap();
let payload_len = payload.len();

if payload_len <= remaining_bytes {
// current packet can be fully extracted
data.extend(payload);
self.ack += payload_len as u32;
remaining_bytes -= payload_len;
} else {
// current packet can only be partially extracted
let remaining_payload = payload.split_off(remaining_bytes);
data.extend_from_slice(&payload);
self.ack += remaining_bytes as u32;
self.unordered_packets.insert(self.ack, remaining_payload);
break;
}
} else {
break; // no more packets to extract
/// Moves contiguous packets from `unordered_packets` into `ordered_packets`.
///
/// Repeatedly looks at the first key of `unordered_packets`. While that key
/// equals the current `ack`, the packet is removed from the map, appended to
/// the back of `ordered_packets`, and `ack` is advanced by the payload length.
/// The loop ends as soon as the map is empty or its first key differs from
/// `ack` (i.e. the next expected bytes have not been buffered yet).
pub(super) fn consume_unordered_packets(&mut self) {
    loop {
        // Only the first key of the map is ever a candidate for extraction.
        let next_seq = match self.unordered_packets.keys().next().copied() {
            Some(seq) if seq == self.ack => seq,
            // Either nothing is buffered or there is a gap before the first packet.
            _ => break,
        };

        // The key was just observed above, so the entry must exist.
        let segment = self.unordered_packets.remove(&next_seq).unwrap();
        let advance = segment.len() as u32;

        // Hand the payload over to the in-order list and acknowledge its bytes.
        self.ordered_packets.push(segment);
        self.ack += advance;
    }
}

if data.is_empty() { None } else { Some(data) }
/// Removes and returns the oldest packet in `ordered_packets`, or `None`
/// when the list is empty.
///
/// NOTE(review): `Vec::remove(0)` shifts every remaining element left, so each
/// pop is linear in the number of queued packets; a `VecDeque` field would make
/// this constant-time, but that requires changing the struct definition.
pub(super) fn pop_ordered_packet(&mut self) -> Option<Vec<u8>> {
    let has_packet = !self.ordered_packets.is_empty();
    has_packet.then(|| self.ordered_packets.remove(0))
}

pub(super) fn increase_seq(&mut self) {
Expand Down Expand Up @@ -359,26 +356,32 @@ mod tests {
tcb.add_unordered_packet(SeqNum(1500), vec![2; 500]); // seq=1500, len=500
tcb.add_unordered_packet(SeqNum(2000), vec![3; 500]); // seq=2000, len=500

// test 1: extract up to 700 bytes
let data = tcb.consume_unordered_packets(700).unwrap();
assert_eq!(data.len(), 700); // extract 500 + 200
assert_eq!(data[..500], vec![1; 500]); // the first packet
assert_eq!(data[500..700], vec![2; 200]); // the first 200 bytes of the second packet
assert_eq!(tcb.ack, SeqNum(1700)); // ack increased by 700
assert_eq!(tcb.unordered_packets.len(), 2); // remaining two packets
assert_eq!(tcb.unordered_packets.get(&SeqNum(1700)).unwrap().len(), 300); // the second packet remaining 300 bytes
assert_eq!(tcb.unordered_packets.get(&SeqNum(2000)).unwrap().len(), 500); // the third packet unchanged

// test 2: extract up to 800 bytes
let data = tcb.consume_unordered_packets(800).unwrap();
assert_eq!(data.len(), 800); // extract 300 bytes of the second packet and the third packet
assert_eq!(data[..300], vec![2; 300]); // the remaining 300 bytes of the second packet
assert_eq!(data[300..800], vec![3; 500]); // the third packet
assert_eq!(tcb.ack, SeqNum(2500)); // ack increased by 800
assert_eq!(tcb.unordered_packets.len(), 0); // no remaining packets

// test 3: no data to extract
let data = tcb.consume_unordered_packets(1000);
// test 1: consume contiguous packets
tcb.consume_unordered_packets();
assert_eq!(tcb.ack, SeqNum(2500)); // ack increased by 1500
assert_eq!(tcb.unordered_packets.len(), 0); // all packets consumed
assert_eq!(tcb.ordered_packets.len(), 3); // three packets in ordered list

// test 2: pop first ordered packet
let data = tcb.pop_ordered_packet().unwrap();
assert_eq!(data.len(), 500);
assert_eq!(data, vec![1; 500]);
assert_eq!(tcb.ordered_packets.len(), 2);

// test 3: pop second ordered packet
let data = tcb.pop_ordered_packet().unwrap();
assert_eq!(data.len(), 500);
assert_eq!(data, vec![2; 500]);
assert_eq!(tcb.ordered_packets.len(), 1);

// test 4: pop third ordered packet
let data = tcb.pop_ordered_packet().unwrap();
assert_eq!(data.len(), 500);
assert_eq!(data, vec![3; 500]);
assert_eq!(tcb.ordered_packets.len(), 0);

// test 5: no more data to extract
let data = tcb.pop_ordered_packet();
assert!(data.is_none());
}

Expand Down
20 changes: 11 additions & 9 deletions src/stream/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -623,7 +623,7 @@ async fn tcp_main_logic_loop(
if flags & ACK == ACK {
if len > 0 {
tcb.add_unordered_packet(incoming_seq, payload);
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?;
}
tcb.change_state(TcpState::Established);
}
Expand All @@ -650,7 +650,7 @@ async fn tcp_main_logic_loop(
PacketType::NewPacket => {
tcb.add_unordered_packet(incoming_seq, payload);
let nt = network_tuple;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, nt, &data_tx, &read_notify)?;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, nt, &read_notify)?;
write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
}
PacketType::Ack => {
Expand Down Expand Up @@ -711,7 +711,7 @@ async fn tcp_main_logic_loop(
} else if flags == (ACK | PSH) && pkt_type == PacketType::NewPacket {
if !payload.is_empty() && tcb.get_ack() == incoming_seq {
tcb.add_unordered_packet(incoming_seq, payload);
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?;
}
} else {
// abnormal case, we do nothing here
Expand Down Expand Up @@ -769,7 +769,7 @@ async fn tcp_main_logic_loop(
if len > 0 {
// if the other side is still sending data, we need to deal with it like PacketStatus::NewPacket
tcb.add_unordered_packet(incoming_seq, payload);
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?;
write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
}
let new_state = tcb.get_state();
Expand Down Expand Up @@ -799,7 +799,7 @@ async fn tcp_main_logic_loop(
} else {
// if the other side is still sending data, we need to deal with it like PacketStatus::NewPacket
tcb.add_unordered_packet(incoming_seq, payload);
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &data_tx, &read_notify)?;
extract_data_n_write_upstream(&up_packet_sender, &mut tcb, network_tuple, &read_notify)?;
write_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
}
if flags & FIN == FIN {
Expand Down Expand Up @@ -833,7 +833,6 @@ fn extract_data_n_write_upstream(
up_packet_sender: &PacketSender,
tcb: &mut Tcb,
network_tuple: NetworkTuple,
data_tx: &tokio::sync::mpsc::UnboundedSender<Vec<u8>>,
read_notify: &std::sync::Arc<std::sync::Mutex<Option<Waker>>>,
) -> std::io::Result<()> {
let (state, seq, ack) = (tcb.get_state(), tcb.get_seq(), tcb.get_ack());
Expand All @@ -843,10 +842,13 @@ fn extract_data_n_write_upstream(
return Ok(());
}

if let Some(data) = tcb.consume_unordered_packets(8192) {
let before_len = tcb.get_ordered_packets_total_len();
tcb.consume_unordered_packets();
let after_len = tcb.get_ordered_packets_total_len();

if after_len > before_len {
let hint = if state == TcpState::Established { "normally" } else { "still" };
log::trace!("{network_tuple} {state:?}: {l_info} {hint} receiving data, len = {}", data.len());
data_tx.send(data).map_err(|e| std::io::Error::new(BrokenPipe, e))?;
log::trace!("{network_tuple} {state:?}: {l_info} {hint} receiving data, new ordered bytes = {}", after_len - before_len);
read_notify.lock().unwrap().take().map(|w| w.wake_by_ref()).unwrap_or(());
write_packet_to_device(up_packet_sender, network_tuple, tcb, None, ACK, None, None)?;
}
Expand Down
Loading