diff --git a/src/MQTTAsync.c b/src/MQTTAsync.c index c548ae31..dad790d6 100644 --- a/src/MQTTAsync.c +++ b/src/MQTTAsync.c @@ -609,7 +609,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) } if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ { - if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5) + if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6) { rc = MQTTASYNC_BAD_STRUCTURE; goto exit; @@ -767,6 +767,11 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) if (m->c->sslopts->CApath) free((void*)m->c->sslopts->CApath); } + if (m->c->sslopts->struct_version >= 6) + { + if (m->c->sslopts->serverName) + free((void*)m->c->sslopts->serverName); + } free((void*)m->c->sslopts); m->c->sslopts = NULL; } @@ -816,6 +821,11 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options) m->c->sslopts->protos = (const unsigned char*)MQTTStrdup((const char*)options->ssl->protos); m->c->sslopts->protos_len = options->ssl->protos_len; } + if (m->c->sslopts->struct_version >= 6) + { + if (options->ssl->serverName) + m->c->sslopts->serverName = MQTTStrdup(options->ssl->serverName); + } } #else if (options->struct_version != 0 && options->ssl) diff --git a/src/MQTTAsync.h b/src/MQTTAsync.h index 7c5b6214..81898a3e 100644 --- a/src/MQTTAsync.h +++ b/src/MQTTAsync.h @@ -1074,12 +1074,13 @@ typedef struct /** The eyecatcher for this structure. Must be MQTS */ char struct_id[4]; - /** The version number of this structure. Must be 0, 1, 2, 3, 4 or 5. + /** The version number of this structure. Must be 0, 1, 2, 3, 4, 5, or 6. * 0 means no sslVersion * 1 means no verify, CApath * 2 means no ssl_error_context, ssl_error_cb * 3 means no ssl_psk_cb, ssl_psk_context, disableDefaultTrustStore * 4 means no protos, protos_len + * 5 means no (SNI) serverName */ int struct_version; @@ -1178,9 +1179,18 @@ typedef struct * Exists only if struct_version >= 5 */ unsigned int protos_len; + + /** + * Optional server name for the Server Name Indication (SNI) TLS + * extension. It's the name of the broker/server host, and must be a + * host name, and not an IP address. It can be used by a multi-homed + * server to choose the correct certificate to present to the client. + * Exists only if struct_version >= 6 + */ + const char *serverName; } MQTTAsync_SSLOptions; -#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 } +#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL } /** Utility structure where name/value pairs are needed */ typedef struct diff --git a/src/MQTTAsyncUtils.c b/src/MQTTAsyncUtils.c index c084a11e..abfeb979 100644 --- a/src/MQTTAsyncUtils.c +++ b/src/MQTTAsyncUtils.c @@ -2851,6 +2851,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) char* serverURI = m->serverURI; #if defined(OPENSSL) int default_port = MQTT_DEFAULT_PORT; + const char* hostname = NULL; // Host name for SNI & verification + size_t hostname_len = 0; #endif FUNC_ENTRY; @@ -2918,7 +2920,6 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) if (m->ssl) { int port; - size_t hostname_len; int setSocketForSSLrc = 0; if (m->c->net.https_proxy) { @@ -2927,9 +2928,9 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) goto exit; } - hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, default_port); + hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len); setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, - serverURI, hostname_len); + hostname, hostname_len); if (setSocketForSSLrc != MQTTASYNC_SUCCESS) { @@ -2937,9 +2938,9 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1) Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical"); rc = m->c->sslopts->struct_version >= 3 ? - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) : - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, NULL, NULL); if (rc == TCPSOCKET_INTERRUPTED) { @@ -3009,10 +3010,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m) #if defined(OPENSSL) else if (m->c->connect_state == SSL_IN_PROGRESS) /* SSL connect sent - wait for completion */ { + hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len); + rc = m->c->sslopts->struct_version >= 3 ? - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) : - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, NULL, NULL); if (rc != 1) goto exit; diff --git a/src/MQTTClient.c b/src/MQTTClient.c index 32d59c3b..c52a0b13 100644 --- a/src/MQTTClient.c +++ b/src/MQTTClient.c @@ -995,10 +995,14 @@ static thread_return_type WINAPI MQTTClient_run(void* n) #if defined(OPENSSL) else if (m->c->connect_state == SSL_IN_PROGRESS) { + const char* hostname; + size_t hostname_len; + + hostname = SSLSocket_getHostName(m->serverURI, m->c->sslopts, &hostname_len); rc = m->c->sslopts->struct_version >= 3 ? - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) : - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, NULL, NULL); if (rc == 1 || rc == SSL_FATAL) { @@ -1282,7 +1286,7 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c #if defined(OPENSSL) if (m->ssl) { - int port1; + const char* hostname; size_t hostname_len; const char *topic; int setSocketForSSLrc = 0; @@ -1293,9 +1297,10 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c goto exit; } - hostname_len = MQTTProtocol_addressPort(serverURI, &port1, &topic, MQTT_DEFAULT_PORT); + hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len); + setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts, - serverURI, hostname_len); + hostname, hostname_len); if (setSocketForSSLrc != MQTTCLIENT_SUCCESS) { @@ -1303,9 +1308,9 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1) Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical"); rc = m->c->sslopts->struct_version >= 3 ? - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) : - SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI, + SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len, m->c->sslopts->verify, NULL, NULL); if (rc == TCPSOCKET_INTERRUPTED) m->c->connect_state = SSL_IN_PROGRESS; /* the connect is still in progress */ @@ -1630,6 +1635,11 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO if (m->c->sslopts->CApath) free((void*)m->c->sslopts->CApath); } + if (m->c->sslopts->struct_version >= 6) + { + if (m->c->sslopts->serverName) + free((void*)m->c->sslopts->serverName); + } free(m->c->sslopts); m->c->sslopts = NULL; } @@ -1678,6 +1688,11 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO m->c->sslopts->protos = options->ssl->protos; m->c->sslopts->protos_len = options->ssl->protos_len; } + if (m->c->sslopts->struct_version >= 6) + { + if (options->ssl->serverName) + m->c->sslopts->serverName = MQTTStrdup(options->ssl->serverName); + } } #endif @@ -1830,7 +1845,7 @@ MQTTResponse MQTTClient_connectAll(MQTTClient handle, MQTTClient_connectOptions* #if defined(OPENSSL) if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */ { - if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5) + if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6) { rc.reasonCode = MQTTCLIENT_BAD_STRUCTURE; goto exit; @@ -2749,11 +2764,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r #if defined(OPENSSL) else if (m->c->connect_state == SSL_IN_PROGRESS) { + const char* hostname; + size_t hostname_len; + + hostname = SSLSocket_getHostName(m->currentServerURI, m->c->sslopts, &hostname_len); *rc = m->c->sslopts->struct_version >= 3 ? - SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI, + SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len, m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) : - SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI, + SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len, m->c->sslopts->verify, NULL, NULL); if (*rc == SSL_FATAL) break; diff --git a/src/MQTTClient.h b/src/MQTTClient.h index c8a4de35..d6066d1a 100644 --- a/src/MQTTClient.h +++ b/src/MQTTClient.h @@ -679,12 +679,13 @@ typedef struct /** The eyecatcher for this structure. Must be MQTS */ char struct_id[4]; - /** The version number of this structure. Must be 0, 1, 2, 3, 4 or 5. + /** The version number of this structure. Must be 0, 1, 2, 3, 4, 5, or 6. * 0 means no sslVersion * 1 means no verify, CApath * 2 means no ssl_error_context, ssl_error_cb * 3 means no ssl_psk_cb, ssl_psk_context, disableDefaultTrustStore * 4 means no protos, protos_len + * 5 means no (SNI) serverName */ int struct_version; @@ -783,9 +784,18 @@ typedef struct * Exists only if struct_version >= 5 */ unsigned int protos_len; + + /** + * Optional server name for the Server Name Indication (SNI) TLS + * extension. It's the name of the broker/server host, and must be a + * host name, and not an IP address. It can be used by a multi-homed + * server to choose the correct certificate to present to the client. + * Exists only if struct_version >= 6 + */ + const char *serverName; } MQTTClient_SSLOptions; -#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 } +#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL } /** * MQTTClient_libraryInfo is used to store details relating to the currently used diff --git a/src/MQTTProtocolOut.c b/src/MQTTProtocolOut.c index b53a5e74..db19332b 100644 --- a/src/MQTTProtocolOut.c +++ b/src/MQTTProtocolOut.c @@ -290,16 +290,21 @@ int MQTTProtocol_connect(const char* address, Clients* aClient, int unixsock, in #if defined(OPENSSL) if (ssl) { + const char* hostname; + size_t hostname_len; + if (aClient->net.https_proxy) { aClient->connect_state = PROXY_CONNECT_IN_PROGRESS; rc = Proxy_connect( &aClient->net, 1, address); } - if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, address, addr_len) == 1) + + hostname = SSLSocket_getHostName(address, aClient->sslopts, &hostname_len); + if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, hostname, hostname_len) == 1) { rc = aClient->sslopts->struct_version >= 3 ? - SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address, + SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len, aClient->sslopts->verify, aClient->sslopts->ssl_error_cb, aClient->sslopts->ssl_error_context) : - SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address, + SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len, aClient->sslopts->verify, NULL, NULL); if (rc == TCPSOCKET_INTERRUPTED) aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */ diff --git a/src/SSLSocket.c b/src/SSLSocket.c index 8cb090c4..c5ed48ab 100644 --- a/src/SSLSocket.c +++ b/src/SSLSocket.c @@ -85,6 +85,32 @@ static int tls_ex_index_ssl_opts; #define snprintf _snprintf #endif +/** + * Gets the hostname to use for SNI and host verification. + * This will use the `serverName` in the SSL options if one is provided, + * otherise will extract the host name from the URI and use that. + * @param serverURI The server URI + * @param opts The SSL options + * @param hostname_len Gets the string length of the returned host name. + * @return The host name to use for SNI and verification. + */ +const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts, size_t* hostname_len) +{ + const char *hostname = NULL; + int port; + + /* If servername is set in the SSL options, use that for the hostname */ + if (opts->struct_version >= 6 && opts->serverName != NULL) { + hostname = opts->serverName; + *hostname_len = strnlen(hostname, MAXHOSTNAMELEN); + } + else { + hostname = serverURI; + *hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, 0); + } + return hostname; +} + /** * Gets the specific error corresponding to SOCKET_ERROR * @param aString the function that was being used when the error occurred @@ -744,6 +770,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, if (hostname_plus_null) { MQTTStrncpy(hostname_plus_null, hostname, hostname_len + 1u); + Log(TRACE_PROTOCOL, -1, "SNI server/host name is %s", hostname_plus_null); if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname_plus_null)) != 1) { if (opts->struct_version >= 3) SSLSocket_error("SSL_set_tlsext_host_name", NULL, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context); @@ -763,7 +790,8 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, /* * Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure */ -int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u) +int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len, + int verify, int (*cb)(const char *str, size_t len, void *u), void* u) { int rc = 0; @@ -784,12 +812,10 @@ int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, i else if (verify) { char* peername = NULL; - int port; - size_t hostname_len; X509* cert = SSL_get_peer_certificate(ssl); - hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL, MQTT_DEFAULT_PORT); + Log(TRACE_PROTOCOL, -1, "X509_check_host for hostname %s", hostname); rc = X509_check_host(cert, hostname, hostname_len, 0, &peername); if (rc == 1) Log(TRACE_PROTOCOL, -1, "peername from X509_check_host is %s", peername); diff --git a/src/SSLSocket.h b/src/SSLSocket.h index 2e804cec..1d3b0c01 100644 --- a/src/SSLSocket.h +++ b/src/SSLSocket.h @@ -36,16 +36,22 @@ /** if we should handle openssl initialization (bool_value == 1) or depend on it to be initalized externally (bool_value == 0) */ void SSLSocket_handleOpensslInit(int bool_value); +/** Get the host name to use for verification either from the URI or options */ +const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts, + size_t* hostname_len); + int SSLSocket_initialize(void); void SSLSocket_terminate(void); -int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len); +int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, + const char* hostname, size_t hostname_len); int SSLSocket_getch(SSL* ssl, SOCKET socket, char* c); char *SSLSocket_getdata(SSL* ssl, SOCKET socket, size_t bytes, size_t* actual_len, int* rc); int SSLSocket_close(networkHandles* net); int SSLSocket_putdatas(SSL* ssl, SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs); -int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u); +int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len, + int verify, int (*cb)(const char *str, size_t len, void *u), void* u); SOCKET SSLSocket_getPendingRead(void); int SSLSocket_continueWrite(pending_writes* pw);