1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
use std::collections::HashMap;
use uuid::Uuid;
use shadow_rs::shadow;
use tokio::net::TcpStream;
use tracing::{debug, warn};
#[cfg(unix)]
use tokio::net::UnixStream;
use anyhow::{Result, bail};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use super::message::Message;

shadow!(build);

#[allow(missing_docs)]
#[derive(Debug)]
pub enum Stream {
  TcpStream(TcpStream),
#[cfg(unix)]
  UnixStream(UnixStream),
}

#[allow(missing_docs)]
impl Stream {
  pub async fn read_exact(&mut self, buf: &mut [u8]) -> tokio::io::Result<usize> {
    match self {
      Stream::TcpStream(stream) => stream.read_exact(buf).await,
#[cfg(unix)]
      Stream::UnixStream(stream) => stream.read_exact(buf).await,
    }
  }
}

/// `Connection` with a random `uuid` and human-readable `description`
#[derive(Debug)]
pub struct Connection {
  uuid: Uuid,
  stream: Stream,
  description: String,
}

impl Connection {
  /// `uuid` is read-only
  pub fn uuid(&self) -> Uuid {
    self.uuid
  }

  /// `description` is read-only
  pub fn description(&self) -> &str {
    &self.description
  }

  async fn check_version(&mut self) -> Result<()> {
    let local_version = Message::Version(
      build::PKG_VERSION.to_owned(),
      build::SHORT_COMMIT.to_owned(),
      build::COMMIT_DATE.to_owned());
    debug!("Connection::check_version: local: {}", local_version);
    self.send_msg(&local_version).await?;
    let remote_version = self.recv_msg().await?;
    if let Message::Version(ver, git, _) = &remote_version {
      debug!("Connection::check_version: remote: {}", remote_version);
      if ver == build::PKG_VERSION {
        if git != build::SHORT_COMMIT {
          warn!("Connection::check_version: Version match, but git commit mismatch. \
            local: {}, remote: {}", local_version, remote_version);
        }
        return Ok(());
      }
      bail!("Connection::check_version: Version mismatch: local. {}, remote: {}", local_version, remote_version);
    }
    bail!("Connection::check_version: unexpected message {}", remote_version);
  }

  /// Construct `Connection` with random UUID
  pub async fn from_stream(stream: Stream, description: String) -> Result<Self> {
    let uuid = Uuid::new_v4();
    let mut connection = Connection {
      uuid,
      stream,
      description: format!("{description} - {uuid}"),
    };
    connection.check_version().await?;
    Ok(connection)
  }

  /// Construct `Connection` by connecting to LibreDR server.
  ///
  /// Return `Error` if connection failed
  /// # Examples
  /// ```
  /// async {
  ///   let mut config = HashMap::new();
  ///   config.insert(String::from("connect"), string::from("127.0.0.1:9000"));
  ///   config.insert(String::from("unix"), string::from("false"));
  ///   config.insert(String::from("tls"), string::from("false"));
  ///   let connection = Connection::from_config(&config).await?;
  /// }
  /// ```
  pub async fn from_config(config: &HashMap<String, String>) -> Result<Self> {
    let (stream, description) = match config["unix"].as_str() {
      "false" => {
        let stream = TcpStream::connect(config["connect"].to_owned()).await?;
        let description = format!("tcp://{} - Server", config["connect"]);
        (Stream::TcpStream(stream), description)
      }
#[cfg(unix)]
      "true" => {
        let stream = UnixStream::connect(config["connect"].to_owned()).await?;
        let description = format!("unix://{} - Server", config["connect"]);
        (Stream::UnixStream(stream), description)
      },
#[cfg(not(unix))]
      "true" => {
        bail!("Connection::new: Error: Unix socket is not supported on current platform");
      },
      other => {
        bail!("Connection::new: Error: unknown config unix: {}", other);
      }
    };
    let uuid = Uuid::new_v4();
    let mut connection = Connection {
      uuid,
      stream,
      description,
    };
    connection.check_version().await?;
    Ok(connection)
  }

  /// Send a `Message`
  pub async fn send_msg(&mut self, msg: &Message) -> Result<()> {
    debug!("Connection::send_msg: timer 0 - {} - Serializing", self.uuid);
    // let mut msg = rmp_serde::to_vec(msg)?;
    let mut msg = postcard::to_stdvec(msg)?;
    let msg_len: u64 = msg.len().try_into()?;
    let mut raw_msg = Vec::from(msg_len.to_le_bytes());
    raw_msg.append(&mut msg);
    debug!("Connection::send_msg: timer 1 - {} - Sending {msg_len} bytes", self.uuid);
    match &mut self.stream {
      Stream::TcpStream(stream) => stream.write_all(&raw_msg).await?,
#[cfg(unix)]
      Stream::UnixStream(stream) => stream.write_all(&raw_msg).await?,
    }
    debug!("Connection::send_msg: timer 2 - {} - Finished", self.uuid);
    Ok(())
  }

  /// Receive a `Message`
  pub async fn recv_msg(&mut self) -> Result<Message> {
    debug!("Connection::recv_msg: timer 0 - {} - Receiving", self.uuid);
    let mut size_buffer = [0u8; 8];
    self.stream.read_exact(&mut size_buffer).await?;
    let msg_len = u64::from_le_bytes(size_buffer);
    let mut read_buffer = vec![0; msg_len as usize];
    self.stream.read_exact(&mut read_buffer).await?;
    debug!("Connection::recv_msg: timer 1 - {} - Deserializing {msg_len} bytes", self.uuid);
    // let msg: Message = rmp_serde::from_slice(&read_buffer)?;
    let msg: Message = postcard::from_bytes(&read_buffer)?;
    debug!("Connection::recv_msg: timer 2 - {} - Finished", self.uuid);
    Ok(msg)
  }
}