-
Notifications
You must be signed in to change notification settings - Fork 2
Move the graph fusion before AddInitial() for constant operands #238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: ort_backend
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -669,7 +669,11 @@ struct GraphFusionInfo { | |||||
// on that the `operations` in `mojom::GraphInfo` have been in topological | ||||||
// order which means if operation 'j' depends on 'i', 'i' must appear before | ||||||
// 'j'. | ||||||
GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) { | ||||||
GraphFusionInfo GetGraphFusionInfo( | ||||||
const mojom::GraphInfo& graph_info, | ||||||
base::flat_map<uint64_t, std::unique_ptr<WebNNConstantOperand>>& | ||||||
constant_operands, | ||||||
const ContextProperties& context_properties) { | ||||||
// If it's disabled, just return empty 'GraphFusionInfo' object which means no | ||||||
// graph fusion will be applied since currently we only enable matmulnbits | ||||||
// fusion. | ||||||
|
@@ -869,6 +873,35 @@ GraphFusionInfo GetGraphFusionInfo(const mojom::GraphInfo& graph_info) { | |||||
graph_fusion_info.matmul_input_b_to_fusible_dequantize_map | ||||||
[transpose->get_transpose()->output_operand_id] = operation.get(); | ||||||
|
||||||
const std::vector<uint32_t>& input_shape = | ||||||
input_operand->descriptor.shape(); | ||||||
uint32_t input_feature_size = input_shape[0]; | ||||||
uint32_t quant_num = input_shape[1]; | ||||||
uint32_t blob_bytes = input_shape[2] / 2; | ||||||
auto input_constant = std::move( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||||||
constant_operands.at(dequantize_linear->input_operand_id)); | ||||||
std::vector<uint32_t> new_input_buffer_shape = {input_feature_size, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Move this line right after declaring |
||||||
quant_num, blob_bytes}; | ||||||
auto new_input_desc = *OperandDescriptor::Create( | ||||||
context_properties, OperandDataType::kUint8, new_input_buffer_shape, | ||||||
dequantize_linear->label); | ||||||
auto new_input_constant = std::make_unique<WebNNConstantOperand>( | ||||||
std::move(new_input_desc), input_constant->TakeData()); | ||||||
constant_operands[dequantize_linear->input_operand_id] = | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You may want to check that this constant is only used by dequantizeLinear before replacing it. |
||||||
std::move(new_input_constant); | ||||||
|
||||||
auto zero_point_constant = std::move( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same here, |
||||||
constant_operands.at(dequantize_linear->zero_point_operand_id)); | ||||||
std::vector<uint32_t> new_zero_point_buffer_shape = { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
input_feature_size * ((quant_num + 1) / 2)}; | ||||||
auto new_zero_point_desc = *OperandDescriptor::Create( | ||||||
context_properties, OperandDataType::kUint8, | ||||||
new_zero_point_buffer_shape, dequantize_linear->label); | ||||||
|
||||||
auto new_zero_point_constant = std::make_unique<WebNNConstantOperand>( | ||||||
std::move(new_zero_point_desc), zero_point_constant->TakeData()); | ||||||
constant_operands[dequantize_linear->zero_point_operand_id] = | ||||||
std::move(new_zero_point_constant); | ||||||
break; | ||||||
} | ||||||
default: { | ||||||
|
@@ -3552,28 +3585,13 @@ GraphBuilderOrt::AddMatMulOperation( | |||||
|
||||||
uint32_t input_feature_size = input_b_shape[0]; | ||||||
uint32_t quant_num = input_b_shape[1]; | ||||||
uint32_t blob_bytes = input_b_shape[2] / 2; | ||||||
// uint32_t blob_bytes = input_b_shape[2] / 2; | ||||||
uint32_t block_size = input_b_shape[2]; | ||||||
uint32_t output_feature_size = quant_num * block_size; | ||||||
|
||||||
const WebNNConstantOperand& input_constant = | ||||||
*constant_operands_.at(dequantize_linear->input_operand_id); | ||||||
std::vector<uint32_t> new_input_buffer_shape = {input_feature_size, | ||||||
quant_num, blob_bytes}; | ||||||
|
||||||
ASSIGN_OR_RETURN(input_b, | ||||||
CreateInitializer<uint8_t>(new_input_buffer_shape, | ||||||
input_constant.ByteSpan())); | ||||||
|
||||||
const WebNNConstantOperand& zero_point_constant = | ||||||
*constant_operands_.at(dequantize_linear->zero_point_operand_id); | ||||||
std::vector<uint32_t> new_zero_point_buffer_shape = {input_feature_size * | ||||||
((quant_num + 1) / 2)}; | ||||||
ASSIGN_OR_RETURN( | ||||||
std::string zero_point, | ||||||
CreateInitializer<uint8_t>(new_zero_point_buffer_shape, | ||||||
zero_point_constant.ByteSpan())); | ||||||
|
||||||
input_b = GetOperandNameById(dequantize_linear->input_operand_id); | ||||||
std::string zero_point = | ||||||
GetOperandNameById(dequantize_linear->zero_point_operand_id); | ||||||
std::string scale = GetOperandNameById(dequantize_linear->scale_operand_id); | ||||||
// Here we insert a reshape since the original reshape has been folded into | ||||||
// scale due to constant folding. | ||||||
|
@@ -4236,6 +4254,9 @@ GraphBuilderOrt::BuildModel() { | |||||
AddInput(input_id); | ||||||
} | ||||||
|
||||||
GraphFusionInfo graph_fusion_info = | ||||||
GetGraphFusionInfo(*graph_info_, constant_operands_, context_properties_); | ||||||
|
||||||
// Add initializers. | ||||||
for (const auto& [constant_id, _] : constant_operands_) { | ||||||
RETURN_IF_ERROR(AddInitializer(constant_id)); | ||||||
|
@@ -4245,8 +4266,6 @@ GraphBuilderOrt::BuildModel() { | |||||
// Find all the bool operands. | ||||||
FindBoolOperands(); | ||||||
|
||||||
GraphFusionInfo graph_fusion_info = GetGraphFusionInfo(*graph_info_); | ||||||
|
||||||
// Add operations. | ||||||
for (const mojom::OperationPtr& operation : graph_info_->operations) { | ||||||
// Skip the operations which are fused into another operation. | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Worth adding a comment explaining why it needs to change the data type.