Skip to content

Commit f8a02e2

Browse files
committed
reimplement setitem_long_long
1 parent dcee323 commit f8a02e2

File tree

2 files changed

+55
-45
lines changed

2 files changed

+55
-45
lines changed

pypy/module/mamba/helper_funcs.py

Lines changed: 50 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -381,72 +381,77 @@ def setitem_long_long_helper( value, other, start, stop ):
381381
return setitem_long_int_helper( value, other.digit(0), start, stop )
382382

383383
if other.sign < 0:
384-
slice_nbits = stop - start
385-
other = other.and_( get_long_mask(slice_nbits) )
384+
other = other.and_( get_long_mask(stop-start) )
386385
if other.numdigits() == 1:
387386
return setitem_long_int_helper( value, other.digit(0), start, stop )
388387

389-
# After the two above checks, we have made sure other has more than one
390-
# digit and wordstart must < wordstop
391388
vsize = value.numdigits()
392389
other = other.lshift( start ) # lshift first to align two rbigints
393390
osize = other.numdigits()
394391

395-
# Now other must be long, wordstart must < wordstop
392+
# After the two above checks, we have made sure other has more than one digit
393+
# assert osize >= 2
394+
# assert wordstart < wordstop
395+
396+
# Also, the caller must have already checked if bitwidth exceeds the slice
397+
# assert wordstart <= osize - 1 <= wordstop
398+
396399
wordstart = start / SHIFT
397400

398401
# 1. vsize <= wordstart < wordstop, concatenate
399402
if vsize <= wordstart:
400403
return rbigint(value._digits[:vsize] + other._digits[vsize:], 1, osize )
401404

402-
wordstop = stop / SHIFT
405+
bitstart = start - wordstart*SHIFT
403406

404-
# 2. wordstart < wordstop < vsize
405-
if wordstop < vsize:
406-
ret = rbigint( value._digits[:vsize], 1, vsize )
407+
wordstop = stop / SHIFT
407408

408-
# do start
409-
bitstart = start - wordstart*SHIFT
410-
tmpstart = other.digit( wordstart ) | (ret.digit(wordstart) & get_int_mask(bitstart))
411-
# if bitstart:
412-
# tmpstart |= ret.digit(wordstart) & get_int_mask(bitstart) # lo
413-
ret.setdigit( wordstart, tmpstart )
414-
415-
i = wordstart+1
416-
417-
# wordstart < osize <= wordstop < vsize
418-
if osize <= wordstop:
419-
while i < osize:
420-
ret.setdigit( i, other.digit(i) )
421-
i += 1
422-
while i < wordstop:
423-
ret._digits[i] = NULLDIGIT
424-
i += 1
425-
# wordstart < wordstop < osize < vsize
426-
else:
427-
while i < wordstop:
428-
ret.setdigit( i, other.digit(i) )
429-
i += 1
409+
# 2. wordstart < vsize <= wordstop, merge wordstart and concatenate
410+
if vsize <= wordstop:
411+
assert wordstart >= 0
412+
ret = rbigint( value._digits[:wordstart] + \
413+
other._digits[wordstart:osize], 1, osize )
430414

431-
# do stop
432-
bitstop = stop - wordstop*SHIFT
433-
if bitstop:
434-
masked_val = ret.digit(wordstop) & ~get_int_mask(bitstop) #hi
435-
ret.setdigit( wordstop, other.digit(wordstop) | masked_val ) # lo|hi
415+
# union start, there is no value in lower bits of other.digit(wordstart)
416+
if bitstart:
417+
value_lo = value.digit(wordstart) & get_int_mask(bitstart) # lo
418+
ret.setdigit( wordstart, value_lo | ret.digit(wordstart) ) # lo | hi
436419

437420
return ret
438421

439-
assert wordstart >= 0
440-
# wordstart < vsize <= wordstop
441-
ret = rbigint( value._digits[:wordstart] + \
442-
other._digits[wordstart:osize], 1, osize )
422+
# 3. wordstart < wordstop < vsize, handle both sides
423+
ret = rbigint( value._digits[:], 1, vsize )
443424

444-
# do start
445-
bitstart = start - wordstart*SHIFT
446-
if bitstart:
447-
masked_val = value.digit(wordstart) & get_int_mask(bitstart) # lo
448-
ret.setdigit( wordstart, masked_val | ret.digit(wordstart) ) # lo | hi
425+
# union start, there is no value in lower bits of other.digit(wordstart)
426+
value_lo = ret.digit(wordstart) & get_int_mask(bitstart) # lo
427+
ret.setdigit( wordstart, value_lo | other.digit( wordstart ) ) # lo | hi
428+
429+
# put other into
430+
i = wordstart + 1
431+
432+
inv_maskstop = ~get_int_mask( stop - wordstop*SHIFT )
433+
# wordstop == osize - 1 means other's last word is wordstop
434+
if wordstop == osize - 1:
435+
while i < wordstop:
436+
ret.setdigit(i, other.digit(i) )
437+
i += 1
438+
# union stop
439+
value_hi = ret.digit(wordstop) & inv_maskstop # hi
440+
ret.setdigit( wordstop, other.digit(wordstop) | value_hi ) # lo|hi
449441

442+
# wordstop > osize - 1, other is shorter
443+
else:
444+
while i < osize:
445+
ret.setdigit(i, other.digit(i) )
446+
i += 1
447+
while i < wordstop:
448+
ret._digits[i] = NULLDIGIT
449+
i += 1
450+
451+
# clear stop
452+
ret.setdigit( wordstop, ret.digit(wordstop) & inv_maskstop )
453+
454+
ret._normalize()
450455
return ret
451456

452457
@jit.elidable

pypy/module/mamba/test/test_bits.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ def make_long(x): return x + 2 ** 100 - 2 ** 100
158158
with raises(ValueError):
159159
b[0:80] = mamba.Bits(100, 0)
160160

161+
a = mamba.Bits(146, 0x00000c26283b3f1402373002002b700000293)
162+
a[0:128] = mamba.Bits(128, 0x00001317134282930000129740710133)
163+
print(a, a == mamba.Bits(146, 0x0000000001317134282930000129740710133), repr(a) == repr(mamba.Bits(146, 0x0000000001317134282930000129740710133)) )
164+
assert a == mamba.Bits(146, 0x0000000001317134282930000129740710133)
165+
161166
def test_setitem_crash(self):
162167
from mamba import Bits
163168
input = Bits(465, 0x00095700000000000000003f950000000000000000000000000000000000000000000000000000000000000000000000000000000000000000f5d )

0 commit comments

Comments
 (0)