libredr_common/
message.rs

1use std::sync::{Arc, Mutex};
2use std::collections::HashSet;
3use uuid::Uuid;
4use chrono::Utc;
5use tracing::info;
6use blake3::Hasher;
7use ndarray::prelude::*;
8use anyhow::{Result, bail, anyhow, ensure};
9use serde::{Deserialize, Serialize};
10use super::geometry::Geometry;
11use super::connection::Connection;
12
13/// Hash type for lazy-loading
14///
15/// Currently use 32-byte long hash from BLAKE3 algorithm.
16#[derive(Serialize, Deserialize, Clone, Eq, Hash, PartialEq, Default)]
17pub struct Hash(pub [u8; 32]);
18
19impl std::fmt::Display for Hash {
20  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21    write!(f, "[{:02X} {:02X} {:02X} {:02X}]", self.0[0], self.0[1], self.0[2], self.0[3])
22  }
23}
24
25impl std::fmt::Debug for Hash {
26  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27    write!(f, "{:?}", self.0)
28  }
29}
30
31impl From<Uuid> for Hash {
32  fn from(uuid: Uuid) -> Self {
33    let mut ret = [0; 32];
34    ret[..16].copy_from_slice(&uuid.into_bytes());
35    Hash(ret)
36  }
37}
38
39impl From<(&Uuid, &Uuid)> for Hash {
40  fn from(uuid: (&Uuid, &Uuid)) -> Self {
41    let mut ret = [0; 32];
42    ret[..16].copy_from_slice(&uuid.0.into_bytes());
43    ret[16..].copy_from_slice(&uuid.1.into_bytes());
44    Hash(ret)
45  }
46}
47
48impl From<u64> for Hash {
49  fn from(hash: u64) -> Self {
50    let mut ret = [0; 32];
51    ret[..8].copy_from_slice(&hash.to_le_bytes());
52    Hash(ret)
53  }
54}
55
56/// Type for lazy-loading data cache
57///
58/// hash -> (access time, data)
59pub type DataCache = Arc<Mutex<hashbrown::HashMap<Hash, (i64, Data)>>>;
60
61/// All kinds of lazy-loading data types
62#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
63pub enum Data {
64  /// Vertex coordinate, vertex normal, and vertex texture coordinate of a TriMesh
65  ///
66  /// (3 + 3 + 2) * 3 * Number of faces
67  TriMeshData(Array3<f32>),
68  /// Ray tracing `ray`, see `ray` argument in `py_ray_tracing_forward`
69  RayData(Array2<f32>),
70  /// `texture` and `envmap`
71  MaterialData(Array3<f32>, Array4<f32>),
72  /// Cache for uv_xyz calculation, only on worker
73  TriMeshUVXYZ(Array3<f32>),
74  /// Cache intermediate data for back propagation and cumulate gradient to reduce communication, only on worker
75  Intermediate(Intermediate),
76}
77
78/// Cache intermediate data for back propagation and cumulate gradient to reduce communication, only on worker
79#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
80pub struct Intermediate {
81  /// Uuid of the client task
82  pub client_uuid: Uuid,
83  /// Uuids of all the request tasks that are assigned to this worker \
84  /// Save forward task and intermediate data for backward task
85  pub forward_requests: hashbrown::HashMap<Uuid, (RequestRayTracingForward, Vec<u8>)>,
86  /// Cumulate `d_texture` and only return to server on the last tile
87  pub d_texture: Array3<f32>,
88  /// Cumulate `d_envmap` and only return to server on the last tile
89  pub d_envmap: Array4<f32>,
90}
91
92impl std::fmt::Display for Data {
93  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94    match self {
95      Data::TriMeshData(data) => write!(f, "Data::TriMeshData size: {}", data.shape()[2]),
96      Data::RayData(data) => write!(f, "Data::RayData size: {}", data.shape()[1]),
97      Data::MaterialData(texture, envmap) =>
98        write!(f, "Data::MaterialData size: {:?} {:?}", texture.shape(), envmap.shape()),
99      Data::TriMeshUVXYZ(data) => write!(f, "Data::TriMeshUVXYZ size: {}", data.shape()[1]),
100      Data::Intermediate(intermediate) => write!(f, "Data::Intermediate client_uuid: {}", intermediate.client_uuid),
101    }
102  }
103}
104
105impl Data {
106  /// Calculate hash for lazy-loading
107  ///
108  /// Currently use 32-byte long hash from BLAKE3 algorithm.
109  pub fn hash(&self) -> Hash {
110    let mut hasher = Hasher::new();
111    let msg = postcard::to_stdvec(self).expect("Internal Data Struct");
112    hasher.update(&msg);
113    Hash(*hasher.finalize().as_bytes())
114  }
115}
116
117/// Load all lazy-loading data hash from `required_data` into `data_cache`
118///
119/// The `Data` is fetched from `connection` \
120/// All `Data` timestamps in `data_cache` are updated
121pub async fn ensure_data(
122    connection: &mut Connection,
123    data_cache: &DataCache,
124    mut required_data: HashSet<Hash>) -> Result<()> {
125  {
126    let mut data_cache = data_cache.lock().expect("No task should panic");
127    required_data.retain(|data_hash| {
128      data_cache.get_mut(data_hash).map_or_else(|| { true }, |entry| {
129        entry.0 = Utc::now().timestamp();
130        false
131      })
132    });
133  }
134  for data_hash in required_data.iter() {
135    let request_task = Message::RequestData(data_hash.to_owned());
136    info!("Message::ensure_data: {} Request {request_task}", connection.description());
137    connection.send_msg(&request_task).await?
138  }
139  while !required_data.is_empty() {
140    let msg_response = connection.recv_msg().await?;
141    info!("Message::ensure_data: {} Response {msg_response}", connection.description());
142    let Message::ResponseData(data) = msg_response else {
143      bail!("ensure_data: Unexpected command from `{}`", connection.description());
144    };
145    let data = data.map_err(|err| {
146      anyhow!("ensure_data: Remote reports `ResponseData` error: {err}")
147    })?;
148    let data_hash = data.hash();
149    if required_data.remove(&data_hash) {
150      let mut data_cache = data_cache.lock().expect("No task should panic");
151      data_cache.insert(data_hash, (Utc::now().timestamp(), data));
152    } else {
153      bail!("ensure_data: Unexpected `ResponseData` hash {data_hash} from `{}`", connection.description());
154    }
155  }
156  Ok(())
157}
158
159/// Get `hash` data from `data_cache`, return result of f(data)
160///
161/// To prevent data copy, `data` is not returned
162pub fn map_cache_data<R, F>(hash: &Hash, data_cache: &DataCache, f: F) -> Result<R>
163    where F: FnOnce(&Data) -> Result<R> {
164  let mut data_cache = data_cache.lock().expect("No task should panic");
165  let data = data_cache.get_mut(hash).ok_or_else(||
166    anyhow!("Message: map_cache_data: Hash {hash} not found"))?;
167  data.0 = Utc::now().timestamp();
168  f(&data.1)
169}
170
171/// Get `hash` data from `data_cache`, insert result of f(data), return new hash
172pub fn insert_map_cache_data<F>(hash: &Hash, data_cache: &DataCache, f: F) -> Result<Hash>
173    where F: FnOnce(&Data) -> Result<Data> {
174  let mut data_cache = data_cache.lock().expect("No task should panic");
175  let data = data_cache.get_mut(hash).ok_or_else(|| anyhow!("Message: map_cache_data: Hash {hash} not found"))?;
176  data.0 = Utc::now().timestamp();
177  let new_data = f(&data.1)?;
178  let new_data_hash = new_data.hash();
179  data_cache.insert(new_data_hash.to_owned(), (Utc::now().timestamp(), new_data));
180  Ok(new_data_hash)
181}
182
183/// Arguments for ray-tracing forward task
184///
185/// See arguments in `ray_tracing_forward` for details
186#[allow(missing_docs)]
187#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
188pub struct RequestRayTracingForward {
189  /// Lazy-loading `geometry`
190  pub geometry: Geometry,
191  /// Lazy-loading `ray`
192  pub ray: Hash,
193  /// Lazy-loading `texture` and `envmap`
194  pub material: Hash,
195  pub sample_per_pixel: (usize, usize),
196  pub max_bounce: (usize, usize, usize, usize),
197  pub switches: (u8, u8, u8, u8),
198  pub clip_near: (f32, f32, f32),
199  pub camera_space: bool,
200  pub requires_grad: bool,
201  pub srand: i32,
202  pub low_discrepancy: u32,
203}
204
205impl RequestRayTracingForward {
206  /// Number of rays in [`RequestRayTracingForward`]
207  pub fn size(&self, data_cache: &DataCache) -> Result<usize> {
208    map_cache_data(&self.ray, data_cache, |ray| {
209      if let Data::RayData(ray) = ray {
210        Ok(ray.shape()[1])
211      } else {
212        bail!("RequestRayTracingForward::size: Wrong data for argument `ray`: {ray}");
213      }
214    })
215  }
216
217  /// Texture resolution and envmap resolution in [`RequestRayTracingForward`]
218  pub fn material_resolution(&self, data_cache: &DataCache) -> Result<(usize, usize)> {
219    map_cache_data(&self.material, data_cache, |material| {
220      if let Data::MaterialData(texture, envmap) = material {
221        Ok((texture.shape()[2], envmap.shape()[3]))
222      } else {
223        bail!("RequestRayTracingForward::material_resolution: Wrong data for argument `material`: {material}");
224      }
225    })
226  }
227}
228
229// Because we hardly see same parameter cross different backward tasks, we don't lazy-load backward data
230#[allow(missing_docs)]
231#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
232pub struct RequestRayTracingBackward {
233  /// `d_ray`, must have the same number of rays as [`RequestRayTracingForward`]
234  pub d_ray: Array2<f32>,
235}
236
237impl RequestRayTracingBackward {
238  /// Number of rays in `RequestRayTracingBackward`
239  pub fn size(&self, _data_cache: &DataCache) -> Result<usize> {
240    Ok(self.d_ray.shape()[1])
241  }
242}
243
244#[allow(missing_docs)]
245#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
246pub enum RequestTask {
247  // Boxed because of large size difference
248  RequestRayTracingForward(Box<RequestRayTracingForward>),
249  RequestRayTracingBackward(Box<RequestRayTracingBackward>),
250}
251
252impl std::fmt::Display for RequestTask {
253  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254    match self {
255      RequestTask::RequestRayTracingForward(request) =>
256        write!(f, "RequestTask::RequestRayTracingForward requires_grad: {}", request.requires_grad),
257      RequestTask::RequestRayTracingBackward(_) => write!(f, "RequestTask::RequestRayTracingBackward"),
258    }
259  }
260}
261
262/// Functions to split task into multiple sub-tasks for worker
263impl RequestTask {
264  /// Get task size for split
265  pub fn size(&self, data_cache: &DataCache) -> Result<usize> {
266    match self {
267      RequestTask::RequestRayTracingForward(request) => request.size(data_cache),
268      RequestTask::RequestRayTracingBackward(request) => request.size(data_cache),
269    }
270  }
271
272  /// Generate a sub-task by slicing current task from `progress` to `progress + tile_size`
273  ///
274  /// The last piece can be smaller than `tile_size` \
275  /// `progress` is updated automatically
276  pub fn split(&self, progress: &mut usize, tile_size: usize, data_cache: &DataCache) -> Result<Option<Self>> {
277    let new_progress = std::cmp::min(*progress + tile_size, self.size(data_cache)?);
278    if *progress >= new_progress {
279      Ok(None)
280    } else {
281      let sub_task = Some(match self {
282        RequestTask::RequestRayTracingForward(request) =>
283          RequestTask::RequestRayTracingForward(Box::new({
284            let sub_ray_hash = insert_map_cache_data(&request.ray, data_cache, |ray| {
285              if let Data::RayData(ray) = ray {
286                Ok(Data::RayData(ray.slice(s![.., *progress..new_progress]).to_owned()))
287              } else {
288                bail!("Message::RequestTask::size: Wrong data for argument `ray`: {ray}");
289              }
290            })?;
291            RequestRayTracingForward {
292              geometry: request.geometry.to_owned(),
293              ray: sub_ray_hash,
294              material: request.material.to_owned(),
295              ..**request
296            }
297          })),
298        RequestTask::RequestRayTracingBackward(request) =>
299          RequestTask::RequestRayTracingBackward(Box::new({
300            let sub_d_ray = request.d_ray.slice(s![.., *progress..new_progress]).to_owned();
301            RequestRayTracingBackward {
302              d_ray: sub_d_ray,
303            }
304          })),
305      });
306      *progress = new_progress;
307      Ok(sub_task)
308    }
309  }
310
311  /// Get all hashes of lazy-loading `Data`
312  pub fn required_data(&self) -> HashSet<Hash> {
313    match self {
314      RequestTask::RequestRayTracingForward(request) => {
315        let mut ret = HashSet::new();
316        ret.extend(request.geometry.required_data());
317        ret.insert(request.ray.to_owned());
318        ret.insert(request.material.to_owned());
319        ret
320      },
321      RequestTask::RequestRayTracingBackward(_request) => HashSet::new(),
322    }
323  }
324
325  /// Create an empty `ResponseTask` as the same type of `RequestTask`
326  pub fn new_response(&self) -> ResponseTask {
327    match self {
328      RequestTask::RequestRayTracingForward(request) => ResponseTask::ResponseRayTracingForward(
329        Box::new(ResponseRayTracingForward {
330          // * if `camera_space` is `true`, (3)
331          //   * ray_render
332          // * if `camera_space` is `false`, (3 + 2 + 1 + 3)
333          //   * ray_render + ray_texture + ray_depth + ray_normal
334          render: Array2::default((if request.camera_space { 3 } else { 9 }, 0)),
335        })
336      ),
337      RequestTask::RequestRayTracingBackward(_) => ResponseTask::ResponseRayTracingBackward(
338        Box::new(ResponseRayTracingBackward {
339          d_texture: None,
340          d_envmap: None,
341          d_ray_texture: None,
342        })
343      ),
344    }
345  }
346}
347
348#[allow(missing_docs)]
349#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
350pub struct ResponseRayTracingForward {
351  pub render: Array2<f32>,
352}
353
354impl ResponseRayTracingForward {
355  fn merge(&mut self, other: &ResponseRayTracingForward) -> Result<()> {
356    self.render.append(Axis(1), other.render.view())?;
357    Ok(())
358  }
359}
360
361#[allow(missing_docs)]
362#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
363pub struct ResponseRayTracingBackward {
364  pub d_texture: Option<Array3<f32>>,
365  pub d_envmap: Option<Array4<f32>>,
366  pub d_ray_texture: Option<Array2<f32>>,
367}
368
369impl ResponseRayTracingBackward {
370  fn merge(&mut self, other: &ResponseRayTracingBackward) -> Result<()> {
371    ensure!(other.d_texture.is_some() == other.d_texture.is_some());
372    if let Some(d_texture) = &mut self.d_texture {
373      if let Some(other_d_texture) = &other.d_texture {
374        ensure!(d_texture.shape() == other_d_texture.shape(), "ResponseRayTracingBackward::merge:
375          d_texture shape mismatch {:?} and {:?}", d_texture.shape(), other_d_texture.shape());
376        *d_texture += other_d_texture;
377      }
378    } else {
379      self.d_texture = other.d_texture.to_owned();
380    }
381    if let Some(d_envmap) = &mut self.d_envmap {
382      if let Some(other_d_envmap) = &other.d_envmap {
383        ensure!(d_envmap.shape() == other_d_envmap.shape(), "ResponseRayTracingBackward::merge:
384        d_envmap shape mismatch {:?} and {:?}", d_envmap.shape(), other_d_envmap.shape());
385        *d_envmap += other_d_envmap;
386      }
387    } else {
388      self.d_envmap = other.d_envmap.to_owned();
389    }
390    if let Some(other_d_ray_texture) = &other.d_ray_texture {
391      if let Some(d_ray_texture) = &mut self.d_ray_texture {
392        d_ray_texture.append(Axis(1), other_d_ray_texture.view())?;
393      } else {
394        self.d_ray_texture = Some(other_d_ray_texture.to_owned());
395      }
396    }
397    Ok(())
398  }
399}
400
401#[allow(missing_docs)]
402#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
403pub enum ResponseTask {
404  ResponseRayTracingForward(Box<ResponseRayTracingForward>),
405  ResponseRayTracingBackward(Box<ResponseRayTracingBackward>),
406}
407
408impl std::fmt::Display for ResponseTask {
409  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410    match self {
411      ResponseTask::ResponseRayTracingForward(_) => write!(f, "ResponseTask::ResponseRayTracingForward"),
412      ResponseTask::ResponseRayTracingBackward(_) => write!(f, "ResponseTask::ResponseRayTracingBackward"),
413    }
414  }
415}
416
417#[allow(missing_docs)]
418impl ResponseTask {
419  pub fn merge(&mut self, other: &ResponseTask) -> Result<()> {
420    match self {
421      ResponseTask::ResponseRayTracingForward(response_task) => {
422        let ResponseTask::ResponseRayTracingForward(other) = other else {
423          bail!("ResponseTask::merge ResponseRayTracingForward and {other}");
424        };
425        response_task.merge(other)
426      },
427      ResponseTask::ResponseRayTracingBackward(response_task) => {
428        let ResponseTask::ResponseRayTracingBackward(other) = other else {
429          bail!("ResponseTask::merge ResponseRayTracingBackward and {other}");
430        };
431        response_task.merge(other)
432      },
433    }
434  }
435}
436
437/// Pair of `client_uuid` and `request_uuid` to help worker find and merge intermediate data
438#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
439pub struct GradUUID {
440  #[allow(missing_docs)]
441  pub client_uuid: Uuid,
442  #[allow(missing_docs)]
443  pub request_uuid: Uuid,
444}
445
446impl GradUUID {
447  /// Create `GradUUID` with both `client_uuid` and `request_uuid` as nil
448  pub fn nil() -> Self {
449    GradUUID { client_uuid: Uuid::nil(), request_uuid: Uuid::nil() }
450  }
451
452  /// Test if `GradUUID` is nil
453  pub fn is_nil(&self) -> bool {
454    self.client_uuid.is_nil() && self.request_uuid.is_nil()
455  }
456}
457
458/// `Message` type shared by Client, Server, and Worker
459#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
460pub enum Message {
461  /// Check remote version
462  Version(String, String, String),
463  /// Request lazy-loading data by hash
464  RequestData(Hash),
465  /// Response lazy-loading data
466  ///
467  /// Error if the data-cache is cleaned up
468  ResponseData(Box<Result<Data, String>>),
469  /// Notify the worker `GradUUID` to save and merge intermediate data
470  RequestGradUUID(GradUUID),
471  /// All types of `RequestTask`
472  RequestTask(RequestTask),
473  /// All types of `ResponseTask`
474  ResponseTask(Result<ResponseTask, String>),
475  /// Close cleanly
476  Close(),
477}
478
479impl std::fmt::Display for Message {
480  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481    match self {
482      Message::Version(ver, git, build) => write!(f, "Message::Version {ver} - {git} - {build}"),
483      Message::RequestData(msg) => write!(f, "Message::RequestData {msg}"),
484      Message::ResponseData(msg) => match msg.as_ref() {
485        Ok(msg) => write!(f, "Message::ResponseData {msg}"),
486        Err(err) => write!(f, "Message::ResponseData Error {err}"),
487      },
488      Message::RequestGradUUID(_) => write!(f, "Message::RequestGradUUID"),
489      Message::RequestTask(msg) => write!(f, "Message::RequestTask {msg}"),
490      Message::ResponseTask(msg) => match msg {
491        Ok(msg) => write!(f, "Message::ResponseTask {msg}"),
492        Err(err) => write!(f, "Message::ResponseTask Error {err}"),
493      },
494      Message::Close() => write!(f, "Message::Close"),
495    }
496  }
497}