1use std::collections::hash_map::Entry;
58use std::fmt;
59use std::hash::Hash;
60use std::marker::PhantomData;
61use std::sync::Mutex;
62
63use pyo3::sync::MutexExt;
64use pyo3::{Bound, Py, Python};
65use rustc_hash::FxHashMap;
66
67use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods};
68use crate::npyffi::{
69 PyArray_DatetimeDTypeMetaData, PyDataType_C_METADATA, NPY_DATETIMEUNIT, NPY_TYPES,
70};
71
72pub trait Unit: Send + Sync + Clone + Copy + PartialEq + Eq + Hash + PartialOrd + Ord {
76 const UNIT: NPY_DATETIMEUNIT;
80
81 const ABBREV: &'static str;
83}
84
85macro_rules! define_units {
86 ($($(#[$meta:meta])* $struct:ident => $unit:ident $abbrev:literal,)+) => {
87 $(
88
89 $(#[$meta])*
90 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
91 pub struct $struct;
92
93 impl Unit for $struct {
94 const UNIT: NPY_DATETIMEUNIT = NPY_DATETIMEUNIT::$unit;
95
96 const ABBREV: &'static str = $abbrev;
97 }
98
99 )+
100 };
101}
102
103pub mod units {
105 use super::*;
106
107 define_units!(
108 #[doc = "Years, i.e. 12 months"]
109 Years => NPY_FR_Y "a",
110 #[doc = "Months, i.e. 30 days"]
111 Months => NPY_FR_M "mo",
112 #[doc = "Weeks, i.e. 7 days"]
113 Weeks => NPY_FR_W "w",
114 #[doc = "Days, i.e. 24 hours"]
115 Days => NPY_FR_D "d",
116 #[doc = "Hours, i.e. 60 minutes"]
117 Hours => NPY_FR_h "h",
118 #[doc = "Minutes, i.e. 60 seconds"]
119 Minutes => NPY_FR_m "min",
120 #[doc = "Seconds"]
121 Seconds => NPY_FR_s "s",
122 #[doc = "Milliseconds, i.e. 10^-3 seconds"]
123 Milliseconds => NPY_FR_ms "ms",
124 #[doc = "Microseconds, i.e. 10^-6 seconds"]
125 Microseconds => NPY_FR_us "µs",
126 #[doc = "Nanoseconds, i.e. 10^-9 seconds"]
127 Nanoseconds => NPY_FR_ns "ns",
128 #[doc = "Picoseconds, i.e. 10^-12 seconds"]
129 Picoseconds => NPY_FR_ps "ps",
130 #[doc = "Femtoseconds, i.e. 10^-15 seconds"]
131 Femtoseconds => NPY_FR_fs "fs",
132 #[doc = "Attoseconds, i.e. 10^-18 seconds"]
133 Attoseconds => NPY_FR_as "as",
134 );
135}
136
137#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
141#[repr(transparent)]
142pub struct Datetime<U: Unit>(i64, PhantomData<U>);
143
144impl<U: Unit> From<i64> for Datetime<U> {
145 fn from(val: i64) -> Self {
146 Self(val, PhantomData)
147 }
148}
149
150impl<U: Unit> From<Datetime<U>> for i64 {
151 fn from(val: Datetime<U>) -> Self {
152 val.0
153 }
154}
155
156unsafe impl<U: Unit> Element for Datetime<U> {
157 const IS_COPY: bool = true;
158
159 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
160 static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_DATETIME) };
161
162 DTYPES.from_unit(py, U::UNIT)
163 }
164
165 clone_methods_impl!(Self);
166}
167
168impl<U: Unit> fmt::Debug for Datetime<U> {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 write!(f, "Datetime({} {})", self.0, U::ABBREV)
171 }
172}
173
174#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
178#[repr(transparent)]
179pub struct Timedelta<U: Unit>(i64, PhantomData<U>);
180
181impl<U: Unit> From<i64> for Timedelta<U> {
182 fn from(val: i64) -> Self {
183 Self(val, PhantomData)
184 }
185}
186
187impl<U: Unit> From<Timedelta<U>> for i64 {
188 fn from(val: Timedelta<U>) -> Self {
189 val.0
190 }
191}
192
193unsafe impl<U: Unit> Element for Timedelta<U> {
194 const IS_COPY: bool = true;
195
196 fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
197 static DTYPES: TypeDescriptors = unsafe { TypeDescriptors::new(NPY_TYPES::NPY_TIMEDELTA) };
198
199 DTYPES.from_unit(py, U::UNIT)
200 }
201
202 clone_methods_impl!(Self);
203}
204
205impl<U: Unit> fmt::Debug for Timedelta<U> {
206 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207 write!(f, "Timedelta({} {})", self.0, U::ABBREV)
208 }
209}
210
211struct TypeDescriptors {
212 npy_type: NPY_TYPES,
213 dtypes: Mutex<Option<FxHashMap<NPY_DATETIMEUNIT, Py<PyArrayDescr>>>>,
214}
215
216impl TypeDescriptors {
217 const unsafe fn new(npy_type: NPY_TYPES) -> Self {
219 Self {
220 npy_type,
221 dtypes: Mutex::new(None),
222 }
223 }
224
225 #[allow(clippy::wrong_self_convention)]
226 fn from_unit<'py>(&self, py: Python<'py>, unit: NPY_DATETIMEUNIT) -> Bound<'py, PyArrayDescr> {
227 let mut dtypes = self
228 .dtypes
229 .lock_py_attached(py)
230 .expect("dtype cache poisoned");
231
232 let dtype = match dtypes.get_or_insert_with(Default::default).entry(unit) {
233 Entry::Occupied(entry) => entry.into_mut(),
234 Entry::Vacant(entry) => {
235 let dtype = PyArrayDescr::new_from_npy_type(py, self.npy_type);
236
237 unsafe {
239 let metadata = &mut *(PyDataType_C_METADATA(py, dtype.as_dtype_ptr())
240 as *mut PyArray_DatetimeDTypeMetaData);
241
242 metadata.meta.base = unit;
243 metadata.meta.num = 1;
244 }
245
246 entry.insert(dtype.into())
247 }
248 };
249
250 dtype.bind(py).to_owned()
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;

    use pyo3::{
        ffi::c_str,
        py_run,
        types::{PyAnyMethods, PyDict, PyModule},
    };

    use crate::array::{PyArray1, PyArrayMethods};

    /// A NumPy `datetime64[D]` array holding the epoch reads back as a raw count of zero.
    #[test]
    fn from_python_to_rust() {
        Python::with_gil(|py| {
            let scope = py
                .eval(c_str!("{ 'np': __import__('numpy') }"), None, None)
                .unwrap()
                .downcast_into::<PyDict>()
                .unwrap();

            let epoch_array = py
                .eval(
                    c_str!("np.array([np.datetime64('1970-01-01')])"),
                    None,
                    Some(&scope),
                )
                .unwrap()
                .downcast_into::<PyArray1<Datetime<units::Days>>>()
                .unwrap();

            let epoch = epoch_array.get_owned(0).unwrap();
            assert_eq!(i64::from(epoch), 0);
        });
    }

    /// A value written from Rust is visible to NumPy with the matching `timedelta64` dtype.
    #[test]
    fn from_rust_to_python() {
        Python::with_gil(|py| {
            // `array` and `np` must keep these names: `py_run!` exposes them to Python by name.
            let array = PyArray1::<Timedelta<units::Minutes>>::zeros(py, 1, false);

            {
                let mut writable = array.readwrite();
                *writable.get_mut(0).unwrap() = Timedelta::<units::Minutes>::from(5);
            }

            let np = py
                .eval(c_str!("__import__('numpy')"), None, None)
                .unwrap()
                .downcast_into::<PyModule>()
                .unwrap();

            py_run!(py, array np, "assert array.dtype == np.dtype('timedelta64[m]')");
            py_run!(py, array np, "assert array[0] == np.timedelta64(5, 'm')");
        });
    }

    /// `Debug` output shows the raw count followed by the unit abbreviation.
    #[test]
    fn debug_formatting() {
        let datetime = Datetime::<units::Days>::from(28);
        assert_eq!(format!("{:?}", datetime), "Datetime(28 d)");

        let timedelta = Timedelta::<units::Milliseconds>::from(160);
        assert_eq!(format!("{:?}", timedelta), "Timedelta(160 ms)");
    }

    /// Casting a `timedelta64` holding `1` between units yields NumPy's conversion factors.
    #[test]
    fn unit_conversion() {
        // Casts a one-element array holding `1` from unit `S` to unit `D` and checks the result.
        #[track_caller]
        fn convert<S: Unit, D: Unit>(py: Python<'_>, expected_value: i64) {
            let source = PyArray1::<Timedelta<S>>::from_slice(py, &[Timedelta::<S>::from(1)]);
            let converted = source.cast::<Timedelta<D>>(false).unwrap();

            let actual = i64::from(converted.get_owned(0).unwrap());
            assert_eq!(actual, expected_value);
        }

        // Average year length over the 400-year Gregorian cycle (97 leap years), truncated.
        let days_per_year = (97 + 400 * 365) / 400;
        let days_per_month = days_per_year / 12;

        Python::with_gil(|py| {
            convert::<units::Years, units::Days>(py, days_per_year);
            convert::<units::Months, units::Days>(py, days_per_month);

            convert::<units::Weeks, units::Seconds>(py, 7 * 24 * 60 * 60);
            convert::<units::Days, units::Seconds>(py, 24 * 60 * 60);
            convert::<units::Hours, units::Seconds>(py, 60 * 60);
            convert::<units::Minutes, units::Seconds>(py, 60);

            convert::<units::Seconds, units::Milliseconds>(py, 1_000);
            convert::<units::Seconds, units::Microseconds>(py, 1_000_000);
            convert::<units::Seconds, units::Nanoseconds>(py, 1_000_000_000);
            convert::<units::Seconds, units::Picoseconds>(py, 1_000_000_000_000);
            convert::<units::Seconds, units::Femtoseconds>(py, 1_000_000_000_000_000);

            convert::<units::Femtoseconds, units::Attoseconds>(py, 1_000);
        });
    }
}