Skip to content

Commit e51aaa9

Browse files
committed
Add basic definition of augmentation class
1 parent 6ae432e commit e51aaa9

File tree

5 files changed

+180
-0
lines changed

5 files changed

+180
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ find_package(Boost 1.49
164164
COMPONENTS
165165
filesystem
166166
system
167+
regex
167168
program_options
168169
serialization
169170
unit_test_framework

augmentation/CMakeLists.txt

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
cmake_minimum_required(VERSION 3.1.0 FATAL_ERROR)
2+
project(augmentation)
3+
4+
option(DEBUG "DEBUG" OFF)
5+
6+
set(DIR_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/)
7+
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../../")
8+
9+
set(SOURCES
10+
augmentation.hpp
11+
)
12+
13+
foreach(file ${SOURCES})
14+
set(DIR_SRCS ${DIR_SRCS} ${CMAKE_CURRENT_SOURCE_DIR}/${file})
15+
endforeach()
16+
17+
# Append sources (with directory name) to list of all models sources (used at
18+
# the parent scope).
19+
set(DIRS ${DIRS} ${DIR_SRCS} PARENT_SCOPE)

augmentation/augmentation.hpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/**
2+
* @file augmentation.hpp
3+
* @author Kartik Dutt
4+
*
5+
* Definition of Augmentation class for augmenting data.
6+
*
7+
* mlpack is free software; you may redistribute it and/or modify it under the
8+
* terms of the 3-clause BSD license. You should have received a copy of the
9+
* 3-clause BSD license along with mlpack. If not, see
10+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
11+
*/
12+
13+
#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
14+
#include <boost/regex.hpp>
15+
16+
#ifndef MODELS_AUGMENTATION_HPP
17+
#define MODELS_AUGMENTATION_HPP
18+
19+
/**
20+
* Augmentation class used to perform augmentations / transform the data.
21+
* For the list of supported augmentation, take a look at our wiki page.
22+
*
23+
* @code
24+
* Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
25+
* augmentation.Transform(dataloader.TrainFeatures);
26+
* @endcode
27+
*
28+
* @tparam DatasetX Datatype on which augmentation will be done.
29+
*/
30+
class Augmentation
31+
{
32+
public:
33+
//! Create the augmenation class object.
34+
Augmentation();
35+
36+
/**
37+
* Constructor for augmentation class.
38+
*
39+
* @param augmentation List of strings containing one of the supported
40+
* augmentation.
41+
* @param augmentationProbability Probability of applying augmentation on
42+
* the dataset.
43+
* NOTE : This doesn't apply to augmentations
44+
* such as resize.
45+
* @param batches Boolean to determine if input is a single data point or
46+
* a batch. Defaults to true.
47+
* NOTE : If true, each data point must be represented as a
48+
* seperate column.
49+
*/
50+
Augmentation(const std::vector<std::string>& augmentation,
51+
const double augmentationProbability);
52+
53+
/**
54+
*/
55+
template<typename DatasetType = arma::mat>
56+
void Transform(DatasetType& dataset);
57+
58+
template<typename DatasetType = arma::mat>
59+
void ResizeTransform(DatasetType& dataset);
60+
61+
template <typename DatasetType = arma::mat>
62+
void HorizontalFlipTransform(DatasetType &dataset);
63+
64+
template<typename DatasetType = arma::mat>
65+
void VerticalFlipTransform(DatasetType& dataset);
66+
67+
68+
private:
69+
/**
70+
* Function to determine if augmentation has Resize function.
71+
*/
72+
bool HasResizeParam()
73+
{
74+
return augmentations.size() <= 0 ? false :
75+
augmentations[0].find("resize") != std::string::npos ;
76+
}
77+
78+
/**
79+
* Sets size of output width and output height of the new data.
80+
*
81+
* @param outWidth Output width of resized data point.
82+
* @param outHeight Output height of resized data point.
83+
*/
84+
void GetResizeParam(size_t& outWidth, size_t& outHeight)
85+
{
86+
if (!HasResizeParam())
87+
{
88+
return;
89+
}
90+
91+
outWidth = -1;
92+
outHeight = -1;
93+
94+
// Use regex to find one / two numbers. If only one provided
95+
// set output width equal to output height.
96+
boost::regex regex{"[0-9]+"};
97+
98+
// Create an iterator to find matches.
99+
boost::sregex_token_iterator matches(augmentations[0].begin(),
100+
augmentations[0].end(), regex, 0), end;
101+
102+
size_t matchesCount = std::distance(matches, end);
103+
104+
if (matchesCount == 0)
105+
{
106+
mlpack::Log::Fatal << "Invalid size / shape in " <<
107+
augmentations[0] << std::endl;
108+
}
109+
110+
if (matchesCount == 1)
111+
{
112+
outWidth = std::stoi(*matches);
113+
outHeight = outWidth;
114+
}
115+
else
116+
{
117+
outWidth = std::stoi(*matches);
118+
matches++;
119+
outHeight = std::stoi(*matches);
120+
}
121+
}
122+
123+
//! Locally held augmentations / transforms that need to be applied.
124+
std::vector<std::string> augmentations;
125+
126+
//! Locally held value of augmentation probability.
127+
double augmentationProbability;
128+
};
129+
130+
#endif

tests/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ set(MODEL_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/)
77
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/../")
88

99
add_executable(models_test
10+
augmentation_tests.cpp
1011
dataloader_tests.cpp
1112
utils_tests.cpp
1213
)
@@ -19,6 +20,7 @@ target_link_libraries(models_test
1920
${Boost_UNIT_TEST_FRAMEWORK_LIBRARY}
2021
${Boost_SYSTEM_LIBRARY}
2122
${Boost_SERIALIZATION_LIBRARY}
23+
${Boost_REGEX_LIBRARY}
2224
${MLPACK_LIBRARIES}
2325
)
2426

tests/augmentation_tests.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
/**
2+
* @file augmentation.cpp
3+
* @author Kartik Dutt
4+
*
5+
* Tests for various functionalities of utils.
6+
*
7+
* mlpack is free software; you may redistribute it and/or modify it under the
8+
* terms of the 3-clause BSD license. You should have received a copy of the
9+
* 3-clause BSD license along with mlpack. If not, see
10+
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
11+
*/
12+
#define BOOST_TEST_DYN_LINK
13+
#include <boost/regex.hpp>
14+
#include <boost/test/unit_test.hpp>
15+
using namespace boost::unit_test;
16+
17+
BOOST_AUTO_TEST_SUITE(AugmentationTest);
18+
19+
BOOST_AUTO_TEST_CASE(REGEXTest)
20+
{
21+
std::string s = " resize = { 19, 112 }, resize : 133,442, resize = [12 213]";
22+
boost::regex expr{"[0-9]+"};
23+
boost::smatch what;
24+
boost::sregex_token_iterator iter(s.begin(), s.end(), expr, 0);
25+
boost::sregex_token_iterator end;
26+
}
27+
28+
BOOST_AUTO_TEST_SUITE_END();

0 commit comments

Comments
 (0)