Skip to content

Commit 84ed120

Browse files
lisa0314shiyi9801
authored andcommitted
implement resample2d (otcshare#88)
1 parent 566ac79 commit 84ed120

File tree

3 files changed

+71
-2
lines changed

3 files changed

+71
-2
lines changed

services/webnn/ort/context_impl_ort.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ ContextProperties ContextImplOrt::GetContextProperties() {
137137
/*reduce_sum_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
138138
/*reduce_sum_square_input=*/DataTypeConstraint::kFloat16To32Ints32To64,
139139
/*relu_input=*/DataTypeConstraint::kFloat16To32Int8To32,
140-
/*resample2d_input=*/{},
140+
/*resample2d_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
141141
/*reshape_input=*/DataTypeConstraint::kAllDataTypesAtLeast8bits,
142142
/*reverse_input=*/{},
143143
/*scatter_elements_input=*/{},

services/webnn/ort/graph_builder_ort.cc

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ constexpr char kOpTypeReduceSum[] = "ReduceSum";
9292
constexpr char kOpTypeReduceSumSquare[] = "ReduceSumSquare";
9393

9494
constexpr char kOpTypeRelu[] = "Relu";
95+
constexpr char kOpTypeResample2d[] = "Resize";
9596
constexpr char kOpTypeReshape[] = "Reshape";
9697
constexpr char kOpTypeSigmoid[] = "Sigmoid";
9798
constexpr char kOpTypeSlice[] = "Slice";
@@ -1302,6 +1303,70 @@ void GraphBuilderOrt::AddReduceOperation(const mojom::Reduce& reduce) {
13021303
attributes);
13031304
}
13041305

1306+
void GraphBuilderOrt::AddResample2dOperation(
1307+
const mojom::Resample2d& resample2d) {
1308+
const std::string node_name = GetNodeName(resample2d.label);
1309+
const std::string input_name = GetOperandName(resample2d.input_operand_id);
1310+
const std::string output_name = GetOperandName(resample2d.output_operand_id);
1311+
const std::vector<uint32_t>& output_shape =
1312+
GetOperand(resample2d.output_operand_id).descriptor.shape();
1313+
std::vector<const char*> input_names = {input_name.c_str()};
1314+
1315+
// ROI only takes effect when ONNX Resize op's attribute
1316+
// coordinate_transformation_mode is “tf_crop_and_resize” and the default
1317+
// value of coordinate_transformation_mode is "half_pixel". Currently, WebNN
1318+
// only supports "half_pixel".
1319+
const std::string roi_name = "";
1320+
input_names.push_back(roi_name.c_str());
1321+
1322+
CHECK_EQ(resample2d.axes.size(), 2u);
1323+
std::string scales_name;
1324+
std::string sizes_name;
1325+
if (resample2d.scales) {
1326+
// The number of elements of scales should be the same as the rank of axes
1327+
// if provided.
1328+
std::array<float, 2> scales_data = {resample2d.scales->at(0),
1329+
resample2d.scales->at(1)};
1330+
scales_name = CreateInitializer<float>({2}, scales_data);
1331+
sizes_name = "";
1332+
} else {
1333+
// The number of elements of sizes should be the same as the length of axes
1334+
// if provided.
1335+
std::array<int64_t, 2> sizes_data = {
1336+
base::checked_cast<int64_t>(output_shape[resample2d.axes[0]]),
1337+
base::checked_cast<int64_t>(output_shape[resample2d.axes[1]])};
1338+
sizes_name = CreateInitializer<int64_t>({2}, sizes_data);
1339+
scales_name = "";
1340+
}
1341+
input_names.push_back(scales_name.c_str());
1342+
input_names.push_back(sizes_name.c_str());
1343+
1344+
std::array<int64_t, 2> axes = {
1345+
base::checked_cast<int64_t>(resample2d.axes[0]),
1346+
base::checked_cast<int64_t>(resample2d.axes[1])};
1347+
ScopedOrtOpAttrPtr attr_axes =
1348+
model_builder_.CreateAttribute(/*name=*/"axes", axes);
1349+
1350+
std::string mode;
1351+
switch (resample2d.mode) {
1352+
case mojom::Resample2d::InterpolationMode::kLinear:
1353+
mode = "linear";
1354+
break;
1355+
case mojom::Resample2d::InterpolationMode::kNearestNeighbor:
1356+
mode = "nearest";
1357+
break;
1358+
}
1359+
ScopedOrtOpAttrPtr attr_mode =
1360+
model_builder_.CreateAttribute(/*name=*/"mode", mode);
1361+
std::array<OrtOpAttr*, 2> attributes = {attr_axes.Release(),
1362+
attr_mode.Release()};
1363+
1364+
std::array<const char*, 1> output_names = {output_name.c_str()};
1365+
1366+
model_builder_.AddNode(kOpTypeResample2d, node_name, input_names,
1367+
output_names, attributes);
1368+
}
1369+
13051370
void GraphBuilderOrt::AddReshapeOperation(const mojom::Reshape& reshape) {
13061371
const std::string node_name = GetNodeName(reshape.label);
13071372
const std::string input_name = GetOperandName(reshape.input_operand_id);
@@ -1524,6 +1589,10 @@ GraphBuilderOrt::BuildModel() {
15241589
AddUnaryOperation(*operation->get_relu(), kOpTypeRelu);
15251590
break;
15261591
}
1592+
case mojom::Operation::Tag::kResample2d: {
1593+
AddResample2dOperation(*operation->get_resample2d());
1594+
break;
1595+
}
15271596
case mojom::Operation::Tag::kReshape: {
15281597
AddReshapeOperation(*operation->get_reshape());
15291598
break;
@@ -1565,7 +1634,6 @@ GraphBuilderOrt::BuildModel() {
15651634
case mojom::Operation::Tag::kLstmCell:
15661635
case mojom::Operation::Tag::kPrelu:
15671636
case mojom::Operation::Tag::kQuantizeLinear:
1568-
case mojom::Operation::Tag::kResample2d:
15691637
case mojom::Operation::Tag::kReverse:
15701638
case mojom::Operation::Tag::kScatterElements:
15711639
case mojom::Operation::Tag::kScatterNd:

services/webnn/ort/graph_builder_ort.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ class GraphBuilderOrt {
185185
const mojom::Pad& pad);
186186
void AddPool2dOperation(const mojom::Pool2d& pool2d);
187187
void AddReduceOperation(const mojom::Reduce& reduce);
188+
void AddResample2dOperation(const mojom::Resample2d& resample2d);
188189
void AddReshapeOperation(const mojom::Reshape& reshape);
189190
void AddSliceOperation(const mojom::Slice& slice);
190191
void AddSoftmaxOperation(const mojom::Softmax& softmax);

0 commit comments

Comments
 (0)