Skip to content

Commit 6d4192d

Browse files
authored
Merge pull request #6 from kalcohol/main-fix-sample
fix classification
2 parents 916d2d9 + 2ba26e2 commit 6d4192d

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

examples/classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def preprocess_image(image_path, target_size=(256, 256), crop_size=(224, 224)):
5252

5353
def get_top_k_predictions(output, k=5):
5454
# Get top k predictions
55-
top_k_indices = np.argsort(output[0])[-k:][::-1]
56-
top_k_scores = output[0][top_k_indices]
55+
top_k_indices = np.argsort(output[0].flatten())[-k:][::-1]
56+
top_k_scores = output[0].flatten()[top_k_indices]
5757
return top_k_indices, top_k_scores
5858

5959

0 commit comments

Comments
 (0)