Commit c6b46562 authored by Baptiste Esteban's avatar Baptiste Esteban
Browse files

Better error message, C contiguous check and remove error flag + tests

parent 72b069e5
Pipeline #26739 passed with stage
in 16 minutes and 5 seconds
find_package(fmt 6.0 REQUIRED)
add_library(Pylene-numpy) add_library(Pylene-numpy)
add_library(Pylene::Pylene-numpy ALIAS Pylene-numpy) add_library(Pylene::Pylene-numpy ALIAS Pylene-numpy)
target_include_directories(Pylene-numpy PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include> target_include_directories(Pylene-numpy PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
...@@ -10,6 +12,7 @@ target_sources(Pylene-numpy PRIVATE src/core/image_cast.cpp ...@@ -10,6 +12,7 @@ target_sources(Pylene-numpy PRIVATE src/core/image_cast.cpp
src/core/numpy_format.cpp) src/core/numpy_format.cpp)
# REPLACE PYTHON_LIBRARIES BY pybind11::pybind11 WHEN cmake_find_package FOR PYBIND11 WILL BE FIXED # REPLACE PYTHON_LIBRARIES BY pybind11::pybind11 WHEN cmake_find_package FOR PYBIND11 WILL BE FIXED
target_link_libraries(Pylene-numpy PUBLIC Pylene ${PYTHON_LIBRARIES}) target_link_libraries(Pylene-numpy PUBLIC Pylene ${PYTHON_LIBRARIES})
target_link_libraries(Pylene-numpy PRIVATE fmt::fmt)
pybind11_add_module(pylena) pybind11_add_module(pylena)
target_link_libraries(pylena PRIVATE Pylene Pylene-numpy ${PYTHON_LIBRARIES}) target_link_libraries(pylena PRIVATE Pylene Pylene-numpy ${PYTHON_LIBRARIES})
......
...@@ -4,29 +4,34 @@ ...@@ -4,29 +4,34 @@
#include <pybind11/cast.h> #include <pybind11/cast.h>
#include <fmt/format.h>
#include <vector> #include <vector>
namespace pln namespace pln
{ {
mln::ndbuffer_image from_numpy(pybind11::array arr) mln::ndbuffer_image from_numpy(pybind11::array arr)
{ {
if (!pybind11::detail::check_flags(arr.ptr(), pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_))
throw std::invalid_argument("Array should be C contiguous");
auto base = arr.base(); auto base = arr.base();
const auto info = arr.request(); const auto info = arr.request();
mln::sample_type_id type = get_sample_type(info.format); mln::sample_type_id type = get_sample_type(info.format);
if (type == mln::sample_type_id::OTHER) if (type == mln::sample_type_id::OTHER)
throw std::invalid_argument("Invalid dtype argument"); throw std::invalid_argument(fmt::format(
"Invalid dtype argument (Got dtype format {} expected types: [u]int[8, 16, 32, 64], float, double or bool)",
info.format));
const bool is_rgb8 = info.ndim == 3 && info.shape[2] == 3 && type == mln::sample_type_id::UINT8; const bool is_rgb8 = info.ndim == 3 && info.shape[2] == 3 && type == mln::sample_type_id::UINT8;
const auto pdim = info.ndim - (is_rgb8 ? 1 : 0); const auto pdim = info.ndim - (is_rgb8 ? 1 : 0);
if (pdim > mln::PYLENE_NDBUFFER_DEFAULT_DIM) if (pdim > mln::PYLENE_NDBUFFER_DEFAULT_DIM)
throw std::invalid_argument("Invalid number of dimension (should be < 5)"); throw std::invalid_argument(
fmt::format("Invalid number of dimension from numpy array (Got {} but should be < 5)", pdim));
int size[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0}; int size[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0};
std::ptrdiff_t strides[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0}; std::ptrdiff_t strides[mln::PYLENE_NDBUFFER_DEFAULT_DIM] = {0};
for (auto d = 0; d < pdim; d++) for (auto d = 0; d < pdim; d++)
{ {
size[d] = info.shape[pdim - d - 1]; size[d] = info.shape[pdim - d - 1];
strides[d] = info.strides[pdim - d - 1]; strides[d] = info.strides[pdim - d - 1];
if (d > 0 && strides[d] < strides[d - 1])
throw std::invalid_argument("Array should be C contiguous");
} }
const auto sample_type = is_rgb8 ? mln::sample_type_id::RGB8 : type; const auto sample_type = is_rgb8 ? mln::sample_type_id::RGB8 : type;
auto res = auto res =
...@@ -38,14 +43,11 @@ namespace pln ...@@ -38,14 +43,11 @@ namespace pln
pybind11::object to_numpy(mln::ndbuffer_image img) pybind11::object to_numpy(mln::ndbuffer_image img)
{ {
auto& api = pybind11::detail::npy_api::get(); const auto& api = pybind11::detail::npy_api::get();
pybind11::object data = pybind11::none(); pybind11::object data = pybind11::none();
int flags = pybind11::detail::npy_api::NPY_ARRAY_OWNDATA_ | pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_; int flags = pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_;
if (img.__data()) if (img.__data())
{
data = pybind11::reinterpret_borrow<pybind11::object>(pybind11::cast(img.__data())); data = pybind11::reinterpret_borrow<pybind11::object>(pybind11::cast(img.__data()));
flags |= pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_;
}
/* For the moment, restrict RGB8 image to 2D image */ /* For the moment, restrict RGB8 image to 2D image */
const bool is_rgb8 = img.pdim() == 2 && img.sample_type() == mln::sample_type_id::RGB8; const bool is_rgb8 = img.pdim() == 2 && img.sample_type() == mln::sample_type_id::RGB8;
...@@ -73,8 +75,7 @@ namespace pln ...@@ -73,8 +75,7 @@ namespace pln
void init_pylena_numpy(pybind11::module& m) void init_pylena_numpy(pybind11::module& m)
{ {
if (!pybind11::detail::get_global_type_info(typeid(mln::internal::ndbuffer_image_data)) && if (!pybind11::detail::get_global_type_info(typeid(mln::internal::ndbuffer_image_data)))
!pybind11::detail::get_global_type_info(typeid(std::shared_ptr<mln::internal::ndbuffer_image_data>)))
{ {
pybind11::class_<mln::internal::ndbuffer_image_data, std::shared_ptr<mln::internal::ndbuffer_image_data>>( pybind11::class_<mln::internal::ndbuffer_image_data, std::shared_ptr<mln::internal::ndbuffer_image_data>>(
m, "ndbuffer_image_data"); m, "ndbuffer_image_data");
......
...@@ -72,19 +72,15 @@ class TestNumpyImage(unittest.TestCase): ...@@ -72,19 +72,15 @@ class TestNumpyImage(unittest.TestCase):
self.assertTrue(np.all(res2 == expected2)) self.assertTrue(np.all(res2 == expected2))
def test_incorrect_type(self): def test_incorrect_type(self):
ERROR_MSG = "Invalid dtype argument"
img = np.zeros((10, 10), dtype=str) img = np.zeros((10, 10), dtype=str)
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError, msg="Invalid dtype argument (Got dtype format 1w, expected types: [u]int[8, 16, 32, 64], float, double or bool)"):
pln.id(img) pln.id(img)
self.assertTrue(ERROR_MSG in str(context.exception))
class WrongType: class WrongType:
pass pass
img = np.zeros((10, 10), dtype=WrongType) img = np.zeros((10, 10), dtype=WrongType)
with self.assertRaises(ValueError) as context: with self.assertRaises(ValueError, msg="Invalid dtype argument (Got dtype format 0, expected types: [u]int[8, 16, 32, 64], float, double or bool)"):
pln.id(img) pln.id(img)
self.assertTrue(ERROR_MSG in str(context.exception))
def test_memory(self): def test_memory(self):
import gc import gc
...@@ -130,6 +126,11 @@ class TestNumpyImage(unittest.TestCase): ...@@ -130,6 +126,11 @@ class TestNumpyImage(unittest.TestCase):
del img2 del img2
self.assertTrue(sys.getrefcount(base) == 2) self.assertTrue(sys.getrefcount(base) == 2)
def test_invalid_dim(self):
img = np.zeros(np.arange(5))
with self.assertRaises(ValueError, msg="Invalid number of dimension from numpy array (Got 5 but should be < 5)"):
pln.id(img)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment