Commit c8ff69f1 authored by Quentin Kaci's avatar Quentin Kaci

Factorize some parts with accumulators

parent d0664f29
Pipeline #27405 failed with stages
in 8 minutes and 36 seconds
#pragma once
#include "mln/hierarchies/accumulators/hierarchy_accumulator_base.hpp"
namespace mln
{
class MaxAccumulator : public HierarchyAccumulatorBase<int>
{
public:
explicit MaxAccumulator()
// Neutral element
: acc_(std::numeric_limits<int>::min())
{
}
~MaxAccumulator() override = default;
inline void init(int n) override { acc_ = n; }
inline void invalidate() override { acc_ = -1; }
inline void merge(MaxAccumulator& other) { acc_ = std::max(acc_, other.get_value()); }
inline int get_value() const override { return acc_; }
private:
using HierarchyAccumulatorBase<int>::merge;
int acc_;
};
} // namespace mln
\ No newline at end of file
#pragma once
#include "mln/hierarchies/accumulators/hierarchy_accumulator_base.hpp"
namespace mln
{
class MinAccumulator : public HierarchyAccumulatorBase<int>
{
public:
explicit MinAccumulator()
// Neutral element
: acc_(std::numeric_limits<int>::max())
{
}
~MinAccumulator() override = default;
inline void init(int n) override { acc_ = n; }
inline void invalidate() override { acc_ = -1; }
inline void merge(MinAccumulator& other) { acc_ = std::min(acc_, other.get_value()); }
inline int get_value() const override { return acc_; }
private:
using HierarchyAccumulatorBase<int>::merge;
int acc_;
};
} // namespace mln
\ No newline at end of file
......@@ -4,29 +4,30 @@
namespace mln
{
class SumAccumulator : public HierarchyAccumulatorBase<int>
template <typename T>
class SumAccumulator : public HierarchyAccumulatorBase<T>
{
public:
explicit SumAccumulator()
// Neutral element
: acc_(0)
: acc_()
{
}
~SumAccumulator() override = default;
inline void init(int n) override { acc_ = n; }
inline void init(T n) override { acc_ = n; }
inline void invalidate() override { acc_ = -1; }
inline void merge(SumAccumulator& other) { acc_ += other.get_value(); }
inline int get_value() const override { return acc_; }
inline T get_value() const override { return acc_; }
private:
using HierarchyAccumulatorBase<int>::merge;
using HierarchyAccumulatorBase<T>::merge;
int acc_;
T acc_;
};
} // namespace mln
\ No newline at end of file
......@@ -49,7 +49,7 @@ namespace mln
std::function<int(int)> diff_altitude_;
SumAccumulator sum_;
SumAccumulator<int> sum_;
};
} // namespace mln
\ No newline at end of file
#pragma once
#include "mln/hierarchies/accumulators/hierarchy_accumulator_base.hpp"
#include "mln/hierarchies/graph.hpp"
#include "mln/hierarchies/hierarchy_tree.hpp"
#include <numeric>
#include <vector>
namespace mln
......@@ -18,12 +20,70 @@ namespace mln
bool exclude_leaves = false;
};
std::vector<int> hierarchy_traversal(const HierarchyTree& tree, const HierarchyTraversal& traversal);
inline std::vector<int> hierarchy_traversal(const HierarchyTree& tree, const HierarchyTraversal& traversal)
{
// Exclude root
std::vector<int> res((traversal.exclude_leaves ? tree.get_nb_vertices() / 2 : tree.get_nb_vertices()) - 1);
std::iota(res.begin(), res.end(), traversal.exclude_leaves ? tree.leaf_graph.get_nb_vertices() : 0);
if (traversal.order == HierarchyTraversal::ROOT_TO_LEAVES)
std::reverse(res.begin(), res.end());
return res;
}
template <typename T, Accumulator<T> AccumulatorType>
std::vector<T> compute_attribute_from_accumulator(const HierarchyTree& tree, const AccumulatorType& acc,
const HierarchyTraversal& traversal,
std::vector<T> init_list = std::vector<T>());
inline std::vector<T> compute_attribute_from_accumulator(const HierarchyTree& tree, const AccumulatorType& acc,
const HierarchyTraversal& traversal,
std::vector<T> init_list = std::vector<T>())
{
int tree_nb_vertices = tree.get_nb_vertices();
std::vector<AccumulatorType> accumulators(tree_nb_vertices, acc);
for (int i = 0; i < tree_nb_vertices; ++i)
accumulators[i].set_associated_node(i);
for (size_t i = 0; i < init_list.size(); ++i)
accumulators[i].init(init_list[i]);
std::vector<int> traversal_vector = hierarchy_traversal(tree, traversal);
if (traversal.order == HierarchyTraversal::LEAVES_TO_ROOT)
{
for (const auto& node : traversal_vector)
{
int parent_node = tree.get_parent(node);
if (parent_node == -1)
{
accumulators[node].invalidate();
continue;
}
accumulators[parent_node].merge(accumulators[node]);
}
}
else if (traversal.order == HierarchyTraversal::ROOT_TO_LEAVES)
{
for (const auto& node : traversal_vector)
{
int parent_node = tree.get_parent(node);
if (parent_node == -1)
{
accumulators[node].invalidate();
continue;
}
accumulators[node].merge(accumulators[parent_node]);
}
}
std::vector<T> attribute(tree_nb_vertices);
for (int i = 0; i < tree_nb_vertices; ++i)
attribute[i] = accumulators[i].get_value();
return attribute;
}
std::vector<int> depth_attribute(const HierarchyTree& tree);
......
#include "mln/hierarchies/attributes.hpp"
#include "mln/hierarchies/graph.hpp"
#include "mln/hierarchies/accumulators/deepest_altitude_accumulator.hpp"
#include "mln/hierarchies/accumulators/depth_accumulator.hpp"
......@@ -8,75 +7,8 @@
#include "mln/hierarchies/accumulators/sum_accumulator.hpp"
#include "mln/hierarchies/accumulators/volume_accumulator.hpp"
#include <numeric>
namespace mln
{
std::vector<int> hierarchy_traversal(const HierarchyTree& tree, const HierarchyTraversal& traversal)
{
// Exclude root
std::vector<int> res((traversal.exclude_leaves ? tree.get_nb_vertices() / 2 : tree.get_nb_vertices()) - 1);
std::iota(res.begin(), res.end(), traversal.exclude_leaves ? tree.leaf_graph.get_nb_vertices() : 0);
if (traversal.order == HierarchyTraversal::ROOT_TO_LEAVES)
std::reverse(res.begin(), res.end());
return res;
}
template <typename T, Accumulator<T> AccumulatorType>
std::vector<T> compute_attribute_from_accumulator(const HierarchyTree& tree, const AccumulatorType& acc,
const HierarchyTraversal& traversal, std::vector<T> init_list)
{
int tree_nb_vertices = tree.get_nb_vertices();
std::vector<AccumulatorType> accumulators(tree_nb_vertices, acc);
for (int i = 0; i < tree_nb_vertices; ++i)
accumulators[i].set_associated_node(i);
for (size_t i = 0; i < init_list.size(); ++i)
accumulators[i].init(init_list[i]);
std::vector<int> traversal_vector = hierarchy_traversal(tree, traversal);
if (traversal.order == HierarchyTraversal::LEAVES_TO_ROOT)
{
for (const auto& node : traversal_vector)
{
int parent_node = tree.get_parent(node);
if (parent_node == -1)
{
accumulators[node].invalidate();
continue;
}
accumulators[parent_node].merge(accumulators[node]);
}
}
else if (traversal.order == HierarchyTraversal::ROOT_TO_LEAVES)
{
for (const auto& node : traversal_vector)
{
int parent_node = tree.get_parent(node);
if (parent_node == -1)
{
accumulators[node].invalidate();
continue;
}
accumulators[node].merge(accumulators[parent_node]);
}
}
std::vector<T> attribute(tree_nb_vertices);
for (int i = 0; i < tree_nb_vertices; ++i)
attribute[i] = accumulators[i].get_value();
return attribute;
}
std::vector<int> depth_attribute(const HierarchyTree& tree)
{
return compute_attribute_from_accumulator<int>(tree, DepthAccumulator(),
......@@ -86,7 +18,7 @@ namespace mln
std::vector<int> area_attribute(const HierarchyTree& tree)
{
std::vector<int> one_leaves(tree.leaf_graph.get_nb_vertices(), 1);
return compute_attribute_from_accumulator<int>(tree, SumAccumulator(), HierarchyTraversal{}, one_leaves);
return compute_attribute_from_accumulator<int>(tree, SumAccumulator<int>(), HierarchyTraversal{}, one_leaves);
}
std::vector<int> volume_attribute(const HierarchyTree& tree)
......
#include "mln/hierarchies/segmentation.hpp"
#include "mln/hierarchies/accumulators/max_accumulator.hpp"
#include "mln/hierarchies/accumulators/sum_accumulator.hpp"
#include "mln/hierarchies/attributes.hpp"
namespace mln
{
static std::vector<int> get_qbt_computed_attribute(const Graph& leaf_graph, const QBT& qbt,
......@@ -11,22 +12,18 @@ namespace mln
int qbt_nb_vertices = qbt.get_nb_vertices();
int qbt_root = qbt_nb_vertices - 1;
std::vector<int> qbt_computed_attribute(qbt_nb_vertices, std::numeric_limits<int>::min());
std::fill_n(qbt_computed_attribute.begin(), leaf_graph.get_nb_vertices(), 0);
for (int node = 0; node < qbt_root; ++node)
std::vector<int> different_i_node_altitude(qbt_nb_vertices, std::numeric_limits<int>::min());
std::fill_n(different_i_node_altitude.begin(), leaf_graph.get_nb_vertices(), 0);
for (int i_node = leaf_graph.get_nb_vertices(); i_node < qbt_nb_vertices - 1; ++i_node)
{
int parent_node = qbt.get_parent(node);
if (node >= leaf_graph.get_nb_vertices() && leaf_graph.weight_node(parent_node) != leaf_graph.weight_node(node))
qbt_computed_attribute[node] = attribute[node];
qbt_computed_attribute[parent_node] = std::max(qbt_computed_attribute[parent_node], qbt_computed_attribute[node]);
if (leaf_graph.weight_node(qbt.get_parent(i_node)) != leaf_graph.weight_node(i_node))
different_i_node_altitude[i_node] = attribute[i_node];
}
qbt_computed_attribute[qbt_root] = attribute[qbt_root];
different_i_node_altitude[qbt_root] = attribute[qbt_root];
return qbt_computed_attribute;
return compute_attribute_from_accumulator<int>(qbt, MaxAccumulator(), HierarchyTraversal{},
different_i_node_altitude);
}
Graph watershed_graph(Graph& graph, const std::function<std::vector<int>(const HierarchyTree&)>& attribute_func)
......@@ -94,19 +91,12 @@ namespace mln
int width = image.width();
std::vector<rgb<int>> mean_color(tree_nb_vertices, rgb<int>{0, 0, 0});
std::vector<rgb<int>> colors(tree_nb_vertices, rgb<int>{0, 0, 0});
for (int leaf = 0; leaf < leaf_graph.get_nb_vertices(); ++leaf)
mean_color[leaf] = image({leaf % width, leaf / width});
colors[leaf] = image({leaf % width, leaf / width});
for (int node = 0; node < tree_nb_vertices - 1; ++node)
{
int parent_node = tree.get_parent(node);
if (parent_node == -1)
continue;
mean_color[parent_node] += mean_color[node];
}
std::vector<rgb<int>> mean_color = compute_attribute_from_accumulator<rgb<int>>(tree, SumAccumulator<rgb<int>>(), HierarchyTraversal{}, colors);
std::vector<int> area = area_attribute(tree);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment