From 379c2e8746efc50b24d7f08cc06e094c51ae66e8 Mon Sep 17 00:00:00 2001
From: Sebastian Gomez-Gonzalez <sebastian@robot-learning.de>
Date: Thu, 21 Sep 2017 16:07:30 +0200
Subject: [PATCH] Creating functionality to seemless running CPU or GPU
 implementation. Not tested yet on the GPU

---
 CMakeLists.txt                         |  32 +++++--
 examples/CMakeLists.txt                |  33 +++----
 examples/tracking/cpu_track_conf.json  |  60 +++++++++++++
 examples/tracking/track.cpp            |   4 +-
 examples/tracking/track_conf.json      |   5 +-
 include/ball_tracking/cuda/tracker.hpp |  21 ++++-
 include/ball_tracking/tracker.hpp      |  44 +++++++++-
 src/cuda/tracker.cpp                   |  12 ++-
 src/tracker.cpp                        | 117 +++++++++++++++++++++++--
 9 files changed, 288 insertions(+), 40 deletions(-)
 create mode 100644 examples/tracking/cpu_track_conf.json

diff --git a/CMakeLists.txt b/CMakeLists.txt
index c025cf8..8951782 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -2,28 +2,42 @@ cmake_minimum_required(VERSION 3.5)
 project(ball_tracking)
 
 find_package(OpenCV REQUIRED)
-find_package(CUDA 8.0 REQUIRED)
+find_package(CUDA 8.0)
+find_library(ZMQPP NAMES zmqpp)
+find_path(ZMQPP_INCLUDES NAMES zmqpp/zmqpp.hpp)
 
-option (PYLIB "Create a Python Module with interface to some of the C++ implementations" ON)
+if (CUDA_FOUND)
+  option (WITH_CUDA "Compile the library with GPU implementations" ON)
+endif (CUDA_FOUND)
+
+if (WITH_CUDA)
+  include_directories(${CUDA_INCLUDE_DIRS})
+  set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -arch=sm_30" )
+  cuda_add_library(cu_ball_track SHARED
+    src/cuda/img_proc.cu
+  )
+  set(GPU_CPP_SRC
+    src/cuda/tracker.cpp
+    )
+  set (GPU_BT_LIB
+    cu_ball_track
+    )
+  add_definitions(-DWITH_CUDA)
+endif(WITH_CUDA)
 
 include_directories(include
   ${OpenCV_INCLUDES}
-  ${CUDA_INCLUDE_DIRS})
-
-set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -std=c++11 -arch=sm_30" )
-cuda_add_library(cu_ball_track SHARED
-  src/cuda/img_proc.cu
   )
 
 add_library(ball_tracking SHARED
   src/img_proc.cpp
   src/utils.cpp
   src/tracker.cpp
-  src/cuda/tracker.cpp
+  ${GPU_CPP_SRC}
   )
 target_link_libraries(ball_tracking
   ${OpenCV_LIBS}
-  cu_ball_track
+  ${GPU_BT_LIB}
   )
 
 #Compile with C++11 support only
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 2d1e5a7..5583d2a 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -11,14 +11,6 @@ target_link_libraries(color_approach
   ${Boost_LIBRARIES}
   )
 
-add_executable(log_reg_approach
-  img_proc/log_reg_approach.cpp
-  )
-target_link_libraries(log_reg_approach
-  ball_tracking
-  ${Boost_LIBRARIES}
-  )
-
 add_executable(ball_log_lh
   img_proc/ball_log_lh.cpp
   )
@@ -35,10 +27,21 @@ target_link_libraries(track
   ${Boost_LIBRARIES}
   )
 
