From c8b526c3da1e375155a1a0294280b6be1a5ddc97 Mon Sep 17 00:00:00 2001 From: Sebastian Gomez-Gonzalez <sgomez@tue.mpg.de> Date: Fri, 22 Sep 2017 16:19:11 +0200 Subject: [PATCH] Testing a max blob implementation (Works) --- include/ball_tracking/cuda/tracker.hpp | 27 ---------- src/cuda/tracker.cpp | 2 + src/tracker.cpp | 68 +++++++++++++++++++++++++- 3 files changed, 69 insertions(+), 28 deletions(-) diff --git a/include/ball_tracking/cuda/tracker.hpp b/include/ball_tracking/cuda/tracker.hpp index 970926c..b453a01 100644 --- a/include/ball_tracking/cuda/tracker.hpp +++ b/include/ball_tracking/cuda/tracker.hpp @@ -74,33 +74,6 @@ namespace ball_tracking { void operator()(const cv::cuda::GpuMat& src, cv::cuda::GpuMat& dst, cv::cuda::Stream& stream = cv::cuda::Stream::Null()); }; - /** - * @brief Find candidate ball positions if any in a likelihood image - */ - class FindBallBlob { - private: - class Impl; - std::shared_ptr<Impl> _impl; - public: - FindBallBlob(const nlohmann::json& params); - ~FindBallBlob(); - - /** - * @brief Creates an empty object - */ - FindBallBlob(); - - /** - * @brief Returns all candidate locations of a ball in a single channel - * likelihood image. - * - * @param[in] src A binary (thresholded) image segmenting the interesting parts - * @returns a vector of Key Points with the locations and other optional - * properties of the ball blobs found in the image - */ - std::vector<cv::KeyPoint> operator()(cv::InputArray src); - }; - }; }; diff --git a/src/cuda/tracker.cpp b/src/cuda/tracker.cpp index 262124f..0befcf5 100644 --- a/src/cuda/tracker.cpp +++ b/src/cuda/tracker.cpp @@ -7,6 +7,7 @@ #include <ball_tracking/cuda/tracker.hpp> #include <ball_tracking/utils.hpp> #include <opencv2/core.hpp> +#include <opencv2/cudaarithm.hpp> using namespace cv::cuda; using namespace cv; @@ -65,6 +66,7 @@ namespace ball_tracking { } } }; + }; class BallLogLikelihood::Impl { diff --git a/src/tracker.cpp b/src/tracker.cpp index 0129d01..ff134b5 100644 --- a/src/tracker.cpp +++ b/src/tracker.cpp @@ -4,6 +4,8 @@ #include <ball_tracking/utils.hpp> #include <algorithm> #include <json.hpp> +#include <queue> +#include <unordered_set> #ifdef WITH_CUDA #include <ball_tracking/cuda/tracker.hpp> @@ -80,7 +82,6 @@ namespace ball_tracking { /** * OpenCV Blob detection algorithm - * TODO: Continue */ class CV_blob_detect { private: @@ -120,6 +121,69 @@ namespace ball_tracking { } }; + /** + * Finds the blob of highest intensity and returns it only if is above + * a given threshold + */ + class Max_blob_detect { + private: + double high_thresh, low_thresh; + int maxArea; + public: + void set_blob_params(const json& conf) { + double phigh = conf.at("high_thresh"); + double plow = conf.at("low_thresh"); + high_thresh = log(phigh/(1-phigh)); + low_thresh = log(plow/(1-plow)); + maxArea = conf.at("maxArea"); + } + + Max_blob_detect(const json& conf) { + set_blob_params(conf); + } + + int h(int rows, const Point& p) { + return p.x + p.y*rows; + } + + vector<KeyPoint> operator()(cv::InputArray _llh) { + Mat llh = _llh.getMat(); + Point maxLoc; + double maxVal; + minMaxLoc(llh, 0, &maxVal, 0, &maxLoc); + vector<KeyPoint> ans; + const vector<Point>& delta{Point(0,1), Point(0,-1), Point(1,0), Point(-1,0)}; + int rows = llh.rows, cols = llh.cols; + if (maxVal > high_thresh) { + deque<Point> q{maxLoc}; + unordered_set<int> s{h(rows,maxLoc)}; + vector<Point> visited{maxLoc}; + for (int i=0; i<maxArea && !q.empty(); i++) { + Point parent = q.front(); q.pop_front(); + for (const auto& d : delta) { + Point child = parent + d; + if (child.x >= 0 && child.x<cols && child.y >= 0 && child.y<rows && + llh.at<double>(child.y, child.x)>low_thresh && s.count(h(rows, child))==0) { + q.push_back(child); + visited.push_back(child); + s.insert(h(rows, child)); + } + } + } + + Point2f center(0,0); + int low_x=cols+1, high_x=-1; + for (const auto p : visited) { + center.x += p.x; center.y += p.y; + low_x = min(low_x, p.x); + high_x = max(high_x, p.x); + } + ans.push_back( KeyPoint((1.0/visited.size())*center, high_x-low_x+1) ); + } + return ans; + } + }; + }; class Binarizer::Impl { @@ -197,6 +261,8 @@ namespace ball_tracking { const json& conf = config.at("conf"); if (type == "cv_blob_detect") { bf = CV_blob_detect(conf); + } else if (type == "max_blob_detect") { + bf = Max_blob_detect(conf); } else { throw std::logic_error("Type of the blob finder algorithm not recognized"); } -- GitLab