diff --git a/src/libraries/System.Collections/src/System/Collections/BitArray.cs b/src/libraries/System.Collections/src/System/Collections/BitArray.cs index fa515728e085d4..cf9712cfdfa134 100644 --- a/src/libraries/System.Collections/src/System/Collections/BitArray.cs +++ b/src/libraries/System.Collections/src/System/Collections/BitArray.cs @@ -6,7 +6,6 @@ using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Runtime.Intrinsics; -using System.Runtime.Intrinsics.Arm; using System.Runtime.Intrinsics.X86; namespace System.Collections @@ -85,31 +84,38 @@ public BitArray(byte[] bytes) m_array = new int[GetInt32ArrayLengthFromByteLength(bytes.Length)]; m_length = bytes.Length * BitsPerByte; - uint totalCount = (uint)bytes.Length / 4; - - ReadOnlySpan byteSpan = bytes; - for (int i = 0; i < totalCount; i++) + if (BitConverter.IsLittleEndian) { - m_array[i] = BinaryPrimitives.ReadInt32LittleEndian(byteSpan); - byteSpan = byteSpan.Slice(4); + bytes.CopyTo(MemoryMarshal.AsBytes(m_array.AsSpan())); } + else + { + int totalCount = bytes.Length / 4; - Debug.Assert(byteSpan.Length >= 0 && byteSpan.Length < 4); + ReadOnlySpan byteSpan = bytes; + for (int i = 0; i < totalCount; i++) + { + m_array[i] = BinaryPrimitives.ReadInt32LittleEndian(byteSpan); + byteSpan = byteSpan.Slice(4); + } - int last = 0; - switch (byteSpan.Length) - { - case 3: - last = byteSpan[2] << 16; - goto case 2; - // fall through - case 2: - last |= byteSpan[1] << 8; - goto case 1; - // fall through - case 1: - m_array[totalCount] = last | byteSpan[0]; - break; + Debug.Assert(byteSpan.Length >= 0 && byteSpan.Length < 4); + + int last = 0; + switch (byteSpan.Length) + { + case 3: + last = byteSpan[2] << 16; + goto case 2; + // fall through + case 2: + last |= byteSpan[1] << 8; + goto case 1; + // fall through + case 1: + m_array[totalCount] = last | byteSpan[0]; + break; + } } _version = 0; @@ -122,12 +128,7 @@ public BitArray(bool[] values) m_array = new 
int[GetInt32ArrayLengthFromBitLength(values.Length)]; m_length = values.Length; - uint i = 0; - - if (values.Length < Vector256.Count) - { - goto LessThan32; - } + int i = 0; // Comparing with 1s would get rid of the final negation, however this would not work for some CLR bools // (true for any non-zero values, false for 0) - any values between 2-255 will be interpreted as false. @@ -136,49 +137,47 @@ public BitArray(bool[] values) ref byte value = ref Unsafe.As(ref MemoryMarshal.GetArrayDataReference(values)); if (Vector512.IsHardwareAccelerated) { - for (; i <= (uint)values.Length - Vector512.Count; i += (uint)Vector512.Count) + for (; i <= values.Length - Vector512.Count; i += Vector512.Count) { - Vector512 vector = Vector512.LoadUnsafe(ref value, i); + Vector512 vector = Vector512.LoadUnsafe(ref value, (uint)i); Vector512 isFalse = Vector512.Equals(vector, Vector512.Zero); - ulong result = isFalse.ExtractMostSignificantBits(); - m_array[i / 32u] = (int)(~result & 0x00000000FFFFFFFF); - m_array[(i / 32u) + 1] = (int)((~result >> 32) & 0x00000000FFFFFFFF); + ulong result = ~(isFalse.ExtractMostSignificantBits()); + m_array[i / 32] = (int)result; + m_array[(i / 32) + 1] = (int)(result >> 32); } } - else if (Vector256.IsHardwareAccelerated) + if (Vector256.IsHardwareAccelerated) { - for (; i <= (uint)values.Length - Vector256.Count; i += (uint)Vector256.Count) + for (; i <= values.Length - Vector256.Count; i += Vector256.Count) { - Vector256 vector = Vector256.LoadUnsafe(ref value, i); + Vector256 vector = Vector256.LoadUnsafe(ref value, (uint)i); Vector256 isFalse = Vector256.Equals(vector, Vector256.Zero); - uint result = isFalse.ExtractMostSignificantBits(); - m_array[i / 32u] = (int)(~result); + m_array[i / 32] = (int)~(isFalse.ExtractMostSignificantBits()); } } else if (Vector128.IsHardwareAccelerated) { - for (; i <= (uint)values.Length - Vector128.Count * 2u; i += (uint)Vector128.Count * 2u) + for (; i <= values.Length - Vector128.Count * 2; i += 
Vector128.Count * 2) { - Vector128 lowerVector = Vector128.LoadUnsafe(ref value, i); + Vector128 lowerVector = Vector128.LoadUnsafe(ref value, (uint)i); Vector128 lowerIsFalse = Vector128.Equals(lowerVector, Vector128.Zero); - uint lowerResult = lowerIsFalse.ExtractMostSignificantBits(); + uint lowerResultNot = lowerIsFalse.ExtractMostSignificantBits(); - Vector128 upperVector = Vector128.LoadUnsafe(ref value, i + (uint)Vector128.Count); + Vector128 upperVector = Vector128.LoadUnsafe(ref value, (uint)(i + Vector128.Count)); Vector128 upperIsFalse = Vector128.Equals(upperVector, Vector128.Zero); - uint upperResult = upperIsFalse.ExtractMostSignificantBits(); + uint upperResultNot = upperIsFalse.ExtractMostSignificantBits(); - m_array[i / 32u] = (int)(~((upperResult << 16) | lowerResult)); + m_array[i / 32] = (int)(~((upperResultNot << 16) | lowerResultNot)); } } - LessThan32: - for (; i < (uint)values.Length; i++) + for (; i < values.Length; i++) { if (values[i]) { - int elementIndex = Div32Rem((int)i, out int extraBits); + int elementIndex = Div32Rem(i, out int extraBits); m_array[elementIndex] |= 1 << extraBits; } } @@ -342,36 +341,36 @@ public BitArray And(BitArray value) case 0: goto Done; } - uint i = 0; + int i = 0; ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + for (; i <= count - Vector512.Count; i += Vector512.Count) { - Vector512 result = Vector512.LoadUnsafe(ref left, i) & Vector512.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector512 result = Vector512.LoadUnsafe(ref left, (uint)i) & Vector512.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + if 
(Vector256.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) + for (; i <= count - Vector256.Count; i += Vector256.Count) { - Vector256 result = Vector256.LoadUnsafe(ref left, i) & Vector256.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector256 result = Vector256.LoadUnsafe(ref left, (uint)i) & Vector256.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) + for (; i <= count - Vector128.Count; i += Vector128.Count) { - Vector128 result = Vector128.LoadUnsafe(ref left, i) & Vector128.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector128 result = Vector128.LoadUnsafe(ref left, (uint)i) & Vector128.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - for (; i < (uint)count; i++) + for (; i < count; i++) thisArray[i] &= valueArray[i]; Done: @@ -415,36 +414,36 @@ public BitArray Or(BitArray value) case 0: goto Done; } - uint i = 0; + int i = 0; ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + for (; i <= count - Vector512.Count; i += Vector512.Count) { - Vector512 result = Vector512.LoadUnsafe(ref left, i) | Vector512.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector512 result = Vector512.LoadUnsafe(ref left, (uint)i) | Vector512.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + if (Vector256.IsHardwareAccelerated) { - for (; i < 
(uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) + for (; i <= count - Vector256.Count; i += Vector256.Count) { - Vector256 result = Vector256.LoadUnsafe(ref left, i) | Vector256.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector256 result = Vector256.LoadUnsafe(ref left, (uint)i) | Vector256.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) + for (; i <= count - Vector128.Count; i += Vector128.Count) { - Vector128 result = Vector128.LoadUnsafe(ref left, i) | Vector128.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector128 result = Vector128.LoadUnsafe(ref left, (uint)i) | Vector128.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - for (; i < (uint)count; i++) + for (; i < count; i++) thisArray[i] |= valueArray[i]; Done: @@ -488,37 +487,37 @@ public BitArray Xor(BitArray value) case 0: goto Done; } - uint i = 0; + int i = 0; ref int left = ref MemoryMarshal.GetArrayDataReference(thisArray); ref int right = ref MemoryMarshal.GetArrayDataReference(valueArray); - if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + for (; i <= count - Vector512.Count; i += Vector512.Count) { - Vector512 result = Vector512.LoadUnsafe(ref left, i) ^ Vector512.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector512 result = Vector512.LoadUnsafe(ref left, (uint)i) ^ Vector512.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + if (Vector256.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector256.Count - 1u); i += 
(uint)Vector256.Count) + for (; i <= count - Vector256.Count; i += Vector256.Count) { - Vector256 result = Vector256.LoadUnsafe(ref left, i) ^ Vector256.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector256 result = Vector256.LoadUnsafe(ref left, (uint)i) ^ Vector256.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) + for (; i <= count - Vector128.Count; i += Vector128.Count) { - Vector128 result = Vector128.LoadUnsafe(ref left, i) ^ Vector128.LoadUnsafe(ref right, i); - result.StoreUnsafe(ref left, i); + Vector128 result = Vector128.LoadUnsafe(ref left, (uint)i) ^ Vector128.LoadUnsafe(ref right, (uint)i); + result.StoreUnsafe(ref left, (uint)i); } } - for (; i < (uint)count; i++) + for (; i < count; i++) thisArray[i] ^= valueArray[i]; Done: @@ -541,6 +540,8 @@ public BitArray Not() int[] thisArray = m_array; int count = GetInt32ArrayLengthFromBitLength(Length); + if ((uint)count > (uint)thisArray.Length) + ThrowHelper.ThrowConcurrentOperation(); // Unroll loop for count less than Vector256 size. 
switch (count) @@ -555,35 +556,35 @@ public BitArray Not() case 0: goto Done; } - uint i = 0; + int i = 0; ref int value = ref MemoryMarshal.GetArrayDataReference(thisArray); - if (Vector512.IsHardwareAccelerated && (uint)count >= Vector512.Count) + if (Vector512.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector512.Count - 1u); i += (uint)Vector512.Count) + for (; i <= count - Vector512.Count; i += Vector512.Count) { - Vector512 result = ~Vector512.LoadUnsafe(ref value, i); - result.StoreUnsafe(ref value, i); + Vector512 result = ~Vector512.LoadUnsafe(ref value, (uint)i); + result.StoreUnsafe(ref value, (uint)i); } } - else if (Vector256.IsHardwareAccelerated && (uint)count >= Vector256.Count) + if (Vector256.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector256.Count - 1u); i += (uint)Vector256.Count) + for (; i <= count - Vector256.Count; i += Vector256.Count) { - Vector256 result = ~Vector256.LoadUnsafe(ref value, i); - result.StoreUnsafe(ref value, i); + Vector256 result = ~Vector256.LoadUnsafe(ref value, (uint)i); + result.StoreUnsafe(ref value, (uint)i); } } - else if (Vector128.IsHardwareAccelerated && (uint)count >= Vector128.Count) + if (Vector128.IsHardwareAccelerated) { - for (; i < (uint)count - (Vector128.Count - 1u); i += (uint)Vector128.Count) + for (; i <= count - Vector128.Count; i += Vector128.Count) { - Vector128 result = ~Vector128.LoadUnsafe(ref value, i); - result.StoreUnsafe(ref value, i); + Vector128 result = ~Vector128.LoadUnsafe(ref value, (uint)i); + result.StoreUnsafe(ref value, (uint)i); } } - for (; i < (uint)count; i++) + for (; i < count; i++) thisArray[i] = ~thisArray[i]; Done: @@ -780,7 +781,7 @@ public unsafe void CopyTo(Array array, int index) } // equivalent to m_length % BitsPerByte, since BitsPerByte is a power of 2 - uint extraBits = (uint)m_length & (BitsPerByte - 1); + int extraBits = m_length & (BitsPerByte - 1); if (extraBits > 0) { // last byte is not aligned, we will directly copy one less byte @@ 
-790,178 +791,220 @@ public unsafe void CopyTo(Array array, int index) Span span = byteArray.AsSpan(index); int quotient = Div4Rem(arrayLength, out int remainder); - for (int i = 0; i < quotient; i++) + if (BitConverter.IsLittleEndian) { - BinaryPrimitives.WriteInt32LittleEndian(span, m_array[i]); - span = span.Slice(4); + MemoryMarshal.AsBytes(m_array).Slice(0, arrayLength).CopyTo(span); + span = span.Slice(quotient * 4); } - - if (extraBits > 0) + else { - Debug.Assert(span.Length > 0); - Debug.Assert(m_array.Length > quotient); - // mask the final byte - span[remainder] = (byte)((m_array[quotient] >> (remainder * 8)) & ((1 << (int)extraBits) - 1)); + for (int i = 0; i < quotient; i++) + { + BinaryPrimitives.WriteInt32LittleEndian(span, m_array[i]); + span = span.Slice(4); + } + + switch (remainder) + { + case 3: + span[2] = (byte)(m_array[quotient] >> 16); + goto case 2; + // fall through + case 2: + span[1] = (byte)(m_array[quotient] >> 8); + goto case 1; + // fall through + case 1: + span[0] = (byte)m_array[quotient]; + break; + } } - switch (remainder) + if (extraBits > 0) { - case 3: - span[2] = (byte)(m_array[quotient] >> 16); - goto case 2; - // fall through - case 2: - span[1] = (byte)(m_array[quotient] >> 8); - goto case 1; - // fall through - case 1: - span[0] = (byte)m_array[quotient]; - break; + // mask the final byte + span[remainder] = (byte)((m_array[quotient] >> (remainder * 8)) & ((1 << extraBits) - 1)); } } else if (array is bool[] boolArray) { - if (array.Length - index < m_length) + // This method uses unsafe code to manipulate data in the BitArray. To avoid issues with + // buggy code concurrently mutating this instance in a way that could cause memory corruption, + // we snapshot the array then operate only on this snapshot. We don't care about such code + // corrupting the BitArray data in a way that produces incorrect answers, since BitArray is not meant + // to be thread-safe; we only care about avoiding buffer overruns. 
+ int[] thisArray = m_array; + int thisLength = m_length; + if (thisLength < 0 || thisArray.Length < GetInt32ArrayLengthFromBitLength(thisLength)) + { + ThrowHelper.ThrowConcurrentOperation(); + } + if (array.Length - index < thisLength) { throw new ArgumentException(SR.Argument_InvalidOffLen); } - uint i = 0; - - if (m_length < BitsPerInt32) - goto LessThan32; + int i = 0; - // The mask used when shuffling a single int into Vector128/256/512. - // On little endian machines, the lower 8 bits of int belong in the first byte, next lower 8 in the second and so on. - // We place the bytes that contain the bits to its respective byte so that we can mask out only the relevant bits later. - Vector128 lowerShuffleMask_CopyToBoolArray = Vector128.Create(0, 0x01010101_01010101).AsByte(); - Vector128 upperShuffleMask_CopyToBoolArray = Vector128.Create(0x02020202_02020202, 0x03030303_03030303).AsByte(); + ref int thisRef = ref MemoryMarshal.GetArrayDataReference(thisArray); + ref bool boolRef = ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(boolArray), index); - if (Avx512BW.IsSupported && (uint)m_length >= Vector512.Count) + if (Vector512.IsHardwareAccelerated && thisLength >= Vector512.Count) { - Vector256 upperShuffleMask_CopyToBoolArray256 = Vector256.Create(0x04040404_04040404, 0x05050505_05050505, - 0x06060606_06060606, 0x07070707_07070707).AsByte(); - Vector256 lowerShuffleMask_CopyToBoolArray256 = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray); - Vector512 shuffleMask = Vector512.Create(lowerShuffleMask_CopyToBoolArray256, upperShuffleMask_CopyToBoolArray256); - Vector512 bitMask = Vector512.Create(0x80402010_08040201).AsByte(); + // With AVX512, we could take advantage of the instruction "kmovq k, m64" + // and do something like + // Vector512.ConditionalSelect(Avx512F.CreateMask(Unsafe.ReadUnaligned(ref Unsafe.AddByteOffset(ref Unsafe.As(ref thisRef), i >> BitShiftPerByte))), + // Vector512.Create((byte)1), + // 
Vector512.Zero) + // Unfortunately Avx512F.CreateMask does not exist. See dotnet/runtime#87097 + + Vector512 bitMask; Vector512 ones = Vector512.Create((byte)1); + if (BitConverter.IsLittleEndian) + { + bitMask = Vector512.Create(0x80402010_08040201).AsByte(); + } + else + { + bitMask = Vector512.Create(0x01020408_10204080).AsByte(); + } - fixed (bool* destination = &boolArray[index]) + for (; i <= thisLength - Vector512.Count; i += Vector512.Count) { - for (; (i + Vector512.Count) <= (uint)m_length; i += (uint)Vector512.Count) + Vector512 bits = Vector512.Create(Unsafe.ReadUnaligned(ref Unsafe.AddByteOffset(ref Unsafe.As(ref thisRef), i >> BitShiftPerByte))); + + Vector512 shuffled; + // Feed the shuffle indices directly to Vector512.Shuffle() for the best codegen. + // See dotnet/runtime#115078 + if (BitConverter.IsLittleEndian) { - ulong bits = (ulong)(uint)m_array[i / (uint)BitsPerInt32] + ((ulong)m_array[(i / (uint)BitsPerInt32) + 1] << BitsPerInt32); - Vector512 scalar = Vector512.Create(bits); - Vector512 shuffled = Avx512BW.Shuffle(scalar.AsByte(), shuffleMask); - Vector512 extracted = Avx512F.And(shuffled, bitMask); - - // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 - // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) - Vector512 normalized = Avx512BW.Min(extracted, ones); - Avx512F.Store((byte*)destination + i, normalized); + shuffled = Vector512.Shuffle(bits.AsByte(), + // The shuffle indices should be chosen so that with AVX512, vpshufb can be emitted instead of the slower, less available (AVX512_VBMI) vpermb. 
+ Vector512.Create(0x00000000_00000000, 0x09090909_09090909, 0x12121212_12121212, 0x1B1B1B1B_1B1B1B1B, + 0x24242424_24242424, 0x2D2D2D2D_2D2D2D2D, 0x36363636_36363636, 0x3F3F3F3F_3F3F3F3F).AsByte()); } + else + { + shuffled = Vector512.Shuffle(bits.AsByte(), + Vector512.Create(0x03030303_03030303, 0x0A0A0A0A_0A0A0A0A, 0x11111111_11111111, 0x18181818_18181818, + 0x27272727_27272727, 0x2E2E2E2E_2E2E2E2E, 0x35353535_35353535, 0x3C3C3C3C_3C3C3C3C).AsByte()); + } + + Vector512 extracted = shuffled & bitMask; + + // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 + // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) + Vector512 normalized = Vector512.Min(extracted, ones); + normalized.StoreUnsafe(ref Unsafe.As(ref boolRef), (uint)i); } } - else if (Avx2.IsSupported && (uint)m_length >= Vector256.Count) + if (Vector256.IsHardwareAccelerated && thisLength >= Vector256.Count) { - Vector256 shuffleMask = Vector256.Create(lowerShuffleMask_CopyToBoolArray, upperShuffleMask_CopyToBoolArray); - Vector256 bitMask = Vector256.Create(0x80402010_08040201).AsByte(); - //Internal.Console.WriteLine(bitMask); + Vector256 bitMask; Vector256 ones = Vector256.Create((byte)1); + if (BitConverter.IsLittleEndian) + { + bitMask = Vector256.Create(0x80402010_08040201).AsByte(); + } + else + { + bitMask = Vector256.Create(0x01020408_10204080).AsByte(); + } - fixed (bool* destination = &boolArray[index]) + for (; i <= thisLength - Vector256.Count; i += Vector256.Count) { - for (; (i + Vector256.Count) <= (uint)m_length; i += (uint)Vector256.Count) + Vector256 bits = Vector256.Create(Unsafe.AddByteOffset(ref thisRef, i >> BitShiftPerByte)); + + Vector256 shuffled; + // Feed the shuffle indices directly to Vector256.Shuffle() for the best codegen. 
+ // See dotnet/runtime#115078 + if (BitConverter.IsLittleEndian) { - int bits = m_array[i / (uint)BitsPerInt32]; - Vector256 scalar = Vector256.Create(bits); - Vector256 shuffled = Avx2.Shuffle(scalar.AsByte(), shuffleMask); - Vector256 extracted = Avx2.And(shuffled, bitMask); - - // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 - // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) - Vector256 normalized = Avx2.Min(extracted, ones); - Avx.Store((byte*)destination + i, normalized); + shuffled = Vector256.Shuffle(bits.AsByte(), + // The shuffle indices should be chosen so that with AVX2, vpshufb can be emitted. + Vector256.Create(0x00000000_00000000, 0x09090909_09090909, 0x12121212_12121212, 0x1B1B1B1B_1B1B1B1B).AsByte()); } + else + { + shuffled = Vector256.Shuffle(bits.AsByte(), + Vector256.Create(0x03030303_03030303, 0x0A0A0A0A_0A0A0A0A, 0x11111111_11111111, 0x18181818_18181818).AsByte()); + } + + Vector256 extracted = shuffled & bitMask; + + // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 + // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) + Vector256 normalized = Vector256.Min(extracted, ones); + normalized.StoreUnsafe(ref Unsafe.As(ref boolRef), (uint)i); } } - else if (Ssse3.IsSupported && ((uint)m_length >= Vector128.Count * 2u)) + else if (Vector128.IsHardwareAccelerated + && (Ssse3.IsSupported || !Sse.IsSupported)// We need SSSE3 for pshufb + && thisLength >= Vector128.Count * 2) { - Vector128 lowerShuffleMask = lowerShuffleMask_CopyToBoolArray; - Vector128 upperShuffleMask = upperShuffleMask_CopyToBoolArray; + Vector128 bitMask; Vector128 ones = Vector128.Create((byte)1); - Vector128 bitMask128 = BitConverter.IsLittleEndian ? 
- Vector128.Create(0x80402010_08040201).AsByte() : - Vector128.Create(0x01020408_10204080).AsByte(); + if (BitConverter.IsLittleEndian) + { + bitMask = Vector128.Create(0x80402010_08040201).AsByte(); + } + else + { + bitMask = Vector128.Create(0x01020408_10204080).AsByte(); + } - fixed (bool* destination = &boolArray[index]) + while (i <= thisLength - Vector128.Count * 2) { - for (; (i + Vector128.Count * 2u) <= (uint)m_length; i += (uint)Vector128.Count * 2u) + Vector128 bits = Vector128.CreateScalarUnsafe(Unsafe.AddByteOffset(ref thisRef, i >> BitShiftPerByte)); + + + Vector128 shuffledLower; + // Feed the shuffle indices directly to Vector128.Shuffle() for the best codegen. + // See dotnet/runtime#115078 + if (BitConverter.IsLittleEndian) { - int bits = m_array[i / (uint)BitsPerInt32]; - Vector128 scalar = Vector128.CreateScalarUnsafe(bits); - - Vector128 shuffledLower = Ssse3.Shuffle(scalar.AsByte(), lowerShuffleMask); - Vector128 extractedLower = Sse2.And(shuffledLower, bitMask128); - Vector128 normalizedLower = Sse2.Min(extractedLower, ones); - Sse2.Store((byte*)destination + i, normalizedLower); - - Vector128 shuffledHigher = Ssse3.Shuffle(scalar.AsByte(), upperShuffleMask); - Vector128 extractedHigher = Sse2.And(shuffledHigher, bitMask128); - Vector128 normalizedHigher = Sse2.Min(extractedHigher, ones); - Sse2.Store((byte*)destination + i + Vector128.Count, normalizedHigher); + shuffledLower = Vector128.Shuffle(bits.AsByte(), Vector128.Create(0x00000000_00000000, 0x01010101_01010101).AsByte()); + } + else + { + shuffledLower = Vector128.Shuffle(bits.AsByte(), Vector128.Create(0x03030303_03030303, 0x02020202_02020202).AsByte()); } - } - } - else if (AdvSimd.Arm64.IsSupported) - { - Vector128 ones = Vector128.Create((byte)1); - Vector128 bitMask128 = BitConverter.IsLittleEndian ? 
- Vector128.Create(0x80402010_08040201).AsByte() : - Vector128.Create(0x01020408_10204080).AsByte(); - fixed (bool* destination = &boolArray[index]) - { - for (; (i + Vector128.Count * 2u) <= (uint)m_length; i += (uint)Vector128.Count * 2u) + Vector128 extractedLower = shuffledLower & bitMask; + + // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 + // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) + Vector128 normalizedLower = Vector128.Min(extractedLower, ones); + normalizedLower.StoreUnsafe(ref Unsafe.As(ref boolRef), (uint)i); + i += Vector128.Count; + + + Vector128 shuffledUpper; + // Feed the shuffle indices directly to Vector128.Shuffle() for the best codegen. + // See dotnet/runtime#115078 + if (BitConverter.IsLittleEndian) + { + shuffledUpper = Vector128.Shuffle(bits.AsByte(), Vector128.Create(0x02020202_02020202, 0x03030303_03030303).AsByte()); + } + else { - int bits = m_array[i / (uint)BitsPerInt32]; - // Same logic as SSSE3 path, except we do not have Shuffle instruction. 
- // (TableVectorLookup could be an alternative - dotnet/runtime#1277) - // Instead we use chained ZIP1/2 instructions: - // (A0 is the byte containing LSB, A3 is the byte containing MSB) - // bits (on Big endian) - A3 A2 A1 A0 - // bits (Little endian) / Byte reversal - A0 A1 A2 A3 - // v1 = Vector128.Create - A0 A1 A2 A3 A0 A1 A2 A3 A0 A1 A2 A3 A0 A1 A2 A3 - // v2 = ZipLow(v1, v1) - A0 A0 A1 A1 A2 A2 A3 A3 A0 A0 A1 A1 A2 A2 A3 A3 - // v3 = ZipLow(v2, v2) - A0 A0 A0 A0 A1 A1 A1 A1 A2 A2 A2 A2 A3 A3 A3 A3 - // shuffledLower = ZipLow(v3, v3) - A0 A0 A0 A0 A0 A0 A0 A0 A1 A1 A1 A1 A1 A1 A1 A1 - // shuffledHigher = ZipHigh(v3, v3) - A2 A2 A2 A2 A2 A2 A2 A2 A3 A3 A3 A3 A3 A3 A3 A3 - if (!BitConverter.IsLittleEndian) - { - bits = BinaryPrimitives.ReverseEndianness(bits); - } - Vector128 vector = Vector128.Create(bits).AsByte(); - vector = AdvSimd.Arm64.ZipLow(vector, vector); - vector = AdvSimd.Arm64.ZipLow(vector, vector); - - Vector128 shuffledLower = AdvSimd.Arm64.ZipLow(vector, vector); - Vector128 extractedLower = AdvSimd.And(shuffledLower, bitMask128); - Vector128 normalizedLower = AdvSimd.Min(extractedLower, ones); - - Vector128 shuffledHigher = AdvSimd.Arm64.ZipHigh(vector, vector); - Vector128 extractedHigher = AdvSimd.And(shuffledHigher, bitMask128); - Vector128 normalizedHigher = AdvSimd.Min(extractedHigher, ones); - - AdvSimd.Arm64.StorePair((byte*)destination + i, normalizedLower, normalizedHigher); + shuffledUpper = Vector128.Shuffle(bits.AsByte(), Vector128.Create(0x01010101_01010101, 0x00000000_00000000).AsByte()); } + + Vector128 extractedUpper = shuffledUpper & bitMask; + + // The extracted bits can be anywhere between 0 and 255, so we normalise the value to either 0 or 1 + // to ensure compatibility with "C# bool" (0 for false, 1 for true, rest undefined) + Vector128 normalizedUpper = Vector128.Min(extractedUpper, ones); + normalizedUpper.StoreUnsafe(ref Unsafe.As(ref boolRef), (uint)i); + i += Vector128.Count; } } - LessThan32: - for (; i < 
(uint)m_length; i++) + for (; i < thisLength; i++) { - int elementIndex = Div32Rem((int)i, out int extraBits); - boolArray[(uint)index + i] = ((m_array[elementIndex] >> extraBits) & 0x00000001) != 0; + int elementIndex = Div32Rem(i, out int extraBits); + boolArray[index + i] = ((thisArray[elementIndex] >> extraBits) & 0x00000001) != 0; } } else @@ -1052,20 +1095,14 @@ public bool HasAnySet() /// /// Used for conversion between different representations of bit array. - /// Returns (n + (32 - 1)) / 32, rearranged to avoid arithmetic overflow. - /// For example, in the bit to int case, the straightforward calc would - /// be (n + 31) / 32, but that would cause overflow. So instead it's - /// rearranged to ((n - 1) / 32) + 1. - /// Due to sign extension, we don't need to special case for n == 0, if we use - /// bitwise operations (since ((n - 1) >> 5) + 1 = 0). - /// This doesn't hold true for ((n - 1) / 32) + 1, which equals 1. + /// Returns (n + (32 - 1)) / 32, using unsigned arithmetic to avoid overflow. /// /// Usage: /// GetArrayLength(77): returns how many ints must be /// allocated to store 77 bits. /// /// - /// how many ints are required to store n bytes + /// how many ints are required to store n bits private static int GetInt32ArrayLengthFromBitLength(int n) { Debug.Assert(n >= 0); @@ -1075,21 +1112,19 @@ private static int GetInt32ArrayLengthFromBitLength(int n) private static int GetInt32ArrayLengthFromByteLength(int n) { Debug.Assert(n >= 0); - // Due to sign extension, we don't need to special case for n == 0, since ((n - 1) >> 2) + 1 = 0 - // This doesn't hold true for ((n - 1) / 4) + 1, which equals 1. return (int)((uint)(n - 1 + (1 << BitShiftForBytesPerInt32)) >> BitShiftForBytesPerInt32); } private static int GetByteArrayLengthFromBitLength(int n) { Debug.Assert(n >= 0); - // Due to sign extension, we don't need to special case for n == 0, since ((n - 1) >> 3) + 1 = 0 - // This doesn't hold true for ((n - 1) / 8) + 1, which equals 1. 
return (int)((uint)(n - 1 + (1 << BitShiftPerByte)) >> BitShiftPerByte); } private static int Div32Rem(int number, out int remainder) { + Debug.Assert(number >= 0); + uint quotient = (uint)number / 32; remainder = number & (32 - 1); // equivalent to number % 32, since 32 is a power of 2 return (int)quotient; @@ -1097,6 +1132,8 @@ private static int Div32Rem(int number, out int remainder) private static int Div4Rem(int number, out int remainder) { + Debug.Assert(number >= 0); + uint quotient = (uint)number / 4; remainder = number & (4 - 1); // equivalent to number % 4, since 4 is a power of 2 return (int)quotient;