214 changes: 214 additions & 0 deletions export_sift.py
@@ -0,0 +1,214 @@
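"""Export a feature extractor + LightGlue matcher to ONNX, with SIFT support.

For SIFT, only the LightGlue matcher is traced to ONNX; keypoints, scales,
and orientations are computed in PyTorch first. Example invocation (flags
as defined in parse_args below):

    python export_sift.py --extractor_type sift --dynamic
"""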
import argparse
from typing import List

import torch

from lightglue_onnx import DISK, LightGlue, LightGlueEnd2End, SuperPoint, SIFT
from lightglue_onnx.end2end import normalize_keypoints
from lightglue_onnx.utils import load_image, rgb_to_grayscale


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_size",
        nargs="+",
        type=int,
        default=512,
        required=False,
        help="Sample image size for ONNX tracing. If a single integer is given, resize the longer side of the image to this value. Otherwise, please provide two integers (height width).",
    )
    parser.add_argument(
        "--extractor_type",
        type=str,
        default="sift",
        choices=["superpoint", "disk", "sift"],
        required=False,
        help="Type of feature extractor. Supported extractors are 'superpoint', 'disk', and 'sift'. Defaults to 'sift'.",
    )
    parser.add_argument(
        "--extractor_path",
        type=str,
        default=None,
        required=False,
        help="Path to save the feature extractor ONNX model.",
    )
    parser.add_argument(
        "--lightglue_path",
        type=str,
        default=None,
        required=False,
        help="Path to save the LightGlue ONNX model.",
    )
    parser.add_argument(
        "--end2end",
        action="store_true",
        help="Whether to export an end-to-end pipeline instead of individual models.",
    )
    parser.add_argument(
        "--dynamic", action="store_true", help="Whether to allow dynamic image sizes."
    )

    # Extractor-specific args:
    parser.add_argument(
        "--max_num_keypoints",
        type=int,
        default=None,
        required=False,
        help="Maximum number of keypoints outputted by the extractor.",
    )

    return parser.parse_args()


def export_onnx(
    img_size=512,
    extractor_type="superpoint",
    extractor_path=None,
    lightglue_path=None,
    img0_path="assets/sacre_coeur1.jpg",
    img1_path="assets/sacre_coeur2.jpg",
    end2end=False,
    dynamic=False,
    max_num_keypoints=128,
):
    # Handle args
    if isinstance(img_size, List) and len(img_size) == 1:
        img_size = img_size[0]

    if extractor_path is not None and end2end:
        raise ValueError(
            "Extractor will be combined with LightGlue when exporting end-to-end model."
        )
    if extractor_path is None:
        extractor_path = f"weights/{extractor_type}.onnx"
        if max_num_keypoints is not None:
            extractor_path = extractor_path.replace(
                ".onnx", f"_{max_num_keypoints}.onnx"
            )

    if lightglue_path is None:
        lightglue_path = (
            f"weights/{extractor_type}_lightglue"
            f"{'_end2end' if end2end else ''}"
            ".onnx"
        )

    # Sample images for tracing
    image0, scales0 = load_image(img0_path, resize=img_size)
    image1, scales1 = load_image(img1_path, resize=img_size)

    # Models
    extractor_type = extractor_type.lower()
    if extractor_type == "superpoint":
        # SuperPoint works on grayscale images.
        image0 = rgb_to_grayscale(image0)
        image1 = rgb_to_grayscale(image1)
        extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval()
        lightglue = LightGlue(extractor_type).eval()
    elif extractor_type == "disk":
        extractor = DISK(max_num_keypoints=max_num_keypoints).eval()
        lightglue = LightGlue(extractor_type).eval()
    elif extractor_type == "sift":
        # Respect the CLI cap; fall back to the previously hard-coded 128.
        extractor = SIFT(max_num_keypoints=max_num_keypoints or 128).eval()
        lightglue = LightGlue(extractor_type).eval()
    else:
        raise NotImplementedError(
            f"LightGlue has not been trained on {extractor_type} features."
        )

    # ONNX Export
    if end2end:
        pipeline = LightGlueEnd2End(extractor, lightglue).eval()

        dynamic_axes = {
            "kpts0": {1: "num_keypoints0"},
            "kpts1": {1: "num_keypoints1"},
            "matches0": {0: "num_matches0"},
            "mscores0": {0: "num_matches0"},
        }
        if dynamic:
            dynamic_axes.update(
                {
                    "image0": {2: "height0", 3: "width0"},
                    "image1": {2: "height1", 3: "width1"},
                }
            )

        torch.onnx.export(
            pipeline,
            (image0[None], image1[None]),
            lightglue_path,
            input_names=["image0", "image1"],
            output_names=[
                "kpts0",
                "kpts1",
                "matches0",
                "mscores0",
            ],
            opset_version=17,
            dynamic_axes=dynamic_axes,
        )
    else:
        # Export Extractor
        dynamic_axes = {
            "keypoints": {1: "num_keypoints"},
            "scores": {1: "num_keypoints"},
            "descriptors": {1: "num_keypoints"},
        }
        if dynamic:
            dynamic_axes.update({"image": {2: "height", 3: "width"}})
        else:
            print(
                f"WARNING: Exporting without --dynamic implies that the {extractor_type} extractor's input image size will be locked to {image0.shape[-2:]}"
            )
            extractor_path = extractor_path.replace(
                ".onnx", f"_{image0.shape[-2]}x{image0.shape[-1]}.onnx"
            )

        # NOTE: The SIFT extractor is not traced to ONNX here; its features are
        # computed in PyTorch below and only the LightGlue matcher is exported.
        # torch.onnx.export(
        #     extractor,
        #     image0[None],
        #     extractor_path,
        #     input_names=["image"],
        #     output_names=["keypoints", "scores", "descriptors"],
        #     opset_version=17,
        #     dynamic_axes=dynamic_axes,
        # )

        # Export LightGlue
        feats0 = extractor.extract(image0[None])
        feats1 = extractor.extract(image1[None])
        kpts0, scores0, desc0 = (
            feats0["keypoints"],
            feats0["keypoint_scores"],
            feats0["descriptors"],
        )
        kpts1, scores1, desc1 = (
            feats1["keypoints"],
            feats1["keypoint_scores"],
            feats1["descriptors"],
        )
        kpts0 = normalize_keypoints(kpts0, image0.shape[1], image0.shape[2])
        kpts1 = normalize_keypoints(kpts1, image1.shape[1], image1.shape[2])
        # The SIFT variant of LightGlue consumes scale and orientation
        # alongside the normalized keypoint coordinates.
        kpts0 = torch.cat(
            [kpts0] + [feats0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
        )
        kpts1 = torch.cat(
            [kpts1] + [feats1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
        )

        torch.onnx.export(
            lightglue,
            (kpts0, kpts1, desc0, desc1),
            lightglue_path,
            input_names=["kpts0", "kpts1", "desc0", "desc1"],
            output_names=["matches0", "mscores0"],
            opset_version=17,
            dynamic_axes={
                "kpts0": {1: "num_keypoints0"},
                "kpts1": {1: "num_keypoints1"},
                "desc0": {1: "num_keypoints0"},
                "desc1": {1: "num_keypoints1"},
                "matches0": {0: "num_matches0"},
                "mscores0": {0: "num_matches0"},
            },
        )


if __name__ == "__main__":
    args = parse_args()
    export_onnx(**vars(args))
144 changes: 144 additions & 0 deletions infer_sift.py
@@ -0,0 +1,144 @@
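"""Run SIFT + LightGlue matching with ONNX Runtime on a pair of images.

Example invocation (defaults point at the sample images and the CPU model
configured in parse_args below):

    python infer_sift.py --viz
"""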
import argparse
from typing import List
import time
from onnx_runner import (
    LightGlueRunner,
    LightGlueRunner_SIFT,
    load_image,
    load_image_sift,
    rgb_to_grayscale,
    viz2d,
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--img_paths",
        nargs=2,
        default=["assets/DSC_0410.JPG", "assets/DSC_0411.JPG"],
        required=False,
        type=str,
    )
    parser.add_argument(
        "--lightglue_path",
        type=str,
        default="weights/sift_lightglue_fused_cpu.onnx",
        required=False,
        help="Path to the LightGlue ONNX model or end-to-end LightGlue pipeline.",
    )
    parser.add_argument(
        "--extractor_type",
        type=str,
        default="sift",
        choices=["superpoint", "disk", "sift"],
        required=False,
        help="Type of feature extractor. Supported extractors are 'superpoint', 'disk', and 'sift'.",
    )
    parser.add_argument(
        "--extractor_path",
        type=str,
        default=" ",
        required=False,
        help="Path to the feature extractor ONNX model. If this argument is not provided, it is assumed that lightglue_path refers to an end-to-end model.",
    )
    parser.add_argument(
        "--img_size",
        nargs="+",
        type=int,
        default=512,
        required=False,
        help="Sample image size for ONNX tracing. If a single integer is given, resize the longer side of the images to this value. Otherwise, please provide two integers (height width) to resize both images to this size, or four integers (height width height width).",
    )
    parser.add_argument(
        "--trt",
        action="store_true",
        help="Whether to use TensorRT (experimental).",
    )
    parser.add_argument(
        "--viz",
        action="store_true",
        default=True,
        help="Whether to visualize the results.",
    )
    return parser.parse_args()


def infer(
    img_paths: List[str],
    lightglue_path: str,
    extractor_type: str,
    extractor_path=None,
    img_size=512,
    trt=False,
    viz=False,
):
    # Handle args
    img0_path = img_paths[0]
    img1_path = img_paths[1]
    if isinstance(img_size, List):
        if len(img_size) == 1:
            size0 = size1 = img_size[0]
        elif len(img_size) == 2:
            size0 = size1 = img_size
        elif len(img_size) == 4:
            size0, size1 = img_size[:2], img_size[2:]
        else:
            raise ValueError("Invalid img_size. Please provide 1, 2, or 4 integers.")
    else:
        size0 = size1 = img_size

    # Load the sample images; the SIFT path uses its dedicated loader.
    extractor_type = extractor_type.lower()
    if extractor_type == "superpoint":
        # SuperPoint works on grayscale images.
        image0, scales0 = load_image(img0_path, resize=size0)
        image1, scales1 = load_image(img1_path, resize=size1)
        image0 = rgb_to_grayscale(image0)
        image1 = rgb_to_grayscale(image1)
    elif extractor_type == "disk":
        image0, scales0 = load_image(img0_path, resize=size0)
        image1, scales1 = load_image(img1_path, resize=size1)
    elif extractor_type == "sift":
        image0, scales0 = load_image_sift(img0_path, resize=size0)
        image1, scales1 = load_image_sift(img1_path, resize=size1)
    else:
        raise NotImplementedError(
            f"Unsupported feature extractor type: {extractor_type}."
        )

    # Load ONNX models
    providers = ["CPUExecutionProvider"]  # prepend "CUDAExecutionProvider" for GPU
    if trt:
        providers = [
            (
                "TensorrtExecutionProvider",
                {
                    "trt_fp16_enable": True,
                    "trt_engine_cache_enable": True,
                    "trt_engine_cache_path": "weights/cache",
                },
            )
        ] + providers

    runner = LightGlueRunner_SIFT(
        extractor_path=extractor_path,
        lightglue_path=lightglue_path,
        providers=providers,
    )

    # Run inference
    start_time = time.time()
    m_kpts0, m_kpts1 = runner.run(image0, image1, scales0, scales1)
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"Inference time: {elapsed_time:.3f} s")

    # Visualisation
    if viz:
        orig_image0, _ = load_image(img0_path)
        orig_image1, _ = load_image(img1_path)
        viz2d.plot_images(
            [orig_image0[0].transpose(1, 2, 0), orig_image1[0].transpose(1, 2, 0)]
        )
        viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
        viz2d.plt.show()

    return m_kpts0, m_kpts1


if __name__ == "__main__":
    args = parse_args()
    m_kpts0, m_kpts1 = infer(**vars(args))
7 changes: 4 additions & 3 deletions lightglue/__init__.py
@@ -1,4 +1,5 @@
from .lightglue import LightGlue
from .superpoint import SuperPoint
from .disk import DISK
from .sift import SIFT  # noqa
from .utils import match_pair