63 changes: 41 additions & 22 deletions services/webnn/ort/graph_builder_ort.cc
@@ -669,7 +669,11 @@ struct GraphFusionInfo {
 // on that the `operations` in `mojom::GraphInfo` have been in topological
 // order which means if operation 'j' depends on 'i', 'i' must appear before
 // 'j'.
-GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) {
+GraphFusionInfo GetGraphFusionInfo(
+    const mojom::GraphInfo& graph_info,
+    base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>&
+        constant_operands,
+    const ContextProperties& context_properties) {
   // If it's disabled, just return an empty `GraphFusionInfo` object, which
   // means no graph fusion will be applied, since currently we only enable
   // MatMulNBits fusion.
@@ -869,6 +873,35 @@ GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) {
       graph_fusion_info.matmul_input_b_to_fusible_dequantize_map
           [transpose->get_transpose()->output_operand_id] = operation.get();

Reviewer (Collaborator): Worth adding a comment explaining why it needs to change the data type.

+        const std::vector<uint32_t>& input_shape =
+            input_operand->descriptor.shape();
+        uint32_t input_feature_size = input_shape[0];
+        uint32_t quant_num = input_shape[1];
+        uint32_t blob_bytes = input_shape[2] / 2;
+        auto input_constant = std::move(
+            constant_operands.at(dequantize_linear->input_operand_id));

Reviewer (Collaborator): std::move is unnecessary? I guess std::move of a reference would result in a copy.
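A standalone sketch of the semantics in question, using plain STL types rather than the Chromium ones: at() returns an lvalue reference to the stored unique_ptr, which is move-only, so without std::move the initialization would fail to compile rather than copy, and with std::move the map entry is left null until it is reassigned below.

#include <cstdint>
#include <map>
#include <memory>
#include <utility>

int main() {
  std::map<uint64_t, std::unique_ptr<int>> constants;
  constants[1] = std::make_unique<int>(42);

  // auto taken = constants.at(1);          // error: unique_ptr is not copyable.
  auto taken = std::move(constants.at(1));  // OK: moves ownership out of the map.

  // The map still has key 1, but its value is now nullptr, which is why the
  // code above immediately replaces the entry with a new constant operand.
  bool entry_is_null = (constants.at(1) == nullptr);  // true
  constants[1] = std::move(taken);                    // hand ownership back.
  return entry_is_null ? 0 : 1;
}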
+        std::vector<uint32_t> new_input_buffer_shape = {input_feature_size,
+                                                        quant_num, blob_bytes};

Reviewer (Collaborator): Move this line right after declaring blob_bytes.
+        auto new_input_desc = *OperandDescriptor::Create(
+            context_properties, OperandDataType::kUint8, new_input_buffer_shape,
+            dequantize_linear->label);
+        auto new_input_constant = std::make_unique<WebNNConstantOperand>(
+            std::move(new_input_desc), input_constant->TakeData());
+        constant_operands[dequantize_linear->input_operand_id] =
+            std::move(new_input_constant);

Reviewer (Collaborator): You may want to check this constant is only used by dequantizeLinear before replacing it.
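A minimal sketch of that check, assuming a hypothetical Operation type with a flat input_operand_ids list (the real code would walk the operations in mojom::GraphInfo):

#include <cstdint>
#include <vector>

struct Operation {
  std::vector<uint64_t> input_operand_ids;
};

// Returns true if no operation other than `allowed_user` reads `constant_id`,
// i.e. the constant can safely be rewritten in place for the fusion.
bool ConstantOnlyUsedBy(const std::vector<Operation>& operations,
                        uint64_t constant_id,
                        const Operation& allowed_user) {
  for (const Operation& op : operations) {
    for (uint64_t id : op.input_operand_ids) {
      if (id == constant_id && &op != &allowed_user) {
        return false;  // Another operation also consumes this constant.
      }
    }
  }
  return true;
}

int main() {
  std::vector<Operation> ops = {{{7, 8}}, {{9}}};
  return ConstantOnlyUsedBy(ops, 7, ops[0]) ? 0 : 1;  // ops[0] is the only user.
}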

+        auto zero_point_constant = std::move(
+            constant_operands.at(dequantize_linear->zero_point_operand_id));

Reviewer (Collaborator): Same here, std::move seems to be unnecessary.
+        std::vector<uint32_t> new_zero_point_buffer_shape = {
+            input_feature_size * ((quant_num + 1) / 2)};

Reviewer (Collaborator) suggested change:
-        std::vector<uint32_t> new_zero_point_buffer_shape = {
+        std::vector<uint32_t> new_zero_point_shape = {

+        auto new_zero_point_desc = *OperandDescriptor::Create(
+            context_properties, OperandDataType::kUint8,
+            new_zero_point_buffer_shape, dequantize_linear->label);

+        auto new_zero_point_constant = std::make_unique<WebNNConstantOperand>(
+            std::move(new_zero_point_desc), zero_point_constant->TakeData());
+        constant_operands[dequantize_linear->zero_point_operand_id] =
+            std::move(new_zero_point_constant);
         break;
       }
       default: {
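A standalone illustration of the shape arithmetic in this block, with made-up dimensions: re-describing the 4-bit data as uint8 packs two values per byte, so the innermost weight dimension halves and the zero points need input_feature_size * ceil(quant_num / 2) bytes.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<uint32_t> input_shape = {8, 5, 32};  // hypothetical example
  uint32_t input_feature_size = input_shape[0];    // 8 rows
  uint32_t quant_num = input_shape[1];             // 5 quantization blocks
  uint32_t blob_bytes = input_shape[2] / 2;        // 32 int4 values -> 16 bytes

  std::vector<uint32_t> new_input_shape = {input_feature_size, quant_num,
                                           blob_bytes};  // {8, 5, 16}
  uint32_t zero_point_bytes =
      input_feature_size * ((quant_num + 1) / 2);  // 8 * ceil(5 / 2) = 24

  std::cout << "packed weights: {" << new_input_shape[0] << ", "
            << new_input_shape[1] << ", " << new_input_shape[2]
            << "}, packed zero points: " << zero_point_bytes << " bytes\n";
}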
@@ -3552,28 +3585,13 @@ GraphBuilderOrt::AddMatMulOperation(

   uint32_t input_feature_size = input_b_shape[0];
   uint32_t quant_num = input_b_shape[1];
-  uint32_t blob_bytes = input_b_shape[2] / 2;
+  // uint32_t blob_bytes = input_b_shape[2] / 2;
   uint32_t block_size = input_b_shape[2];
   uint32_t output_feature_size = quant_num * block_size;

-  const WebNNConstantOperand& input_constant =
-      *constant_operands_.at(dequantize_linear->input_operand_id);
-  std::vector<uint32_t> new_input_buffer_shape = {input_feature_size,
-                                                  quant_num, blob_bytes};
-
-  ASSIGN_OR_RETURN(input_b,
-                   CreateInitializer<uint8_t>(new_input_buffer_shape,
-                                              input_constant.ByteSpan()));
-
-  const WebNNConstantOperand& zero_point_constant =
-      *constant_operands_.at(dequantize_linear->zero_point_operand_id);
-  std::vector<uint32_t> new_zero_point_buffer_shape = {input_feature_size *
-                                                       ((quant_num + 1) / 2)};
-  ASSIGN_OR_RETURN(
-      std::string zero_point,
-      CreateInitializer<uint8_t>(new_zero_point_buffer_shape,
-                                 zero_point_constant.ByteSpan()));
-
+  input_b = GetOperandNameById(dequantize_linear->input_operand_id);
+  std::string zero_point =
+      GetOperandNameById(dequantize_linear->zero_point_operand_id);
   std::string scale = GetOperandNameById(dequantize_linear->scale_operand_id);
   // Here we insert a reshape since the original reshape has been folded into
   // scale due to constant folding.
@@ -4236,6 +4254,9 @@ GraphBuilderOrt::BuildModel() {
     AddInput(input_id);
   }

+  GraphFusionInfo graph_fusion_info =
+      GetGraphFusionInfo(*graph_info_, constant_operands_, context_properties_);

   // Add initializers.
   for (const auto& [constant_id, _] : constant_operands_) {
     RETURN_IF_ERROR(AddInitializer(constant_id));
@@ -4245,8 +4266,6 @@ GraphBuilderOrt::BuildModel() {
   // Find all the bool operands.
   FindBoolOperands();

-  GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(*graph_info_);
-
   // Add operations.
   for (const mojom::OperationPtr& operation : graph_info_->operations) {
     // Skip the operations which are fused into another operation.
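These two hunks move the GetGraphFusionInfo() call ahead of the initializer loop. A simplified sketch of the ordering constraint, with hypothetical signatures rather than the real GraphBuilderOrt API: the fusion pass now rewrites constant operands in place, so it has to run before those constants are serialized into the model.

#include <cstdint>
#include <map>

struct ConstantOperand {
  int data = 0;  // stand-in for a descriptor plus weight bytes
};

// May replace entries in place, e.g. repacking 4-bit weights for the
// MatMulNBits fusion, as GetGraphFusionInfo() now does.
void RunFusionPass(std::map<uint64_t, ConstantOperand>& constants) {
  for (auto& entry : constants) {
    entry.second = ConstantOperand{entry.second.data * 2};  // placeholder rewrite
  }
}

// Serializes each constant's current contents, as the AddInitializer() loop does.
void SerializeInitializers(const std::map<uint64_t, ConstantOperand>& constants) {}

void BuildModel(std::map<uint64_t, ConstantOperand>& constants) {
  RunFusionPass(constants);          // must run first...
  SerializeInitializers(constants);  // ...so the rewritten constants are used.
}

int main() {
  std::map<uint64_t, ConstantOperand> constants = {{1, {21}}};
  BuildModel(constants);
  return constants.at(1).data == 42 ? 0 : 1;
}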