diff options
-rw-r--r-- | src/mscorlib/shared/System.Private.CoreLib.Shared.projitems | 1 | ||||
-rw-r--r-- | src/mscorlib/shared/System/Convert.Base64.cs | 217 | ||||
-rw-r--r-- | src/mscorlib/shared/System/Convert.cs | 334 |
3 files changed, 327 insertions, 225 deletions
diff --git a/src/mscorlib/shared/System.Private.CoreLib.Shared.projitems b/src/mscorlib/shared/System.Private.CoreLib.Shared.projitems index 2bc6464e97..270d06d3c4 100644 --- a/src/mscorlib/shared/System.Private.CoreLib.Shared.projitems +++ b/src/mscorlib/shared/System.Private.CoreLib.Shared.projitems @@ -94,6 +94,7 @@ <Compile Include="$(MSBuildThisFileDirectory)System\Configuration\Assemblies\AssemblyHashAlgorithm.cs" /> <Compile Include="$(MSBuildThisFileDirectory)System\Configuration\Assemblies\AssemblyVersionCompatibility.cs" /> <Compile Include="$(MSBuildThisFileDirectory)System\Convert.cs" /> + <Compile Include="$(MSBuildThisFileDirectory)System\Convert.Base64.cs" /> <Compile Include="$(MSBuildThisFileDirectory)System\CurrentSystemTimeZone.cs" /> <Compile Include="$(MSBuildThisFileDirectory)System\DataMisalignedException.cs" /> <Compile Include="$(MSBuildThisFileDirectory)System\DateTime.cs" /> diff --git a/src/mscorlib/shared/System/Convert.Base64.cs b/src/mscorlib/shared/System/Convert.Base64.cs new file mode 100644 index 0000000000..7e2aee31b2 --- /dev/null +++ b/src/mscorlib/shared/System/Convert.Base64.cs @@ -0,0 +1,217 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using Internal.Runtime.CompilerServices; + +namespace System +{ + public static partial class Convert + { + /// <summary> + /// Decode the span of UTF-16 encoded text represented as base 64 into binary data. + /// If the input is not a multiple of 4, or contains illegal characters, it will decode as much as it can, to the largest possible multiple of 4. + /// This invariant allows continuation of the parse with a slower, whitespace-tolerant algorithm. + /// + /// <param name="utf16">The input span which contains UTF-16 encoded text in base 64 that needs to be decoded.</param> + /// <param name="bytes">The output span which contains the result of the operation, i.e. the decoded binary data.</param> + /// <param name="consumed">The number of input bytes consumed during the operation. This can be used to slice the input for subsequent calls, if necessary.</param> + /// <param name="written">The number of bytes written into the output span. This can be used to slice the output for subsequent calls, if necessary.</param> + /// <returns>Returns: + /// - true - The entire input span was successfully parsed. + /// - false - Only a part of the input span was successfully parsed. Failure causes may include embedded or trailing whitespace, + /// other illegal Base64 characters, trailing characters after an encoding pad ('='), an input span whose length is not divisible by 4 + /// or a destination span that's too small. <paramref name="consumed"/> and <paramref name="written"/> are set so that + /// parsing can continue with a slower whitespace-tolerant algorithm. + /// + /// Note: This is a cut down version of the implementation of Base64.DecodeFromUtf8(), modified the accept UTF16 chars and act as a fast-path + /// helper for the Convert routines when the input string contains no whitespace. + /// + /// </summary> + private static bool TryDecodeFromUtf16(ReadOnlySpan<char> utf16, Span<byte> bytes, out int consumed, out int written) + { + ref char srcChars = ref MemoryMarshal.GetReference(utf16); + ref byte destBytes = ref MemoryMarshal.GetReference(bytes); + + int srcLength = utf16.Length & ~0x3; // only decode input up to the closest multiple of 4. + int destLength = bytes.Length; + + int sourceIndex = 0; + int destIndex = 0; + + if (utf16.Length == 0) + goto DoneExit; + + ref sbyte decodingMap = ref s_decodingMap[0]; + + // Last bytes could have padding characters, so process them separately and treat them as valid. + const int skipLastChunk = 4; + + int maxSrcLength; + if (destLength >= (srcLength >> 2) * 3) + { + maxSrcLength = srcLength - skipLastChunk; + } + else + { + // This should never overflow since destLength here is less than int.MaxValue / 4 * 3 (i.e. 1610612733) + // Therefore, (destLength / 3) * 4 will always be less than 2147483641 + maxSrcLength = (destLength / 3) * 4; + } + + while (sourceIndex < maxSrcLength) + { + int result = Decode(ref Unsafe.Add(ref srcChars, sourceIndex), ref decodingMap); + if (result < 0) + goto InvalidExit; + WriteThreeLowOrderBytes(ref Unsafe.Add(ref destBytes, destIndex), result); + destIndex += 3; + sourceIndex += 4; + } + + if (maxSrcLength != srcLength - skipLastChunk) + goto InvalidExit; + + // If input is less than 4 bytes, srcLength == sourceIndex == 0 + // If input is not a multiple of 4, sourceIndex == srcLength != 0 + if (sourceIndex == srcLength) + { + goto InvalidExit; + } + + int i0 = Unsafe.Add(ref srcChars, srcLength - 4); + int i1 = Unsafe.Add(ref srcChars, srcLength - 3); + int i2 = Unsafe.Add(ref srcChars, srcLength - 2); + int i3 = Unsafe.Add(ref srcChars, srcLength - 1); + if (((i0 | i1 | i2 | i3) & 0xffffff00) != 0) + goto InvalidExit; + + i0 = Unsafe.Add(ref decodingMap, i0); + i1 = Unsafe.Add(ref decodingMap, i1); + + i0 <<= 18; + i1 <<= 12; + + i0 |= i1; + + if (i3 != EncodingPad) + { + i2 = Unsafe.Add(ref decodingMap, i2); + i3 = Unsafe.Add(ref decodingMap, i3); + + i2 <<= 6; + + i0 |= i3; + i0 |= i2; + + if (i0 < 0) + goto InvalidExit; + if (destIndex > destLength - 3) + goto InvalidExit; + WriteThreeLowOrderBytes(ref Unsafe.Add(ref destBytes, destIndex), i0); + destIndex += 3; + } + else if (i2 != EncodingPad) + { + i2 = Unsafe.Add(ref decodingMap, i2); + + i2 <<= 6; + + i0 |= i2; + + if (i0 < 0) + goto InvalidExit; + if (destIndex > destLength - 2) + goto InvalidExit; + Unsafe.Add(ref destBytes, destIndex) = (byte)(i0 >> 16); + Unsafe.Add(ref destBytes, destIndex + 1) = (byte)(i0 >> 8); + destIndex += 2; + } + else + { + if (i0 < 0) + goto InvalidExit; + if (destIndex > destLength - 1) + goto InvalidExit; + Unsafe.Add(ref destBytes, destIndex) = (byte)(i0 >> 16); + destIndex += 1; + } + + sourceIndex += 4; + + if (srcLength != utf16.Length) + goto InvalidExit; + + DoneExit: + consumed = sourceIndex; + written = destIndex; + return true; + + InvalidExit: + consumed = sourceIndex; + written = destIndex; + Debug.Assert((consumed % 4) == 0); + return false; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static int Decode(ref char encodedChars, ref sbyte decodingMap) + { + int i0 = encodedChars; + int i1 = Unsafe.Add(ref encodedChars, 1); + int i2 = Unsafe.Add(ref encodedChars, 2); + int i3 = Unsafe.Add(ref encodedChars, 3); + + if (((i0 | i1 | i2 | i3) & 0xffffff00) != 0) + return -1; // One or more chars falls outside the 00..ff range. This cannot be a valid Base64 character. + + i0 = Unsafe.Add(ref decodingMap, i0); + i1 = Unsafe.Add(ref decodingMap, i1); + i2 = Unsafe.Add(ref decodingMap, i2); + i3 = Unsafe.Add(ref decodingMap, i3); + + i0 <<= 18; + i1 <<= 12; + i2 <<= 6; + + i0 |= i3; + i1 |= i2; + + i0 |= i1; + return i0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static void WriteThreeLowOrderBytes(ref byte destination, int value) + { + destination = (byte)(value >> 16); + Unsafe.Add(ref destination, 1) = (byte)(value >> 8); + Unsafe.Add(ref destination, 2) = (byte)value; + } + + // Pre-computing this table using a custom string(s_characters) and GenerateDecodingMapAndVerify (found in tests) + private static readonly sbyte[] s_decodingMap = { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, //62 is placed at index 43 (for +), 63 at index 47 (for /) + 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, -1, -1, -1, -1, -1, -1, //52-61 are placed at index 48-57 (for 0-9), 64 at index 61 (for =) + -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, -1, -1, -1, -1, -1, //0-25 are placed at index 65-90 (for A-Z) + -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, + 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, //26-51 are placed at index 97-122 (for a-z) + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Bytes over 122 ('z') are invalid and cannot be decoded + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, // Hence, padding the map with 255, which indicates invalid input + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + }; + + private const byte EncodingPad = (byte)'='; // '=', for padding + } +} diff --git a/src/mscorlib/shared/System/Convert.cs b/src/mscorlib/shared/System/Convert.cs index 756bf17fc5..63342ad000 100644 --- a/src/mscorlib/shared/System/Convert.cs +++ b/src/mscorlib/shared/System/Convert.cs @@ -96,7 +96,7 @@ namespace System // When passed Value.DBNull, the Value.ToXXX() methods all throw an // InvalidCastException. - public static class Convert + public static partial class Convert { //A typeof operation is fairly expensive (does a system call), so we'll cache these here //statically. These are exactly lined up with the TypeCode, eg. ConvertType[TypeCode.Int16] @@ -2656,44 +2656,124 @@ namespace System return TryFromBase64Chars(s.AsSpan(), bytes, out bytesWritten); } - public static unsafe bool TryFromBase64Chars(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten) + public static bool TryFromBase64Chars(ReadOnlySpan<char> chars, Span<byte> bytes, out int bytesWritten) { - if (chars.Length == 0) - { - bytesWritten = 0; - return true; - } + // This is actually local to one of the nested blocks but is being declared at the top as we don't want multiple stackallocs + // for each iteraton of the loop. + Span<char> tempBuffer = stackalloc char[4]; // Note: The tempBuffer size could be made larger than 4 but the size must be a multiple of 4. - // We need to get rid of any trailing white spaces. - // Otherwise we would be rejecting input such as "abc= ": - while (chars.Length > 0) + bytesWritten = 0; + + while (chars.Length != 0) { - char lastChar = chars[chars.Length - 1]; - if (lastChar != ' ' && lastChar != '\n' && lastChar != '\r' && lastChar != '\t') + // Attempt to decode a segment that doesn't contain whitespace. + bool complete = TryDecodeFromUtf16(chars, bytes, out int consumedInThisIteration, out int bytesWrittenInThisIteration); + bytesWritten += bytesWrittenInThisIteration; + if (complete) + return true; + + chars = chars.Slice(consumedInThisIteration); + bytes = bytes.Slice(bytesWrittenInThisIteration); + + Debug.Assert(chars.Length != 0); // If TryDecodeFromUtf16() consumed the entire buffer, it could not have returned false. + if (chars[0].IsSpace()) { - break; - } - chars = chars.Slice(0, chars.Length - 1); - } + // If we got here, the very first character not consumed was a whitespace. We can skip past any consecutive whitespace, then continue decoding. - fixed (char* charsPtr = &MemoryMarshal.GetReference(chars)) - { - int resultLength = FromBase64_ComputeResultLength(charsPtr, chars.Length); - Debug.Assert(resultLength >= 0); - if (resultLength > bytes.Length) + int indexOfFirstNonSpace = 1; + for (; ; ) + { + if (indexOfFirstNonSpace == chars.Length) + break; + if (!chars[indexOfFirstNonSpace].IsSpace()) + break; + indexOfFirstNonSpace++; + } + + chars = chars.Slice(indexOfFirstNonSpace); + + if ((bytesWrittenInThisIteration % 3) != 0 && chars.Length != 0) + { + // If we got here, the last successfully decoded block encountered an end-marker, yet we have trailing non-whitespace characters. + // That is not allowed. + bytesWritten = default; + return false; + } + + // We now loop again to decode the next run of non-space characters. + } + else { - bytesWritten = 0; - return false; + Debug.Assert(chars.Length != 0 && !chars[0].IsSpace()); + + // If we got here, it is possible that there is whitespace that occurred in the middle of a 4-byte chunk. That is, we still have + // up to three Base64 characters that were left undecoded by the fast-path helper because they didn't form a complete 4-byte chunk. + // This is hopefully the rare case (multiline-formatted base64 message with a non-space character width that's not a multiple of 4.) + // We'll filter out whitespace and copy the remaining characters into a temporary buffer. + CopyToTempBufferWithoutWhiteSpace(chars, tempBuffer, out int consumedFromChars, out int charsWritten); + if ((charsWritten & 0x3) != 0) + { + // Even after stripping out whitespace, the number of characters is not divisible by 4. This cannot be a legal Base64 string. + bytesWritten = default; + return false; + } + + tempBuffer = tempBuffer.Slice(0, charsWritten); + if (!TryDecodeFromUtf16(tempBuffer, bytes, out int consumedFromTempBuffer, out int bytesWrittenFromTempBuffer)) + { + bytesWritten = default; + return false; + } + bytesWritten += bytesWrittenFromTempBuffer; + chars = chars.Slice(consumedFromChars); + bytes = bytes.Slice(bytesWrittenFromTempBuffer); + + if ((bytesWrittenFromTempBuffer % 3) != 0) + { + // If we got here, this decode contained one or more padding characters ('='). We can accept trailing whitespace after this + // but nothing else. + for (int i = 0; i < chars.Length; i++) + { + if (!chars[i].IsSpace()) + { + bytesWritten = default; + return false; + } + } + return true; + } + + // We now loop again to decode the next run of non-space characters. } + } - fixed (byte* bytesPtr = &MemoryMarshal.GetReference(bytes)) + return true; + } + + private static void CopyToTempBufferWithoutWhiteSpace(ReadOnlySpan<char> chars, Span<char> tempBuffer, out int consumed, out int charsWritten) + { + Debug.Assert(tempBuffer.Length != 0); // We only bound-check after writing a character to the tempBuffer. + + charsWritten = 0; + for (int i = 0; i < chars.Length; i++) + { + char c = chars[i]; + if (!c.IsSpace()) { - bytesWritten = FromBase64_Decode(charsPtr, chars.Length, bytesPtr, bytes.Length); - return true; + tempBuffer[charsWritten++] = c; + if (charsWritten == tempBuffer.Length) + { + consumed = i + 1; + return; + } } } + consumed = chars.Length; } + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsSpace(this char c) => c == ' ' || c == '\t' || c == '\r' || c == '\n'; + /// <summary> /// Converts the specified range of a Char array, which encodes binary data as Base64 digits, to the equivalent byte array. /// </summary> @@ -2730,8 +2810,6 @@ namespace System } } - - /// <summary> /// Convert Base64 encoding characters to bytes: /// - Compute result length exactly by actually walking the input; @@ -2769,11 +2847,10 @@ namespace System Byte[] decodedBytes = new Byte[resultLength]; // Convert Base64 chars into bytes: - Int32 actualResultLength; - fixed (Byte* decodedBytesPtr = decodedBytes) - actualResultLength = FromBase64_Decode(inputPtr, inputLength, decodedBytesPtr, resultLength); + if (!TryFromBase64Chars(new ReadOnlySpan<char>(inputPtr, inputLength), decodedBytes, out int _)) + throw new FormatException(SR.Format_BadBase64Char); - // Note that actualResultLength can differ from resultLength if the caller is modifying the array + // Note that the number of bytes written can differ from resultLength if the caller is modifying the array // as it is being converted. Silently ignore the failure. // Consider throwing exception in an non in-place release. @@ -2781,199 +2858,6 @@ namespace System return decodedBytes; } - - /// <summary> - /// Decode characters representing a Base64 encoding into bytes: - /// Walk the input. Every time 4 chars are read, convert them to the 3 corresponding output bytes. - /// This method is a bit lengthy on purpose. We are trying to avoid jumps to helpers in the loop - /// to aid performance. - /// </summary> - /// <param name="inputPtr">Pointer to first input char</param> - /// <param name="inputLength">Number of input chars</param> - /// <param name="destPtr">Pointer to location for the first result byte</param> - /// <param name="destLength">Max length of the preallocated result buffer</param> - /// <returns>If the result buffer was not large enough to write all result bytes, return -1; - /// Otherwise return the number of result bytes actually produced.</returns> - private static unsafe Int32 FromBase64_Decode(Char* startInputPtr, Int32 inputLength, Byte* startDestPtr, Int32 destLength) - { - // You may find this method weird to look at. It's written for performance, not aesthetics. - // You will find unrolled loops label jumps and bit manipulations. - - const UInt32 intA = (UInt32)'A'; - const UInt32 inta = (UInt32)'a'; - const UInt32 int0 = (UInt32)'0'; - const UInt32 intEq = (UInt32)'='; - const UInt32 intPlus = (UInt32)'+'; - const UInt32 intSlash = (UInt32)'/'; - const UInt32 intSpace = (UInt32)' '; - const UInt32 intTab = (UInt32)'\t'; - const UInt32 intNLn = (UInt32)'\n'; - const UInt32 intCRt = (UInt32)'\r'; - const UInt32 intAtoZ = (UInt32)('Z' - 'A'); // = ('z' - 'a') - const UInt32 int0to9 = (UInt32)('9' - '0'); - - Char* inputPtr = startInputPtr; - Byte* destPtr = startDestPtr; - - // Pointers to the end of input and output: - Char* endInputPtr = inputPtr + inputLength; - Byte* endDestPtr = destPtr + destLength; - - // Current char code/value: - UInt32 currCode; - - // This 4-byte integer will contain the 4 codes of the current 4-char group. - // Eeach char codes for 6 bits = 24 bits. - // The remaining byte will be FF, we use it as a marker when 4 chars have been processed. - UInt32 currBlockCodes = 0x000000FFu; - - unchecked - { - while (true) - { - // break when done: - if (inputPtr >= endInputPtr) - goto _AllInputConsumed; - - // Get current char: - currCode = (UInt32)(*inputPtr); - inputPtr++; - - // Determine current char code: - - if (currCode - intA <= intAtoZ) - currCode -= intA; - - else if (currCode - inta <= intAtoZ) - currCode -= (inta - 26u); - - else if (currCode - int0 <= int0to9) - currCode -= (int0 - 52u); - - else - { - // Use the slower switch for less common cases: - switch (currCode) - { - // Significant chars: - case intPlus: - currCode = 62u; - break; - - case intSlash: - currCode = 63u; - break; - - // Legal no-value chars (we ignore these): - case intCRt: - case intNLn: - case intSpace: - case intTab: - continue; - - // The equality char is only legal at the end of the input. - // Jump after the loop to make it easier for the JIT register predictor to do a good job for the loop itself: - case intEq: - goto _EqualityCharEncountered; - - // Other chars are illegal: - default: - throw new FormatException(SR.Format_BadBase64Char); - } - } - - // Ok, we got the code. Save it: - currBlockCodes = (currBlockCodes << 6) | currCode; - - // Last bit in currBlockCodes will be on after in shifted right 4 times: - if ((currBlockCodes & 0x80000000u) != 0u) - { - if ((Int32)(endDestPtr - destPtr) < 3) - return -1; - - *(destPtr) = (Byte)(currBlockCodes >> 16); - *(destPtr + 1) = (Byte)(currBlockCodes >> 8); - *(destPtr + 2) = (Byte)(currBlockCodes); - destPtr += 3; - - currBlockCodes = 0x000000FFu; - } - } - } // unchecked while - - // 'd be nice to have an assert that we never get here, but CS0162: Unreachable code detected. - // Debug.Fail("We only leave the above loop by jumping; should never get here."); - - // We jump here out of the loop if we hit an '=': - _EqualityCharEncountered: - - Debug.Assert(currCode == intEq); - - // Recall that inputPtr is now one position past where '=' was read. - // '=' can only be at the last input pos: - if (inputPtr == endInputPtr) - { - // Code is zero for trailing '=': - currBlockCodes <<= 6; - - // The '=' did not complete a 4-group. The input must be bad: - if ((currBlockCodes & 0x80000000u) == 0u) - throw new FormatException(SR.Format_BadBase64CharArrayLength); - - if ((int)(endDestPtr - destPtr) < 2) // Autch! We underestimated the output length! - return -1; - - // We are good, store bytes form this past group. We had a single "=", so we take two bytes: - *(destPtr++) = (Byte)(currBlockCodes >> 16); - *(destPtr++) = (Byte)(currBlockCodes >> 8); - - currBlockCodes = 0x000000FFu; - } - else - { // '=' can also be at the pre-last position iff the last is also a '=' excluding the white spaces: - // We need to get rid of any intermediate white spaces. - // Otherwise we would be rejecting input such as "abc= =": - while (inputPtr < (endInputPtr - 1)) - { - Int32 lastChar = *(inputPtr); - if (lastChar != (Int32)' ' && lastChar != (Int32)'\n' && lastChar != (Int32)'\r' && lastChar != (Int32)'\t') - break; - inputPtr++; - } - - if (inputPtr == (endInputPtr - 1) && *(inputPtr) == '=') - { - // Code is zero for each of the two '=': - currBlockCodes <<= 12; - - // The '=' did not complete a 4-group. The input must be bad: - if ((currBlockCodes & 0x80000000u) == 0u) - throw new FormatException(SR.Format_BadBase64CharArrayLength); - - if ((Int32)(endDestPtr - destPtr) < 1) // Autch! We underestimated the output length! - return -1; - - // We are good, store bytes form this past group. We had a "==", so we take only one byte: - *(destPtr++) = (Byte)(currBlockCodes >> 16); - - currBlockCodes = 0x000000FFu; - } - else // '=' is not ok at places other than the end: - throw new FormatException(SR.Format_BadBase64Char); - } - - // We get here either from above or by jumping out of the loop: - _AllInputConsumed: - - // The last block of chars has less than 4 items - if (currBlockCodes != 0x000000FFu) - throw new FormatException(SR.Format_BadBase64CharArrayLength); - - // Return how many bytes were actually recovered: - return (Int32)(destPtr - startDestPtr); - } // Int32 FromBase64_Decode(...) - - /// <summary> /// Compute the number of bytes encoded in the specified Base 64 char array: /// Walk the entire input counting white spaces and padding chars, then compute result length |