libredr_common/
connection.rs1use std::collections::HashMap;
2use uuid::Uuid;
3use shadow_rs::shadow;
4use tokio::net::TcpStream;
5use tracing::{debug, warn};
6#[cfg(unix)]
7use tokio::net::UnixStream;
8use anyhow::{Result, bail};
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10use super::message::Message;
11
// Generates the `build` module (via shadow_rs) that exposes build-time metadata;
// `build::PKG_VERSION`, `build::SHORT_COMMIT` and `build::COMMIT_DATE` are read by
// `Connection::check_version` for the version handshake.
shadow!(build);
13
14#[allow(missing_docs)]
15#[derive(Debug)]
16pub enum Stream {
17 TcpStream(TcpStream),
18#[cfg(unix)]
19 UnixStream(UnixStream),
20}
21
22#[allow(missing_docs)]
23impl Stream {
24 pub async fn read_exact(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
25 match self {
26 Stream::TcpStream(stream) => stream.read_exact(buf).await,
27#[cfg(unix)]
28 Stream::UnixStream(stream) => stream.read_exact(buf).await,
29 }
30 }
31}
32
/// A connection to a peer that exchanges length-prefixed, postcard-encoded [`Message`]s.
///
/// Instances are created by [`Connection::from_stream`] or [`Connection::from_config`],
/// both of which run the version handshake before returning.
#[derive(Debug)]
pub struct Connection {
    /// Random (v4) identifier assigned at creation; included in debug log lines.
    uuid: Uuid,
    /// Underlying TCP or Unix socket.
    stream: Stream,
    /// Human-readable description of the peer/transport.
    description: String,
}
40
41impl Connection {
42 pub fn uuid(&self) -> Uuid {
44 self.uuid
45 }
46
47 pub fn description(&self) -> &str {
49 &self.description
50 }
51
52 async fn check_version(&mut self) -> Result<()> {
53 let local_version = Message::Version(
54 build::PKG_VERSION.to_owned(),
55 build::SHORT_COMMIT.to_owned(),
56 build::COMMIT_DATE.to_owned());
57 debug!("Connection::check_version: local: {}", local_version);
58 self.send_msg(&local_version).await?;
59 let remote_version = self.recv_msg().await?;
60 if let Message::Version(ver, git, _) = &remote_version {
61 debug!("Connection::check_version: remote: {}", remote_version);
62 if ver == build::PKG_VERSION {
63 if git != build::SHORT_COMMIT {
64 warn!("Connection::check_version: Version match, but git commit mismatch. \
65 local: {}, remote: {}", local_version, remote_version);
66 }
67 return Ok(());
68 }
69 bail!("Connection::check_version: Version mismatch: local. {}, remote: {}", local_version, remote_version);
70 }
71 bail!("Connection::check_version: unexpected message {}", remote_version);
72 }
73
74 pub async fn from_stream(stream: Stream, description: String) -> Result<Self> {
76 let uuid = Uuid::new_v4();
77 let mut connection = Connection {
78 uuid,
79 stream,
80 description: format!("{description} - {uuid}"),
81 };
82 connection.check_version().await?;
83 Ok(connection)
84 }
85
86 pub async fn from_config(config: &HashMap<String, String>) -> Result<Self> {
100 let (stream, description) = match config["unix"].as_str() {
101 "false" => {
102 let stream = TcpStream::connect(config["connect"].to_owned()).await?;
103 let description = format!("tcp://{} - Server", config["connect"]);
104 (Stream::TcpStream(stream), description)
105 }
106#[cfg(unix)]
107 "true" => {
108 let stream = UnixStream::connect(config["connect"].to_owned()).await?;
109 let description = format!("unix://{} - Server", config["connect"]);
110 (Stream::UnixStream(stream), description)
111 },
112#[cfg(not(unix))]
113 "true" => {
114 bail!("Connection::new: Error: Unix socket is not supported on current platform");
115 },
116 other => {
117 bail!("Connection::new: Error: unknown config unix: {}", other);
118 }
119 };
120 let uuid = Uuid::new_v4();
121 let mut connection = Connection {
122 uuid,
123 stream,
124 description,
125 };
126 connection.check_version().await?;
127 Ok(connection)
128 }
129
130 pub async fn send_msg(&mut self, msg: &Message) -> Result<()> {
132 debug!("Connection::send_msg: timer 0 - {} - Serializing", self.uuid);
133 let mut msg = postcard::to_stdvec(msg)?;
135 let msg_len: u64 = msg.len().try_into()?;
136 let mut raw_msg = Vec::from(msg_len.to_le_bytes());
137 raw_msg.append(&mut msg);
138 debug!("Connection::send_msg: timer 1 - {} - Sending {msg_len} bytes", self.uuid);
139 match &mut self.stream {
140 Stream::TcpStream(stream) => stream.write_all(&raw_msg).await?,
141#[cfg(unix)]
142 Stream::UnixStream(stream) => stream.write_all(&raw_msg).await?,
143 }
144 debug!("Connection::send_msg: timer 2 - {} - Finished", self.uuid);
145 Ok(())
146 }
147
148 pub async fn recv_msg(&mut self) -> Result<Message> {
150 debug!("Connection::recv_msg: timer 0 - {} - Receiving", self.uuid);
151 let mut size_buffer = [0u8; 8];
152 self.stream.read_exact(&mut size_buffer).await?;
153 let msg_len = u64::from_le_bytes(size_buffer);
154 let mut read_buffer = vec![0; msg_len as usize];
155 self.stream.read_exact(&mut read_buffer).await?;
156 debug!("Connection::recv_msg: timer 1 - {} - Deserializing {msg_len} bytes", self.uuid);
157 let msg: Message = postcard::from_bytes(&read_buffer)?;
159 debug!("Connection::recv_msg: timer 2 - {} - Finished", self.uuid);
160 Ok(msg)
161 }
162}