libredr_common/
connection.rs

1use std::collections::HashMap;
2use uuid::Uuid;
3use shadow_rs::shadow;
4use tokio::net::TcpStream;
5use tracing::{debug, warn};
6#[cfg(unix)]
7use tokio::net::UnixStream;
8use anyhow::{Result, bail};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use super::message::Message;
11
12shadow!(build);
13
14#[allow(missing_docs)]
15#[derive(Debug)]
16pub enum Stream {
17  TcpStream(TcpStream),
18#[cfg(unix)]
19  UnixStream(UnixStream),
20}
21
22#[allow(missing_docs)]
23impl Stream {
24  pub async fn read_exact(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
25    match self {
26      Stream::TcpStream(stream) => stream.read_exact(buf).await,
27#[cfg(unix)]
28      Stream::UnixStream(stream) => stream.read_exact(buf).await,
29    }
30  }
31}
32
33/// `Connection` with a random `uuid` and human-readable `description`
34#[derive(Debug)]
35pub struct Connection {
36  uuid: Uuid,
37  stream: Stream,
38  description: String,
39}
40
41impl Connection {
42  /// `uuid` is read-only
43  pub fn uuid(&self) -> Uuid {
44    self.uuid
45  }
46
47  /// `description` is read-only
48  pub fn description(&self) -> &str {
49    &self.description
50  }
51
52  async fn check_version(&mut self) -> Result<()> {
53    let local_version = Message::Version(
54      build::PKG_VERSION.to_owned(),
55      build::SHORT_COMMIT.to_owned(),
56      build::COMMIT_DATE.to_owned());
57    debug!("Connection::check_version: local: {}", local_version);
58    self.send_msg(&local_version).await?;
59    let remote_version = self.recv_msg().await?;
60    if let Message::Version(ver, git, _) = &remote_version {
61      debug!("Connection::check_version: remote: {}", remote_version);
62      if ver == build::PKG_VERSION {
63        if git != build::SHORT_COMMIT {
64          warn!("Connection::check_version: Version match, but git commit mismatch. \
65            local: {}, remote: {}", local_version, remote_version);
66        }
67        return Ok(());
68      }
69      bail!("Connection::check_version: Version mismatch: local. {}, remote: {}", local_version, remote_version);
70    }
71    bail!("Connection::check_version: unexpected message {}", remote_version);
72  }
73
74  /// Construct `Connection` with random UUID
75  pub async fn from_stream(stream: Stream, description: String) -> Result<Self> {
76    let uuid = Uuid::new_v4();
77    let mut connection = Connection {
78      uuid,
79      stream,
80      description: format!("{description} - {uuid}"),
81    };
82    connection.check_version().await?;
83    Ok(connection)
84  }
85
86  /// Construct `Connection` by connecting to LibreDR server.
87  ///
88  /// Return `Error` if connection failed
89  /// # Examples
90  /// ```
91  /// async {
92  ///   let mut config = HashMap::new();
93  ///   config.insert(String::from("connect"), string::from("127.0.0.1:9000"));
94  ///   config.insert(String::from("unix"), string::from("false"));
95  ///   config.insert(String::from("tls"), string::from("false"));
96  ///   let connection = Connection::from_config(&config).await?;
97  /// }
98  /// ```
99  pub async fn from_config(config: &HashMap<String, String>) -> Result<Self> {
100    let (stream, description) = match config["unix"].as_str() {
101      "false" => {
102        let stream = TcpStream::connect(config["connect"].to_owned()).await?;
103        let description = format!("tcp://{} - Server", config["connect"]);
104        (Stream::TcpStream(stream), description)
105      }
106#[cfg(unix)]
107      "true" => {
108        let stream = UnixStream::connect(config["connect"].to_owned()).await?;
109        let description = format!("unix://{} - Server", config["connect"]);
110        (Stream::UnixStream(stream), description)
111      },
112#[cfg(not(unix))]
113      "true" => {
114        bail!("Connection::new: Error: Unix socket is not supported on current platform");
115      },
116      other => {
117        bail!("Connection::new: Error: unknown config unix: {}", other);
118      }
119    };
120    let uuid = Uuid::new_v4();
121    let mut connection = Connection {
122      uuid,
123      stream,
124      description,
125    };
126    connection.check_version().await?;
127    Ok(connection)
128  }
129
130  /// Send a `Message`
131  pub async fn send_msg(&mut self, msg: &Message) -> Result<()> {
132    debug!("Connection::send_msg: timer 0 - {} - Serializing", self.uuid);
133    // let mut msg = rmp_serde::to_vec(msg)?;
134    let mut msg = postcard::to_stdvec(msg)?;
135    let msg_len: u64 = msg.len().try_into()?;
136    let mut raw_msg = Vec::from(msg_len.to_le_bytes());
137    raw_msg.append(&mut msg);
138    debug!("Connection::send_msg: timer 1 - {} - Sending {msg_len} bytes", self.uuid);
139    match &mut self.stream {
140      Stream::TcpStream(stream) => stream.write_all(&raw_msg).await?,
141#[cfg(unix)]
142      Stream::UnixStream(stream) => stream.write_all(&raw_msg).await?,
143    }
144    debug!("Connection::send_msg: timer 2 - {} - Finished", self.uuid);
145    Ok(())
146  }
147
148  /// Receive a `Message`
149  pub async fn recv_msg(&mut self) -> Result<Message> {
150    debug!("Connection::recv_msg: timer 0 - {} - Receiving", self.uuid);
151    let mut size_buffer = [0u8; 8];
152    self.stream.read_exact(&mut size_buffer).await?;
153    let msg_len = u64::from_le_bytes(size_buffer);
154    let mut read_buffer = vec![0; msg_len as usize];
155    self.stream.read_exact(&mut read_buffer).await?;
156    debug!("Connection::recv_msg: timer 1 - {} - Deserializing {msg_len} bytes", self.uuid);
157    // let msg: Message = rmp_serde::from_slice(&read_buffer)?;
158    let msg: Message = postcard::from_bytes(&read_buffer)?;
159    debug!("Connection::recv_msg: timer 2 - {} - Finished", self.uuid);
160    Ok(msg)
161  }
162}