Commit c6b46562 authored by Baptiste Esteban's avatar Baptiste Esteban
Browse files

Better error message, C contiguous check and remove error flag + tests

parent 72b069e5
Pipeline #26739 passed with stage
in 16 minutes and 5 seconds
find_package(fmt 6.0 REQUIRED)
add_library(Pylene-numpy)
add_library(Pylene::Pylene-numpy ALIAS Pylene-numpy)
target_include_directories(Pylene-numpy PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
......@@ -10,6 +12,7 @@ target_sources(Pylene-numpy PRIVATE src/core/image_cast.cpp
src/core/numpy_format.cpp)
# REPLACE PYTHON_LIBRARIES BY pybind11::pybind11 WHEN cmake_find_package FOR PYBIND11 WILL BE FIXED
target_link_libraries(Pylene-numpy PUBLIC Pylene ${PYTHON_LIBRARIES})
target_link_libraries(Pylene-numpy PRIVATE fmt::fmt)
pybind11_add_module(pylena)
target_link_libraries(pylena PRIVATE Pylene Pylene-numpy ${PYTHON_LIBRARIES})
......
......@@ -4,29 +4,34 @@
#include <pybind11/cast.h>
#include <fmt/format.h>
#include <vector>
namespace pln
{
mln::ndbuffer_image from_numpy(pybind11::array arr)
{
if (!pybind11::detail::check_flags(arr.ptr(), pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_))
throw std::invalid_argument("Array should be C contiguous");
auto base = arr.base();
const auto info = arr.request();
mln::sample_type_id type = get_sample_type(info.format);
if (type == mln::sample_type_id::OTHER)
throw std::invalid_argument("Invalid dtype argument");
throw std::invalid_argument(fmt::format(
"Invalid dtype argument (Got dtype format {} expected types: [u]int[8, 16, 32, 64], float, double or bool)",
info.format));
const bool is_rgb8 = info.ndim == 3 && info.shape[2] == 3 && type == mln::sample_type_id::UINT8;
const auto pdim = info.ndim - (is_rgb8 ? 1 : 0);
if (pdim > mln::PYLENE_NDBUFFER_DEFAULT_DIM)
throw std::invalid_argument("Invalid number of dimension (should be < 5)");
throw std::invalid_argument(
fmt::format("Invalid number of dimension from numpy array (Got {} but should be < 5)", pdim));
int size[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0};
std::ptrdiff_t strides[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0};
for (auto d = 0; d < pdim; d++)
{
size[d] = info.shape[pdim - d - 1];
strides[d] = info.strides[pdim - d - 1];
if (d > 0 && strides[d] < strides[d - 1])
throw std::invalid_argument("Array should be C contiguous");
}
const auto sample_type = is_rgb8 ? mln::sample_type_id::RGB8 : type;
auto res =
......@@ -38,14 +43,11 @@ namespace pln
pybind11::object to_numpy(mln::ndbuffer_image img)
{
auto& api = pybind11::detail::npy_api::get();
const auto& api = pybind11::detail::npy_api::get();
pybind11::object data = pybind11::none();
int flags = pybind11::detail::npy_api::NPY_ARRAY_OWNDATA_ | pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_;
int flags = pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_;
if (img.__data())
{
data = pybind11::reinterpret_borrow<pybind11::object>(pybind11::cast(img.__data()));
flags |= pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
}
/* For the moment, restrict RGB8 image to 2D image */
const bool is_rgb8 = img.pdim() == 2 && img.sample_type() == mln::sample_type_id::RGB8;
......@@ -73,8 +75,7 @@ namespace pln
void init_pylena_numpy(pybind11::module& m)
{
if (!pybind11::detail::get_global_type_info(typeid(mln::internal::ndbuffer_image_data)) &&
!pybind11::detail::get_global_type_info(typeid(std::shared_ptr<mln::internal::ndbuffer_image_data>)))
if (!pybind11::detail::get_global_type_info(typeid(mln::internal::ndbuffer_image_data)))
{
pybind11::class_<mln::internal::ndbuffer_image_data, std::shared_ptr<mln::internal::ndbuffer_image_data>>(
m, "ndbuffer_image_data");
......
......@@ -72,19 +72,15 @@ class TestNumpyImage(unittest.TestCase):
self.assertTrue(np.all(res2 == expected2))
def test_incorrect_type(self):
ERROR_MSG = "Invalid dtype argument"
img = np.zeros((10, 10), dtype=str)
with self.assertRaises(ValueError) as context:
with self.assertRaises(ValueError, msg="Invalid dtype argument (Got dtype format 1w, expected types: [u]int[8, 16, 32, 64], float, double or bool)"):
pln.id(img)
self.assertTrue(ERROR_MSG in str(context.exception))
class WrongType:
pass
img = np.zeros((10, 10), dtype=WrongType)
with self.assertRaises(ValueError) as context:
with self.assertRaises(ValueError, msg="Invalid dtype argument (Got dtype format 0, expected types: [u]int[8, 16, 32, 64], float, double or bool)"):
pln.id(img)
self.assertTrue(ERROR_MSG in str(context.exception))
def test_memory(self):
import gc
......@@ -130,6 +126,11 @@ class TestNumpyImage(unittest.TestCase):
del img2
self.assertTrue(sys.getrefcount(base) == 2)
def test_invalid_dim(self):
img = np.zeros(np.arange(5))
with self.assertRaises(ValueError, msg="Invalid number of dimension from numpy array (Got 5 but should be < 5)"):
pln.id(img)
if __name__ == "__main__":
unittest.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment