Skip to content

Commit 383cec5

Browse files
committed
Fix Numpy typing
1 parent 3a81a76 commit 383cec5

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

bench_runner/hpt.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,20 +178,20 @@ def prepare_one_row(
178178
rank, rep = get_rank(por_x)
179179
wl = get_ranksum(rank[:n], rep[:n])
180180
wr = get_ranksum(rank[n:], rep[n:])
181-
ml = np.median(por_x[:n])
182-
mr = np.median(por_x[n:])
181+
ml = np.float64(np.median(por_x[:n]))
182+
mr = np.float64(np.median(por_x[n:]))
183183

184184
return wl, wr, ml, mr
185185

186186

187-
def unibench(ub_x: NDArray[np.float64], alpha: float) -> np.float64 | None:
187+
def unibench(ub_x: NDArray[np.float64], alpha: float) -> np.float64:
188188
wl, _, ml, mr = prepare_one_row(ub_x)
189189
target = float(wl)
190190

191191
rst_lower, rst_upper = ranksum_table(len(ub_x) // 2, alpha)
192192
if target <= rst_lower or target >= rst_upper:
193193
return np.subtract(ml, mr)
194-
return None
194+
return np.float64(np.nan)
195195

196196

197197
def crossbench(cb_x: NDArray[np.float64]) -> tuple[float, float, float]:
@@ -230,7 +230,7 @@ def hpt_basic(
230230
meddiff = np.zeros((len(mtx_a),), float)
231231

232232
for i, bm in enumerate(mtx_a.keys()):
233-
hpt_x = np.hstack((multi * mtx_a[bm], mtx_b[bm]))
233+
hpt_x = np.hstack((multi * mtx_a[bm], mtx_b[bm]), dtype=np.float64)
234234
meddiff[i] = unibench(hpt_x, alpha)
235235

236236
return crossbench(meddiff)

0 commit comments

Comments
 (0)