From 1e9da3ceea4fb5be458356d78724bcdd45580762 Mon Sep 17 00:00:00 2001 From: wooseokRo <43750676+wooseokRo@users.noreply.github.com> Date: Tue, 4 Aug 2020 11:23:51 +0900 Subject: [PATCH 1/4] Add files via upload --- CMakeLists.txt | 2 +- Makefile | 23 +++++++-- Makefile.config | 121 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 Makefile.config diff --git a/CMakeLists.txt b/CMakeLists.txt index 27d172f900b..6a543a495a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,7 +32,7 @@ caffe_option(USE_CUDNN "Build Caffe with cuDNN library support" ON IF NOT CPU_ON caffe_option(USE_NCCL "Build Caffe with NCCL library support" OFF) caffe_option(BUILD_SHARED_LIBS "Build shared libraries" ON) caffe_option(BUILD_python "Build Python wrapper" ON) -set(python_version "2" CACHE STRING "Specify which Python version to use") +set(python_version "3" CACHE STRING "Specify which Python version to use") caffe_option(BUILD_matlab "Build Matlab wrapper" OFF IF UNIX OR APPLE) caffe_option(BUILD_docs "Build documentation" ON IF UNIX OR APPLE) caffe_option(BUILD_python_layer "Build the Caffe Python layer" ON) diff --git a/Makefile b/Makefile index b7660e852d6..35fb4feee3e 100644 --- a/Makefile +++ b/Makefile @@ -186,6 +186,7 @@ USE_LMDB ?= 1 # This code is taken from https://github.com/sh1r0/caffe-android-lib USE_HDF5 ?= 1 USE_OPENCV ?= 1 +USE_PKG_CONFIG ?= 0 ifeq ($(USE_LEVELDB), 1) LIBRARIES += leveldb snappy @@ -200,12 +201,21 @@ endif ifeq ($(USE_OPENCV), 1) LIBRARIES += opencv_core opencv_highgui opencv_imgproc - ifeq ($(OPENCV_VERSION), 3) + ifeq ($(OPENCV_VERSION), $(filter $(OPENCV_VERSION), 3 4)) LIBRARIES += opencv_imgcodecs endif + ifeq ($(OPENCV_VERSION), 4) + ifeq ($(USE_PKG_CONFIG), 1) + INCLUDE_DIRS += $(shell pkg-config opencv4 --cflags-only-I | sed 's/-I//g') + else + INCLUDE_DIRS += /usr/include/opencv4 /usr/local/include/opencv4 + INCLUDE_DIRS += /usr/include/opencv4/opencv /usr/local/include/opencv4/opencv + endif + endif + endif -PYTHON_LIBRARIES ?= boost_python python2.7 +PYTHON_LIBRARIES ?= boost_python python3.5 WARNINGS := -Wall -Wno-sign-compare ############################## @@ -427,9 +437,12 @@ NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS) MATLAB_CXXFLAGS := $(CXXFLAGS) -Wno-uninitialized LINKFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS) -USE_PKG_CONFIG ?= 0 ifeq ($(USE_PKG_CONFIG), 1) - PKG_CONFIG := $(shell pkg-config opencv --libs) + ifeq ($(OPENCV_VERSION), 4) + PKG_CONFIG := $(shell pkg-config opencv4 --libs) + else + PKG_CONFIG := $(shell pkg-config opencv --libs) + endif else PKG_CONFIG := endif @@ -543,7 +556,7 @@ runtest: $(TEST_ALL_BIN) $(TEST_ALL_BIN) $(TEST_GPUID) --gtest_shuffle $(TEST_FILTER) pytest: py - cd python; python -m unittest discover -s caffe/test + cd python; python3 -m unittest discover -s caffe/test mattest: mat cd matlab; $(MATLAB_DIR)/bin/matlab -nodisplay -r 'caffe.run_tests(), exit()' diff --git a/Makefile.config b/Makefile.config new file mode 100644 index 00000000000..35be8b56623 --- /dev/null +++ b/Makefile.config @@ -0,0 +1,121 @@ +## Refer to http://caffe.berkeleyvision.org/installation.html +# Contributions simplifying and improving our build system are welcome! + +# cuDNN acceleration switch (uncomment to build with cuDNN). +USE_CUDNN := 1 + +# CPU-only switch (uncomment to build without GPU support). +# CPU_ONLY := 1 + +# uncomment to disable IO dependencies and corresponding data layers +# USE_OPENCV := 0 +# USE_LEVELDB := 0 +# USE_LMDB := 0 +# This code is taken from https://github.com/sh1r0/caffe-android-lib +# USE_HDF5 := 0 + +# uncomment to allow MDB_NOLOCK when reading LMDB files (only if necessary) +# You should not set this flag if you will be reading LMDBs with any +# possibility of simultaneous read and write +# ALLOW_LMDB_NOLOCK := 1 + +# Uncomment if you're using OpenCV 3 or 4 +OPENCV_VERSION := 4 +USE_PKG_CONFIG := 1 + +# To customize your choice of compiler, uncomment and set the following. +# N.B. the default for Linux is g++ and the default for OSX is clang++ +# CUSTOM_CXX := g++ + +# CUDA directory contains bin/ and lib/ directories that we need. +CUDA_DIR := /usr/local/cuda +# On Ubuntu 14.04, if cuda tools are installed via +# "sudo apt-get install nvidia-cuda-toolkit" then use this instead: +# CUDA_DIR := /usr + +# CUDA architecture setting: going with all of them. +# For CUDA < 6.0, comment the *_50 through *_61 lines for compatibility. +# For CUDA < 8.0, comment the *_60 and *_61 lines for compatibility. +# For CUDA >= 9.0, comment the *_20 and *_21 lines for compatibility. +CUDA_ARCH := #-gencode arch=compute_20,code=sm_20 \ + #-gencode arch=compute_20,code=sm_21 \ + #-gencode arch=compute_30,code=sm_30 \ + #-gencode arch=compute_35,code=sm_35 \ + -gencode arch=compute_50,code=sm_50 \ + -gencode arch=compute_52,code=sm_52 \ + -gencode arch=compute_60,code=sm_60 \ + -gencode arch=compute_61,code=sm_61 \ + -gencode arch=compute_75,code=sm_75 \ + -gencode arch=compute_75,code=compute_75 + +# BLAS choice: +# atlas for ATLAS (default) +# mkl for MKL +# open for OpenBlas +BLAS := open +# Custom (MKL/ATLAS/OpenBLAS) include and lib directories. +# Leave commented to accept the defaults for your choice of BLAS +# (which should work)! +# BLAS_INCLUDE := /path/to/your/blas +# BLAS_LIB := /path/to/your/blas + +# Homebrew puts openblas in a directory that is not on the standard search path +# BLAS_INCLUDE := $(shell brew --prefix openblas)/include +# BLAS_LIB := $(shell brew --prefix openblas)/lib + +# This is required only if you will compile the matlab interface. +# MATLAB directory should contain the mex binary in /bin. +MATLAB_DIR := /usr/local/MATLAB/R2019b +# MATLAB_DIR := /Applications/MATLAB_R2012b.app + +# NOTE: this is required only if you will compile the python interface. +# We need to be able to find Python.h and numpy/arrayobject.h. +# PYTHON_INCLUDE := /usr/include/python2.7 \ +# /usr/lib/python2.7/dist-packages/numpy/core/include +# Anaconda Python distribution is quite popular. Include path: +# Verify anaconda location, sometimes it's in root. +ANACONDA_HOME := $(HOME)/anaconda +PYTHON_INCLUDE := $(ANACONDA_HOME)/include \ + $(ANACONDA_HOME)/include/python3.5m \ + $(ANACONDA_HOME)/lib/python3.5/site-packages/numpy/core/include + +# Uncomment to use Python 3 (Python 2 is not supported after 2019) +PYTHON_LIBRARIES := boost_python-py35 python3.5m +#PYTHON_INCLUDE := /usr/include/python3.5m \ +# /usr/lib/python3/dist-packages/numpy/core/include + +# We need to be able to find libpythonX.X.so or .dylib. +PYTHON_LIB := /usr/lib +PYTHON_LIB := $(ANACONDA_HOME)/lib + +# Homebrew installs numpy in a non standard path (keg only) +# PYTHON_INCLUDE += $(dir $(shell python -c 'import numpy.core; print(numpy.core.__file__)'))/include +# PYTHON_LIB += $(shell brew --prefix numpy)/lib + +# Uncomment to support layers written in Python (will link against Python libs) +WITH_PYTHON_LAYER := 1 + +# Whatever else you find you need goes here. +INCLUDE_DIRS := $(PYTHON_INCLUDE) /usr/local/include /usr/include/hdf5/serial/ +LIBRARY_DIRS := $(PYTHON_LIB) /usr/local/lib /usr/lib /usr/lib/x86_64-linux-gnu/hdf5/serial/ + +# If Homebrew is installed at a non standard location (for example your home directory) and you use it for general dependencies +# INCLUDE_DIRS += $(shell brew --prefix)/include +# LIBRARY_DIRS += $(shell brew --prefix)/lib + +# NCCL acceleration switch (uncomment to build with NCCL) +# https://github.com/NVIDIA/nccl (last tested version: v1.2.3-1+cuda8.0) +# USE_NCCL := 1 + +# N.B. both build and distribute dirs are cleared on `make clean` +BUILD_DIR := build +DISTRIBUTE_DIR := distribute + +# Uncomment for debugging. Does not work on OSX due to https://github.com/BVLC/caffe/issues/171 +# DEBUG := 1 + +# The ID of the GPU that 'make runtest' will use to run unit tests. +TEST_GPUID := 0 + +# enable pretty build (comment to see full commands) +Q ?= @ From c48b67ad876b2ea9abcd61b6834403fcf5423ff6 Mon Sep 17 00:00:00 2001 From: wooseokRo <43750676+wooseokRo@users.noreply.github.com> Date: Tue, 4 Aug 2020 11:27:05 +0900 Subject: [PATCH 2/4] Add files via upload --- src/caffe/layers/heatmaps_from_vec_layer.cpp | 116 +++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 src/caffe/layers/heatmaps_from_vec_layer.cpp diff --git a/src/caffe/layers/heatmaps_from_vec_layer.cpp b/src/caffe/layers/heatmaps_from_vec_layer.cpp new file mode 100644 index 00000000000..2f56211ccdb --- /dev/null +++ b/src/caffe/layers/heatmaps_from_vec_layer.cpp @@ -0,0 +1,116 @@ +#include + +#include "caffe/layers/heatmaps_from_vec_layer.hpp" + +namespace caffe { + +template +void HeatmapsFromVecLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + range_ = 1.5f; + heatmap_size_ = 32; + kernel_size_ = 3; + gradient_fact_ = ((float)(heatmap_size_ - 1)) / (2.f * kernel_size_); + gaussian_.resize((kernel_size_+1)*(kernel_size_+1)); // bottom-right quarter Gaussian values + + // un-normalized Gaussian (bottom-right quarter) in pixel space + for (int k_r = 0; k_r <= kernel_size_; k_r++) + { + for (int k_c = 0; k_c <= kernel_size_; k_c++) + { + int linID = k_r * (kernel_size_+1) + k_c; + gaussian_[linID] = exp(-0.5 * (k_r*k_r + k_c*k_c)); + } + } +} + +template +void HeatmapsFromVecLayer::Reshape(const vector*>& bottom, + const vector*>& top) { // bottom has N x numJoints x 1 x 3 + std::vector shape = bottom[0]->shape(); + num_vecs_ = shape[1]; // number of vectors corresponds to channels + proj_vecs_.resize(2*num_vecs_); + shape[2] = heatmap_size_; + shape[3] = heatmap_size_; + top[0]->Reshape(shape); +} + +template +void HeatmapsFromVecLayer::Forward_cpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->cpu_data(); + Dtype* top_data = top[0]->mutable_cpu_data(); + caffe_set(top[0]->count(), (Dtype)0.0, top_data); + int example_size = top[0]->count(1); // this is C x H x W + int num = top[0]->shape(0); // batch_size + + int u_id, v_id; // top-left corner is 0,0 + // x -> left to right, y -> top to bottom, assuming ortographic camera for the moment + int top_step = heatmap_size_ * heatmap_size_; + + for (int n = 0; n < num; n++) + { + for (int j = 0; j < num_vecs_; j++) + { + u_id = (int)round((bottom_data[n * (3 * num_vecs_) + j * 3] + range_) / (2 * range_) * (heatmap_size_ - 1)); + v_id = (int)round((bottom_data[n * (3 * num_vecs_) + j * 3 + 1] + range_) / (2 * range_) * (heatmap_size_ - 1)); + proj_vecs_[j*2] = u_id; + proj_vecs_[j*2 + 1] = v_id; + + for (int w = -kernel_size_; w <= kernel_size_; w++) + { + for (int h = -kernel_size_; h <= kernel_size_; h++) + { + if (u_id + w >= 0 && u_id + w < heatmap_size_ && v_id + h >= 0 && v_id + h < heatmap_size_) + { // in this case, put Gaussian in top blob centered at u_id, v_id + top_data[n * example_size + j * top_step + (v_id + h) * heatmap_size_ + (u_id + w)] = gaussian_[abs(h) * (kernel_size_+1) + abs(w)]; + } + } + } + } + } +} + +template +void HeatmapsFromVecLayer::Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + if (propagate_down[0]) + { + const Dtype* top_data = top[0]->cpu_data(); + const Dtype* top_diff = top[0]->cpu_diff(); + Dtype* bottom_diff = bottom[0]->mutable_cpu_diff(); + caffe_set(bottom[0]->count(), (Dtype)0.0, bottom_diff); // init to 0 to enable accumulation + + int example_size = top[0]->count(1); // this is C x H x W + int top_step = heatmap_size_ * heatmap_size_; + int num = top[0]->shape(0); // batch_size + int top_id; + + for (int n = 0; n < num; n++) + { + for (int j = 0; j < num_vecs_; j++) + { + for (int w = -kernel_size_; w <= kernel_size_; w++) // all gradients outside the Gaussian are 0 + { + for (int h = -kernel_size_; h <= kernel_size_; h++) + { + if (proj_vecs_[j*2] + w >= 0 && proj_vecs_[j*2] + w < heatmap_size_ && proj_vecs_[j*2+1] + h >= 0 && proj_vecs_[j*2+1] + h < heatmap_size_) + { // in this case, accumulate gradient + top_id = n * example_size + j * top_step + (proj_vecs_[j*2+1] + h) * heatmap_size_ + (proj_vecs_[j*2] + w); + // x gradient + bottom_diff[n * (3 * num_vecs_) + j * 3] += top_diff[top_id] * top_data[top_id] * w * gradient_fact_; + // y gradient + bottom_diff[n * (3 * num_vecs_) + j * 3 + 1] += top_diff[top_id] * top_data[top_id] * h * gradient_fact_; + } + } + } + } + } + + } +} + +INSTANTIATE_CLASS(HeatmapsFromVecLayer); +REGISTER_LAYER_CLASS(HeatmapsFromVec); + +} // namespace caffe From 9230f759dd9d6f65faf26077315337b80782a023 Mon Sep 17 00:00:00 2001 From: wooseokRo <43750676+wooseokRo@users.noreply.github.com> Date: Tue, 4 Aug 2020 11:27:47 +0900 Subject: [PATCH 3/4] Add files via upload --- .../caffe/layers/heatmaps_from_vec_layer.hpp | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 include/caffe/layers/heatmaps_from_vec_layer.hpp diff --git a/include/caffe/layers/heatmaps_from_vec_layer.hpp b/include/caffe/layers/heatmaps_from_vec_layer.hpp new file mode 100644 index 00000000000..65fdc9383a8 --- /dev/null +++ b/include/caffe/layers/heatmaps_from_vec_layer.hpp @@ -0,0 +1,55 @@ +#ifndef CAFFE_HEATMAPS_FROM_VEC_LAYER_HPP_ +#define CAFFE_HEATMAPS_FROM_VEC_LAYER_HPP_ + +#include + +#include "caffe/blob.hpp" +#include "caffe/layer.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +/** + * @brief Projects 3D positions onto the image plane of a virtual camera and creates heatmaps. + */ +template +class HeatmapsFromVecLayer : public Layer { + public: + /** + * @param param provides options: + * - num_iter. The number of IK iterations to best fit the 3D vector to the input heatmaps. + */ + explicit HeatmapsFromVecLayer(const LayerParameter& param) + : Layer(param) {} + + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + + virtual inline const char* type() const { return "HeatmapsFromVec"; } + virtual inline int ExactNumBottomBlobs() const { return 1; } + virtual inline int ExactNumTopBlobs() const { return 1; } + + protected: + virtual void Forward_cpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_cpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + float fx_; // focal length x + float fy_; // focal length y + float ux_; // principal point x + float uy_; // principal point y + int num_vecs_; // number of 3D vectors to be transformed to heatmaps + int heatmap_size_; // resolution of heatmaps (always square) + int kernel_size_; // size of Gaussian kernel to put is 2*kernel_size_+1 + float range_; // range -val to val corresponds to heatmap width and height + std::vector gaussian_; // Gaussian kernel to be put around every projected 2D location (linearized) + std::vector proj_vecs_; // save projection of 3D vecs for gradient computation + float gradient_fact_; // constant factor in the Jacobi matrix of projection (for orthographic camera) +}; + +} // namespace caffe + +#endif // CAFFE_HEATMAPS_FROM_VEC_LAYER_HPP_ From d12d62e2e521b6dd1cab1ab8040c0e1e34155e80 Mon Sep 17 00:00:00 2001 From: wooseokRo <43750676+wooseokRo@users.noreply.github.com> Date: Tue, 4 Aug 2020 11:28:15 +0900 Subject: [PATCH 4/4] Add files via upload --- src/caffe/proto/caffe.proto | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 3dcad697f6d..62ea48bcc6b 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -322,7 +322,7 @@ message ParamSpec { // NOTE // Update the next available ID when you add a new LayerParameter field. // -// LayerParameter next available layer-specific ID: 149 (last added: clip_param) +// LayerParameter next available layer-specific ID: 150 (last added: heatmaps_from_vec_layer) message LayerParameter { optional string name = 1; // the layer name optional string type = 2; // the layer type @@ -393,6 +393,7 @@ message LayerParameter { optional FlattenParameter flatten_param = 135; optional HDF5DataParameter hdf5_data_param = 112; optional HDF5OutputParameter hdf5_output_param = 113; + optional HeatmapsFromVecParameter heatmaps_from_vec_layer = 149; optional HingeLossParameter hinge_loss_param = 114; optional ImageDataParameter image_data_param = 115; optional InfogainLossParameter infogain_loss_param = 116; @@ -507,6 +508,7 @@ message ArgMaxParameter { } // Message that stores parameters used by ClipLayer + message ClipParameter { required float min = 1; required float max = 2; @@ -1447,3 +1449,7 @@ message PReLUParameter { // Whether or not slope parameters are shared across channels. optional bool channel_shared = 2 [default = false]; } + +message HeatmapsFromVecParameter { + optional uint32 heatmap_size = 1 [default = 32]; +}