-add_executable(gpu_track
-  tracking/gpu_track.cpp
-  )
-target_link_libraries(gpu_track
-  ball_tracking
-  ${Boost_LIBRARIES}
-  )
+if (WITH_CUDA)
+  add_executable(log_reg_approach
+    img_proc/log_reg_approach.cpp
+    )
+  target_link_libraries(log_reg_approach
+    ball_tracking
+    ${Boost_LIBRARIES}
+    )
+
+  add_executable(gpu_track
+    tracking/gpu_track.cpp
+    )
+  target_link_libraries(gpu_track
+    ball_tracking
+    ${Boost_LIBRARIES}
+    )
+endif (WITH_CUDA)
+
diff --git a/examples/tracking/cpu_track_conf.json b/examples/tracking/cpu_track_conf.json
new file mode 100644
index 0000000..55ca9d5
--- /dev/null
+++ b/examples/tracking/cpu_track_conf.json
@@ -0,0 +1,60 @@
+{
+  "tracker": {
+    "type": "cpu",
+    "conf": {
+      "ball_log_lh": {
+        "type": "cb_log_reg",
+        "conf": { 
+          "weights": [
+            -4.3645153397769523,
+            -1.7528550846385467,
+            -0.94643386271480279,
+            -2.1172356652966022,
+            -1.7266645284508131,
+            -4.0142794641079673,
+            -3.8619400850106795,
+            1.7698900417473211,
+            2.2392411493959243,
+            1.6657076794425114,
+            1.9180643323474467,
+            -1.8974895205433868,
+            2.6587407284968458,
+            2.2491171019131664,
+            2.6281082938750533,
+            2.4812839914967717,
+            -2.2706336207474567,
+            2.5778824379463727,
+            1.2787179559340944,
+            1.9590196639347666,
+            -1.1268150111657678,
+            3.278679714617708,
+            2.232487133767151,
+            -1.6055001267916118,
+            3.5447611950425828,
+            -5.4301792573145677,
+            -4.7388290757015756,
+            -5.739125161010926
+          ],
+          "gauss_smooth": {
+            "size": 5,
+            "sigma": 0
+          }
+        }
+      },
+      "blob_detection": {
+        "type": "cv_blob_detect",
+        "conf": {
+          "binarizer": {
+            "p_thresh": 0.05
+          },
+          "filterByArea": true,
+          "minArea": 8,
+          "maxArea": 500,
+          "filterByCircularity": true,
+          "minCircularity": 0.75,
+          "maxCircularity": 1.0
+        }
+      }
+    }
+  }
+}
diff --git a/examples/tracking/track.cpp b/examples/tracking/track.cpp
index 2f25de1..5ad4dd9 100644
--- a/examples/tracking/track.cpp
+++ b/examples/tracking/track.cpp
@@ -107,7 +107,7 @@ int main(int argc, char** argv) {
       end_t = std::chrono::steady_clock::now();
       llh_time.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end_t - start_t).count());
 
-      //2.1) Binarize
+      //2.1) Binarize (Just for debugging purposes)
       start_t = std::chrono::steady_clock::now();
       auto bin_img = bin(l_lh);
       end_t = std::chrono::steady_clock::now();
