numpy/
strings.rs

1//! Types to support arrays of [ASCII][ascii] and [UCS4][ucs4] strings
2//!
3//! [ascii]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_STRING
4//! [ucs4]: https://numpy.org/doc/stable/reference/c-api/dtype.html#c.NPY_UNICODE
5
6use std::collections::hash_map::Entry;
7use std::fmt;
8use std::mem::size_of;
9use std::os::raw::c_char;
10use std::str;
11use std::sync::Mutex;
12
13use pyo3::sync::MutexExt;
14use pyo3::{
15    ffi::{Py_UCS1, Py_UCS4},
16    Bound, Py, Python,
17};
18use rustc_hash::FxHashMap;
19
20use crate::dtype::{clone_methods_impl, Element, PyArrayDescr, PyArrayDescrMethods};
21use crate::npyffi::PyDataType_SET_ELSIZE;
22use crate::npyffi::NPY_TYPES;
23
/// A newtype wrapper around [`[u8; N]`][Py_UCS1] to handle [`bytes_` scalars][numpy-bytes] while satisfying coherence.
///
/// Note that when creating arrays of ASCII strings without an explicit `dtype`,
/// NumPy will automatically determine the smallest possible array length at runtime.
///
/// For example,
///
/// ```python
/// array = numpy.array([b"foo", b"bar", b"foobar"])
/// ```
///
/// yields `S6` for `array.dtype`.
///
/// On the Rust side however, the length `N` of `PyFixedString<N>` must always be given
/// explicitly and as a compile-time constant. For this to work reliably, the Python code
/// should set the `dtype` explicitly, e.g.
///
/// ```python
/// numpy.array([b"foo", b"bar", b"foobar"], dtype='S12')
/// ```
///
/// always matching `PyArray1<PyFixedString<12>>`.
///
/// The [`Display`](fmt::Display) implementation does not print trailing NUL bytes,
/// so `[b'f', b'o', b'o', 0, 0, 0]` is formatted as `foo`.
///
/// # Example
///
/// ```rust
/// # use pyo3::Python;
/// use numpy::{PyArray1, PyUntypedArrayMethods, PyFixedString};
///
/// # Python::with_gil(|py| {
/// let array = PyArray1::<PyFixedString<3>>::from_vec(py, vec![[b'f', b'o', b'o'].into()]);
///
/// assert!(array.dtype().to_string().contains("S3"));
/// # });
/// ```
///
/// [numpy-bytes]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.bytes_
// `repr(transparent)` gives the wrapper the same layout as the wrapped array.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PyFixedString<const N: usize>(pub [Py_UCS1; N]);
64
65impl<const N: usize> fmt::Display for PyFixedString<N> {
66    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
67        fmt.write_str(str::from_utf8(&self.0).unwrap().trim_end_matches('\0'))
68    }
69}
70
71impl<const N: usize> From<[Py_UCS1; N]> for PyFixedString<N> {
72    fn from(val: [Py_UCS1; N]) -> Self {
73        Self(val)
74    }
75}
76
77unsafe impl<const N: usize> Element for PyFixedString<N> {
78    const IS_COPY: bool = true;
79
80    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
81        static DTYPES: TypeDescriptors = TypeDescriptors::new();
82
83        unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_STRING, b'|' as _, size_of::<Self>()) }
84    }
85
86    clone_methods_impl!(Self);
87}
88
/// A newtype wrapper around [`[Py_UCS4; N]`][Py_UCS4] to handle [`str_` scalars][numpy-str] while satisfying coherence.
///
/// Note that when creating arrays of Unicode strings without an explicit `dtype`,
/// NumPy will automatically determine the smallest possible array length at runtime.
///
/// For example,
///
/// ```python
/// numpy.array(["foo🐍", "bar🦀", "foobar"])
/// ```
///
/// yields `U6` for `array.dtype`.
///
/// On the Rust side however, the length `N` of `PyFixedUnicode<N>` must always be given
/// explicitly and as a compile-time constant. For this to work reliably, the Python code
/// should set the `dtype` explicitly, e.g.
///
/// ```python
/// numpy.array(["foo🐍", "bar🦀", "foobar"], dtype='U12')
/// ```
///
/// always matching `PyArray1<PyFixedUnicode<12>>`.
///
/// The [`Display`](fmt::Display) implementation does not print trailing NUL code points,
/// so `['f', 'o', 'o', 0, 0, 0]` is formatted as `foo`.
///
/// # Example
///
/// ```rust
/// # use pyo3::Python;
/// use numpy::{PyArray1, PyUntypedArrayMethods, PyFixedUnicode};
///
/// # Python::with_gil(|py| {
/// let array = PyArray1::<PyFixedUnicode<3>>::from_vec(py, vec![[b'b' as _, b'a' as _, b'r' as _].into()]);
///
/// assert!(array.dtype().to_string().contains("U3"));
/// # });
/// ```
///
/// [numpy-str]: https://numpy.org/doc/stable/reference/arrays.scalars.html#numpy.str_
// `repr(transparent)` gives the wrapper the same layout as the wrapped array.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PyFixedUnicode<const N: usize>(pub [Py_UCS4; N]);
129
130impl<const N: usize> fmt::Display for PyFixedUnicode<N> {
131    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
132        for character in self.0 {
133            if character == 0 {
134                break;
135            }
136
137            write!(fmt, "{}", char::from_u32(character).unwrap())?;
138        }
139
140        Ok(())
141    }
142}
143
144impl<const N: usize> From<[Py_UCS4; N]> for PyFixedUnicode<N> {
145    fn from(val: [Py_UCS4; N]) -> Self {
146        Self(val)
147    }
148}
149
150unsafe impl<const N: usize> Element for PyFixedUnicode<N> {
151    const IS_COPY: bool = true;
152
153    fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
154        static DTYPES: TypeDescriptors = TypeDescriptors::new();
155
156        unsafe { DTYPES.from_size(py, NPY_TYPES::NPY_UNICODE, b'=' as _, size_of::<Self>()) }
157    }
158
159    clone_methods_impl!(Self);
160}
161
/// Cache of type descriptors, looked up by element size in bytes.
///
/// The map starts out as `None` and is only created on first use, which lets
/// `new` be a `const fn` so that the cache can live in a `static`.
struct TypeDescriptors {
    // Element size -> descriptor that was created for that size.
    dtypes: Mutex<Option<FxHashMap<usize, Py<PyArrayDescr>>>>,
}
165
166impl TypeDescriptors {
167    const fn new() -> Self {
168        Self {
169            dtypes: Mutex::new(None),
170        }
171    }
172
173    /// `npy_type` must be either `NPY_STRING` or `NPY_UNICODE` with matching `byteorder` and `size`
174    #[allow(clippy::wrong_self_convention)]
175    unsafe fn from_size<'py>(
176        &self,
177        py: Python<'py>,
178        npy_type: NPY_TYPES,
179        byteorder: c_char,
180        size: usize,
181    ) -> Bound<'py, PyArrayDescr> {
182        let mut dtypes = self
183            .dtypes
184            .lock_py_attached(py)
185            .expect("dtype cache poisoned");
186
187        let dtype = match dtypes.get_or_insert_with(Default::default).entry(size) {
188            Entry::Occupied(entry) => entry.into_mut(),
189            Entry::Vacant(entry) => {
190                let dtype = PyArrayDescr::new_from_npy_type(py, npy_type);
191
192                let descr = &mut *dtype.as_dtype_ptr();
193                PyDataType_SET_ELSIZE(py, descr, size.try_into().unwrap());
194                descr.byteorder = byteorder;
195
196                entry.insert(dtype.into())
197            }
198        };
199
200        dtype.bind(py).to_owned()
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns an array of `char`s into the matching array of UCS4 code units.
    fn ucs4<const N: usize>(chars: [char; N]) -> [Py_UCS4; N] {
        chars.map(Py_UCS4::from)
    }

    #[test]
    fn format_fixed_string() {
        // Trailing NUL bytes are not part of the output.
        let padded = PyFixedString(*b"foo\0\0\0");
        assert_eq!(format!("{}", padded), "foo");

        // A value that fills the whole array is printed in full.
        let full = PyFixedString(*b"foobar");
        assert_eq!(format!("{}", full), "foobar");
    }

    #[test]
    fn format_fixed_unicode() {
        // Trailing NUL code points are not part of the output.
        let padded = PyFixedUnicode(ucs4(['f', 'o', 'o', '\0', '\0', '\0']));
        assert_eq!(format!("{}", padded), "foo");

        // Code points outside of the basic multilingual plane.
        let emoji = PyFixedUnicode(ucs4(['🦀', '🐍', '\0', '\0', '\0', '\0']));
        assert_eq!(format!("{}", emoji), "🦀🐍");

        // A value that fills the whole array is printed in full.
        let full = PyFixedUnicode(ucs4(['f', 'o', 'o', 'b', 'a', 'r']));
        assert_eq!(format!("{}", full), "foobar");
    }
}