Skip to content

Commit a569e0c

Browse files
committed
fix: move more parts to sklearnex
1 parent 3f917d4 commit a569e0c

File tree

4 files changed

+50
-27
lines changed

4 files changed

+50
-27
lines changed

sklearnex/neighbors/common.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,32 @@ def _compute_weights(self, distances, weights_param):
250250
"'distance', or a callable function"
251251
)
252252

253+
def _process_self_neighbors(self, results, X, n_neighbors, return_distance):
    """Remove each query point's own index from a training-data neighbor query.

    Parameters
    ----------
    results : array or tuple of arrays
        ``(neigh_dist, neigh_ind)`` when ``return_distance`` is true,
        otherwise just ``neigh_ind``.
        NOTE(review): the arrays are presumably of shape
        ``(n_queries, n_neighbors + 1)`` because the callers request one
        extra neighbor; any other width makes the reshape below fail.
    X : array-like or None
        Query data; only ``X.shape[0]`` is read. When ``None``, the number
        of queries is taken from ``self._fit_X``.
    n_neighbors : int
        Number of neighbors to keep per query after the self-match is removed.
    return_distance : bool
        Whether ``results`` carries distances and whether they are returned.

    Returns
    -------
    neigh_ind or (neigh_dist, neigh_ind)
        Arrays reshaped to ``(n_queries, n_neighbors)``.
    """
    if return_distance:
        neigh_dist, neigh_ind = results
    else:
        neigh_ind = results

    n_queries = X.shape[0] if X is not None else self._fit_X.shape[0]

    # Column vector of row numbers: row i of the mask is True wherever the
    # returned neighbor index differs from i, i.e. is not the sample itself.
    sample_range = np.arange(n_queries)[:, None]
    sample_mask = neigh_ind != sample_range

    # Corner case: when the number of duplicates is greater than the number
    # of neighbors, the sample itself may be missing from its own row (only
    # duplicates were returned). Mask the first entry of such rows instead,
    # so that every row drops exactly one entry.
    dup_gr_nbrs = np.all(sample_mask, axis=1)
    sample_mask[dup_gr_nbrs, 0] = False

    neigh_ind = np.reshape(neigh_ind[sample_mask], (n_queries, n_neighbors))

    if return_distance:
        neigh_dist = np.reshape(neigh_dist[sample_mask], (n_queries, n_neighbors))
        return neigh_dist, neigh_ind
    return neigh_ind
278+
253279
def _validate_feature_count(self, X, method_name=""):
254280
n_features = getattr(self, "n_features_in_", None)
255281
shape = getattr(X, "shape", None)

sklearnex/neighbors/knn_classification.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,13 @@ def score(self, X, y, sample_weight=None):
175175
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
176176
check_is_fitted(self)
177177

178-
# Handle X=None case at sklearnex level
179-
if X is None:
180-
query_is_train = True
178+
# Handle X=None case and determine if we need self-neighbor processing at sklearnex level
179+
query_is_train = X is None
180+
if query_is_train:
181181
X = self._fit_X
182-
# Include an extra neighbor to account for the sample itself being returned
182+
# For training data queries, we need to get extra neighbors to remove self-matches
183183
effective_n_neighbors = (n_neighbors if n_neighbors is not None else self.n_neighbors) + 1
184184
else:
185-
query_is_train = False
186185
effective_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
187186
check_feature_names(self, X, reset=False)
188187
# Perform preprocessing at sklearnex level
@@ -193,7 +192,7 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
193192
# Validate kneighbors parameters
194193
self._validate_kneighbors_params(n_neighbors, X)
195194

196-
# Call oneDAL with the effective n_neighbors
195+
# Call oneDAL with the effective n_neighbors (all preprocessing done)
197196
result = dispatch(
198197
self,
199198
"kneighbors",
@@ -204,14 +203,14 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
204203
"sklearn": _sklearn_KNeighborsClassifier.kneighbors,
205204
},
206205
X,
207-
n_neighbors=n_neighbors,
206+
n_neighbors=effective_n_neighbors, # Pass effective_n_neighbors to sklearn too
208207
return_distance=return_distance,
209208
)
210209