@@ -115,7 +115,7 @@ int main(int argc, char** argv) {
 
       //3) Run blob detection algorithm
       start_t = std::chrono::steady_clock::now();
-      auto key_pts = bf(bin_img);
+      auto key_pts = bf(l_lh);
       end_t = std::chrono::steady_clock::now();
       blob_time.push_back(std::chrono::duration_cast<std::chrono::microseconds>(end_t - start_t).count());
     
diff --git a/examples/tracking/track_conf.json b/examples/tracking/track_conf.json
index 7040981..cd3f6eb 100644
--- a/examples/tracking/track_conf.json
+++ b/examples/tracking/track_conf.json
@@ -40,10 +40,13 @@
   },
   "binarizer": {
     "p_thresh": 0.05
-  }, 
+  },
   "blob_detection": {
     "type": "cv_blob_detect",
     "conf": {
+      "binarizer": {
+        "p_thresh": 0.05
+      },
       "filterByArea": true,
       "minArea": 8,
       "maxArea": 500,
diff --git a/include/ball_tracking/cuda/tracker.hpp b/include/ball_tracking/cuda/tracker.hpp
index 3daab03..970926c 100644
--- a/include/ball_tracking/cuda/tracker.hpp
+++ b/include/ball_tracking/cuda/tracker.hpp
@@ -22,7 +22,7 @@ namespace ball_tracking {
     class BallLogLikelihood {
       private:
         class Impl;
-        std::unique_ptr<Impl> _impl;
+        std::shared_ptr<Impl> _impl;
       public:
         /**
          * @brief Creates an object with given configuration parameters
@@ -30,6 +30,11 @@ namespace ball_tracking {
         BallLogLikelihood(const nlohmann::json& params);
         ~BallLogLikelihood();
 
+        /**
+         * @brief Creates an empty object
+         */
+        BallLogLikelihood();
+
         /**
          * @brief Produces a single channel image representing the log likelihood of each
          * pixel being part of the ball
@@ -49,11 +54,16 @@ namespace ball_tracking {
     class Binarizer {
       private:
         class Impl;
-        std::unique_ptr<Impl> _impl;
+        std::shared_ptr<Impl> _impl;
       public:
         Binarizer(const nlohmann::json& params);
         ~Binarizer();
 
+        /**
+         * @brief Creates an empty object
+         */
+        Binarizer();
+
         /**
          * @brief Returns an 8 bit single channel binarized image from a given single channel
          * real valued image. The image contains only 0 or 255 values in each pixel.
@@ -70,11 +80,16 @@ namespace ball_tracking {
     class FindBallBlob {
       private:
         class Impl;
-        std::unique_ptr<Impl> _impl;
+        std::shared_ptr<Impl> _impl;
       public:
         FindBallBlob(const nlohmann::json& params);
         ~FindBallBlob();
 
+        /**
+         * @brief Creates an empty object
+         */
+        FindBallBlob();
+
         /**
          * @brief Returns all candidate locations of a ball in a single channel
          * likelihood image.
diff --git a/include/ball_tracking/tracker.hpp b/include/ball_tracking/tracker.hpp
index 6f564bf..6bbaa88 100644
--- a/include/ball_tracking/tracker.hpp
+++ b/include/ball_tracking/tracker.hpp
@@ -25,12 +25,21 @@ namespace ball_tracking {
   class BallLogLikelihood {
     private:
       class Impl;
-      std::unique_ptr<Impl> _impl;
+      std::shared_ptr<Impl> _impl;
     public:
       /**
        * @brief Creates an object with given configuration parameters
        */
       BallLogLikelihood(const nlohmann::json& params);
+
+      /**
+       * @brief Creates an empty object. 
+       *
+       * Using an object created with this default constructor will result 
+       * in undefined behaviour.
+       */
+      BallLogLikelihood();
+
       ~BallLogLikelihood();
 
       /**
@@ -51,9 +60,18 @@ namespace ball_tracking {
   class Binarizer {
     private:
       class Impl;
-      std::unique_ptr<Impl> _impl;
+      std::shared_ptr<Impl> _impl;
     public:
+      /**
+       * @brief Creates a binarizer with the given configuration
+       */
       Binarizer(const nlohmann::json& params);
+
+      /**
+       * @brief Creates an empty object
+       */
+      Binarizer();
+
       ~Binarizer();
 
       /**
@@ -72,11 +90,16 @@ namespace ball_tracking {
   class FindBallBlob {
     private:
       class Impl;
-      std::unique_ptr<Impl> _impl;
+      std::shared_ptr<Impl> _impl;
     public:
       FindBallBlob(const nlohmann::json& params);
       ~FindBallBlob();
 
+      /**
+       * @brief Creates an empty object
+       */
+      FindBallBlob();
+
       /**
        * @brief Returns all candidate locations of a ball in a single channel
        * likelihood image.
@@ -87,6 +110,21 @@ namespace ball_tracking {
        */
       std::vector<cv::KeyPoint> operator()(cv::InputArray src);
   };
+
+  /**
+   * @brief Returns the object position if any from the original image
+   */
+  class Tracker {
+    private:
+      class Impl;
+      std::shared_ptr<Impl> _impl;
+    public:
+      Tracker(const nlohmann::json& params);
+      Tracker();
+      ~Tracker();
+      std::vector<cv::KeyPoint> operator()(cv::InputArray src);
+  };
+
 };
 
 #endif
diff --git a/src/cuda/tracker.cpp b/src/cuda/tracker.cpp
index 638fb6b..262124f 100644
--- a/src/cuda/tracker.cpp
+++ b/src/cuda/tracker.cpp
@@ -83,7 +83,11 @@ namespace ball_tracking {
     };
 
     BallLogLikelihood::BallLogLikelihood(const nlohmann::json& params) {
-      _impl = unique_ptr<Impl>(new Impl(params));
+      _impl = shared_ptr<Impl>(new Impl(params));
+    }
+
+    BallLogLikelihood::BallLogLikelihood() {
+      _impl = nullptr;
     }
         
     BallLogLikelihood::~BallLogLikelihood() = default;
@@ -114,7 +118,11 @@ namespace ball_tracking {
     };
 
     Binarizer::Binarizer(const nlohmann::json& params) {
-      _impl = unique_ptr<Impl>(new Impl(params));
+      _impl = shared_ptr<Impl>(new Impl(params));
+    }
+
+    Binarizer::Binarizer() {
+      _impl = nullptr;
     }
 
     Binarizer::~Binarizer() = default;
diff --git a/src/tracker.cpp b/src/tracker.cpp
index d199c6c..90c2e9b 100644
--- a/src/tracker.cpp
+++ b/src/tracker.cpp
@@ -28,7 +28,7 @@ namespace ball_tracking {
 
         cv::Mat operator()(cv::InputArray _src) {
           Mat dst;
-          GaussianBlur(_src, dst, Size(size,size), sigma, sigma);
+          cv::GaussianBlur(_src, dst, Size(size,size), sigma, sigma);
           return dst;
         }
     };
@@ -82,6 +82,7 @@ namespace ball_tracking {
         double thresh;
         SimpleBlobDetector::Params params;
         Ptr<SimpleBlobDetector> detector;
+        Binarizer bin;
       public:
         void set_blob_params(const json& conf) {
           if (conf.count("filterByArea")) {
@@ -94,6 +95,7 @@ namespace ball_tracking {
             params.minCircularity = conf.at("minCircularity");
             params.maxCircularity = conf.at("maxCircularity");
           }
+          bin = Binarizer(conf.at("binarizer"));
         }
 
         CV_blob_detect(const json& conf) {
@@ -105,8 +107,9 @@ namespace ball_tracking {
 #endif
         }
 
-        vector<KeyPoint> operator()(cv::InputArray bin_img) {
+        vector<KeyPoint> operator()(cv::InputArray llh) {
           vector<KeyPoint> ans;
+          Mat bin_img = bin(llh);
           detector->detect(bin_img, ans);
           return ans;
         }
@@ -136,7 +139,11 @@ namespace ball_tracking {
   };
 
   Binarizer::Binarizer(const nlohmann::json& params) {
-    _impl = unique_ptr<Impl>(new Impl(params));
+    _impl = shared_ptr<Impl>(new Impl(params));
+  }
+
+  Binarizer::Binarizer() {
+    _impl = nullptr;
   }
 
   Binarizer::~Binarizer() = default;
@@ -162,7 +169,11 @@ namespace ball_tracking {
   };
 
   BallLogLikelihood::BallLogLikelihood(const nlohmann::json& params) {
-    _impl = unique_ptr<Impl>(new Impl(params));
+    _impl = shared_ptr<Impl>(new Impl(params));
+  }
+
+  BallLogLikelihood::BallLogLikelihood() {
+    _impl = nullptr;
   }
       
   BallLogLikelihood::~BallLogLikelihood() = default;
@@ -188,7 +199,11 @@ namespace ball_tracking {
   };
 
   FindBallBlob::FindBallBlob(const nlohmann::json& params) {
-    _impl = unique_ptr<Impl>(new Impl(params));
+    _impl = shared_ptr<Impl>(new Impl(params));
+  }
+
+  FindBallBlob::FindBallBlob() {
+    _impl = nullptr;
   }
 
   FindBallBlob::~FindBallBlob() = default;
@@ -198,4 +213,96 @@ namespace ball_tracking {
   }
 
 
+  /**
+   * Implementation of the tracker class
+   */
+  namespace {
+
+    class CPUTracker {
+      private:
+        BallLogLikelihood llh;
+        FindBallBlob blob_detect;
+      public:
+        CPUTracker(const json& conf) {
+          llh = BallLogLikelihood(conf.at("ball_log_lh"));
+          blob_detect = FindBallBlob(conf.at("blob_detection"));
+        }
+
+        vector<KeyPoint> operator()(cv::InputArray img) {
+          Mat llh_img = llh(img);
+          return blob_detect(llh_img);
+        }
+    };
+
+#ifdef WITH_CUDA
+#include <ball_tracking/cuda/tracker.hpp>
+#include <opencv2/core.hpp>
+    class GPUTracker {
+      private:
+        ball_tracking::cuda::BallLogLikelihood llh; //!< Fully implemented in GPU
+        FindBallBlob blob_detect; //!< For the moment only implemented in CPU
+        cv::cuda::Stream stream;
+      public:
+        GPUTracker(const json& conf) {
+          llh = ball_tracking::cuda::BallLogLikelihood(conf.at("ball_log_lh"));
+          blob_detect = FindBallBlob(conf.at("blob_detection"));
+        }
+
+        vector<KeyPoint> operator()(cv::InputArray _img) {
+          Mat img = _img.getMat();
+          GpuMat gpu_img(img, stream);
+          GpuMat gpu_llh_img;
+          llh(gpu_img, gpu_llh_img, stream);
+
+          Mat llh_img;
+          gpu_llh_img.download(llh_img, stream);
+          stream.waitForCompletion();
+
+          return blob_detect(llh_img);
+        }
+    };
+#else
+    class GPUTracker {
+      public:
+        GPUTracker(const json& conf) {
+          throw new std::logic_error("Calling a GPU function on a CPU only compiled tracker");
+        }
+
+        vector<KeyPoint> operator()(cv::InputArray _img) {
+          throw new std::logic_error("Calling a GPU function on a CPU only compiled tracker");
+        }
+    };
+#endif
+  };
+
+  class Tracker::Impl {
+    public:
+      blob_finder tracker;
+
+      Impl(const json& conf) {
+        string type = conf.at("type");
+        if (type == "cpu") {
+          tracker = CPUTracker(conf.at("conf"));
+        } else if (type == "gpu") {
+          tracker = GPUTracker(conf.at("gpu"));
+        } else {
+          throw std::logic_error("Type of tracker selected in the configuration is not recognized");
+        }
+      }
+  };
+
+  Tracker::Tracker(const nlohmann::json& params) {
+    _impl = shared_ptr<Impl>(new Impl(params));
+  }
+
+  Tracker::Tracker() {
+    _impl = nullptr;
+  }
+
+  Tracker::~Tracker() = default;
+      
+  std::vector<cv::KeyPoint> Tracker::operator()(cv::InputArray src) {
+    return _impl->tracker(src);
+  }
+
 };
-- 
GitLab