diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index eb556a79ec..dddedf29fd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -730,6 +730,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/ResolvedServerSpn.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/ResolvedServerSpn.cs new file mode 100644 index 0000000000..bbf161b66e --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/ResolvedServerSpn.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +#nullable enable + +namespace Microsoft.Data.SqlClient.SNI +{ + /// + /// This is used to hold the ServerSpn for a given connection. Most connection types have a single format, although TCP connections may allow + /// with and without a port. Depending on how the SPN is registered on the server, either one may be the correct name. + /// + /// + /// + /// + /// + /// SQL Server SPN format follows these patterns: + /// + /// + /// Default instance, no port (primary): + /// MSSQLSvc/fully-qualified-domain-name + /// + /// + /// Default instance, default port (secondary): + /// MSSQLSvc/fully-qualified-domain-name:1433 + /// + /// + /// Named instance or custom port: + /// MSSQLSvc/fully-qualified-domain-name:port_or_instance_name + /// + /// + /// For TCP connections to named instances, the port number is used in SPN. + /// For Named Pipe connections to named instances, the instance name is used in SPN. + /// When hostname resolution fails, the user-provided hostname is used instead of FQDN. + /// For default instances with TCP protocol, both forms (with and without port) may be returned. + /// + internal readonly struct ResolvedServerSpn(string primary, string? secondary = null) + { + public string Primary => primary; + + public string? Secondary => secondary; + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs index 265f80246c..59491bf8b4 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs @@ -3,12 +3,9 @@ // See the LICENSE file in the project root for more information. using System; -using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net; -using System.Net.Security; using System.Net.Sockets; using System.Text; using Microsoft.Data.ProviderBase; @@ -34,7 +31,7 @@ internal class SNIProxy /// Full server name from connection string /// Timer expiration /// Instance name - /// SPNs + /// SPN /// pre-defined SPN /// Flush packet cache /// Asynchronous connection @@ -51,7 +48,7 @@ internal static SNIHandle CreateConnectionHandle( string fullServerName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spns, + out ResolvedServerSpn resolvedSpn, string serverSPN, bool flushCache, bool async, @@ -65,6 +62,7 @@ internal static SNIHandle CreateConnectionHandle( string serverCertificateFilename) { instanceName = new byte[1]; + resolvedSpn = default; bool errorWithLocalDBProcessing; string localDBDataSource = GetLocalDBDataSource(fullServerName, out errorWithLocalDBProcessing); @@ -103,7 +101,7 @@ internal static SNIHandle CreateConnectionHandle( { try { - spns = GetSqlServerSPNs(details, serverSPN); + resolvedSpn = GetSqlServerSPNs(details, serverSPN); } catch (Exception e) { @@ -115,12 +113,12 @@ internal static SNIHandle CreateConnectionHandle( return sniHandle; } - private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN) + private static ResolvedServerSpn GetSqlServerSPNs(DataSource dataSource, string serverSPN) { Debug.Assert(!string.IsNullOrWhiteSpace(dataSource.ServerName)); if (!string.IsNullOrWhiteSpace(serverSPN)) { - return new[] { serverSPN }; + return new(serverSPN); } string hostName = dataSource.ServerName; @@ -138,7 +136,7 @@ private static string[] GetSqlServerSPNs(DataSource dataSource, string serverSPN return GetSqlServerSPNs(hostName, postfix, dataSource.ResolvedProtocol); } - private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) + private static ResolvedServerSpn GetSqlServerSPNs(string hostNameOrAddress, string portOrInstanceName, DataSource.Protocol protocol) { Debug.Assert(!string.IsNullOrWhiteSpace(hostNameOrAddress)); IPHostEntry hostEntry = null; @@ -169,12 +167,12 @@ private static string[] GetSqlServerSPNs(string hostNameOrAddress, string portOr string serverSpnWithDefaultPort = serverSpn + $":{DefaultSqlServerPort}"; // Set both SPNs with and without Port as Port is optional for default instance SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPNs {0} and {1}", serverSpn, serverSpnWithDefaultPort); - return new[] { serverSpn, serverSpnWithDefaultPort }; + return new(serverSpn, serverSpnWithDefaultPort); } // else Named Pipes do not need to valid port SqlClientEventSource.Log.TryAdvancedTraceEvent("SNIProxy.GetSqlServerSPN | Info | ServerSPN {0}", serverSpn); - return new[] { serverSpn }; + return new(serverSpn); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index d66984d8b7..7b947e4ce5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -121,8 +121,6 @@ internal sealed partial class TdsParser private bool _is2022 = false; - private string[] _serverSpn = null; - // SqlStatistics private SqlStatistics _statistics = null; @@ -395,7 +393,6 @@ internal void Connect(ServerInfo serverInfo, } else { - _serverSpn = null; SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | Connection Object Id {0}, Authentication Mode: {1}", _connHandler.ObjectID, authType == SqlAuthenticationMethod.NotSpecified ? SqlAuthenticationMethod.SqlPassword.ToString() : authType.ToString()); } @@ -407,7 +404,6 @@ internal void Connect(ServerInfo serverInfo, SqlClientEventSource.Log.TryTraceEvent(" Encryption will be disabled as target server is a SQL Local DB instance."); } - _serverSpn = null; _authenticationProvider = null; // AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server @@ -446,7 +442,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _serverSpn, + out var resolvedServerSpn, false, true, fParallel, @@ -459,8 +455,6 @@ internal void Connect(ServerInfo serverInfo, hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -546,7 +540,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _serverSpn, + out resolvedServerSpn, true, true, fParallel, @@ -559,8 +553,6 @@ internal void Connect(ServerInfo serverInfo, hostNameInCertificate, serverCertificateFilename); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -599,6 +591,11 @@ internal void Connect(ServerInfo serverInfo, } SqlClientEventSource.Log.TryTraceEvent(" Prelogin handshake successful"); + if (_authenticationProvider is { }) + { + _authenticationProvider.Initialize(serverInfo, _physicalStateObj, this, resolvedServerSpn.Primary, resolvedServerSpn.Secondary); + } + if (_fMARS && marsCapable) { // if user explicitly disables mars or mars not supported, don't create the session pool @@ -744,7 +741,7 @@ private void SendPreLoginHandshake( // UNDONE - need to do some length verification to ensure packet does not // get too big!!! Not beyond it's max length! - + for (int option = (int)PreLoginOptions.VERSION; option < (int)PreLoginOptions.NUMOPT; option++) { int optionDataSize = 0; @@ -935,7 +932,7 @@ private PreLoginHandshakeStatus ConsumePreLoginHandshake( string serverCertificateFilename) { // Assign default values - marsCapable = _fMARS; + marsCapable = _fMARS; fedAuthRequired = false; Debug.Assert(_physicalStateObj._syncOverAsync, "Should not attempt pends in a synchronous call"); TdsOperationStatus result = _physicalStateObj.TryReadNetworkPacket(); @@ -2181,7 +2178,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle dataStream.BrowseModeInfoConsumed = true; } else - { + { // no dataStream result = stateObj.TrySkipBytes(tokenLength); if (result != TdsOperationStatus.Done) @@ -2195,7 +2192,7 @@ internal TdsOperationStatus TryRun(RunBehavior runBehavior, SqlCommand cmdHandle case TdsEnums.SQLDONE: case TdsEnums.SQLDONEPROC: case TdsEnums.SQLDONEINPROC: - { + { // RunBehavior can be modified - see SQL BU DT 269516 & 290090 result = TryProcessDone(cmdHandler, dataStream, ref runBehavior, stateObj); if (result != TdsOperationStatus.Done) @@ -4122,7 +4119,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length, { return result; } - + byte len; result = stateObj.TryReadByte(out len); if (result != TdsOperationStatus.Done) @@ -4321,7 +4318,7 @@ internal TdsOperationStatus TryProcessReturnValue(int length, { return result; } - + if (rec.collation.IsUTF8) { // UTF8 collation rec.encoding = Encoding.UTF8; @@ -4776,13 +4773,13 @@ internal TdsOperationStatus TryProcessAltMetaData(int cColumns, TdsParserStateOb { // internal meta data class _SqlMetaData col = altMetaDataSet[i]; - + result = stateObj.TryReadByte(out _); if (result != TdsOperationStatus.Done) { return result; } - + result = stateObj.TryReadUInt16(out _); if (result != TdsOperationStatus.Done) { @@ -5466,7 +5463,7 @@ private TdsOperationStatus TryProcessColInfo(_SqlMetaDataSet columns, SqlDataRea for (int i = 0; i < columns.Length; i++) { _SqlMetaData col = columns[i]; - + TdsOperationStatus result = stateObj.TryReadByte(out _); if (result != TdsOperationStatus.Done) { @@ -7386,7 +7383,7 @@ private byte[] SerializeSqlMoney(SqlMoney value, int length, TdsParserStateObjec private void WriteSqlMoney(SqlMoney value, int length, TdsParserStateObject stateObj) { - // UNDONE: can I use SqlMoney.ToInt64()? + // UNDONE: can I use SqlMoney.ToInt64()? int[] bits = decimal.GetBits(value.Value); // this decimal should be scaled by 10000 (regardless of what the incoming decimal was scaled by) @@ -9906,7 +9903,7 @@ private Task TDSExecuteRPCAddParameter(TdsParserStateObject stateObj, SqlParamet WriteUDTMetaData(value, names[0], names[1], names[2], stateObj); - // UNDONE - re-org to use code below to write value! + // UNDONE - re-org to use code below to write value! if (!isNull) { WriteUnsignedLong((ulong)udtVal.Length, stateObj); // PLP length @@ -12340,7 +12337,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: case TdsEnums.SQLXMLTYPE: - case TdsEnums.SQLJSON: + case TdsEnums.SQLJSON: { Debug.Assert(!isDataFeed || (value is TextDataFeed || value is XmlDataFeed), "Value must be a TextReader or XmlReader"); Debug.Assert(isDataFeed || (value is string || value is byte[]), "Value is a byte array or string"); @@ -13556,15 +13553,14 @@ private TdsOperationStatus TryProcessUDTMetaData(SqlMetaDataPriv metaData, TdsPa + " _connHandler = {14}\n\t" + " _fMARS = {15}\n\t" + " _sessionPool = {16}\n\t" - + " _sniSpnBuffer = {17}\n\t" - + " _errors = {18}\n\t" - + " _warnings = {19}\n\t" - + " _attentionErrors = {20}\n\t" - + " _attentionWarnings = {21}\n\t" - + " _statistics = {22}\n\t" - + " _statisticsIsInTransaction = {23}\n\t" - + " _fPreserveTransaction = {24}" - + " _fParallel = {25}" + + " _errors = {17}\n\t" + + " _warnings = {18}\n\t" + + " _attentionErrors = {19}\n\t" + + " _attentionWarnings = {20}\n\t" + + " _statistics = {21}\n\t" + + " _statisticsIsInTransaction = {22}\n\t" + + " _fPreserveTransaction = {23}" + + " _fParallel = {24}" ; internal string TraceString() { @@ -13587,7 +13583,6 @@ internal string TraceString() _connHandler == null ? "(null)" : _connHandler.ObjectID.ToString((IFormatProvider)null), _fMARS ? bool.TrueString : bool.FalseString, _sessionPool == null ? "(null)" : _sessionPool.TraceString(), - _serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs index 5e1e268b06..d0f6225831 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using Microsoft.Data.Common; using Microsoft.Data.ProviderBase; +using Microsoft.Data.SqlClient.SNI; namespace Microsoft.Data.SqlClient { @@ -55,7 +56,7 @@ internal TdsParserStateObject(TdsParser parser, TdsParserStateObject physicalCon AddError(parser.ProcessSNIError(this)); ThrowExceptionAndWarning(); } - + // we post a callback that represents the call to dispose; once the // object is disposed, the next callback will cause the GC Handle to // be released. @@ -71,7 +72,7 @@ internal abstract void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spns, + out ResolvedServerSpn resolvedSpn, bool flushCache, bool async, bool fParallel, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs index e6dddc79f9..706096165d 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectManaged.cs @@ -81,7 +81,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spns, + out ResolvedServerSpn resolvedSpn, bool flushCache, bool async, bool parallel, @@ -94,7 +94,7 @@ internal override void CreatePhysicalSNIHandle( string hostNameInCertificate, string serverCertificateFilename) { - SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, ref spns, serverSPN, + SNIHandle? sessionHandle = SNIProxy.CreateConnectionHandle(serverName, timeout, out instanceName, out resolvedSpn, serverSPN, flushCache, async, parallel, isIntegratedSecurity, iPAddressPreference, cachedFQDN, ref pendingDNSInfo, tlsFirst, hostNameInCertificate, serverCertificateFilename); diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs index b8d1b6cccb..7bbd000160 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObjectNative.cs @@ -13,6 +13,7 @@ using Interop.Windows.Sni; using Microsoft.Data.Common; using Microsoft.Data.ProviderBase; +using Microsoft.Data.SqlClient.SNI; namespace Microsoft.Data.SqlClient { @@ -144,7 +145,7 @@ internal override void CreatePhysicalSNIHandle( string serverName, TimeoutTimer timeout, out byte[] instanceName, - ref string[] spns, + out ResolvedServerSpn resolvedSpn, bool flushCache, bool async, bool fParallel, @@ -178,7 +179,7 @@ internal override void CreatePhysicalSNIHandle( _sessionHandle = new SNIHandle(myInfo, serverName, ref serverSPN, timeout.MillisecondsRemainingInt, out instanceName, flushCache, !async, fParallel, ipPreference, cachedDNSInfo, hostNameInCertificate); - spns = new[] { serverSPN.TrimEnd() }; + resolvedSpn = new(serverSPN.TrimEnd()); } protected override uint SniPacketGetData(PacketHandle packet, byte[] _inBuff, ref uint dataSize) @@ -423,7 +424,7 @@ internal override uint WaitForSSLHandShakeToComplete(out int protocolVersion) } else if (nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL3_CLIENT) || nativeProtocol.HasFlag(NativeProtocols.SP_PROT_SSL3_SERVER)) { -// SSL 2.0 and 3.0 are only referenced to log a warning, not explicitly used for connections + // SSL 2.0 and 3.0 are only referenced to log a warning, not explicitly used for connections #pragma warning disable CS0618, CA5397 protocolVersion = (int)SslProtocols.Ssl3; } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs index 130094dc53..9c4bd4eafe 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -121,8 +121,6 @@ internal sealed partial class TdsParser private bool _is2022 = false; - private string _serverSpn = null; - // SqlStatistics private SqlStatistics _statistics = null; @@ -396,6 +394,8 @@ internal void Connect(ServerInfo serverInfo, Debug.Fail("SNI returned status != success, but no error thrown?"); } + string serverSpn = null; + //Create LocalDB instance if necessary if (connHandler.ConnectionOptions.LocalDBInstance != null) { @@ -415,13 +415,13 @@ internal void Connect(ServerInfo serverInfo, if (!string.IsNullOrEmpty(serverInfo.ServerSPN)) { - _serverSpn = serverInfo.ServerSPN; + serverSpn = serverInfo.ServerSPN; SqlClientEventSource.Log.TryTraceEvent(" Server SPN `{0}` from the connection string is used.", serverInfo.ServerSPN); } else { // Empty signifies to interop layer that SPN needs to be generated - _serverSpn = string.Empty; + serverSpn = string.Empty; } SqlClientEventSource.Log.TryTraceEvent(" SSPI or Active Directory Authentication Library for SQL Server based integrated authentication"); @@ -429,7 +429,6 @@ internal void Connect(ServerInfo serverInfo, else { _authenticationProvider = null; - _serverSpn = null; switch (authType) { @@ -508,7 +507,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _serverSpn, + ref serverSpn, false, true, fParallel, @@ -518,8 +517,6 @@ internal void Connect(ServerInfo serverInfo, FQDNforDNSCache, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -602,7 +599,7 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ExtendedServerName, timeout, out instanceName, - ref _serverSpn, + ref serverSpn, true, true, fParallel, @@ -612,8 +609,6 @@ internal void Connect(ServerInfo serverInfo, serverInfo.ResolvedServerName, hostNameInCertificate); - _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this); - if (TdsEnums.SNI_SUCCESS != _physicalStateObj.Status) { _physicalStateObj.AddError(ProcessSNIError(_physicalStateObj)); @@ -648,6 +643,8 @@ internal void Connect(ServerInfo serverInfo, } SqlClientEventSource.Log.TryTraceEvent(" Prelogin handshake successful"); + _authenticationProvider?.Initialize(serverInfo, _physicalStateObj, this, serverSpn); + if (_fMARS && marsCapable) { // if user explicitly disables mars or mars not supported, don't create the session pool @@ -13669,15 +13666,14 @@ internal ulong PlpBytesTotalLength(TdsParserStateObject stateObj) + " _connHandler = {14}\n\t" + " _fMARS = {15}\n\t" + " _sessionPool = {16}\n\t" - + " _sniSpnBuffer = {17}\n\t" - + " _errors = {18}\n\t" - + " _warnings = {19}\n\t" - + " _attentionErrors = {20}\n\t" - + " _attentionWarnings = {21}\n\t" - + " _statistics = {22}\n\t" - + " _statisticsIsInTransaction = {23}\n\t" - + " _fPreserveTransaction = {24}" - + " _fParallel = {25}" + + " _errors = {17}\n\t" + + " _warnings = {18}\n\t" + + " _attentionErrors = {19}\n\t" + + " _attentionWarnings = {20}\n\t" + + " _statistics = {21}\n\t" + + " _statisticsIsInTransaction = {22}\n\t" + + " _fPreserveTransaction = {23}" + + " _fParallel = {24}" ; internal string TraceString() { @@ -13700,7 +13696,6 @@ internal string TraceString() _connHandler == null ? "(null)" : _connHandler.ObjectID.ToString((IFormatProvider)null), _fMARS ? bool.TrueString : bool.FalseString, _sessionPool == null ? "(null)" : _sessionPool.TraceString(), - _serverSpn == null ? "(null)" : _serverSpn.Length.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.ErrorCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.WarningCount.ToString((IFormatProvider)null), _physicalStateObj != null ? "(null)" : _physicalStateObj.PreAttentionErrorCount.ToString((IFormatProvider)null), diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs index 5dc52010b3..a74651cf2d 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/NegotiateSspiContextProvider.cs @@ -2,26 +2,25 @@ using System; using System.Buffers; +using System.Diagnostics; using System.Net.Security; #nullable enable namespace Microsoft.Data.SqlClient { - internal sealed class NegotiateSspiContextProvider : SspiContextProvider + internal sealed class NegotiateSspiContextProvider : SspiContextProvider, IDisposable { - private NegotiateAuthentication? _negotiateAuth = null; + private NegotiateAuthentication? _negotiateAuth; protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { - NegotiateAuthenticationStatusCode statusCode = NegotiateAuthenticationStatusCode.UnknownCredentials; - - _negotiateAuth ??= new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.Resource }); - var sendBuff = _negotiateAuth.GetOutgoingBlob(incomingBlob, out statusCode)!; + var negotiateAuth = GetNegotiateAuthenticationForParams(authParams); + var sendBuff = negotiateAuth.GetOutgoingBlob(incomingBlob, out var statusCode)!; // Log session id, status code and the actual SPN used in the negotiation SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | Session Id {2}, StatusCode={3}, SPN={4}", nameof(NegotiateSspiContextProvider), - nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, _negotiateAuth.TargetName); + nameof(GenerateSspiClientContext), _physicalStateObj.SessionId, statusCode, negotiateAuth.TargetName); if (statusCode == NegotiateAuthenticationStatusCode.Completed || statusCode == NegotiateAuthenticationStatusCode.ContinueNeeded) { @@ -31,6 +30,27 @@ protected override bool GenerateSspiClientContext(ReadOnlySpan incomingBlo return false; } + + public void Dispose() + { + _negotiateAuth?.Dispose(); + } + + private NegotiateAuthentication GetNegotiateAuthenticationForParams(SspiAuthenticationParameters authParams) + { + if (_negotiateAuth is { }) + { + if (string.Equals(_negotiateAuth.TargetName, authParams.Resource, StringComparison.Ordinal)) + { + return _negotiateAuth; + } + + // Dispose of it since we're not going to use it now + _negotiateAuth.Dispose(); + } + + return _negotiateAuth = new(new NegotiateAuthenticationClientOptions { Package = "Negotiate", TargetName = authParams.Resource }); + } } } #endif diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs index ff83422f10..e246d58d38 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/SspiContextProvider.cs @@ -10,15 +10,48 @@ internal abstract class SspiContextProvider { private TdsParser _parser = null!; private ServerInfo _serverInfo = null!; + + private SspiAuthenticationParameters? _primaryAuthParams; + private SspiAuthenticationParameters? _secondaryAuthParams; + private protected TdsParserStateObject _physicalStateObj = null!; - internal void Initialize(ServerInfo serverInfo, TdsParserStateObject physicalStateObj, TdsParser parser) +#if NET + /// + /// for details as to what and means and why there are two. + /// +#endif + internal void Initialize( + ServerInfo serverInfo, + TdsParserStateObject physicalStateObj, + TdsParser parser, + string primaryServerSpn, + string? secondaryServerSpn = null + ) { _parser = parser; _physicalStateObj = physicalStateObj; _serverInfo = serverInfo; + var options = parser.Connection.ConnectionOptions; + + SqlClientEventSource.Log.StateDumpEvent(" Initializing provider {0} with SPN={1} and alternate={2}", GetType().FullName, primaryServerSpn, secondaryServerSpn); + + _primaryAuthParams = CreateAuthParams(options, primaryServerSpn); + + if (secondaryServerSpn is { }) + { + _secondaryAuthParams = CreateAuthParams(options, secondaryServerSpn); + } + Initialize(); + + static SspiAuthenticationParameters CreateAuthParams(SqlConnectionString connString, string serverSpn) => new(connString.DataSource, serverSpn) + { + DatabaseName = connString.InitialCatalog, + UserId = connString.UserID, + Password = connString.Password, + }; } private protected virtual void Initialize() @@ -27,46 +60,41 @@ private protected virtual void Initialize() protected abstract bool GenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams); - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, string serverSpn) + internal void WriteSSPIContext(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter) { using var _ = TrySNIEventScope.Create(nameof(SspiContextProvider)); - if (!RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn)) + if (_primaryAuthParams is { }) { - // If we've hit here, the SSPI context provider implementation failed to generate the SSPI context. - SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT); - } - } + if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, _primaryAuthParams)) + { + return; + } - internal void SSPIData(ReadOnlySpan receivedBuff, IBufferWriter outgoingBlobWriter, ReadOnlySpan serverSpns) - { - using var _ = TrySNIEventScope.Create(nameof(SspiContextProvider)); + // remove _primaryAuth from future attempts as it failed + _primaryAuthParams = null; + } - foreach (var serverSpn in serverSpns) + if (_secondaryAuthParams is { }) { - if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, serverSpn)) + if (RunGenerateSspiClientContext(receivedBuff, outgoingBlobWriter, _secondaryAuthParams)) { return; } + + // remove _secondaryAuthParams from future attempts as it failed + _secondaryAuthParams = null; } // If we've hit here, the SSPI context provider implementation failed to generate the SSPI context. SSPIError(SQLMessage.SSPIGenerateError(), TdsEnums.GEN_CLIENT_CONTEXT); } - private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, string serverSpn) + private bool RunGenerateSspiClientContext(ReadOnlySpan incomingBlob, IBufferWriter outgoingBlobWriter, SspiAuthenticationParameters authParams) { - var options = _parser.Connection.ConnectionOptions; - var authParams = new SspiAuthenticationParameters(options.DataSource, serverSpn) - { - DatabaseName = options.InitialCatalog, - UserId = options.UserID, - Password = options.Password, - }; - try { - SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), serverSpn); + SqlClientEventSource.Log.TryTraceEvent("{0}.{1} | Info | SPN={1}", GetType().FullName, nameof(GenerateSspiClientContext), authParams.Resource); return GenerateSspiClientContext(incomingBlob, outgoingBlobWriter, authParams); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs index 90a69b5670..835ceee88f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlClientEventSource.cs @@ -893,6 +893,12 @@ internal void StateDumpEvent(string message, T0 args0, T1 args1) { StateDump(string.Format(message, args0?.ToString() ?? NullStr, args1?.ToString() ?? NullStr)); } + + [NonEvent] + internal void StateDumpEvent(string message, T0 args0, T1 args1, T2 args2) + { + StateDump(string.Format(message, args0?.ToString() ?? NullStr, args1?.ToString() ?? NullStr, args2?.ToString() ?? NullStr)); + } #endregion #region SNI Trace diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs index ad8226c7fe..675a1483d4 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -35,7 +35,7 @@ internal void ProcessSSPI(int receivedLength) try { // make call for SSPI data - _authenticationProvider!.SSPIData(receivedBuff.AsSpan(0, receivedLength), writer, _serverSpn); + _authenticationProvider!.WriteSSPIContext(receivedBuff.AsSpan(0, receivedLength), writer); // DO NOT SEND LENGTH - TDS DOC INCORRECT! JUST SEND SSPI DATA! _physicalStateObj.WriteByteSpan(writer.WrittenSpan); @@ -175,7 +175,7 @@ internal void TdsLogin( // byte[] buffer and 0 for the int length. Debug.Assert(SniContext.Snix_Login == _physicalStateObj.SniContext, $"Unexpected SniContext. Expecting Snix_Login, actual value is '{_physicalStateObj.SniContext}'"); _physicalStateObj.SniContext = SniContext.Snix_LoginSspi; - _authenticationProvider.SSPIData(ReadOnlySpan.Empty, sspiWriter, _serverSpn); + _authenticationProvider.WriteSSPIContext(ReadOnlySpan.Empty, sspiWriter); _physicalStateObj.SniContext = SniContext.Snix_Login; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs index 9f6673332c..ee1905e914 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/InstanceNameTest/InstanceNameTest.cs @@ -212,9 +212,9 @@ private static string GetSPNInfo(string dataSource, string inInstanceName) string serverSPN = ""; MethodInfo getSqlServerSPNs = sniProxyObj.GetType().GetMethod("GetSqlServerSPNs", BindingFlags.Static | BindingFlags.Public | BindingFlags.NonPublic, null, CallingConventions.Any, getSqlServerSPNsTypesArray, null); - string[] result = (string[])getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN }); + object resolvedSpns = getSqlServerSPNs.Invoke(sniProxyObj, new object[] { dataSrcInfo, serverSPN }); - string spnInfo = result[0]; + string spnInfo = (string)resolvedSpns.GetType().GetProperty("Primary", BindingFlags.Instance | BindingFlags.Public).GetValue(resolvedSpns); return spnInfo; }