diff --git a/services/webnn/ort/graph_builder_ort.cc b/services/webnn/ort/graph_builder_ort.cc index fe01db84394976..5bd0e1b8279fbe 100644 --- a/services/webnn/ort/graph_builder_ort.cc +++ b/services/webnn/ort/graph_builder_ort.cc @@ -669,7 +669,11 @@ struct GraphFusionInfo { // on that the `operations` in `mojom::GraphInfo` have been in topological // order which means if operation 'j' depends on 'i', 'i' must appear before // 'j'. -GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) { +GraphFusionInfo GetGraphFusionInfo( + const mojom::GraphInfo& graph_info, + base::flat_map>& + constant_operands, + const ContextProperties& context_properties) { // If it's disabled, just return empty 'GraphFusionInfo' object which means no // graph fusion will be applied since currently we only enable matmulnbits // fusion. @@ -869,6 +873,35 @@ GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) { graph_fusion_info.matmul_input_b_to_fusible_dequantize_map [transpose->get_transpose()->output_operand_id] = operation.get(); + const std::vector& input_shape = + input_operand->descriptor.shape(); + uint32_t input_feature_size = input_shape[0]; + uint32_t quant_num = input_shape[1]; + uint32_t blob_bytes = input_shape[2] / 2; + auto input_constant = std::move( + constant_operands.at(dequantize_linear->input_operand_id)); + std::vector new_input_buffer_shape = {input_feature_size, + quant_num, blob_bytes}; + auto new_input_desc = *OperandDescriptor::Create( + context_properties, OperandDataType::kUint8, new_input_buffer_shape, + dequantize_linear->label); + auto new_input_constant = std::make_unique( + std::move(new_input_desc), input_constant->TakeData()); + constant_operands[dequantize_linear->input_operand_id] = + std::move(new_input_constant); + + auto zero_point_constant = std::move( + constant_operands.at(dequantize_linear->zero_point_operand_id)); + std::vector new_zero_point_buffer_shape = { + input_feature_size * ((quant_num + 1) / 2)}; + auto new_zero_point_desc = *OperandDescriptor::Create( + context_properties, OperandDataType::kUint8, + new_zero_point_buffer_shape, dequantize_linear->label); + + auto new_zero_point_constant = std::make_unique( + std::move(new_zero_point_desc), zero_point_constant->TakeData()); + constant_operands[dequantize_linear->zero_point_operand_id] = + std::move(new_zero_point_constant); break; } default: { @@ -3552,28 +3585,13 @@ GraphBuilderOrt::AddMatMulOperation( uint32_t input_feature_size = input_b_shape[0]; uint32_t quant_num = input_b_shape[1]; - uint32_t blob_bytes = input_b_shape[2] / 2; + // uint32_t blob_bytes = input_b_shape[2] / 2; uint32_t block_size = input_b_shape[2]; uint32_t output_feature_size = quant_num * block_size; - const WebNNConstantOperand& input_constant = - *constant_operands_.at(dequantize_linear->input_operand_id); - std::vector new_input_buffer_shape = {input_feature_size, - quant_num, blob_bytes}; - - ASSIGN_OR_RETURN(input_b, - CreateInitializer(new_input_buffer_shape, - input_constant.ByteSpan())); - - const WebNNConstantOperand& zero_point_constant = - *constant_operands_.at(dequantize_linear->zero_point_operand_id); - std::vector new_zero_point_buffer_shape = {input_feature_size * - ((quant_num + 1) / 2)}; - ASSIGN_OR_RETURN( - std::string zero_point, - CreateInitializer(new_zero_point_buffer_shape, - zero_point_constant.ByteSpan())); - + input_b = GetOperandNameById(dequantize_linear->input_operand_id); + std::string zero_point = + GetOperandNameById(dequantize_linear->zero_point_operand_id); std::string scale = GetOperandNameById(dequantize_linear->scale_operand_id); // Here we insert a reshape since the original reshape has been folded into // scale due to constant folding. @@ -4236,6 +4254,9 @@ GraphBuilderOrt::BuildModel() { AddInput(input_id); } + GraphFusionInfo graph_fusion_info = + GetGraphFusionInfo(*graph_info_, constant_operands_, context_properties_); + // Add initializers. for (const auto& [constant_id, _] : constant_operands_) { RETURN_IF_ERROR(AddInitializer(constant_id)); @@ -4245,8 +4266,6 @@ GraphBuilderOrt::BuildModel() { // Find all the bool operands. FindBoolOperands(); - GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(*graph_info_); - // Add operations. for (const mojom::OperationPtr& operation : graph_info_->operations) { // Skip the operations which are fused into another operation.