@@ -1,24 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-This is a demo script showing how to use the
-PrithviGeospatialMAE model with vLLM
-This script is based on: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/blob/main/inference.py # noqa
-
-Target model weights: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11/resolve/main/Prithvi-EO-V2-300M-TL-Sen1Floods11.pt # noqa
-
-The requirements for running this script are:
-- Installing [terratorch, albumentations, rasterio] in your python environment
-- downloading the model weights in a 'model' folder local to the script
-  (temporary measure until the proper config.json file is uploaded to HF)
-- download an input example image (India_900498_S2Hand.tif) and place it in
-  the same folder with the script (or specify with the --data_file argument)
-
-Run the example:
-python prithvi_geospatial_mae.py
-
-""" # noqa: E501
-
import argparse
import datetime
import os
@@ -34,89 +15,13 @@

from vllm import LLM

+torch.set_default_dtype(torch.float16)
+
NO_DATA = -9999
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99

-model_config = """{
-    "architectures": ["PrithviGeoSpatialMAE"],
-    "num_classes": 0,
-    "pretrained_cfg": {
-        "task_args": {
-            "task": "SemanticSegmentationTask",
-            "model_factory": "EncoderDecoderFactory",
-            "loss": "ce",
-            "ignore_index": -1,
-            "lr": 0.001,
-            "freeze_backbone": false,
-            "freeze_decoder": false,
-            "plot_on_val": 10,
-            "optimizer": "AdamW",
-            "scheduler": "CosineAnnealingLR"
-        },
-        "model_args": {
-            "backbone_pretrained": false,
-            "backbone": "prithvi_eo_v2_300_tl",
-            "decoder": "UperNetDecoder",
-            "decoder_channels": 256,
-            "decoder_scale_modules": true,
-            "num_classes": 2,
-            "rescale": true,
-            "backbone_bands": [
-                "BLUE",
-                "GREEN",
-                "RED",
-                "NIR_NARROW",
-                "SWIR_1",
-                "SWIR_2"
-            ],
-            "head_dropout": 0.1,
-            "necks": [
-                {
-                    "name": "SelectIndices",
-                    "indices": [
-                        5,
-                        11,
-                        17,
-                        23
-                    ]
-                },
-                {
-                    "name": "ReshapeTokensToImage"
-                }
-            ]
-        },
-        "optimizer_params": {
-            "lr": 5.0e-05,
-            "betas": [0.9, 0.999],
-            "eps": [1.0e-08],
-            "weight_decay": 0.05,
-            "amsgrad": false,
-            "maximize": false,
-            "capturable": false,
-            "differentiable": false
-        },
-        "scheduler_params": {
-            "T_max": 50,
-            "eta_min": 0,
-            "last_epoch": -1,
-            "verbose": "deprecated"
-        }
-    },
-
-
-    "torch_dtype": "float32"
-}
-"""
-
-# Temporarily creating the "config.json" for the model.
-# This is going to disappear once the correct config.json is available on HF
-with open(
-    os.path.join(os.path.dirname(__file__), "./model/config.json"), "w"
-) as config_file:
-    config_file.write(model_config)
-
datamodule_config = {
    "bands": ["BLUE", "GREEN", "RED", "NIR_NARROW", "SWIR_1", "SWIR_2"],
    "batch_size": 16,
@@ -138,28 +43,24 @@


class PrithviMAE:
-    def __init__(self):
-        print("Initializing PrithviMAE model")
-        self.llm = LLM(
-            model=os.path.join(os.path.dirname(__file__), "./model"),
-            skip_tokenizer_init=True,
-            dtype="float32",
+    def __init__(self, model):
+        self.model = LLM(
+            model=model, skip_tokenizer_init=True, dtype="float16", enforce_eager=True
        )

    def run(self, input_data, location_coords):
-        print("################ Running inference on vLLM ##############")
        # merge the inputs into one data structure
+        if input_data is not None and input_data.dtype == torch.float32:
+            input_data = input_data.to(torch.float16)
+        input_data = input_data[0]
+
        mm_data = {
-            "pixel_values": torch.empty(0) if input_data is None else input_data,
-            "location_coords": torch.empty(0)
-            if location_coords is None
-            else location_coords,
+            "pixel_values": input_data,
+            "location_coords": location_coords,
        }

        prompt = {"prompt_token_ids": [1], "multi_modal_data": mm_data}
-
-        outputs = self.llm.encode(prompt, use_tqdm=False)
-        print("################ Inference done (it took seconds) ##############")
+        outputs = self.model.encode(prompt, use_tqdm=False)

        return outputs[0].outputs.data

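For reference, a minimal sketch of how this wrapper can be driven end to end. The tensor shape and the (lat, lon) layout are assumptions inferred from the surrounding code; the model ID is the default defined further down in this file:

```python
import torch

# Hypothetical driver for the PrithviMAE wrapper above.
mae = PrithviMAE(model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM")

# Assumed layout: (batch, bands, frames, H, W) in float32; run() downcasts
# to float16 and strips the batch dimension before handing data to vLLM.
pixels = torch.randn(1, 6, 1, 512, 512, dtype=torch.float32)
coords = torch.tensor([[20.0, 80.0]])  # assumed (lat, lon) pair

logits = mae.run(pixels, location_coords=coords)
print(logits.shape)
```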
@@ -181,11 +82,12 @@ def process_channel_group(orig_img, channels):
    """
    Args:
        orig_img: torch.Tensor representing original image (reference)
-            with shape = (bands, H, W).
+        with shape = (bands, H, W).
        channels: list of indices representing RGB channels.

    Returns:
-        torch.Tensor with shape (num_channels, height, width) for original image
+        torch.Tensor with shape (num_channels, height, width)
+        for original image
    """

    orig_img = orig_img[channels, ...]
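As a quick illustration of the fancy-indexing idiom above, which reorders a six-band tensor into RGB in one step (the sizes and band positions here are made up for the example):

```python
import torch

img = torch.arange(6 * 2 * 2).reshape(6, 2, 2)  # (bands, H, W)
channels = [2, 1, 0]      # RED, GREEN, BLUE positions in a BGR-ordered stack
rgb = img[channels, ...]  # same idiom as process_channel_group
print(rgb.shape)          # torch.Size([3, 2, 2])
```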
@@ -260,10 +162,10 @@ def load_example(

    Args:
        file_paths: list of file paths.
-        mean: list containing mean values for each band in the images
-            in *file_paths*.
-        std: list containing std values for each band in the images
-            in *file_paths*.
+        mean: list containing mean values for each band in the
+            images in *file_paths*.
+        std: list containing std values for each band in the
+            images in *file_paths*.

    Returns:
        np.array containing created example
@@ -308,7 +210,7 @@ def load_example(
            print(f"Could not extract timestamp for {file} ({e})")

    imgs = np.stack(imgs, axis=0)  # num_frames, H, W, C
-    imgs = np.moveaxis(imgs, -1, 0).astype("float32")
+    imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
    imgs = np.expand_dims(imgs, axis=0)  # add batch dim

    return imgs, temporal_coords, location_coords, metas
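A small self-contained sanity check of the axis bookkeeping in the two lines above, using a synthetic array (the sizes are arbitrary):

```python
import numpy as np

imgs = np.zeros((1, 4, 4, 6))                      # num_frames, H, W, C
imgs = np.moveaxis(imgs, -1, 0).astype("float32")  # C, num_frames, H, W
imgs = np.expand_dims(imgs, axis=0)                # add batch dim
print(imgs.shape)  # (1, 6, 1, 4, 4)
```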
@@ -332,8 +234,10 @@ def run_model(
    )

    # Build sliding window
+
    batch_size = 1
-    batch = torch.tensor(input_data, device="cpu")
+    # batch = torch.tensor(input_data, device="cpu")
+    batch = torch.tensor(input_data)
    windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
@@ -344,34 +248,24 @@ def run_model(
    num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
    windows = torch.tensor_split(windows, num_batches, dim=0)

-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-
    if temporal_coords:
-        temporal_coords = torch.tensor(temporal_coords, device=device).unsqueeze(0)
+        temporal_coords = torch.tensor(temporal_coords).unsqueeze(0)
    else:
        temporal_coords = None
    if location_coords:
-        location_coords = torch.tensor(location_coords[0], device=device).unsqueeze(0)
+        location_coords = torch.tensor(location_coords[0]).unsqueeze(0)
    else:
        location_coords = None

-    # Run model
+    # Run Prithvi-EO-V2-300M-TL-Sen1Floods11
    pred_imgs = []
    for x in windows:
        # Apply standardization
        x = datamodule.test_transform(image=x.squeeze().numpy().transpose(1, 2, 0))
        x = datamodule.aug(x)["image"]

        with torch.no_grad():
-            x = x.to(device)
            pred = model.run(x, location_coords=location_coords)
-        if lightning_model:
-            pred_lightning = lightning_model(
-                x, temporal_coords=temporal_coords, location_coords=location_coords
-            )
-            pred_lightning = pred_lightning.output.detach().cpu()
-            if not torch.equal(pred, pred_lightning):
-                print("Inference output is not equal")
        y_hat = pred.argmax(dim=1)

        y_hat = torch.nn.functional.interpolate(
@@ -403,52 +297,18 @@ def run_model(
    return pred_imgs


-def parse_args():
-    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
-
-    parser.add_argument(
-        "--data_file",
-        type=str,
-        default="./India_900498_S2Hand.tif",
-        help="Path to the file.",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output",
-        help="Path to the directory where to save outputs.",
-    )
-    parser.add_argument(
-        "--input_indices",
-        default=[1, 2, 3, 8, 11, 12],
-        type=int,
-        nargs="+",
-        help="0-based indices of the six Prithvi channels to be selected from the "
-        "input. By default selects [1,2,3,8,11,12] for S2L1C data.",
-    )
-    parser.add_argument(
-        "--rgb_outputs",
-        action="store_true",
-        help="If present, output files will only contain RGB channels. "
-        "Otherwise, all bands will be saved.",
-    )
-
-
def main(
    data_file: str,
+    model: str,
    output_dir: str,
    rgb_outputs: bool,
    input_indices: list[int] = None,
):
    os.makedirs(output_dir, exist_ok=True)

-    # Load model ---------------------------------------------------------------
-
-    model_obj = PrithviMAE()
+    model_obj = PrithviMAE(model=model)
    datamodule = generate_datamodule()
-    img_size = 256  # Size of Sen1Floods11
-
-    # Loading data -------------------------------------------------------------
+    img_size = 512  # Size of Sen1Floods11

    input_data, temporal_coords, location_coords, meta_data = load_example(
        file_paths=[data_file],
@@ -460,16 +320,13 @@ def main(
    if input_data.mean() > 1:
        input_data = input_data / 10000  # Convert to range 0-1

-    # Running model ------------------------------------------------------------
-
    channels = [
        datamodule_config["bands"].index(b) for b in ["RED", "GREEN", "BLUE"]
    ]  # BGR -> RGB

    pred = run_model(
        input_data, temporal_coords, location_coords, model_obj, datamodule, img_size
    )
-
    # Save pred
    meta_data.update(count=1, dtype="uint8", compress="lzw", nodata=0)
    pred_file = os.path.join(
@@ -487,6 +344,7 @@ def main(
        orig_img=torch.Tensor(input_data[0, :, 0, ...]),
        channels=channels,
    )
+    rgb_orig = rgb_orig.to(torch.float32)

    pred[pred == 0.0] = np.nan
    img_pred = rgb_orig * 0.7 + pred * 0.3
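A minimal sketch of why the added float32 cast matters for this blend: under the new `torch.set_default_dtype(torch.float16)`, `torch.Tensor(...)` yields float16, while `pred` carries float32 NaNs, so casting first keeps the arithmetic in one dtype (toy shapes assumed):

```python
import torch

torch.set_default_dtype(torch.float16)
rgb_orig = torch.Tensor(3, 4, 4)       # float16 under the new default
rgb_orig = rgb_orig.to(torch.float32)  # the cast added above

pred = torch.full((1, 4, 4), float("nan"), dtype=torch.float32)
img_pred = rgb_orig * 0.7 + pred * 0.3  # stays float32, broadcasts over C
print(img_pred.dtype)                   # torch.float32
```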
@@ -503,9 +361,10 @@ def main(

    # Save image rgb
    if rgb_outputs:
+        name_suffix = os.path.splitext(os.path.basename(data_file))[0]
        rgb_file = os.path.join(
            output_dir,
-            f"original_rgb_{os.path.splitext(os.path.basename(data_file))[0]}.tiff",
+            f"original_rgb_{name_suffix}.tiff",
        )
        save_geotiff(
            image=_convert_np_uint8(rgb_orig),
@@ -515,6 +374,42 @@ def main(


if __name__ == "__main__":
-    args = parse_args()
+    parser = argparse.ArgumentParser("MAE run inference", add_help=False)
+
+    parser.add_argument(
+        "--data_file",
+        type=str,
+        default="./India_900498_S2Hand.tif",
+        help="Path to the file.",
+    )
+    parser.add_argument(
+        "--model",
+        type=str,
+        default="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
+        help="Path to a checkpoint file to load from.",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output",
+        help="Path to the directory where to save outputs.",
+    )
+    parser.add_argument(
+        "--input_indices",
+        default=[1, 2, 3, 8, 11, 12],
+        type=int,
+        nargs="+",
+        help="""
+        0-based indices of the six Prithvi channels to be selected from the input.
+        By default selects [1,2,3,8,11,12] for S2L1C data.
+        """,
+    )
+    parser.add_argument(
+        "--rgb_outputs",
+        action="store_true",
+        help="If present, output files will only contain RGB channels. "
+        "Otherwise, all bands will be saved.",
+    )
+    args = parser.parse_args()

    main(**vars(args))
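For completeness, the equivalent programmatic invocation of the entry point, using only the defaults defined above:

```python
# Same effect as: python prithvi_geospatial_mae.py
main(
    data_file="./India_900498_S2Hand.tif",
    model="christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM",
    output_dir="output",
    rgb_outputs=False,
    input_indices=[1, 2, 3, 8, 11, 12],
)
```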