211-
# Process results at sklearnex level
210+
# If query was training data, post-process to remove self-neighbors at sklearnex level
212211
if query_is_train:
213212
final_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
214-
result = self._process_kneighbors_results(result, X, final_n_neighbors, return_distance, query_is_train)
213+
result = self._process_self_neighbors(result, X, final_n_neighbors, return_distance)
215214

216215
return result
217216

sklearnex/neighbors/knn_regression.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,14 +153,13 @@ def score(self, X, y, sample_weight=None):
153153
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
154154
check_is_fitted(self)
155155

156-
# Handle X=None case at sklearnex level
157-
if X is None:
158-
query_is_train = True
156+
# Handle X=None case and determine if we need self-neighbor processing at sklearnex level
157+
query_is_train = X is None
158+
if query_is_train:
159159
X = self._fit_X
160-
# Include an extra neighbor to account for the sample itself being returned
160+
# For training data queries, we need to get extra neighbors to remove self-matches
161161
effective_n_neighbors = (n_neighbors if n_neighbors is not None else self.n_neighbors) + 1
162162
else:
163-
query_is_train = False
164163
effective_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
165164
check_feature_names(self, X, reset=False)
166165
# Perform preprocessing at sklearnex level
@@ -171,7 +170,7 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
171170
# Validate kneighbors parameters
172171
self._validate_kneighbors_params(n_neighbors, X)
173172

174-
# Call oneDAL with the effective n_neighbors
173+
# Call oneDAL with the effective n_neighbors (all preprocessing done)
175174
result = dispatch(
176175
self,
177176
"kneighbors",
@@ -182,14 +181,14 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
182181
"sklearn": _sklearn_KNeighborsRegressor.kneighbors,
183182
},
184183
X,
185-
n_neighbors=n_neighbors,
184+
n_neighbors=effective_n_neighbors, # Pass effective_n_neighbors to sklearn too
186185
return_distance=return_distance,
187186
)
188187

189-
# Process results at sklearnex level
188+
# If query was training data, post-process to remove self-neighbors at sklearnex level
190189
if query_is_train:
191190
final_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
192-
result = self._process_kneighbors_results(result, X, final_n_neighbors, return_distance, query_is_train)
191+
result = self._process_self_neighbors(result, X, final_n_neighbors, return_distance)
193192

194193
return result
195194

sklearnex/neighbors/knn_unsupervised.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,13 @@ def fit(self, X, y=None):
7676
def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
7777
check_is_fitted(self)
7878

79-
# Handle X=None case at sklearnex level
80-
if X is None:
81-
query_is_train = True
79+
# Handle X=None case and determine if we need self-neighbor processing at sklearnex level
80+
query_is_train = X is None
81+
if query_is_train:
8282
X = self._fit_X
83-
# Include an extra neighbor to account for the sample itself being returned
83+
# For training data queries, we need to get extra neighbors to remove self-matches
8484
effective_n_neighbors = (n_neighbors if n_neighbors is not None else self.n_neighbors) + 1
8585
else:
86-
query_is_train = False
8786
effective_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
8887
check_feature_names(self, X, reset=False)
8988
# Perform preprocessing at sklearnex level
@@ -94,7 +93,7 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
9493
# Validate kneighbors parameters
9594
self._validate_kneighbors_params(n_neighbors, X)
9695

97-
# Call oneDAL with the effective n_neighbors
96+
# Call oneDAL with the effective n_neighbors (all preprocessing done)
9897
result = dispatch(
9998
self,
10099
"kneighbors",
@@ -105,14 +104,14 @@ def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
105104
"sklearn": _sklearn_NearestNeighbors.kneighbors,
106105
},
107106
X,
108-
n_neighbors=n_neighbors,
107+
n_neighbors=effective_n_neighbors, # Pass effective_n_neighbors to sklearn too
109108
return_distance=return_distance,
110109
)
111110

112-
# Process results at sklearnex level
111+
# If query was training data, post-process to remove self-neighbors at sklearnex level
113112
if query_is_train:
114113
final_n_neighbors = n_neighbors if n_neighbors is not None else self.n_neighbors
115-
result = self._process_kneighbors_results(result, X, final_n_neighbors, return_distance, query_is_train)
114+
result = self._process_self_neighbors(result, X, final_n_neighbors, return_distance)
116115

117116
return result
118117

0 commit comments

Comments
 (0)