1+ /* *
2+ * @file augmentation.hpp
3+ * @author Kartik Dutt
4+ *
5+ * Definition of Augmentation class for augmenting data.
6+ *
7+ * mlpack is free software; you may redistribute it and/or modify it under the
8+ * terms of the 3-clause BSD license. You should have received a copy of the
9+ * 3-clause BSD license along with mlpack. If not, see
10+ * http://www.opensource.org/licenses/BSD-3-Clause for more information.
11+ */
12+
13+ #include < mlpack/methods/ann/layer/bilinear_interpolation.hpp>
14+ #include < boost/regex.hpp>
15+
16+ #ifndef MODELS_AUGMENTATION_HPP
17+ #define MODELS_AUGMENTATION_HPP
18+
19+ /* *
20+ * Augmentation class used to perform augmentations / transform the data.
21+ * For the list of supported augmentation, take a look at our wiki page.
22+ *
23+ * @code
24+ * Augmentation<> augmentation({"horizontal-flip", "resize = (224, 224)"}, 0.2);
25+ * augmentation.Transform(dataloader.TrainFeatures);
26+ * @endcode
27+ *
28+ * @tparam DatasetX Datatype on which augmentation will be done.
29+ */
30+ class Augmentation
31+ {
32+ public:
33+ // ! Create the augmenation class object.
34+ Augmentation ();
35+
36+ /* *
37+ * Constructor for augmentation class.
38+ *
39+ * @param augmentation List of strings containing one of the supported
40+ * augmentation.
41+ * @param augmentationProbability Probability of applying augmentation on
42+ * the dataset.
43+ * NOTE : This doesn't apply to augmentations
44+ * such as resize.
45+ * @param batches Boolean to determine if input is a single data point or
46+ * a batch. Defaults to true.
47+ * NOTE : If true, each data point must be represented as a
48+ * seperate column.
49+ */
50+ Augmentation (const std::vector<std::string>& augmentation,
51+ const double augmentationProbability);
52+
53+ /* *
54+ */
55+ template <typename DatasetType = arma::mat>
56+ void Transform (DatasetType& dataset);
57+
58+ template <typename DatasetType = arma::mat>
59+ void ResizeTransform (DatasetType& dataset);
60+
61+ template <typename DatasetType = arma::mat>
62+ void HorizontalFlipTransform (DatasetType &dataset);
63+
64+ template <typename DatasetType = arma::mat>
65+ void VerticalFlipTransform (DatasetType& dataset);
66+
67+
68+ private:
69+ /* *
70+ * Function to determine if augmentation has Resize function.
71+ */
72+ bool HasResizeParam ()
73+ {
74+ return augmentations.size () <= 0 ? false :
75+ augmentations[0 ].find (" resize" ) != std::string::npos ;
76+ }
77+
78+ /* *
79+ * Sets size of output width and output height of the new data.
80+ *
81+ * @param outWidth Output width of resized data point.
82+ * @param outHeight Output height of resized data point.
83+ */
84+ void GetResizeParam (size_t & outWidth, size_t & outHeight)
85+ {
86+ if (!HasResizeParam ())
87+ {
88+ return ;
89+ }
90+
91+ outWidth = -1 ;
92+ outHeight = -1 ;
93+
94+ // Use regex to find one / two numbers. If only one provided
95+ // set output width equal to output height.
96+ boost::regex regex{" [0-9]+" };
97+
98+ // Create an iterator to find matches.
99+ boost::sregex_token_iterator matches (augmentations[0 ].begin (),
100+ augmentations[0 ].end (), regex, 0 ), end;
101+
102+ size_t matchesCount = std::distance (matches, end);
103+
104+ if (matchesCount == 0 )
105+ {
106+ mlpack::Log::Fatal << " Invalid size / shape in " <<
107+ augmentations[0 ] << std::endl;
108+ }
109+
110+ if (matchesCount == 1 )
111+ {
112+ outWidth = std::stoi (*matches);
113+ outHeight = outWidth;
114+ }
115+ else
116+ {
117+ outWidth = std::stoi (*matches);
118+ matches++;
119+ outHeight = std::stoi (*matches);
120+ }
121+ }
122+
123+ // ! Locally held augmentations / transforms that need to be applied.
124+ std::vector<std::string> augmentations;
125+
126+ // ! Locally held value of augmentation probability.
127+ double augmentationProbability;
128+ };
129+
130+ #endif
0 commit comments