1use std::sync::{Arc, Mutex};
2use std::collections::HashSet;
3use uuid::Uuid;
4use chrono::Utc;
5use tracing::info;
6use blake3::Hasher;
7use ndarray::prelude::*;
8use anyhow::{Result, bail, anyhow, ensure};
9use serde::{Deserialize, Serialize};
10use super::geometry::Geometry;
11use super::connection::Connection;
12
13#[derive(Serialize, Deserialize, Clone, Eq, Hash, PartialEq, Default)]
17pub struct Hash(pub [u8; 32]);
18
19impl std::fmt::Display for Hash {
20 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
21 write!(f, "[{:02X} {:02X} {:02X} {:02X}]", self.0[0], self.0[1], self.0[2], self.0[3])
22 }
23}
24
25impl std::fmt::Debug for Hash {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 write!(f, "{:?}", self.0)
28 }
29}
30
31impl From<Uuid> for Hash {
32 fn from(uuid: Uuid) -> Self {
33 let mut ret = [0; 32];
34 ret[..16].copy_from_slice(&uuid.into_bytes());
35 Hash(ret)
36 }
37}
38
39impl From<(&Uuid, &Uuid)> for Hash {
40 fn from(uuid: (&Uuid, &Uuid)) -> Self {
41 let mut ret = [0; 32];
42 ret[..16].copy_from_slice(&uuid.0.into_bytes());
43 ret[16..].copy_from_slice(&uuid.1.into_bytes());
44 Hash(ret)
45 }
46}
47
48impl From<u64> for Hash {
49 fn from(hash: u64) -> Self {
50 let mut ret = [0; 32];
51 ret[..8].copy_from_slice(&hash.to_le_bytes());
52 Hash(ret)
53 }
54}
55
56pub type DataCache = Arc<Mutex<hashbrown::HashMap<Hash, (i64, Data)>>>;
60
61#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
63pub enum Data {
64 TriMeshData(Array3<f32>),
68 RayData(Array2<f32>),
70 MaterialData(Array3<f32>, Array4<f32>),
72 TriMeshUVXYZ(Array3<f32>),
74 Intermediate(Intermediate),
76}
77
78#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
80pub struct Intermediate {
81 pub client_uuid: Uuid,
83 pub forward_requests: hashbrown::HashMap<Uuid, (RequestRayTracingForward, Vec<u8>)>,
86 pub d_texture: Array3<f32>,
88 pub d_envmap: Array4<f32>,
90}
91
92impl std::fmt::Display for Data {
93 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94 match self {
95 Data::TriMeshData(data) => write!(f, "Data::TriMeshData size: {}", data.shape()[2]),
96 Data::RayData(data) => write!(f, "Data::RayData size: {}", data.shape()[1]),
97 Data::MaterialData(texture, envmap) =>
98 write!(f, "Data::MaterialData size: {:?} {:?}", texture.shape(), envmap.shape()),
99 Data::TriMeshUVXYZ(data) => write!(f, "Data::TriMeshUVXYZ size: {}", data.shape()[1]),
100 Data::Intermediate(intermediate) => write!(f, "Data::Intermediate client_uuid: {}", intermediate.client_uuid),
101 }
102 }
103}
104
105impl Data {
106 pub fn hash(&self) -> Hash {
110 let mut hasher = Hasher::new();
111 let msg = postcard::to_stdvec(self).expect("Internal Data Struct");
112 hasher.update(&msg);
113 Hash(*hasher.finalize().as_bytes())
114 }
115}
116
117pub async fn ensure_data(
122 connection: &mut Connection,
123 data_cache: &DataCache,
124 mut required_data: HashSet<Hash>) -> Result<()> {
125 {
126 let mut data_cache = data_cache.lock().expect("No task should panic");
127 required_data.retain(|data_hash| {
128 data_cache.get_mut(data_hash).map_or_else(|| { true }, |entry| {
129 entry.0 = Utc::now().timestamp();
130 false
131 })
132 });
133 }
134 for data_hash in required_data.iter() {
135 let request_task = Message::RequestData(data_hash.to_owned());
136 info!("Message::ensure_data: {} Request {request_task}", connection.description());
137 connection.send_msg(&request_task).await?
138 }
139 while !required_data.is_empty() {
140 let msg_response = connection.recv_msg().await?;
141 info!("Message::ensure_data: {} Response {msg_response}", connection.description());
142 let Message::ResponseData(data) = msg_response else {
143 bail!("ensure_data: Unexpected command from `{}`", connection.description());
144 };
145 let data = data.map_err(|err| {
146 anyhow!("ensure_data: Remote reports `ResponseData` error: {err}")
147 })?;
148 let data_hash = data.hash();
149 if required_data.remove(&data_hash) {
150 let mut data_cache = data_cache.lock().expect("No task should panic");
151 data_cache.insert(data_hash, (Utc::now().timestamp(), data));
152 } else {
153 bail!("ensure_data: Unexpected `ResponseData` hash {data_hash} from `{}`", connection.description());
154 }
155 }
156 Ok(())
157}
158
159pub fn map_cache_data<R, F>(hash: &Hash, data_cache: &DataCache, f: F) -> Result<R>
163 where F: FnOnce(&Data) -> Result<R> {
164 let mut data_cache = data_cache.lock().expect("No task should panic");
165 let data = data_cache.get_mut(hash).ok_or_else(||
166 anyhow!("Message: map_cache_data: Hash {hash} not found"))?;
167 data.0 = Utc::now().timestamp();
168 f(&data.1)
169}
170
171pub fn insert_map_cache_data<F>(hash: &Hash, data_cache: &DataCache, f: F) -> Result<Hash>
173 where F: FnOnce(&Data) -> Result<Data> {
174 let mut data_cache = data_cache.lock().expect("No task should panic");
175 let data = data_cache.get_mut(hash).ok_or_else(|| anyhow!("Message: map_cache_data: Hash {hash} not found"))?;
176 data.0 = Utc::now().timestamp();
177 let new_data = f(&data.1)?;
178 let new_data_hash = new_data.hash();
179 data_cache.insert(new_data_hash.to_owned(), (Utc::now().timestamp(), new_data));
180 Ok(new_data_hash)
181}
182
183#[allow(missing_docs)]
187#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
188pub struct RequestRayTracingForward {
189 pub geometry: Geometry,
191 pub ray: Hash,
193 pub material: Hash,
195 pub sample_per_pixel: (usize, usize),
196 pub max_bounce: (usize, usize, usize, usize),
197 pub switches: (u8, u8, u8, u8),
198 pub clip_near: (f32, f32, f32),
199 pub camera_space: bool,
200 pub requires_grad: bool,
201 pub srand: i32,
202 pub low_discrepancy: u32,
203}
204
205impl RequestRayTracingForward {
206 pub fn size(&self, data_cache: &DataCache) -> Result<usize> {
208 map_cache_data(&self.ray, data_cache, |ray| {
209 if let Data::RayData(ray) = ray {
210 Ok(ray.shape()[1])
211 } else {
212 bail!("RequestRayTracingForward::size: Wrong data for argument `ray`: {ray}");
213 }
214 })
215 }
216
217 pub fn material_resolution(&self, data_cache: &DataCache) -> Result<(usize, usize)> {
219 map_cache_data(&self.material, data_cache, |material| {
220 if let Data::MaterialData(texture, envmap) = material {
221 Ok((texture.shape()[2], envmap.shape()[3]))
222 } else {
223 bail!("RequestRayTracingForward::material_resolution: Wrong data for argument `material`: {material}");
224 }
225 })
226 }
227}
228
229#[allow(missing_docs)]
231#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
232pub struct RequestRayTracingBackward {
233 pub d_ray: Array2<f32>,
235}
236
237impl RequestRayTracingBackward {
238 pub fn size(&self, _data_cache: &DataCache) -> Result<usize> {
240 Ok(self.d_ray.shape()[1])
241 }
242}
243
244#[allow(missing_docs)]
245#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
246pub enum RequestTask {
247 RequestRayTracingForward(Box<RequestRayTracingForward>),
249 RequestRayTracingBackward(Box<RequestRayTracingBackward>),
250}
251
252impl std::fmt::Display for RequestTask {
253 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
254 match self {
255 RequestTask::RequestRayTracingForward(request) =>
256 write!(f, "RequestTask::RequestRayTracingForward requires_grad: {}", request.requires_grad),
257 RequestTask::RequestRayTracingBackward(_) => write!(f, "RequestTask::RequestRayTracingBackward"),
258 }
259 }
260}
261
262impl RequestTask {
264 pub fn size(&self, data_cache: &DataCache) -> Result<usize> {
266 match self {
267 RequestTask::RequestRayTracingForward(request) => request.size(data_cache),
268 RequestTask::RequestRayTracingBackward(request) => request.size(data_cache),
269 }
270 }
271
272 pub fn split(&self, progress: &mut usize, tile_size: usize, data_cache: &DataCache) -> Result<Option<Self>> {
277 let new_progress = std::cmp::min(*progress + tile_size, self.size(data_cache)?);
278 if *progress >= new_progress {
279 Ok(None)
280 } else {
281 let sub_task = Some(match self {
282 RequestTask::RequestRayTracingForward(request) =>
283 RequestTask::RequestRayTracingForward(Box::new({
284 let sub_ray_hash = insert_map_cache_data(&request.ray, data_cache, |ray| {
285 if let Data::RayData(ray) = ray {
286 Ok(Data::RayData(ray.slice(s![.., *progress..new_progress]).to_owned()))
287 } else {
288 bail!("Message::RequestTask::size: Wrong data for argument `ray`: {ray}");
289 }
290 })?;
291 RequestRayTracingForward {
292 geometry: request.geometry.to_owned(),
293 ray: sub_ray_hash,
294 material: request.material.to_owned(),
295 ..**request
296 }
297 })),
298 RequestTask::RequestRayTracingBackward(request) =>
299 RequestTask::RequestRayTracingBackward(Box::new({
300 let sub_d_ray = request.d_ray.slice(s![.., *progress..new_progress]).to_owned();
301 RequestRayTracingBackward {
302 d_ray: sub_d_ray,
303 }
304 })),
305 });
306 *progress = new_progress;
307 Ok(sub_task)
308 }
309 }
310
311 pub fn required_data(&self) -> HashSet<Hash> {
313 match self {
314 RequestTask::RequestRayTracingForward(request) => {
315 let mut ret = HashSet::new();
316 ret.extend(request.geometry.required_data());
317 ret.insert(request.ray.to_owned());
318 ret.insert(request.material.to_owned());
319 ret
320 },
321 RequestTask::RequestRayTracingBackward(_request) => HashSet::new(),
322 }
323 }
324
325 pub fn new_response(&self) -> ResponseTask {
327 match self {
328 RequestTask::RequestRayTracingForward(request) => ResponseTask::ResponseRayTracingForward(
329 Box::new(ResponseRayTracingForward {
330 render: Array2::default((if request.camera_space { 3 } else { 9 }, 0)),
335 })
336 ),
337 RequestTask::RequestRayTracingBackward(_) => ResponseTask::ResponseRayTracingBackward(
338 Box::new(ResponseRayTracingBackward {
339 d_texture: None,
340 d_envmap: None,
341 d_ray_texture: None,
342 })
343 ),
344 }
345 }
346}
347
348#[allow(missing_docs)]
349#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
350pub struct ResponseRayTracingForward {
351 pub render: Array2<f32>,
352}
353
354impl ResponseRayTracingForward {
355 fn merge(&mut self, other: &ResponseRayTracingForward) -> Result<()> {
356 self.render.append(Axis(1), other.render.view())?;
357 Ok(())
358 }
359}
360
361#[allow(missing_docs)]
362#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
363pub struct ResponseRayTracingBackward {
364 pub d_texture: Option<Array3<f32>>,
365 pub d_envmap: Option<Array4<f32>>,
366 pub d_ray_texture: Option<Array2<f32>>,
367}
368
369impl ResponseRayTracingBackward {
370 fn merge(&mut self, other: &ResponseRayTracingBackward) -> Result<()> {
371 ensure!(other.d_texture.is_some() == other.d_texture.is_some());
372 if let Some(d_texture) = &mut self.d_texture {
373 if let Some(other_d_texture) = &other.d_texture {
374 ensure!(d_texture.shape() == other_d_texture.shape(), "ResponseRayTracingBackward::merge:
375 d_texture shape mismatch {:?} and {:?}", d_texture.shape(), other_d_texture.shape());
376 *d_texture += other_d_texture;
377 }
378 } else {
379 self.d_texture = other.d_texture.to_owned();
380 }
381 if let Some(d_envmap) = &mut self.d_envmap {
382 if let Some(other_d_envmap) = &other.d_envmap {
383 ensure!(d_envmap.shape() == other_d_envmap.shape(), "ResponseRayTracingBackward::merge:
384 d_envmap shape mismatch {:?} and {:?}", d_envmap.shape(), other_d_envmap.shape());
385 *d_envmap += other_d_envmap;
386 }
387 } else {
388 self.d_envmap = other.d_envmap.to_owned();
389 }
390 if let Some(other_d_ray_texture) = &other.d_ray_texture {
391 if let Some(d_ray_texture) = &mut self.d_ray_texture {
392 d_ray_texture.append(Axis(1), other_d_ray_texture.view())?;
393 } else {
394 self.d_ray_texture = Some(other_d_ray_texture.to_owned());
395 }
396 }
397 Ok(())
398 }
399}
400
401#[allow(missing_docs)]
402#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
403pub enum ResponseTask {
404 ResponseRayTracingForward(Box<ResponseRayTracingForward>),
405 ResponseRayTracingBackward(Box<ResponseRayTracingBackward>),
406}
407
408impl std::fmt::Display for ResponseTask {
409 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
410 match self {
411 ResponseTask::ResponseRayTracingForward(_) => write!(f, "ResponseTask::ResponseRayTracingForward"),
412 ResponseTask::ResponseRayTracingBackward(_) => write!(f, "ResponseTask::ResponseRayTracingBackward"),
413 }
414 }
415}
416
417#[allow(missing_docs)]
418impl ResponseTask {
419 pub fn merge(&mut self, other: &ResponseTask) -> Result<()> {
420 match self {
421 ResponseTask::ResponseRayTracingForward(response_task) => {
422 let ResponseTask::ResponseRayTracingForward(other) = other else {
423 bail!("ResponseTask::merge ResponseRayTracingForward and {other}");
424 };
425 response_task.merge(other)
426 },
427 ResponseTask::ResponseRayTracingBackward(response_task) => {
428 let ResponseTask::ResponseRayTracingBackward(other) = other else {
429 bail!("ResponseTask::merge ResponseRayTracingBackward and {other}");
430 };
431 response_task.merge(other)
432 },
433 }
434 }
435}
436
437#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
439pub struct GradUUID {
440 #[allow(missing_docs)]
441 pub client_uuid: Uuid,
442 #[allow(missing_docs)]
443 pub request_uuid: Uuid,
444}
445
446impl GradUUID {
447 pub fn nil() -> Self {
449 GradUUID { client_uuid: Uuid::nil(), request_uuid: Uuid::nil() }
450 }
451
452 pub fn is_nil(&self) -> bool {
454 self.client_uuid.is_nil() && self.request_uuid.is_nil()
455 }
456}
457
458#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
460pub enum Message {
461 Version(String, String, String),
463 RequestData(Hash),
465 ResponseData(Box<Result<Data, String>>),
469 RequestGradUUID(GradUUID),
471 RequestTask(RequestTask),
473 ResponseTask(Result<ResponseTask, String>),
475 Close(),
477}
478
479impl std::fmt::Display for Message {
480 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
481 match self {
482 Message::Version(ver, git, build) => write!(f, "Message::Version {ver} - {git} - {build}"),
483 Message::RequestData(msg) => write!(f, "Message::RequestData {msg}"),
484 Message::ResponseData(msg) => match msg.as_ref() {
485 Ok(msg) => write!(f, "Message::ResponseData {msg}"),
486 Err(err) => write!(f, "Message::ResponseData Error {err}"),
487 },
488 Message::RequestGradUUID(_) => write!(f, "Message::RequestGradUUID"),
489 Message::RequestTask(msg) => write!(f, "Message::RequestTask {msg}"),
490 Message::ResponseTask(msg) => match msg {
491 Ok(msg) => write!(f, "Message::ResponseTask {msg}"),
492 Err(err) => write!(f, "Message::ResponseTask Error {err}"),
493 },
494 Message::Close() => write!(f, "Message::Close"),
495 }
496 }
497}