|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | +import pandas as pd |
| 4 | +from torch.utils.data import DataLoader |
| 5 | + |
| 6 | +from pytorch_widedeep.models import TabResnet # noqa: F401 |
| 7 | +from pytorch_widedeep.models import TabMlp, Vision, BasicRNN, WideDeep |
| 8 | +from pytorch_widedeep.training import TrainerFromFolder |
| 9 | +from pytorch_widedeep.callbacks import EarlyStopping, ModelCheckpoint |
| 10 | +from pytorch_widedeep.preprocessing import ( |
| 11 | + ImagePreprocessor, |
| 12 | + ChunkTabPreprocessor, |
| 13 | + ChunkTextPreprocessor, |
| 14 | +) |
| 15 | +from pytorch_widedeep.load_from_folder import ( |
| 16 | + TabFromFolder, |
| 17 | + TextFromFolder, |
| 18 | + ImageFromFolder, |
| 19 | + WideDeepDatasetFromFolder, |
| 20 | +) |
| 21 | + |
| 22 | +use_cuda = torch.cuda.is_available() |
| 23 | + |
| 24 | +if __name__ == "__main__": |
| 25 | + # The airbnb dataset, which you could get from here: |
| 26 | + # http://insideairbnb.com/get-the-data.html, is too big to be included in |
| 27 | + # our datasets module (when including images). Therefore, go there, |
| 28 | + # download it, and use the download_images.py script to get the images |
| 29 | + # and the airbnb_data_processing.py to process the data. We'll find |
| 30 | + # better datasets in the future ;). Note that here we are only using a |
| 31 | + # small sample to illustrate the use, so PLEASE ignore the results, just |
| 32 | + # focus on usage |
| 33 | + |
| 34 | + # For this exercise, we use a small sample of the airbnb dataset, |
| 35 | + # comprised of tabular data with a text column ('description') and an |
| 36 | + # image column ('id') that point to the images of the properties listed |
| 37 | + # in Airbnb. We know the size of the sample before hand (1001) so we set |
| 38 | + # a series of parameters accordingly |
| 39 | + train_size = 800 |
| 40 | + eval_size = 100 |
| 41 | + test_size = 101 |
| 42 | + chunksize = 100 |
| 43 | + n_chunks = int(np.ceil(train_size / chunksize)) |
| 44 | + |
| 45 | + data_path = "../tmp_data/airbnb/" |
| 46 | + train_fname = "airbnb_sample_train.csv" |
| 47 | + eval_fname = "airbnb_sample_eval.csv" |
| 48 | + test_fname = "airbnb_sample_test.csv" |
| 49 | + |
| 50 | + # the images are stored in the 'property_picture' while the text is a |
| 51 | + # column in the 'airbnb_sample' dataframe. Let's then define the dir and |
| 52 | + # file variables |
| 53 | + img_path = "../tmp_data/airbnb/property_picture/" |
| 54 | + img_col = "id" |
| 55 | + text_col = "description" |
| 56 | + target_col = "yield" |
| 57 | + cat_embed_cols = [ |
| 58 | + "host_listings_count", |
| 59 | + "neighbourhood_cleansed", |
| 60 | + "is_location_exact", |
| 61 | + "property_type", |
| 62 | + "room_type", |
| 63 | + "accommodates", |
| 64 | + "bathrooms", |
| 65 | + "bedrooms", |
| 66 | + "beds", |
| 67 | + "guests_included", |
| 68 | + "minimum_nights", |
| 69 | + "instant_bookable", |
| 70 | + "cancellation_policy", |
| 71 | + "has_house_rules", |
| 72 | + "host_gender", |
| 73 | + "accommodates_catg", |
| 74 | + "guests_included_catg", |
| 75 | + "minimum_nights_catg", |
| 76 | + "host_listings_count_catg", |
| 77 | + "bathrooms_catg", |
| 78 | + "bedrooms_catg", |
| 79 | + "beds_catg", |
| 80 | + "security_deposit", |
| 81 | + "extra_people", |
| 82 | + ] |
| 83 | + cont_cols = ["latitude", "longitude"] |
| 84 | + |
| 85 | + # Now, processing the data from here on can be done in two ways: |
| 86 | + # 1. The tabular data itsel fits in memory as is only the images that do |
| 87 | + # not: in this case you could use the 'standard' Preprocessors and off |
| 88 | + # you go, move directly to the '[...]FromFolder' functionalities |
| 89 | + |
| 90 | + # 2. The tabular data is also very large and does not fit in memory, so we |
| 91 | + # have to process it in chuncks. For this second case I have created the |
| 92 | + # Chunk Processors (Wide, Tab and Text). Note that at the moment ONLY csv |
| 93 | + # format is allowed for the tabular file. More formats will be supported |
| 94 | + # in the future. |
| 95 | + |
| 96 | + # For the following I will assume (simply for illustration purposes) that |
| 97 | + # we are in the second case. Nonetheless, the process, whether 1 or 2, |
| 98 | + # can be summarised as follows: |
| 99 | + # 1. Process the data |
| 100 | + # 2. Define the loaders from folder |
| 101 | + # 3. Define the datasets and dataloaders |
| 102 | + # 4. Define the model and the Trainer |
| 103 | + # 5. Fit the model and Predict |
| 104 | + |
| 105 | + # Process the data in Chunks |
| 106 | + tab_preprocessor = ChunkTabPreprocessor( |
| 107 | + embed_cols=cat_embed_cols, |
| 108 | + continuous_cols=cont_cols, |
| 109 | + n_chunks=n_chunks, |
| 110 | + default_embed_dim=8, |
| 111 | + verbose=0, |
| 112 | + ) |
| 113 | + |
| 114 | + text_preprocessor = ChunkTextPreprocessor( |
| 115 | + n_chunks=n_chunks, |
| 116 | + text_col=text_col, |
| 117 | + n_cpus=1, |
| 118 | + ) |
| 119 | + |
| 120 | + # Note that all the (pre)processing of the images will occur 'on the fly', |
| 121 | + # as they are loaded from disk. Therefore, the flow for the image dataset |
| 122 | + # and for the tabular and text data modes is not entirely the same. |
| 123 | + # Tabular and text data uses Chunk processors while such processing |
| 124 | + # approach is not needed for the images |
| 125 | + img_preprocessor = ImagePreprocessor( |
| 126 | + img_col=img_col, |
| 127 | + img_path=img_path, |
| 128 | + ) |
| 129 | + |
| 130 | + for i, chunk in enumerate( |
| 131 | + pd.read_csv("/".join([data_path, train_fname]), chunksize=chunksize) |
| 132 | + ): |
| 133 | + print(f"chunk in loop: {i}") |
| 134 | + tab_preprocessor.fit(chunk) |
| 135 | + text_preprocessor.fit(chunk) |
| 136 | + |
| 137 | + # Instantiate the loaders from folder: again here some explanation is |
| 138 | + # required. As I mentioned earlier the "[...]FromFolder" functionalities |
| 139 | + # are thought for the case when we have tabular and text and/or image |
| 140 | + # datasets and the latter do not fit in memory, so they have to be loaded |
| 141 | + # from disk. With this in mind, the tabular data is the reference, and |
| 142 | + # must have columns that point to the image files and to the text files |
| 143 | + # (in case these exists instead of a column with the texts). Since the |
| 144 | + # tabular data is used as a reference, is the one that has to be splitted |
| 145 | + # in train/validation/test. The test and image 'FromFolder' objects only |
| 146 | + # point to the corresponding column or files, and therefore, we do not |
| 147 | + # need to create a separate instance per train/validation/test dataset |
| 148 | + train_tab_folder = TabFromFolder( |
| 149 | + fname=train_fname, |
| 150 | + directory=data_path, |
| 151 | + target_col=target_col, |
| 152 | + preprocessor=tab_preprocessor, |
| 153 | + text_col=text_col, |
| 154 | + img_col=img_col, |
| 155 | + ) |
| 156 | + eval_tab_folder = TabFromFolder(fname=eval_fname, reference=train_tab_folder) # type: ignore[arg-type] |
| 157 | + test_tab_folder = TabFromFolder( |
| 158 | + fname=test_fname, reference=train_tab_folder, ignore_target=True # type: ignore[arg-type] |
| 159 | + ) |
| 160 | + |
| 161 | + text_folder = TextFromFolder( |
| 162 | + preprocessor=text_preprocessor, |
| 163 | + ) |
| 164 | + |
| 165 | + img_folder = ImageFromFolder(preprocessor=img_preprocessor) |
| 166 | + |
| 167 | + # Following 'standard' pytorch approaches, we define the datasets and then |
| 168 | + # the dataloaders |
| 169 | + train_dataset_folder = WideDeepDatasetFromFolder( |
| 170 | + n_samples=train_size, |
| 171 | + tab_from_folder=train_tab_folder, |
| 172 | + text_from_folder=text_folder, |
| 173 | + img_from_folder=img_folder, |
| 174 | + ) |
| 175 | + eval_dataset_folder = WideDeepDatasetFromFolder( |
| 176 | + n_samples=eval_size, |
| 177 | + tab_from_folder=eval_tab_folder, |
| 178 | + reference=train_dataset_folder, |
| 179 | + ) |
| 180 | + test_dataset_folder = WideDeepDatasetFromFolder( |
| 181 | + n_samples=test_size, |
| 182 | + tab_from_folder=test_tab_folder, |
| 183 | + reference=train_dataset_folder, |
| 184 | + ) |
| 185 | + train_loader = DataLoader(train_dataset_folder, batch_size=16, num_workers=1) |
| 186 | + eval_loader = DataLoader(eval_dataset_folder, batch_size=16, num_workers=1) |
| 187 | + test_loader = DataLoader(test_dataset_folder, batch_size=16, num_workers=1) |
| 188 | + |
| 189 | + # And from here on, is all pretty standard within the library |
| 190 | + basic_rnn = BasicRNN( |
| 191 | + vocab_size=len(text_preprocessor.vocab.itos), |
| 192 | + embed_dim=32, |
| 193 | + hidden_dim=64, |
| 194 | + n_layers=2, |
| 195 | + head_hidden_dims=[100, 50], |
| 196 | + ) |
| 197 | + |
| 198 | + deepimage = Vision( |
| 199 | + pretrained_model_name="resnet18", n_trainable=0, head_hidden_dims=[200, 100] |
| 200 | + ) |
| 201 | + |
| 202 | + deepdense = TabMlp( |
| 203 | + mlp_hidden_dims=[64, 32], |
| 204 | + column_idx=tab_preprocessor.column_idx, |
| 205 | + cat_embed_input=tab_preprocessor.cat_embed_input, |
| 206 | + continuous_cols=cont_cols, |
| 207 | + ) |
| 208 | + |
| 209 | + model = WideDeep( |
| 210 | + deeptabular=deepdense, |
| 211 | + deeptext=basic_rnn, |
| 212 | + deepimage=deepimage, |
| 213 | + ) |
| 214 | + |
| 215 | + callbacks = [EarlyStopping, ModelCheckpoint(filepath="model_weights/wd_out.pt")] |
| 216 | + |
| 217 | + trainer = TrainerFromFolder( |
| 218 | + model, |
| 219 | + objective="regression", |
| 220 | + callbacks=callbacks, |
| 221 | + ) |
| 222 | + |
| 223 | + trainer.fit( |
| 224 | + train_loader=train_loader, |
| 225 | + eval_loader=eval_loader, |
| 226 | + finetune=True, |
| 227 | + finetune_epochs=1, |
| 228 | + ) |
| 229 | + preds = trainer.predict(test_loader=test_loader) |
0 commit comments