Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • stephen/minipac
1 result
Show changes
Commits on Source (3)
......@@ -3,13 +3,13 @@ image: "rust:latest"
before_script:
- rustup component add rustfmt
- rustup component add clippy
- cargo install --git https://gitlab.scd31.com/stephen/tcp-relay
- cargo install --git https://gitlab.scd31.com/stephen/kiss-relay
- trap 'kill $(jobs -p)' EXIT
test:
script:
- cargo fmt -- --check
- cargo clippy --all-targets --all-features -- -D warnings
- relay &
- kiss-relay &
- sleep 1
- cargo test --release
[package]
name = "minipac"
version = "0.2.0"
version = "0.3.0"
edition = "2018"
repository = "https://gitlab.scd31.com/stephen/minipac"
authors = ["Stephen D"]
......@@ -11,8 +11,12 @@ license = "MIT"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
kiss-tnc = "0.1.3"
kiss-tnc = "0.2.1"
thiserror = "1.0"
rand = "0.8"
zstd = "0.9"
crc = "2.1"
tokio = { version = "1.0", features = ["rt", "sync", "tracing"] }
[dev-dependencies]
tokio = { version = "1", features = ["macros", "rt", "sync", "time"] }
......@@ -5,11 +5,11 @@ use crate::packet::{Packet, PacketData};
use crate::state::ConnectionState;
use crate::stream::Stream;
use kiss_tnc::tnc::Tnc;
use std::net::ToSocketAddrs;
use std::sync::{mpsc, mpsc::TryRecvError};
use std::thread;
use kiss_tnc::Tnc;
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::ToSocketAddrs;
use tokio::sync::{mpsc, mpsc::error::TryRecvError};
pub struct Client {
us: Hostname,
......@@ -19,7 +19,7 @@ pub struct Client {
baud_rate: usize,
// sends packets to the Tnc (mpsc)
packet_sender: mpsc::Sender<Vec<u8>>,
packet_sender: mpsc::UnboundedSender<Vec<u8>>,
// sends channels which packets should be sent through
channel_sender: mpsc::Sender<mpsc::Sender<Packet>>,
}
......@@ -27,7 +27,7 @@ pub struct Client {
impl Client {
/// Create a new client
/// Used for initiating connections and sending pings
pub fn new<A: ToSocketAddrs>(
pub async fn new_tcp<A: ToSocketAddrs>(
addr: A,
us: Hostname,
timeout: Duration,
......@@ -35,19 +35,34 @@ impl Client {
max_retries: usize,
baud_rate: usize,
) -> Result<Client, ConnectError> {
let mut tnc = Tnc::connect(addr)?;
let tnc = Tnc::connect_tcp(addr).await?;
let (packet_sender, packet_receiver) = mpsc::channel::<Vec<_>>();
let (channel_sender, channel_receiver) = mpsc::channel();
Client::new(tnc, us, timeout, max_packet_size, max_retries, baud_rate).await
}
pub async fn new<
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
>(
tnc: Tnc<R, W>,
us: Hostname,
timeout: Duration,
max_packet_size: usize,
max_retries: usize,
baud_rate: usize,
) -> Result<Client, ConnectError> {
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel::<Vec<_>>();
let (channel_sender, mut channel_receiver) = mpsc::channel(1);
let (mut tnc_read, mut tnc_write) = tnc.split();
// take incoming frames and stuff them in a channel
{
let mut tnc = tnc.try_clone()?;
let packet_sender = packet_sender.clone();
thread::spawn(move || {
tokio::task::spawn(async move {
let mut channels: Vec<mpsc::Sender<Packet>> = vec![];
while let Ok((_, data)) = tnc.read_frame() {
while let Ok((_, data)) = tnc_read.read_frame().await {
loop {
match channel_receiver.try_recv() {
Ok(ch) => channels.push(ch),
......@@ -63,25 +78,27 @@ impl Client {
}
if let Some(packet) = Packet::from_bytes(&data) {
for i in (0..channels.len()).rev() {
if channels[i].send(packet.clone()).is_err() {
channels.remove(i);
if packet.to == us {
for i in (0..channels.len()).rev() {
if channels[i].send(packet.clone()).await.is_err() {
channels.remove(i);
}
}
}
// handle pings
if packet.to == us && packet.data == PacketData::Ping {
// send pong
let mut buf = vec![];
Packet::new(
us,
packet.from,
(packet.seq_num % 255) + 1,
PacketData::Pong,
)
.to_bytes(&mut buf);
if packet_sender.send(buf).is_err() {
break;
// handle pings
if packet.data == PacketData::Ping {
// send pong
let mut buf = vec![];
Packet::new(
us,
packet.from,
(packet.seq_num % 255) + 1,
PacketData::Pong,
)
.to_bytes(&mut buf);
if packet_sender.send(buf).is_err() {
break;
}
}
}
}
......@@ -91,11 +108,12 @@ impl Client {
// process outgoing frames
{
thread::spawn(move || {
while let Ok(data) = packet_receiver.recv() {
if tnc.send_frame(&data).is_err() {
tokio::task::spawn(async move {
while let Some(data) = packet_receiver.recv().await {
if tnc_write.send_frame(&data).await.is_err() {
break;
}
let _ = tnc_write.flush().await;
}
});
}
......@@ -113,10 +131,13 @@ impl Client {
/// Connects to the specified server.
/// Returns a stream which can be used to send/receive data.
pub fn connect(&self, to: Hostname) -> Result<Stream, ConnectError> {
pub async fn connect(&self, to: Hostname) -> Result<Stream, ConnectError> {
let (s1, s2) = Stream::new_entangled();
let (packet_sender, packet_receiver) = mpsc::channel();
self.channel_sender.send(packet_sender).map_broken_pipe()?;
let (packet_sender, packet_receiver) = mpsc::channel(1);
self.channel_sender
.send(packet_sender)
.await
.map_broken_pipe()?;
let mut client = ClientReader {
state: ConnectionState::new(
......@@ -132,12 +153,12 @@ impl Client {
packet_receiver,
};
client.actually_connect()?;
client.actually_connect().await?;
// process incoming frames
{
thread::spawn(move || {
let _ = client.process_forever();
tokio::task::spawn(async move {
let _ = client.process_forever().await;
});
}
......@@ -145,12 +166,15 @@ impl Client {
}
/// Pings a host
pub fn ping(&self, to: Hostname) -> Result<bool, BrokenPipeError> {
pub async fn ping(&self, to: Hostname) -> Result<bool, BrokenPipeError> {
let id = rand::random();
let next = (id % 255) + 1;
let (packet_sender, packet_receiver) = mpsc::channel();
self.channel_sender.send(packet_sender).map_broken_pipe()?;
let (packet_sender, mut packet_receiver) = mpsc::channel(1);
self.channel_sender
.send(packet_sender)
.await
.map_broken_pipe()?;
let mut buf = vec![];
Packet::new(self.us, to, id, PacketData::Ping).to_bytes(&mut buf);
......@@ -158,14 +182,14 @@ impl Client {
let max_time = Instant::now() + self.timeout;
while Instant::now() < max_time {
if let Some(packet) = try_read_packet(&packet_receiver)? {
if let Some(packet) = try_read_packet(&mut packet_receiver)? {
if packet == Packet::new(to, self.us, next, PacketData::Pong) {
// Ping success
return Ok(true);
}
}
thread::sleep(Duration::from_millis(10));
tokio::time::sleep(Duration::from_millis(10)).await;
}
// No response
......@@ -179,38 +203,42 @@ pub struct ClientReader {
}
impl ClientReader {
fn process_forever(&mut self) -> Result<(), Error> {
async fn process_forever(&mut self) -> Result<(), Error> {
loop {
match self.packet_receiver.try_recv() {
Ok(packet) => {
let r =
tokio::time::timeout(Duration::from_millis(50), self.packet_receiver.recv()).await;
match r {
Ok(Some(packet)) => {
if packet.to == self.state.us && packet.from == self.state.them {
self.process_packet(packet)?;
self.process_packet(packet).await?;
}
}
Err(TryRecvError::Empty) => {
thread::sleep(Duration::from_millis(10));
}
Err(TryRecvError::Disconnected) => {
Ok(None) => {
// channel died
// kill our thread
return Err(Error::BrokenPipe);
}
Err(_) => {
// no packet within timeout
// loop as usual
}
}
self.state.process()?;
self.state.process().await?;
}
}
fn process_packet(&mut self, packet: Packet) -> Result<(), Error> {
self.state.incoming_packet(packet)
async fn process_packet(&mut self, packet: Packet) -> Result<(), Error> {
self.state.incoming_packet(packet).await
}
fn actually_connect(&mut self) -> Result<(), ConnectError> {
async fn actually_connect(&mut self) -> Result<(), ConnectError> {
let mut connected = false;
for _ in 0..self.state.max_retries {
match self.try_establish_connection() {
match self.try_establish_connection().await {
Ok(()) => {
connected = true;
break;
......@@ -230,7 +258,7 @@ impl ClientReader {
Ok(())
}
fn try_establish_connection(&mut self) -> Result<(), ConnectError> {
async fn try_establish_connection(&mut self) -> Result<(), ConnectError> {
// Send connect packet
let mut buf = vec![];
Packet::new(
......@@ -240,11 +268,12 @@ impl ClientReader {
PacketData::Connect,
)
.to_bytes(&mut buf);
self.state.packet_sender.send(buf).map_broken_pipe()?;
let max_time = Instant::now() + self.state.timeout;
while Instant::now() < max_time {
if let Some(packet) = try_read_packet(&self.packet_receiver)? {
if let Some(packet) = try_read_packet(&mut self.packet_receiver)? {
if packet.from == self.state.them
&& packet.to == self.state.us
&& packet.data == PacketData::ConnectAck
......@@ -253,6 +282,8 @@ impl ClientReader {
return Ok(());
}
}
tokio::time::sleep(Duration::from_millis(10)).await;
}
Err(ConnectError::ConnectionFailure)
......
......@@ -66,3 +66,19 @@ impl<T, Q> MapBrokenPipe<T> for Result<T, Q> {
}
}
}
impl<T> MapBrokenPipe<T> for Option<T> {
fn map_broken_pipe(self) -> Result<T, BrokenPipeError> {
match self {
Some(x) => Ok(x),
None => Err(BrokenPipeError),
}
}
fn map_disconnected(self) -> Result<T, Error> {
match self {
Some(x) => Ok(x),
None => Err(Error::Disconnected),
}
}
}
......@@ -55,7 +55,12 @@ impl Hostname {
impl Display for Hostname {
fn fmt(&self, ft: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
let callsign: String = self.callsign.iter().map(|x| *x as char).collect();
let callsign: String = self
.callsign
.iter()
.filter(|x| **x > 0)
.map(|x| *x as char)
.collect();
write!(ft, "{}-{}", callsign, self.index)
}
}
......@@ -84,6 +89,8 @@ mod tests {
}),
host
);
assert_eq!("N0CALL-0", host.unwrap().to_string());
}
#[test]
......@@ -97,6 +104,8 @@ mod tests {
}),
host
);
assert_eq!("N0CALL-7", host.unwrap().to_string());
}
#[test]
......@@ -131,6 +140,8 @@ mod tests {
}),
host
);
assert_eq!("N0CALL-255", host.unwrap().to_string());
}
#[test]
......@@ -144,6 +155,8 @@ mod tests {
}),
host
);
assert_eq!("N0CALL-0", host.unwrap().to_string());
}
#[test]
......@@ -164,6 +177,8 @@ mod tests {
}),
host
);
assert_eq!("N0CALL1-7", host.unwrap().to_string());
}
#[test]
......@@ -190,5 +205,7 @@ mod tests {
let host2 = Hostname::new("N0CALL");
assert_eq!(host, host2);
assert_eq!("N0CALL-0", host.unwrap().to_string());
}
}
......@@ -2,7 +2,7 @@ use crate::error::BrokenPipeError;
use crate::hostname::Hostname;
use crc::{Crc, CRC_16_IBM_SDLC};
use std::sync::{mpsc, mpsc::TryRecvError};
use tokio::sync::{mpsc, mpsc::error::TryRecvError};
// 8 bytes for hostname
// 8 bytes for hostname
......@@ -46,7 +46,7 @@ impl Packet {
}
pub fn from_bytes(buf: &[u8]) -> Option<Self> {
if buf.len() < 19 {
if buf.len() < 20 {
return None;
}
......@@ -260,7 +260,7 @@ fn u8s_to_u16(i: &[u8]) -> u16 {
}
pub(crate) fn try_read_packet(
packet_receiver: &mpsc::Receiver<Packet>,
packet_receiver: &mut mpsc::Receiver<Packet>,
) -> Result<Option<Packet>, BrokenPipeError> {
match packet_receiver.try_recv() {
Ok(packet) => Ok(Some(packet)),
......@@ -281,6 +281,34 @@ pub(crate) fn is_valid_seq_num(cur: u8, new: u8) -> bool {
mod tests {
use super::*;
#[test]
fn packet_decode_too_short_never_panics() {
let packet = Packet::new(
Hostname::new("N0CALL-0").unwrap(),
Hostname::new("ABCDEFG-243").unwrap(),
37,
PacketData::Ping,
);
let mut buf = vec![];
packet.to_bytes(&mut buf);
for i in 0..buf.len() {
for j in 0..i {
let mut truncated = buf[j..i].to_vec();
assert_eq!(None, Packet::from_bytes(&truncated));
// add a checksum on so it makes it past that part of the code
let checksum = X25.checksum(&truncated);
truncated.extend(u16_to_u8s(checksum));
// might be none, might be some
let _ = Packet::from_bytes(&truncated);
}
}
}
#[test]
fn encode_and_decode_ping() {
let packet = Packet::new(
......
......@@ -7,24 +7,24 @@ use crate::stream::Stream;
use kiss_tnc::tnc::Tnc;
use std::collections::HashMap;
use std::net::ToSocketAddrs;
use std::sync::{
use std::time::{Duration, Instant};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::ToSocketAddrs;
use tokio::sync::{
mpsc,
mpsc::{Receiver, Sender, TryRecvError},
mpsc::{Receiver, Sender, UnboundedSender},
};
use std::thread;
use std::time::{Duration, Instant};
pub struct Server {
stream_receiver: Receiver<(Hostname, Stream)>,
// sends channels which packets should be sent through
// only used for pings right now
channel_sender: mpsc::Sender<mpsc::Sender<Packet>>,
channel_sender: Sender<Sender<Packet>>,
// packets to be sent over the air
// also only for pings right now
packet_sender: mpsc::Sender<Vec<u8>>,
packet_sender: UnboundedSender<Vec<u8>>,
hostname: Hostname,
timeout: Duration,
......@@ -32,7 +32,8 @@ pub struct Server {
impl Server {
// TODO builder
pub fn new<A: ToSocketAddrs>(
pub async fn new_tcp<A: ToSocketAddrs>(
addr: A,
hostname: Hostname,
timeout: Duration,
......@@ -40,32 +41,54 @@ impl Server {
max_retries: usize,
baud_rate: usize,
) -> Result<Self, std::io::Error> {
let mut tnc = Tnc::connect(addr)?;
let tnc = Tnc::connect_tcp(addr).await?;
Self::new(
tnc,
hostname,
timeout,
max_packet_size,
max_retries,
baud_rate,
)
.await
}
pub async fn new<
R: AsyncRead + Unpin + Send + 'static,
W: AsyncWrite + Unpin + Send + 'static,
>(
tnc: Tnc<R, W>,
hostname: Hostname,
timeout: Duration,
max_packet_size: usize,
max_retries: usize,
baud_rate: usize,
) -> Result<Self, std::io::Error> {
let (stream_sender, stream_receiver) = mpsc::channel(1);
let (packet_sender, mut packet_receiver) = mpsc::unbounded_channel();
let (channel_sender, mut channel_receiver) = mpsc::channel::<mpsc::Sender<Packet>>(1);
let (received_packet_tx, mut received_packet_rx) = mpsc::channel(1);
let (stream_sender, stream_receiver) = mpsc::channel();
let (packet_sender, packet_receiver) = mpsc::channel();
let (channel_sender, channel_receiver) = mpsc::channel::<mpsc::Sender<Packet>>();
let (received_packet_tx, received_packet_rx) = mpsc::channel();
let (mut tnc_read, mut tnc_write) = tnc.split();
// take incoming frames and stuff them in a channel
{
let mut tnc = tnc.try_clone()?;
thread::spawn(move || {
tokio::task::spawn(async move {
let mut channel = None;
while let Ok((_, data)) = tnc.read_frame() {
while let Ok((_, data)) = tnc_read.read_frame().await {
if let Ok(ch) = channel_receiver.try_recv() {
channel = Some(ch);
}
if let Some(packet) = Packet::from_bytes(&data) {
if let Some(ch) = &channel {
if ch.send(packet.clone()).is_err() {
if ch.send(packet.clone()).await.is_err() {
channel = None;
}
}
if received_packet_tx.send(packet).is_err() {
if received_packet_tx.send(packet).await.is_err() {
// Channel closed. End this thread
break;
};
......@@ -77,7 +100,7 @@ impl Server {
// processes incoming frames
{
let packet_sender = packet_sender.clone();
thread::spawn(move || {
tokio::task::spawn(async move {
let mut reader = ServerReader::new(
hostname,
stream_sender,
......@@ -89,10 +112,14 @@ impl Server {
);
loop {
match received_packet_rx.try_recv() {
Ok(packet) => {
let r =
tokio::time::timeout(Duration::from_millis(50), received_packet_rx.recv())
.await;
match r {
Ok(Some(packet)) => {
if packet.to == hostname {
if let Err(BrokenPipeError) = reader.process_packet(packet) {
if let Err(BrokenPipeError) = reader.process_packet(packet).await {
// channel died
// kill our thread
break;
......@@ -100,24 +127,27 @@ impl Server {
}
}
Err(TryRecvError::Empty) => {
thread::sleep(Duration::from_millis(10));
}
Err(TryRecvError::Disconnected) => {
Ok(None) => {
// channel died
// kill our thread
break;
}
Err(_) => {
// no packet within timeout
// loop as usual
}
}
let mut disconnected_hosts = vec![];
for (host, state) in &mut reader.connections {
match state.process() {
match state.process().await {
Ok(()) => {}
Err(Error::Disconnected) => {
if let Err(BrokenPipeError) = state.disconnect(hostname, *host) {
if let Err(BrokenPipeError) =
state.disconnect(hostname, *host).await
{
// channel died
// kill our thread
break;
......@@ -142,10 +172,13 @@ impl Server {
}
// processes outgoing frames
thread::spawn(move || loop {
while let Ok(data) = packet_receiver.recv() {
if tnc.send_frame(&data).is_err() {
break;
tokio::task::spawn(async move {
loop {
while let Some(data) = packet_receiver.recv().await {
if tnc_write.send_frame(&data).await.is_err() {
break;
}
let _ = tnc_write.flush().await;
}
}
});
......@@ -160,18 +193,24 @@ impl Server {
}
/// Accepts a connection from a client.
/// Returns an error if the channel is disconnected (i.e. something went wrong with the TNC connection)
pub fn accept(&self) -> Result<(Hostname, Stream), BrokenPipeError> {
self.stream_receiver.recv().map_broken_pipe()
/// Returns an error if the channel is disconnected (i.e. something went wrong with the TNC connection).
///
/// NOTE: It is critical that this function is called regularly, even if the result is thrown away!
/// Otherwise, the streams will pile up and incoming packets will no longer be accepted properly (even pings!)
pub async fn accept(&mut self) -> Result<(Hostname, Stream), BrokenPipeError> {
self.stream_receiver.recv().await.map_broken_pipe()
}
/// Pings a host
pub fn ping(&self, to: Hostname) -> Result<bool, BrokenPipeError> {
pub async fn ping(&self, to: Hostname) -> Result<bool, BrokenPipeError> {
let id = rand::random();
let next = (id % 255) + 1;
let (packet_sender, packet_receiver) = mpsc::channel();
self.channel_sender.send(packet_sender).map_broken_pipe()?;
let (packet_sender, mut packet_receiver) = mpsc::channel(1);
self.channel_sender
.send(packet_sender)
.await
.map_broken_pipe()?;
let mut buf = vec![];
Packet::new(self.hostname, to, id, PacketData::Ping).to_bytes(&mut buf);
......@@ -179,14 +218,14 @@ impl Server {
let max_time = Instant::now() + self.timeout;
while Instant::now() < max_time {
if let Some(packet) = try_read_packet(&packet_receiver)? {
if let Some(packet) = try_read_packet(&mut packet_receiver)? {
if packet == Packet::new(to, self.hostname, next, PacketData::Pong) {
// Ping success
return Ok(true);
}
}
thread::sleep(Duration::from_millis(10));
tokio::time::sleep(Duration::from_millis(10)).await;
}
// No response
......@@ -197,7 +236,7 @@ impl Server {
struct ServerReader {
hostname: Hostname,
stream_sender: Sender<(Hostname, Stream)>,
packet_sender: Sender<Vec<u8>>,
packet_sender: UnboundedSender<Vec<u8>>,
connections: HashMap<Hostname, ConnectionState>,
// Needed to pass the settings to the ConnectionState
......@@ -211,7 +250,7 @@ impl ServerReader {
fn new(
hostname: Hostname,
stream_sender: Sender<(Hostname, Stream)>,
packet_sender: Sender<Vec<u8>>,
packet_sender: UnboundedSender<Vec<u8>>,
timeout: Duration,
max_packet_size: usize,
max_retries: usize,
......@@ -230,16 +269,16 @@ impl ServerReader {
}
}
fn process_packet(&mut self, packet: Packet) -> Result<(), BrokenPipeError> {
async fn process_packet(&mut self, packet: Packet) -> Result<(), BrokenPipeError> {
match &packet.data {
PacketData::Connect => self.connection_from(packet.from)?,
PacketData::Ping => self.pong(packet.from, packet.seq_num)?,
PacketData::Connect => self.connection_from(packet.from).await?,
PacketData::Ping => self.pong(packet.from, packet.seq_num).await?,
_ => {}
}
let from = packet.from;
if let Some(conn) = self.connections.get_mut(&from) {
match conn.incoming_packet(packet) {
match conn.incoming_packet(packet).await {
Ok(()) => {}
Err(Error::Disconnected) => {
self.connections.remove(&from);
......@@ -251,7 +290,7 @@ impl ServerReader {
Ok(())
}
fn pong(&mut self, to: Hostname, id: u8) -> Result<(), BrokenPipeError> {
async fn pong(&mut self, to: Hostname, id: u8) -> Result<(), BrokenPipeError> {
let mut buf = vec![];
Packet::new(self.hostname, to, (id % 255) + 1, PacketData::Pong).to_bytes(&mut buf);
self.packet_sender.send(buf).map_broken_pipe()?;
......@@ -259,7 +298,7 @@ impl ServerReader {
Ok(())
}
fn connection_from(&mut self, from: Hostname) -> Result<(), BrokenPipeError> {
async fn connection_from(&mut self, from: Hostname) -> Result<(), BrokenPipeError> {
let (s1, s2) = Stream::new_entangled();
let mut state = ConnectionState::new(
......@@ -282,7 +321,10 @@ impl ServerReader {
// If we're already connected, reset the connection
self.connections.insert(from, state);
self.stream_sender.send((from, s2)).map_broken_pipe()?;
self.stream_sender
.send((from, s2))
.await
.map_broken_pipe()?;
Ok(())
}
......
......@@ -4,8 +4,8 @@ use crate::hostname::Hostname;
use crate::packet::{Packet, PacketData};
use crate::stream::Stream;
use std::sync::mpsc::{Sender, TryRecvError};
use std::time::{Duration, Instant};
use tokio::sync::mpsc::{error::TryRecvError, UnboundedSender};
// increments our(tx) sequence
// needs to be a macro to appease the borrow checker
......@@ -34,7 +34,7 @@ pub(crate) struct ConnectionState {
pub timeout: Duration,
baud_rate: usize, // Used for calculating timeouts
pub stream: Stream,
pub packet_sender: Sender<Vec<u8>>,
pub packet_sender: UnboundedSender<Vec<u8>>,
seq_tx: u8,
seq_rx: u8,
pub us: Hostname,
......@@ -49,7 +49,7 @@ impl ConnectionState {
max_retries: usize,
baud_rate: usize,
stream: Stream,
packet_sender: Sender<Vec<u8>>,
packet_sender: UnboundedSender<Vec<u8>>,
us: Hostname,
them: Hostname,
) -> Self {
......@@ -71,7 +71,7 @@ impl ConnectionState {
}
}
fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
async fn send_data(&mut self, data: &[u8]) -> Result<(), Error> {
let mut data_chunks: Vec<Vec<u8>> = data
.chunks(self.max_data_size)
.map(|x| x.to_owned())
......@@ -108,7 +108,7 @@ impl ConnectionState {
Ok(())
}
pub fn incoming_packet(&mut self, packet: Packet) -> Result<(), Error> {
pub async fn incoming_packet(&mut self, packet: Packet) -> Result<(), Error> {
let them = packet.from;
let us = packet.to;
......@@ -123,41 +123,43 @@ impl ConnectionState {
}
let data = packet.data;
match (data, self.sending.clone()) {
(PacketData::DataStartAck, SendState::AwaitingDataStartAck(data)) => {
self.data_start_ack(data)?;
self.data_start_ack(data).await?;
}
(PacketData::DataResend(missing_frame_ids), SendState::AwaitingDataEndAck(data)) => {
self.data_resend(missing_frame_ids, data)?;
self.data_resend(missing_frame_ids, data).await?;
}
(PacketData::DataStart(num_frames, data), _) => {
self.data_start(num_frames, data)?;
self.data_start(num_frames, data).await?;
}
(PacketData::DataMid(frame_id, data), _) => {
if let Some(x) = self.data_mid(frame_id, data) {
if let Some(x) = self.data_mid(frame_id, data).await {
x?;
}
}
(PacketData::Disconnect, _) => {
self.disconnection()?;
self.disconnection().await?;
}
_ => {} // Do nothing
}
Ok(())
}
pub fn process(&mut self) -> Result<(), Error> {
pub async fn process(&mut self) -> Result<(), Error> {
if let SendState::Idle = self.sending {
// Process next action, if available
match self.stream.try_recv() {
Ok(data) => self.compress_and_send_data(&data)?,
Ok(data) => self.compress_and_send_data(&data).await?,
Err(TryRecvError::Disconnected) => {
self.disconnect(self.us, self.them)?;
self.disconnect(self.us, self.them).await?;
return Err(Error::Disconnected);
}
......@@ -237,7 +239,11 @@ impl ConnectionState {
Ok(())
}
pub fn disconnect(&mut self, from: Hostname, to: Hostname) -> Result<(), BrokenPipeError> {
pub async fn disconnect(
&mut self,
from: Hostname,
to: Hostname,
) -> Result<(), BrokenPipeError> {
let mut buf = vec![];
Packet::new(from, to, self.seq_inc(), PacketData::Disconnect).to_bytes(&mut buf);
self.packet_sender.send(buf).map_broken_pipe()
......@@ -259,7 +265,7 @@ impl ConnectionState {
}
}
fn data_start(&mut self, num_frames: u16, data: Vec<u8>) -> Result<(), Error> {
async fn data_start(&mut self, num_frames: u16, data: Vec<u8>) -> Result<(), Error> {
if num_frames == 0 {
// Reset so that we're ready for the next frame
self.receiving = None;
......@@ -281,9 +287,9 @@ impl ConnectionState {
.to_bytes(&mut buf);
self.packet_sender.send(buf).map_broken_pipe()?;
if self.stream.write(data).is_err() {
if self.stream.write(data).await.is_err() {
// disconnect and drop state
self.disconnect(self.us, self.them)?;
self.disconnect(self.us, self.them).await?;
return Err(Error::Disconnected);
}
......@@ -301,7 +307,7 @@ impl ConnectionState {
Ok(())
}
fn data_mid(&mut self, frame_id: u16, data: Vec<u8>) -> Option<Result<(), Error>> {
async fn data_mid(&mut self, frame_id: u16, data: Vec<u8>) -> Option<Result<(), Error>> {
let chunks = self.receiving.as_mut()?;
// Frame may be out of bounds(misconfigured or malicious sender)
// So we must use get_mut and handle the none case
......@@ -344,10 +350,10 @@ impl ConnectionState {
return Some(Err(e.into()));
}
if self.stream.write(data).is_err() {
if self.stream.write(data).await.is_err() {
// Stream is broken
// Disconnect and drop state
if let Err(BrokenPipeError) = self.disconnect(self.us, self.them) {
if let Err(BrokenPipeError) = self.disconnect(self.us, self.them).await {
return Some(Err(Error::BrokenPipe));
}
......@@ -357,14 +363,14 @@ impl ConnectionState {
// Reset so that we're ready for the next frame
self.receiving = None;
} else {
return Some(self.request_missing_frames(missing_frames));
return Some(self.request_missing_frames(missing_frames).await);
}
}
Some(Ok(()))
}
fn disconnection(&mut self) -> Result<(), Error> {
async fn disconnection(&mut self) -> Result<(), Error> {
let mut buf = vec![];
Packet::new(
self.us,
......@@ -378,42 +384,36 @@ impl ConnectionState {
Err(Error::Disconnected)
}
fn data_start_ack(&mut self, data: Vec<Vec<u8>>) -> Result<(), Error> {
async fn data_start_ack(&mut self, data: Vec<Vec<u8>>) -> Result<(), Error> {
let mut total_len = 0;
// send all of the packets
let data_len = data
.clone()
.into_iter()
.enumerate()
.skip(1) // first packet was sent in DataStart
.map(|(frame_id, frame)| {
let mut buf = vec![];
Packet::new(
self.us,
self.them,
self.seq_inc(),
PacketData::DataMid(frame_id as u16 - 1, frame),
)
.to_bytes(&mut buf);
let len = buf.len();
self.packet_sender.send(buf).map_broken_pipe()?;
for (frame_id, frame) in data.clone().into_iter().enumerate().skip(1)
// first packet was sent in DataStart
{
let mut buf = vec![];
Packet::new(
self.us,
self.them,
self.seq_inc(),
PacketData::DataMid(frame_id as u16 - 1, frame),
)
.to_bytes(&mut buf);
let len = buf.len();
self.packet_sender.send(buf).map_broken_pipe()?;
Ok(len)
})
.fold(Ok(0), |acc: Result<usize, Error>, x| match (acc, x) {
(Ok(acc), Ok(x)) => Ok(acc + x),
(Ok(_), Err(e)) => Err(e),
(Err(e), _) => Err(e),
})?;
total_len += len;
}
// Reset state
self.refresh_timeout_time(data_len);
self.refresh_timeout_time(total_len);
self.send_retries = 0;
self.sending = SendState::AwaitingDataEndAck(data);
Ok(())
}
fn data_resend(
async fn data_resend(
&mut self,
missing_frame_ids: Vec<u16>,
data: Vec<Vec<u8>>,
......@@ -438,27 +438,21 @@ impl ConnectionState {
}
// Send frames they requested
let data_len = missing_frame_ids
.iter()
.map(|frame_id| {
let mut buf = vec![];
Packet::new(
self.us,
self.them,
self.seq_inc(),
PacketData::DataMid(*frame_id, data[*frame_id as usize + 1].clone()),
)
.to_bytes(&mut buf);
let len = buf.len();
self.packet_sender.send(buf).map_broken_pipe()?;
let mut total_len = 0;
for frame_id in missing_frame_ids {
let mut buf = vec![];
Packet::new(
self.us,
self.them,
self.seq_inc(),
PacketData::DataMid(frame_id, data[frame_id as usize + 1].clone()),
)
.to_bytes(&mut buf);
let len = buf.len();
self.packet_sender.send(buf).map_broken_pipe()?;
Ok(len)
})
.fold(Ok(0), |acc: Result<usize, Error>, x| match (acc, x) {
(Ok(acc), Ok(x)) => Ok(acc + x),
(Ok(_), Err(e)) => Err(e),
(Err(e), _) => Err(e),
})?;
total_len += len;
}
// If they didn't request the last frame, resend it w/ empty data
// This is used so the other side can detect that we've finished
......@@ -475,7 +469,7 @@ impl ConnectionState {
}
// Reset state
self.refresh_timeout_time(data_len);
self.refresh_timeout_time(total_len);
self.send_retries = 0;
Ok(())
......@@ -487,14 +481,14 @@ impl ConnectionState {
+ Duration::from_millis((bytes * 8 * 1000 / self.baud_rate) as u64);
}
fn compress_and_send_data(&mut self, data: &[u8]) -> Result<(), Error> {
async fn compress_and_send_data(&mut self, data: &[u8]) -> Result<(), Error> {
// Safe to unwrap because read from a vector will never fail
let compressed = zstd::stream::encode_all(data, 9).unwrap();
self.send_data(&compressed)
self.send_data(&compressed).await
}
fn request_missing_frames(&mut self, missing_frame_ids: Vec<u16>) -> Result<(), Error> {
async fn request_missing_frames(&mut self, missing_frame_ids: Vec<u16>) -> Result<(), Error> {
let mut buf = vec![];
Packet::new(
self.us,
......
use crate::error;
use crate::error::MapBrokenPipe;
use std::sync::{
use tokio::sync::{
mpsc,
mpsc::{Receiver, Sender, TryRecvError},
mpsc::{error::TryRecvError, Receiver, Sender},
};
pub struct ReaderStream {
......@@ -10,8 +11,8 @@ pub struct ReaderStream {
}
impl ReaderStream {
pub fn read(&mut self) -> Result<Vec<u8>, error::Error> {
self.receiver.recv().map_err(|_| error::Error::Disconnected)
pub async fn read(&mut self) -> Result<Vec<u8>, error::Error> {
self.receiver.recv().await.map_disconnected()
}
/// Like read, but non-blocking
......@@ -25,10 +26,8 @@ pub struct WriterStream {
}
impl WriterStream {
pub fn write(&mut self, data: Vec<u8>) -> Result<(), error::Error> {
self.sender
.send(data)
.map_err(|_| error::Error::Disconnected)
pub async fn write(&mut self, data: Vec<u8>) -> Result<(), error::Error> {
self.sender.send(data).await.map_disconnected()
}
}
......@@ -39,8 +38,8 @@ pub struct Stream {
impl Stream {
pub fn new_entangled() -> (Self, Self) {
let (s1, r1) = mpsc::channel();
let (s2, r2) = mpsc::channel();
let (s1, r1) = mpsc::channel(1);
let (s2, r2) = mpsc::channel(1);
(
Self {
......@@ -58,8 +57,8 @@ impl Stream {
(self.reader, self.writer)
}
pub fn read(&mut self) -> Result<Vec<u8>, error::Error> {
self.reader.read()
pub async fn read(&mut self) -> Result<Vec<u8>, error::Error> {
self.reader.read().await
}
/// Like read, but non-blocking
......@@ -67,7 +66,7 @@ impl Stream {
self.reader.try_read()
}
pub fn write(&mut self, data: Vec<u8>) -> Result<(), error::Error> {
self.writer.write(data)
pub async fn write(&mut self, data: Vec<u8>) -> Result<(), error::Error> {
self.writer.write(data).await
}
}
use std::collections::hash_map::DefaultHasher;
use std::hash::Hash;
use std::hash::Hasher;
use std::thread;
use std::time::Duration;
use minipac::client::Client;
......@@ -15,9 +14,9 @@ fn calculate_hash<T: Hash>(t: &T) -> u64 {
s.finish()
}
#[test]
fn can_ping_server_when_disconnected() {
let _ = Server::new(
#[tokio::test]
async fn can_ping_server_when_disconnected() {
let _ = Server::new_tcp(
"localhost:8001",
Hostname::new("N0CALL-4").unwrap(),
Duration::from_secs(1),
......@@ -25,9 +24,10 @@ fn can_ping_server_when_disconnected() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8001",
Hostname::new("DEFGH-213").unwrap(),
Duration::from_secs(3),
......@@ -35,14 +35,18 @@ fn can_ping_server_when_disconnected() {
1,
1200,
)
.await
.unwrap();
assert!(client.ping(Hostname::new("N0CALL-4").unwrap()).unwrap());
assert!(client
.ping(Hostname::new("N0CALL-4").unwrap())
.await
.unwrap());
}
#[test]
fn can_ping_server_when_connected() {
let _server = Server::new(
#[tokio::test]
async fn can_ping_server_when_connected() {
let _server = Server::new_tcp(
"localhost:8002",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -50,9 +54,10 @@ fn can_ping_server_when_connected() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8002",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(3),
......@@ -60,16 +65,23 @@ fn can_ping_server_when_connected() {
1,
1200,
)
.await
.unwrap();
let _stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let _stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
assert!(client.ping(Hostname::new("N0CALL-3").unwrap()).unwrap());
assert!(client
.ping(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap());
}
#[test]
fn can_ping_server_when_connected_then_disconnected() {
let _server = Server::new(
#[tokio::test]
async fn can_ping_server_when_connected_then_disconnected() {
let mut server = Server::new_tcp(
"localhost:8000",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -77,9 +89,16 @@ fn can_ping_server_when_connected_then_disconnected() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
tokio::task::spawn(async move {
loop {
let _ = server.accept().await;
}
});
let client = Client::new_tcp(
"localhost:8000",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(3),
......@@ -87,24 +106,42 @@ fn can_ping_server_when_connected_then_disconnected() {
1,
1200,
)
.await
.unwrap();
{
// connect, then immediately disconnect when the stream is dropped
let _stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let _stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
}
{
// do it again for good measure
let _stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let _stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
}
{
// do it again (again) for good measure
let _stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
}
assert!(client.ping(Hostname::new("N0CALL-3").unwrap()).unwrap());
assert!(client
.ping(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap());
}
#[test]
fn can_ping_client() {
let client = Client::new(
#[tokio::test]
async fn can_ping_client() {
let client = Client::new_tcp(
"localhost:8009",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(3),
......@@ -112,9 +149,10 @@ fn can_ping_client() {
1,
1200,
)
.await
.unwrap();
let server = Server::new(
let server = Server::new_tcp(
"localhost:8009",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -122,20 +160,30 @@ fn can_ping_client() {
3,
1200,
)
.await
.unwrap();
assert!(server.ping(Hostname::new("DEFGH-212").unwrap()).unwrap());
assert!(server
.ping(Hostname::new("DEFGH-212").unwrap())
.await
.unwrap());
// connect then ping again
let _stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let _stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
assert!(server.ping(Hostname::new("DEFGH-212").unwrap()).unwrap());
assert!(server
.ping(Hostname::new("DEFGH-212").unwrap())
.await
.unwrap());
}
#[test]
fn cant_ping_nonexistant_host() {
let _server = Server::new(
#[tokio::test]
async fn cant_ping_nonexistant_host() {
let _server = Server::new_tcp(
"localhost:8003",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -143,9 +191,10 @@ fn cant_ping_nonexistant_host() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8003",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -153,11 +202,15 @@ fn cant_ping_nonexistant_host() {
1,
1200,
)
.await
.unwrap();
assert!(!client.ping(Hostname::new("N0CALL-4").unwrap(),).unwrap());
assert!(!client
.ping(Hostname::new("N0CALL-4").unwrap())
.await
.unwrap());
let _server = Server::new(
let _server = Server::new_tcp(
"localhost:8004",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -165,9 +218,10 @@ fn cant_ping_nonexistant_host() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8004",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -175,14 +229,18 @@ fn cant_ping_nonexistant_host() {
1,
1200,
)
.await
.unwrap();
assert!(!client.ping(Hostname::new("N0CAL3-3").unwrap(),).unwrap());
assert!(!client
.ping(Hostname::new("N0CAL3-3").unwrap())
.await
.unwrap());
}
#[test]
fn can_establish_connection_and_transfer_data_single_packet() {
let server = Server::new(
#[tokio::test]
async fn can_establish_connection_and_transfer_data_single_packet() {
let mut server = Server::new_tcp(
"localhost:8005",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -190,19 +248,20 @@ fn can_establish_connection_and_transfer_data_single_packet() {
3,
1200,
)
.await
.unwrap();
// echo server
thread::spawn(move || {
let (host, mut stream) = server.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(vec![5, 34, 54], r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
let client = Client::new(
let client = Client::new_tcp(
"localhost:8005",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -210,21 +269,25 @@ fn can_establish_connection_and_transfer_data_single_packet() {
1,
1200,
)
.await
.unwrap();
let mut stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let mut stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
stream.write(vec![5, 34, 54]).unwrap();
assert_eq!(vec![5, 34, 54], stream.read().unwrap());
stream.write(vec![5, 34, 54]).await.unwrap();
assert_eq!(vec![5, 34, 54], stream.read().await.unwrap());
// Try to read again
let res = stream.read();
let res = stream.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
}
#[test]
fn can_establish_connection_and_transfer_data_many_packets() {
let server = Server::new(
#[tokio::test]
async fn can_establish_connection_and_transfer_data_many_packets() {
let mut server = Server::new_tcp(
"localhost:8006",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -232,6 +295,7 @@ fn can_establish_connection_and_transfer_data_many_packets() {
3,
1200,
)
.await
.unwrap();
let data: Vec<u8> = (0..100_000).map(|_| rand::random()).collect();
......@@ -239,17 +303,17 @@ fn can_establish_connection_and_transfer_data_many_packets() {
// echo server
{
let data = data.clone();
thread::spawn(move || {
let (host, mut stream) = server.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
}
let client = Client::new(
let client = Client::new_tcp(
"localhost:8006",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -257,22 +321,25 @@ fn can_establish_connection_and_transfer_data_many_packets() {
1,
1200,
)
.await
.unwrap();
let mut stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let mut stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
stream.write(data.clone()).unwrap();
assert_eq!(data, stream.read().unwrap());
stream.write(data.clone()).await.unwrap();
assert_eq!(data, stream.read().await.unwrap());
// Try to read again
let res = stream.read();
let res = stream.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
}
// Warning - very slow(>60 seconds)
#[test]
fn can_establish_connection_and_transfer_data_max_packets() {
let server = Server::new(
#[tokio::test]
async fn can_establish_connection_and_transfer_data_max_packets() {
let mut server = Server::new_tcp(
"localhost:8007",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -280,6 +347,7 @@ fn can_establish_connection_and_transfer_data_max_packets() {
3,
1200,
)
.await
.unwrap();
// Can't be random because
......@@ -292,17 +360,17 @@ fn can_establish_connection_and_transfer_data_max_packets() {
// echo server
{
let data = data.clone();
thread::spawn(move || {
let (host, mut stream) = server.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
}
let client = Client::new(
let client = Client::new_tcp(
"localhost:8007",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -310,21 +378,25 @@ fn can_establish_connection_and_transfer_data_max_packets() {
1,
1200,
)
.await
.unwrap();
let mut stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let mut stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
stream.write(data.clone()).unwrap();
assert_eq!(data, stream.read().unwrap());
stream.write(data.clone()).await.unwrap();
assert_eq!(data, stream.read().await.unwrap());
// Try to read again
let res = stream.read();
let res = stream.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
}
#[test]
fn multiple_actions_in_quick_succession() {
let server = Server::new(
#[tokio::test]
async fn multiple_actions_in_quick_succession() {
let mut server = Server::new_tcp(
"localhost:8008",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -332,6 +404,7 @@ fn multiple_actions_in_quick_succession() {
3,
1200,
)
.await
.unwrap();
let data: Vec<u8> = (0..100_000).map(|_| rand::random()).collect();
......@@ -339,21 +412,21 @@ fn multiple_actions_in_quick_succession() {
// echo server
{
let data = data.clone();
thread::spawn(move || {
let (host, mut stream) = server.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
}
let client = Client::new(
let client = Client::new_tcp(
"localhost:8008",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -361,26 +434,30 @@ fn multiple_actions_in_quick_succession() {
1,
1200,
)
.await
.unwrap();
let mut stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let mut stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
stream.write(data.clone()).unwrap();
stream.write(data.clone()).unwrap();
assert_eq!(data, stream.read().unwrap());
assert_eq!(data, stream.read().unwrap());
stream.write(data.clone()).await.unwrap();
stream.write(data.clone()).await.unwrap();
assert_eq!(data, stream.read().await.unwrap());
assert_eq!(data, stream.read().await.unwrap());
// Try to read again
let res = stream.read();
let res = stream.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
}
#[test]
fn two_connections_one_client() {
#[tokio::test]
async fn two_connections_one_client() {
let data1: Vec<u8> = (0..1_000).map(|_| rand::random()).collect();
let data2: Vec<u8> = (0..1_000).map(|_| rand::random()).collect();
let server1 = Server::new(
let mut server1 = Server::new_tcp(
"localhost:8010",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -388,22 +465,23 @@ fn two_connections_one_client() {
3,
1200,
)
.await
.unwrap();
// echo server1
{
let data1 = data1.clone();
thread::spawn(move || {
let (host, mut stream) = server1.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server1.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data1, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
}
let server2 = Server::new(
let mut server2 = Server::new_tcp(
"localhost:8010",
Hostname::new("N0CALL-4").unwrap(),
Duration::from_secs(1),
......@@ -411,22 +489,23 @@ fn two_connections_one_client() {
3,
1200,
)
.await
.unwrap();
// echo server2
{
let data2 = data2.clone();
thread::spawn(move || {
let (host, mut stream) = server2.accept().unwrap();
tokio::task::spawn(async move {
let (host, mut stream) = server2.accept().await.unwrap();
assert_eq!(host, Hostname::new("DEFGH-212").unwrap());
let r = stream.read().unwrap();
let r = stream.read().await.unwrap();
assert_eq!(data2, r);
stream.write(r).unwrap();
stream.write(r).await.unwrap();
});
}
let client = Client::new(
let client = Client::new_tcp(
"localhost:8010",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -434,33 +513,40 @@ fn two_connections_one_client() {
1,
1200,
)
.await
.unwrap();
let mut stream1 = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let mut stream2 = client.connect(Hostname::new("N0CALL-4").unwrap()).unwrap();
let mut stream1 = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
let mut stream2 = client
.connect(Hostname::new("N0CALL-4").unwrap())
.await
.unwrap();
stream1.write(data1.clone()).unwrap();
stream1.write(data1.clone()).await.unwrap();
// sleep to prevent a race condition w/ the relay server
// kind of janky, but necessary since the relay server doesn't understand KISS
// and the servers will talk over each other otherwise
thread::sleep(Duration::from_millis(1000));
tokio::time::sleep(Duration::from_millis(1000)).await;
stream2.write(data2.clone()).unwrap();
assert_eq!(data1, stream1.read().unwrap());
assert_eq!(data2, stream2.read().unwrap());
stream2.write(data2.clone()).await.unwrap();
assert_eq!(data1, stream1.read().await.unwrap());
assert_eq!(data2, stream2.read().await.unwrap());
// Try to read again
let res = stream1.read();
let res = stream1.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
let res = stream2.read();
let res = stream2.read().await;
assert!(matches!(res, Err(Error::Disconnected)));
}
#[test]
fn test_cant_connect_to_non_existant_server() {
let _server = Server::new(
#[tokio::test]
async fn test_cant_connect_to_non_existant_server() {
let _server = Server::new_tcp(
"localhost:8011",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -468,9 +554,10 @@ fn test_cant_connect_to_non_existant_server() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8011",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -478,17 +565,18 @@ fn test_cant_connect_to_non_existant_server() {
1,
1200,
)
.await
.unwrap();
assert!(matches!(
client.connect(Hostname::new("N0CALL-4").unwrap()),
client.connect(Hostname::new("N0CALL-4").unwrap()).await,
Err(ConnectError::ConnectionFailure),
));
}
#[test]
fn test_disconnection_kills_client_stream() {
let server = Server::new(
#[tokio::test]
async fn test_disconnection_kills_client_stream() {
let mut server = Server::new_tcp(
"localhost:8012",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -496,9 +584,10 @@ fn test_disconnection_kills_client_stream() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8012",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -506,23 +595,27 @@ fn test_disconnection_kills_client_stream() {
1,
1200,
)
.await
.unwrap();
let mut cli_stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let (_, serv_stream) = server.accept().unwrap();
let mut cli_stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
let (_, serv_stream) = server.accept().await.unwrap();
thread::spawn(move || {
tokio::task::spawn(async move {
// wait 1 second then disconnect
thread::sleep(Duration::from_secs(1));
tokio::time::sleep(Duration::from_secs(1)).await;
drop(serv_stream)
});
assert!(matches!(cli_stream.read(), Err(Error::Disconnected)));
assert!(matches!(cli_stream.read().await, Err(Error::Disconnected)));
}
#[test]
fn test_disconnection_kills_server_stream() {
let server = Server::new(
#[tokio::test]
async fn test_disconnection_kills_server_stream() {
let mut server = Server::new_tcp(
"localhost:8013",
Hostname::new("N0CALL-3").unwrap(),
Duration::from_secs(1),
......@@ -530,9 +623,10 @@ fn test_disconnection_kills_server_stream() {
3,
1200,
)
.await
.unwrap();
let client = Client::new(
let client = Client::new_tcp(
"localhost:8013",
Hostname::new("DEFGH-212").unwrap(),
Duration::from_secs(1),
......@@ -540,16 +634,20 @@ fn test_disconnection_kills_server_stream() {
1,
1200,
)
.await
.unwrap();
let cli_stream = client.connect(Hostname::new("N0CALL-3").unwrap()).unwrap();
let (_, mut serv_stream) = server.accept().unwrap();
let cli_stream = client
.connect(Hostname::new("N0CALL-3").unwrap())
.await
.unwrap();
let (_, mut serv_stream) = server.accept().await.unwrap();
thread::spawn(move || {
tokio::task::spawn(async move {
// wait 1 second then disconnect
thread::sleep(Duration::from_secs(1));
tokio::time::sleep(Duration::from_secs(1)).await;
drop(cli_stream)
});
assert!(matches!(serv_stream.read(), Err(Error::Disconnected)));
assert!(matches!(serv_stream.read().await, Err(Error::Disconnected)));
}