libredr/
lib.rs

1#![doc = include_str!("../README.md")]
2#![warn(missing_docs)]
3#![warn(missing_debug_implementations)]
4
5use std::env;
6use std::path::Path;
7use std::sync::{Arc, Mutex};
8use pyo3::prelude::*;
9use anyhow::{Result, bail};
10use tokio::runtime::Runtime;
11use tracing::{error, debug};
12use once_cell::sync::OnceCell;
13use tracing_subscriber::{fmt, EnvFilter, prelude::*};
14use numpy::{PyReadonlyArray2, PyReadonlyArray3, PyReadonlyArray4, PyReadonlyArrayDyn, PyArray3, PyArray4, PyArrayDyn,
15  IntoPyArray, PyArrayMethods};
16use common::render;
17use common::message::*;
18pub use common::geometry::Geometry;
19pub use self::client::LibreDR;
20use self::camera::py_camera;
21use self::light_source::py_light_source;
22
23mod client;
24/// All camera models (Rust and Python)
25pub mod camera;
26/// All light source models (Rust and Python)
27pub mod light_source;
28
/// Process-wide Tokio runtime.
///
/// Set by `init_static` while the Python module is initialized. Used to `block_on` the async
/// client calls made by `PyLibreDR` methods and by its `Drop` impl.
static RUNTIME: OnceCell<Runtime> = OnceCell::new();
30
/// Python interface for [`Geometry`]
///
/// Exposed to Python as `Geometry`. The `subclass` option lets Python classes inherit from it.
#[derive(Debug)]
#[pyclass(name = "Geometry", subclass)]
pub struct PyGeometry {
  // The wrapped Rust geometry; `add_obj` loads into it and it is sent with forward requests.
  geometry: Geometry,
  // Shared `Arc<Mutex<HashMap>>` cache (see `py_new`). It is passed to both
  // `Geometry::add_obj` and `LibreDR::ray_tracing_forward`; the meaning of its keys and
  // values is defined outside this file — TODO confirm against `common::message`.
  data_cache: DataCache,
}
38
/// Python interface for [`LibreDR`] client
///
/// Tuple fields, in the order `py_new` builds them:
/// * `0` - the connected [`LibreDR`] client
/// * `1` - the `connect` address string passed to `py_new`
/// * `2` - the `unix` flag passed to `py_new`
/// * `3` - the `tls` flag passed to `py_new`
///
/// In this file, fields `1..=3` are read only by `__getnewargs__`, which uses them to reconnect
/// after unpickling.
#[derive(Debug)]
#[pyclass(name = "LibreDR", subclass)]
pub struct PyLibreDR(LibreDR, String, bool, bool);
43
44#[pymethods]
45impl PyGeometry {
46  #[new]
47  /// Create an empty [`PyGeometry`]
48  pub fn py_new() -> Self {
49    PyGeometry {
50      geometry: Geometry::new(),
51      data_cache: Arc::new(Mutex::new(hashbrown::HashMap::new())),
52    }
53  }
54
55  // TODO: add_trimesh (using vertex and face arrays from python)
56  /// See [`Geometry::add_obj`]
57  #[pyo3(name = "add_obj")]
58  pub fn py_add_obj(
59      &mut self,
60      py: Python,
61      filename: &str,
62      transform_v: PyReadonlyArray2<f32>,
63      transform_vt: PyReadonlyArray2<f32>) -> Result<()> {
64    let transform_v = transform_v.to_owned_array();
65    let transform_vt = transform_vt.to_owned_array();
66    py.allow_threads(|| {
67      self.geometry.add_obj(Path::new(filename), transform_v, transform_vt, &self.data_cache)
68    })?;
69    Ok(())
70  }
71}
72
73#[pymethods]
74impl PyLibreDR {
75  /// See [`render::MISS_NONE`] for details.
76  #[classattr]
77  pub const MISS_NONE: u8 = render::MISS_NONE;
78  /// See [`render::MISS_ENVMAP`] for details.
79  #[classattr]
80  pub const MISS_ENVMAP: u8 = render::MISS_ENVMAP;
81  /// See [`render::REFLECTION_NORMAL_FACE`] for details.
82  #[classattr]
83  pub const REFLECTION_NORMAL_FACE: u8 = render::REFLECTION_NORMAL_FACE;
84  /// See [`render::REFLECTION_NORMAL_VERTEX`] for details.
85  #[classattr]
86  pub const REFLECTION_NORMAL_VERTEX: u8 = render::REFLECTION_NORMAL_VERTEX;
87  /// See [`render::REFLECTION_NORMAL_TEXTURE`] for details.
88  #[classattr]
89  pub const REFLECTION_NORMAL_TEXTURE: u8 = render::REFLECTION_NORMAL_TEXTURE;
90  /// See [`render::REFLECTION_DIFFUSE_NONE`] for details.
91  #[classattr]
92  pub const REFLECTION_DIFFUSE_NONE: u8 = render::REFLECTION_DIFFUSE_NONE;
93  /// See [`render::REFLECTION_DIFFUSE_LAMBERTIAN`] for details.
94  #[classattr]
95  pub const REFLECTION_DIFFUSE_LAMBERTIAN: u8 = render::REFLECTION_DIFFUSE_LAMBERTIAN;
96  /// See [`render::REFLECTION_SPECULAR_NONE`] for details.
97  #[classattr]
98  pub const REFLECTION_SPECULAR_NONE: u8 = render::REFLECTION_SPECULAR_NONE;
99  /// See [`render::REFLECTION_SPECULAR_PHONG`] for details.
100  #[classattr]
101  pub const REFLECTION_SPECULAR_PHONG: u8 = render::REFLECTION_SPECULAR_PHONG;
102  /// See [`render::REFLECTION_SPECULAR_BLINN_PHONG`] for details.
103  #[classattr]
104  pub const REFLECTION_SPECULAR_BLINN_PHONG: u8 = render::REFLECTION_SPECULAR_BLINN_PHONG;
105  /// See [`render::REFLECTION_SPECULAR_TORRANCE_SPARROW_PHONG`] for details.
106  #[classattr]
107  pub const REFLECTION_SPECULAR_TORRANCE_SPARROW_PHONG: u8 = render::REFLECTION_SPECULAR_TORRANCE_SPARROW_PHONG;
108  /// See [`render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BLINN_PHONG`] for details.
109  #[classattr]
110  pub const REFLECTION_SPECULAR_TORRANCE_SPARROW_BLINN_PHONG: u8 =
111    render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BLINN_PHONG;
112    /// See [`render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BECKMANN`] for details.
113  #[classattr]
114  pub const REFLECTION_SPECULAR_TORRANCE_SPARROW_BECKMANN: u8 = render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BECKMANN;
115
116  /// Construct `LibreDR` by connecting to LibreDR server.
117  ///
118  /// See [`LibreDR::new`] for details.
119  #[new]
120  pub fn py_new(py: Python, connect: String, unix: bool, tls: bool) -> Result<Self> {
121    let libredr = py.allow_threads(|| {
122      let rt = RUNTIME.get().expect("Initialized in pymodule");
123      rt.block_on(LibreDR::new(connect.to_owned(), unix, tls))
124    })?;
125    Ok(PyLibreDR(libredr, connect, unix, tls))
126  }
127
128  /// To allow pickle [`PyLibreDR`] object by reconnecting to the server.
129  ///
130  /// Unpickled connection has different UUID.
131  pub fn __getnewargs__(&self) -> (String, bool, bool) {
132    (self.1.to_owned(), self.2, self.3)
133  }
134
135  /// Create a [`RequestRayTracingForward`] task and wait for response
136  ///
137  /// # Arguments
138  /// * `ray` - ray parameters
139  ///   * if `camera_space` is `false` 18 * `image_shape`
140  ///     * including ray position 9 * `image_shape`
141  ///     * including ray direction 9 * `image_shape`
142  ///   * if `camera_space` is `true`, add another (1 + 14) channels
143  ///     * including ray depth 1 * `image_shape` (if depth <= 0, treat as hit miss)
144  ///     * including ray material 14 * `image_shape`
145  /// * `texture` - (3 + 3 + 3 + 1 + 3 + 1) * `texture_resolution` * `texture_resolution` (must be square image)
146  ///   * including normal + diffuse + specular + roughness + intensity + window
147  /// * `envmap` - 3 * 6 * `envmap_resolution` * `envmap_resolution`
148  ///   * (must be box unwrapped 6 square images)
149  /// * `sample_per_pixel` - `sample_per_pixel_forward`, (`sample_per_pixel_backward`)
150  ///   * `sample_per_pixel` can be a single integer,
151  ///     * (same value for forward and backward)
152  ///   * or tuple of 2 integers.
153  ///     * (only `sample_per_pixel_backward` number of rays are stored for backward)
154  ///     * (must ensure `sample_per_pixel_forward` >= `sample_per_pixel_backward`)
155  /// * `max_bounce` - `max_bounce_forward`, (`max_bounce_backward`), (`max_bounce_low_discrepancy`), (`skip_bounce`)
156  ///   * `max_bounce` can be a single integer, or tuple of 2-4 integers.
157  ///   * The default value for `max_bounce_backward` is the same as `max_bounce_forward`.
158  ///   * The default value for `max_bounce_low_discrepancy` is `0`.
159  ///   * The default value for `skip_bounce` is `0`.
160  /// * `switches` - tuple of 4 switches to determine hit miss and reflection behavior
161  ///   * render::MISS_* - determine how to deal with ray hit miss
162  ///     * [`common::render::MISS_NONE`]
163  ///     * [`common::render::MISS_ENVMAP`]
164  ///   * render::REFLECTION_NORMAL_* - determine how to get surface normal
165  ///     * [`common::render::REFLECTION_NORMAL_FACE`]
166  ///     * [`common::render::REFLECTION_NORMAL_VERTEX`]
167  ///     * [`common::render::REFLECTION_NORMAL_TEXTURE`]
168  ///   * render::REFLECTION_DIFFUSE_* - determine diffuse reflection model
169  ///     * [`common::render::REFLECTION_DIFFUSE_NONE`]
170  ///     * [`common::render::REFLECTION_DIFFUSE_LAMBERTIAN`]
171  ///   * render::REFLECTION_SPECULAR_* - determine specular reflection model
172  ///     * [`common::render::REFLECTION_SPECULAR_NONE`]
173  ///     * [`common::render::REFLECTION_SPECULAR_PHONG`]
174  ///     * [`common::render::REFLECTION_SPECULAR_BLINN_PHONG`]
175  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_PHONG`]
176  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BLINN_PHONG`]
177  ///     * [`common::render::REFLECTION_SPECULAR_TORRANCE_SPARROW_BECKMANN`]
178  /// * `clip_near` - clip near distance for camera
179  ///   * `clip_near` can be a single float number (same for all bounces),
180  ///   * or tuple of 3 float numbers (first bounce, second bounce, and other bounces)
181  /// * `camera_space` - if `true`, the first bounce uses the depth and material given by the ray
182  /// * `requires_grad` - if `true`, worker will save intermediate data, the next task must be `ray_tracing_backward`
183  /// * `srand` - random seed
184  ///   * if srand >= 0, the same random seed is used for every pixel
185  ///   * if srand < 0, use different seed for each pixel
186  /// * `low_discrepancy` - (optional) start id of Halton low discrepancy sequence.
187  ///   * The default value is the same as `sample_per_pixel_forward`.
188  ///   * if combine multiple rendered images to reduce noise, this value can be set to: \
189  ///       1 * `sample_per_pixel_forward`, 2 * `sample_per_pixel_forward`, 3 * `sample_per_pixel_forward`, ...
190  ///
191  /// # Return
192  /// Return shape will be,
193  ///   * if `camera_space` is `true`
194  ///     * render image 3 * `image_shape`
195  ///   * if `camera_space` is `false`, add another
196  ///     * ray texture coordinate 2 * `image_shape`
197  ///     * ray depth (Euclidean distance) 1 * `image_shape`
198  ///     * ray normal 3 * `image_shape`
199  #[pyo3(name = "ray_tracing_forward")]
200  #[allow(clippy::too_many_arguments)]
201  pub fn py_ray_tracing_forward<'py>(
202      &mut self,
203      py: Python<'py>,
204      geometry: &PyGeometry,
205      ray: PyReadonlyArrayDyn<f32>,
206      texture: PyReadonlyArray3<f32>,
207      envmap: PyReadonlyArray4<f32>,
208      sample_per_pixel: Bound<'py, PyAny>,
209      max_bounce: Bound<'py, PyAny>,
210      switches: (u8, u8, u8, u8),
211      clip_near: Bound<'py, PyAny>,
212      camera_space: bool,
213      requires_grad: bool,
214      srand: i32,
215      low_discrepancy: Option<u32>) -> Result<Bound<'py, PyArrayDyn<f32>>> {
216    debug!("py_ray_tracing_forward: enter");
217    let ray = ray.to_owned_array();
218    let texture = texture.to_owned_array();
219    let envmap = envmap.to_owned_array();
220    let sample_per_pixel:(usize, usize) = if let Ok(sample_per_pixel_forward) = sample_per_pixel.extract() {
221      (sample_per_pixel_forward, sample_per_pixel_forward)
222    } else if let Ok(sample_per_pixel) = sample_per_pixel.extract() {
223      sample_per_pixel
224    } else {
225      bail!("py_ray_tracing_forward: invalid type in sample_per_pixel")
226    };
227    let max_bounce:(usize, usize, usize, usize) = if let Ok(max_bounce_forward) = max_bounce.extract() {
228      (max_bounce_forward, max_bounce_forward, 0, 0)
229    } else if let Ok((max_bounce_forward, max_bounce_backward)) = max_bounce.extract() {
230      (max_bounce_forward, max_bounce_backward, 0, 0)
231    } else if let Ok((max_bounce_forward, max_bounce_backward, max_bounce_low_discrepancy)) = max_bounce.extract() {
232      (max_bounce_forward, max_bounce_backward, max_bounce_low_discrepancy, 0)
233    } else if let Ok(max_bounce) = max_bounce.extract() {
234      max_bounce
235    } else {
236      bail!("py_ray_tracing_forward: invalid type in max_bounce")
237    };
238    let clip_near:(f32, f32, f32) = if let Ok(clip_near) = clip_near.extract() {
239      (clip_near, clip_near, clip_near)
240    } else if let Ok(clip_near) = clip_near.extract() {
241      clip_near
242    } else {
243      bail!("py_ray_tracing_forward: invalid type in clip_near")
244    };
245    let low_discrepancy = low_discrepancy.unwrap_or(sample_per_pixel.0.try_into()?);
246    let response = py.allow_threads(|| {
247      let rt = RUNTIME.get().expect("Initialized in pymodule");
248      rt.block_on(self.0.ray_tracing_forward(
249        &geometry.geometry,
250        &geometry.data_cache,
251        ray,
252        texture,
253        envmap,
254        sample_per_pixel,
255        max_bounce,
256        switches,
257        clip_near,
258        camera_space,
259        requires_grad,
260        srand,
261        low_discrepancy
262      ))
263    })?;
264    Ok(response.into_pyarray(py))
265  }
266
267  /// Create a [`RequestRayTracingBackward`] task and wait for response.
268  ///
269  /// Must be called consecutive to a [`RequestRayTracingForward`] task with `requires_grad` set to `true`. \
270  /// To create multiple [`RequestRayTracingForward`] tasks and backward together, multiple client connections are
271  /// required.
272  ///
273  /// # Arguments
274  /// * `d_ray` - gradient of image 3 * `image_shape` (must ensure same `image_shape` as [`RequestRayTracingForward`])
275  ///
276  /// # Return
277  /// Return shape will be,
278  /// * if `camera_space` is `false` for [`RequestRayTracingForward`] task
279  ///   * 1st return value (3 + 3 + 3 + 1 + 3 + 1) * `texture_resolution` * `texture_resolution`
280  ///     * (same `texture_resolution` as [`RequestRayTracingForward`])
281  ///     * including d_normal + d_diffuse + d_specular + d_roughness + d_intensity + d_window
282  ///   * 2nd return value 3 * 6 * `envmap_resolution` * `envmap_resolution`
283  ///     * (same `envmap_resolution` as [`RequestRayTracingForward`])
284  ///     * including d_envmap
285  /// * if `camera_space` is `true` for [`RequestRayTracingForward`] task, add another
286  ///   * 3rd return value 14 * `image_shape` (same shape as [`RequestRayTracingForward`])
287  ///     * including d_ray_texture
288  #[pyo3(name = "ray_tracing_backward")]
289  #[allow(clippy::type_complexity)]
290  pub fn py_ray_tracing_backward<'py>(
291      &mut self,
292      py: Python<'py>,
293      d_ray: PyReadonlyArrayDyn<f32>
294    ) -> Result<(Bound<'py, PyArray3<f32>>, Bound<'py, PyArray4<f32>>, Option<Bound<'py, PyArrayDyn<f32>>>)> {
295    debug!("py_ray_tracing_backward: enter");
296    let d_ray = d_ray.to_owned_array();
297    let response = py.allow_threads(|| {
298      let rt = RUNTIME.get().expect("Initialized in pymodule");
299      rt.block_on(self.0.ray_tracing_backward(d_ray))
300    })?;
301    let d_texture = response.0.into_pyarray(py);
302    let d_envmap = response.1.into_pyarray(py);
303    let d_ray_texture = response.2.map(|d_ray_texture| d_ray_texture.into_pyarray(py));
304    Ok((d_texture, d_envmap, d_ray_texture))
305  }
306}
307
308impl Drop for PyLibreDR {
309  fn drop(&mut self) {
310    let rt = RUNTIME.get().expect("Initialized in pymodule");
311    rt.block_on(async {
312      if let Err(err) = self.0.close().await {
313        error!("PyLibreDR::drop: {}", err);
314      }
315    })
316  }
317}
318
319fn init_static() -> Result<()> {
320  let log_level = env::var("LIBREDR_LOG_LEVEL").unwrap_or(String::from("info"));
321  let worker_threads = env::var("LIBREDR_WORKER_THREADS").unwrap_or(String::from("1")).parse()?;
322  RUNTIME.set(tokio::runtime::Builder::new_multi_thread().worker_threads(worker_threads).enable_all().build()?).ok();
323  let fmt_layer = fmt::layer()
324    .with_target(false);
325  let filter_layer = EnvFilter::try_new(log_level)?;
326  tracing_subscriber::registry()
327    .with(filter_layer)
328    .with(fmt_layer)
329    .init();
330  Ok(())
331}
332
333/// Initialize Python module.
334///
335/// Accept `LIBREDR_LOG_LEVEL` environment variable to set `log_level`.
336/// * (default: info, feasible: debug, info, warn, error)
337///
338/// Accept `LIBREDR_WORKER_THREADS` environment variable to set `worker_threads`
339/// * (default: 1)
340#[pymodule]
341#[pyo3(name = "libredr")]
342pub fn py_libredr<'py>(py: Python<'py>, module: &Bound<'py, PyModule>) -> PyResult<()> {
343  init_static()?;
344  module.add("__author__", "Bohan Yu <ybh1998@protonmail.com>")?;
345  module.add("__version__", format!("LibreDR {}", common::CLAP_LONG_VERSION))?;
346  module.add_class::<PyLibreDR>()?;
347  module.add_class::<PyGeometry>()?;
348  py_camera(py, module)?;
349  py_light_source(py, module)?;
350  Ok(())
351}