From c8b526c3da1e375155a1a0294280b6be1a5ddc97 Mon Sep 17 00:00:00 2001
From: Sebastian Gomez-Gonzalez <sgomez@tue.mpg.de>
Date: Fri, 22 Sep 2017 16:19:11 +0200
Subject: [PATCH] Testing a max blob implementation (Works)

---
 include/ball_tracking/cuda/tracker.hpp | 27 ----------
 src/cuda/tracker.cpp                   |  2 +
 src/tracker.cpp                        | 68 +++++++++++++++++++++++++-
 3 files changed, 69 insertions(+), 28 deletions(-)

diff --git a/include/ball_tracking/cuda/tracker.hpp b/include/ball_tracking/cuda/tracker.hpp
index 970926c..b453a01 100644
--- a/include/ball_tracking/cuda/tracker.hpp
+++ b/include/ball_tracking/cuda/tracker.hpp
@@ -74,33 +74,6 @@ namespace ball_tracking {
         void operator()(const cv::cuda::GpuMat& src, cv::cuda::GpuMat& dst, cv::cuda::Stream& stream = cv::cuda::Stream::Null());
     };
 
-    /**
-     * @brief Find candidate ball positions if any in a likelihood image 
-     */
-    class FindBallBlob {
-      private:
-        class Impl;
-        std::shared_ptr<Impl> _impl;
-      public:
-        FindBallBlob(const nlohmann::json& params);
-        ~FindBallBlob();
-
-        /**
-         * @brief Creates an empty object
-         */
-        FindBallBlob();
-
-        /**
-         * @brief Returns all candidate locations of a ball in a single channel
-         * likelihood image.
-         *
-         * @param[in] src A binary (thresholded) image segmenting the interesting parts
-         * @returns a vector of Key Points with the locations and other optional
-         * properties of the ball blobs found in the image
-         */
-        std::vector<cv::KeyPoint> operator()(cv::InputArray src);
-    };
-
   };
 };
 
diff --git a/src/cuda/tracker.cpp b/src/cuda/tracker.cpp
index 262124f..0befcf5 100644
--- a/src/cuda/tracker.cpp
+++ b/src/cuda/tracker.cpp
@@ -7,6 +7,7 @@
 #include <ball_tracking/cuda/tracker.hpp>
 #include <ball_tracking/utils.hpp>
 #include <opencv2/core.hpp>
+#include <opencv2/cudaarithm.hpp>
 
 using namespace cv::cuda;
 using namespace cv;
@@ -65,6 +66,7 @@ namespace ball_tracking {
             }
           }
       };
+
     };
 
     class BallLogLikelihood::Impl {
diff --git a/src/tracker.cpp b/src/tracker.cpp
index 0129d01..ff134b5 100644
--- a/src/tracker.cpp
+++ b/src/tracker.cpp
@@ -4,6 +4,8 @@
 #include <ball_tracking/utils.hpp>
 #include <algorithm>
 #include <json.hpp>
+#include <queue>
+#include <unordered_set>
 
 #ifdef WITH_CUDA
 #include <ball_tracking/cuda/tracker.hpp>
@@ -80,7 +82,6 @@ namespace ball_tracking {
 
     /**
      * OpenCV Blob detection algorithm
-     * TODO: Continue
      */
     class CV_blob_detect {
       private:
@@ -120,6 +121,69 @@ namespace ball_tracking {
         }
     };
 
+    /**
+     * Finds the blob of highest intensity and returns it only if is above
+     * a given threshold
+     */
+    class Max_blob_detect {
+      private:
+        double high_thresh, low_thresh;
+        int maxArea;
+      public:
+        void set_blob_params(const json& conf) {
+          double phigh = conf.at("high_thresh");
+          double plow = conf.at("low_thresh");
+          high_thresh = log(phigh/(1-phigh));
+          low_thresh = log(plow/(1-plow));
+          maxArea = conf.at("maxArea");
+        }
+
+        Max_blob_detect(const json& conf) {
+          set_blob_params(conf);
+        }
+
+        int h(int rows, const Point& p) {
+          return p.x + p.y*rows;
+        }
+
+        vector<KeyPoint> operator()(cv::InputArray _llh) {
+          Mat llh = _llh.getMat();
+          Point maxLoc;
+          double maxVal;
+          minMaxLoc(llh, 0, &maxVal, 0, &maxLoc);
+          vector<KeyPoint> ans;
+          const vector<Point>& delta{Point(0,1), Point(0,-1), Point(1,0), Point(-1,0)};
+          int rows = llh.rows, cols = llh.cols;
+          if (maxVal > high_thresh) {
+            deque<Point> q{maxLoc};
+            unordered_set<int> s{h(rows,maxLoc)};
+            vector<Point> visited{maxLoc};
+            for (int i=0; i<maxArea && !q.empty(); i++) {
+              Point parent = q.front(); q.pop_front();
+              for (const auto& d : delta) {
+                Point child = parent + d;
+                if (child.x >= 0 && child.x<cols && child.y >= 0 && child.y<rows && 
+                    llh.at<double>(child.y, child.x)>low_thresh && s.count(h(rows, child))==0) {
+                  q.push_back(child);
+                  visited.push_back(child);
+                  s.insert(h(rows, child));
+                }
+              }
+            }
+
+            Point2f center(0,0);
+            int low_x=cols+1, high_x=-1;
+            for (const auto p : visited) {
+              center.x += p.x; center.y += p.y;
+              low_x = min(low_x, p.x);
+              high_x = max(high_x, p.x);
+            }
+            ans.push_back( KeyPoint((1.0/visited.size())*center, high_x-low_x+1) );
+          }
+          return ans;
+        }
+    };
+
   };
 
   class Binarizer::Impl {
@@ -197,6 +261,8 @@ namespace ball_tracking {
         const json& conf = config.at("conf");
         if (type == "cv_blob_detect") {
           bf = CV_blob_detect(conf);
+        } else if (type == "max_blob_detect") {
+          bf = Max_blob_detect(conf);
         } else {
           throw std::logic_error("Type of the blob finder algorithm not recognized");
         }
-- 
GitLab