@@ -92,6 +92,7 @@ constexpr char kOpTypeReduceSum[] = "ReduceSum";
92
92
constexpr char kOpTypeReduceSumSquare [] = " ReduceSumSquare" ;
93
93
94
94
constexpr char kOpTypeRelu [] = " Relu" ;
95
+ constexpr char kOpTypeResample2d [] = " Resize" ;
95
96
constexpr char kOpTypeReshape [] = " Reshape" ;
96
97
constexpr char kOpTypeSigmoid [] = " Sigmoid" ;
97
98
constexpr char kOpTypeSlice [] = " Slice" ;
@@ -1302,6 +1303,70 @@ void GraphBuilderOrt::AddReduceOperation(const mojom::Reduce& reduce) {
1302
1303
attributes);
1303
1304
}
1304
1305
1306
+ void GraphBuilderOrt::AddResample2dOperation (
1307
+ const mojom::Resample2d& resample2d) {
1308
+ const std::string node_name = GetNodeName (resample2d.label );
1309
+ const std::string input_name = GetOperandName (resample2d.input_operand_id );
1310
+ const std::string output_name = GetOperandName (resample2d.output_operand_id );
1311
+ const std::vector<uint32_t >& output_shape =
1312
+ GetOperand (resample2d.output_operand_id ).descriptor .shape ();
1313
+ std::vector<const char *> input_names = {input_name.c_str ()};
1314
+
1315
+ // ROI only takes effect when ONNX Resize op's attribute
1316
+ // coordinate_transformation_mode is “tf_crop_and_resize” and the default
1317
+ // value of coordinate_transformation_mode is "half_pixel". Currently, WebNN
1318
+ // only supports "half_pixel".
1319
+ const std::string roi_name = " " ;
1320
+ input_names.push_back (roi_name.c_str ());
1321
+
1322
+ CHECK_EQ (resample2d.axes .size (), 2u );
1323
+ std::string scales_name;
1324
+ std::string sizes_name;
1325
+ if (resample2d.scales ) {
1326
+ // The number of elements of scales should be the same as the rank of axes
1327
+ // if provided.
1328
+ std::array<float , 2 > scales_data = {resample2d.scales ->at (0 ),
1329
+ resample2d.scales ->at (1 )};
1330
+ scales_name = CreateInitializer<float >({2 }, scales_data);
1331
+ sizes_name = " " ;
1332
+ } else {
1333
+ // The number of elements of sizes should be the same as the length of axes
1334
+ // if provided.
1335
+ std::array<int64_t , 2 > sizes_data = {
1336
+ base::checked_cast<int64_t >(output_shape[resample2d.axes [0 ]]),
1337
+ base::checked_cast<int64_t >(output_shape[resample2d.axes [1 ]])};
1338
+ sizes_name = CreateInitializer<int64_t >({2 }, sizes_data);
1339
+ scales_name = " " ;
1340
+ }
1341
+ input_names.push_back (scales_name.c_str ());
1342
+ input_names.push_back (sizes_name.c_str ());
1343
+
1344
+ std::array<int64_t , 2 > axes = {
1345
+ base::checked_cast<int64_t >(resample2d.axes [0 ]),
1346
+ base::checked_cast<int64_t >(resample2d.axes [1 ])};
1347
+ ScopedOrtOpAttrPtr attr_axes =
1348
+ model_builder_.CreateAttribute (/* name=*/ " axes" , axes);
1349
+
1350
+ std::string mode;
1351
+ switch (resample2d.mode ) {
1352
+ case mojom::Resample2d::InterpolationMode::kLinear :
1353
+ mode = " linear" ;
1354
+ break ;
1355
+ case mojom::Resample2d::InterpolationMode::kNearestNeighbor :
1356
+ mode = " nearest" ;
1357
+ break ;
1358
+ }
1359
+ ScopedOrtOpAttrPtr attr_mode =
1360
+ model_builder_.CreateAttribute (/* name=*/ " mode" , mode);
1361
+ std::array<OrtOpAttr*, 2 > attributes = {attr_axes.Release (),
1362
+ attr_mode.Release ()};
1363
+
1364
+ std::array<const char *, 1 > output_names = {output_name.c_str ()};
1365
+
1366
+ model_builder_.AddNode (kOpTypeResample2d , node_name, input_names,
1367
+ output_names, attributes);
1368
+ }
1369
+
1305
1370
void GraphBuilderOrt::AddReshapeOperation (const mojom::Reshape& reshape) {
1306
1371
const std::string node_name = GetNodeName (reshape.label );
1307
1372
const std::string input_name = GetOperandName (reshape.input_operand_id );
@@ -1524,6 +1589,10 @@ GraphBuilderOrt::BuildModel() {
1524
1589
AddUnaryOperation (*operation->get_relu (), kOpTypeRelu );
1525
1590
break ;
1526
1591
}
1592
+ case mojom::Operation::Tag::kResample2d : {
1593
+ AddResample2dOperation (*operation->get_resample2d ());
1594
+ break ;
1595
+ }
1527
1596
case mojom::Operation::Tag::kReshape : {
1528
1597
AddReshapeOperation (*operation->get_reshape ());
1529
1598
break ;
@@ -1565,7 +1634,6 @@ GraphBuilderOrt::BuildModel() {
1565
1634
case mojom::Operation::Tag::kLstmCell :
1566
1635
case mojom::Operation::Tag::kPrelu :
1567
1636
case mojom::Operation::Tag::kQuantizeLinear :
1568
- case mojom::Operation::Tag::kResample2d :
1569
1637
case mojom::Operation::Tag::kReverse :
1570
1638
case mojom::Operation::Tag::kScatterElements :
1571
1639
case mojom::Operation::Tag::kScatterNd :
0 commit comments