Skip to content

Implemented custom SNI server name #1582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/MQTTAsync.c
Original file line number Diff line number Diff line change
Expand Up @@ -609,7 +609,7 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
}
if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */
{
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5)
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6)
{
rc = MQTTASYNC_BAD_STRUCTURE;
goto exit;
Expand Down Expand Up @@ -767,6 +767,11 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
if (m->c->sslopts->CApath)
free((void*)m->c->sslopts->CApath);
}
if (m->c->sslopts->struct_version >= 6)
{
if (m->c->sslopts->serverName)
free((void*)m->c->sslopts->serverName);
}
free((void*)m->c->sslopts);
m->c->sslopts = NULL;
}
Expand Down Expand Up @@ -816,6 +821,11 @@ int MQTTAsync_connect(MQTTAsync handle, const MQTTAsync_connectOptions* options)
m->c->sslopts->protos = (const unsigned char*)MQTTStrdup((const char*)options->ssl->protos);
m->c->sslopts->protos_len = options->ssl->protos_len;
}
if (m->c->sslopts->struct_version >= 6)
{
if (options->ssl->serverName)
m->c->sslopts->serverName = MQTTStrdup(options->ssl->serverName);
}
}
#else
if (options->struct_version != 0 && options->ssl)
Expand Down
14 changes: 12 additions & 2 deletions src/MQTTAsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -1074,12 +1074,13 @@ typedef struct
/** The eyecatcher for this structure. Must be MQTS */
char struct_id[4];

/** The version number of this structure. Must be 0, 1, 2, 3, 4 or 5.
/** The version number of this structure. Must be 0, 1, 2, 3, 4, 5, or 6.
* 0 means no sslVersion
* 1 means no verify, CApath
* 2 means no ssl_error_context, ssl_error_cb
* 3 means no ssl_psk_cb, ssl_psk_context, disableDefaultTrustStore
* 4 means no protos, protos_len
* 5 means no (SNI) serverName
*/
int struct_version;

Expand Down Expand Up @@ -1178,9 +1179,18 @@ typedef struct
* Exists only if struct_version >= 5
*/
unsigned int protos_len;

/**
* Optional server name for the Server Name Indication (SNI) TLS
* extension. It's the name of the broker/server host, and must be a
* host name, and not an IP address. It can be used by a multi-homed
* server to choose the correct certificate to present to the client.
* Exists only if struct_version >= 6
*/
const char *serverName;
} MQTTAsync_SSLOptions;

#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 }
#define MQTTAsync_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL }

/** Utility structure where name/value pairs are needed */
typedef struct
Expand Down
17 changes: 10 additions & 7 deletions src/MQTTAsyncUtils.c
Original file line number Diff line number Diff line change
Expand Up @@ -2851,6 +2851,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
char* serverURI = m->serverURI;
#if defined(OPENSSL)
int default_port = MQTT_DEFAULT_PORT;
const char* hostname = NULL; // Host name for SNI & verification
size_t hostname_len = 0;
#endif

FUNC_ENTRY;
Expand Down Expand Up @@ -2918,7 +2920,6 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
if (m->ssl)
{
int port;
size_t hostname_len;
int setSocketForSSLrc = 0;

if (m->c->net.https_proxy) {
Expand All @@ -2927,19 +2928,19 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
goto exit;
}

hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, default_port);
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
serverURI, hostname_len);
hostname, hostname_len);

if (setSocketForSSLrc != MQTTASYNC_SUCCESS)
{
if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
rc = m->c->sslopts->struct_version >= 3 ?
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, NULL, NULL);
if (rc == TCPSOCKET_INTERRUPTED)
{
Expand Down Expand Up @@ -3009,10 +3010,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
#if defined(OPENSSL)
else if (m->c->connect_state == SSL_IN_PROGRESS) /* SSL connect sent - wait for completion */
{
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);

rc = m->c->sslopts->struct_version >= 3 ?
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, NULL, NULL);
if (rc != 1)
goto exit;
Expand Down
39 changes: 29 additions & 10 deletions src/MQTTClient.c
Original file line number Diff line number Diff line change
Expand Up @@ -995,10 +995,14 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
#if defined(OPENSSL)
else if (m->c->connect_state == SSL_IN_PROGRESS)
{
const char* hostname;
size_t hostname_len;

hostname = SSLSocket_getHostName(m->serverURI, m->c->sslopts, &hostname_len);
rc = m->c->sslopts->struct_version >= 3 ?
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, NULL, NULL);
if (rc == 1 || rc == SSL_FATAL)
{
Expand Down Expand Up @@ -1282,7 +1286,7 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c
#if defined(OPENSSL)
if (m->ssl)
{
int port1;
const char* hostname;
size_t hostname_len;
const char *topic;
int setSocketForSSLrc = 0;
Expand All @@ -1293,19 +1297,20 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c
goto exit;
}

hostname_len = MQTTProtocol_addressPort(serverURI, &port1, &topic, MQTT_DEFAULT_PORT);
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);

setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
serverURI, hostname_len);
hostname, hostname_len);

if (setSocketForSSLrc != MQTTCLIENT_SUCCESS)
{
if (m->c->session != NULL)
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
rc = m->c->sslopts->struct_version >= 3 ?
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
m->c->sslopts->verify, NULL, NULL);
if (rc == TCPSOCKET_INTERRUPTED)
m->c->connect_state = SSL_IN_PROGRESS; /* the connect is still in progress */
Expand Down Expand Up @@ -1630,6 +1635,11 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO
if (m->c->sslopts->CApath)
free((void*)m->c->sslopts->CApath);
}
if (m->c->sslopts->struct_version >= 6)
{
if (m->c->sslopts->serverName)
free((void*)m->c->sslopts->serverName);
}
free(m->c->sslopts);
m->c->sslopts = NULL;
}
Expand Down Expand Up @@ -1678,6 +1688,11 @@ static MQTTResponse MQTTClient_connectURI(MQTTClient handle, MQTTClient_connectO
m->c->sslopts->protos = options->ssl->protos;
m->c->sslopts->protos_len = options->ssl->protos_len;
}
if (m->c->sslopts->struct_version >= 6)
{
if (options->ssl->serverName)
m->c->sslopts->serverName = MQTTStrdup(options->ssl->serverName);
}
}
#endif

Expand Down Expand Up @@ -1830,7 +1845,7 @@ MQTTResponse MQTTClient_connectAll(MQTTClient handle, MQTTClient_connectOptions*
#if defined(OPENSSL)
if (options->struct_version != 0 && options->ssl) /* check validity of SSL options structure */
{
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 5)
if (strncmp(options->ssl->struct_id, "MQTS", 4) != 0 || options->ssl->struct_version < 0 || options->ssl->struct_version > 6)
{
rc.reasonCode = MQTTCLIENT_BAD_STRUCTURE;
goto exit;
Expand Down Expand Up @@ -2749,11 +2764,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
#if defined(OPENSSL)
else if (m->c->connect_state == SSL_IN_PROGRESS)
{
const char* hostname;
size_t hostname_len;

hostname = SSLSocket_getHostName(m->currentServerURI, m->c->sslopts, &hostname_len);

*rc = m->c->sslopts->struct_version >= 3 ?
SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI,
SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len,
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI,
SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len,
m->c->sslopts->verify, NULL, NULL);
if (*rc == SSL_FATAL)
break;
Expand Down
14 changes: 12 additions & 2 deletions src/MQTTClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -679,12 +679,13 @@ typedef struct
/** The eyecatcher for this structure. Must be MQTS */
char struct_id[4];

/** The version number of this structure. Must be 0, 1, 2, 3, 4 or 5.
/** The version number of this structure. Must be 0, 1, 2, 3, 4, 5, or 6.
* 0 means no sslVersion
* 1 means no verify, CApath
* 2 means no ssl_error_context, ssl_error_cb
* 3 means no ssl_psk_cb, ssl_psk_context, disableDefaultTrustStore
* 4 means no protos, protos_len
* 5 means no (SNI) serverName
*/
int struct_version;

Expand Down Expand Up @@ -783,9 +784,18 @@ typedef struct
* Exists only if struct_version >= 5
*/
unsigned int protos_len;

/**
* Optional server name for the Server Name Indication (SNI) TLS
* extension. It's the name of the broker/server host, and must be a
* host name, and not an IP address. It can be used by a multi-homed
* server to choose the correct certificate to present to the client.
* Exists only if struct_version >= 6
*/
const char *serverName;
} MQTTClient_SSLOptions;

#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 5, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0 }
#define MQTTClient_SSLOptions_initializer { {'M', 'Q', 'T', 'S'}, 6, NULL, NULL, NULL, NULL, NULL, 1, MQTT_SSL_VERSION_DEFAULT, 0, NULL, NULL, NULL, NULL, NULL, 0, NULL, 0, NULL }

/**
* MQTTClient_libraryInfo is used to store details relating to the currently used
Expand Down
11 changes: 8 additions & 3 deletions src/MQTTProtocolOut.c
Original file line number Diff line number Diff line change
Expand Up @@ -290,16 +290,21 @@ int MQTTProtocol_connect(const char* address, Clients* aClient, int unixsock, in
#if defined(OPENSSL)
if (ssl)
{
const char* hostname;
size_t hostname_len;

if (aClient->net.https_proxy) {
aClient->connect_state = PROXY_CONNECT_IN_PROGRESS;
rc = Proxy_connect( &aClient->net, 1, address);
}
if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, address, addr_len) == 1)

hostname = SSLSocket_getHostName(address, aClient->sslopts, &hostname_len);
if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, hostname, hostname_len) == 1)
{
rc = aClient->sslopts->struct_version >= 3 ?
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address,
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len,
aClient->sslopts->verify, aClient->sslopts->ssl_error_cb, aClient->sslopts->ssl_error_context) :
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address,
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len,
aClient->sslopts->verify, NULL, NULL);
if (rc == TCPSOCKET_INTERRUPTED)
aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */
Expand Down
34 changes: 30 additions & 4 deletions src/SSLSocket.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,32 @@ static int tls_ex_index_ssl_opts;
#define snprintf _snprintf
#endif

/**
* Gets the hostname to use for SNI and host verification.
* This will use the `serverName` in the SSL options if one is provided,
* otherise will extract the host name from the URI and use that.
* @param serverURI The server URI
* @param opts The SSL options
* @param hostname_len Gets the string length of the returned host name.
* @return The host name to use for SNI and verification.
*/
const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts, size_t* hostname_len)
{
const char *hostname = NULL;
int port;

/* If servername is set in the SSL options, use that for the hostname */
if (opts->struct_version >= 6 && opts->serverName != NULL) {
hostname = opts->serverName;
*hostname_len = strnlen(hostname, MAXHOSTNAMELEN);
}
else {
hostname = serverURI;
*hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, 0);
}
return hostname;
}

/**
* Gets the specific error corresponding to SOCKET_ERROR
* @param aString the function that was being used when the error occurred
Expand Down Expand Up @@ -744,6 +770,7 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
if (hostname_plus_null)
{
MQTTStrncpy(hostname_plus_null, hostname, hostname_len + 1u);
Log(TRACE_PROTOCOL, -1, "SNI server/host name is %s", hostname_plus_null);
if ((rc = SSL_set_tlsext_host_name(net->ssl, hostname_plus_null)) != 1) {
if (opts->struct_version >= 3)
SSLSocket_error("SSL_set_tlsext_host_name", NULL, net->socket, rc, opts->ssl_error_cb, opts->ssl_error_context);
Expand All @@ -763,7 +790,8 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
/*
* Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
*/
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u)
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len,
int verify, int (*cb)(const char *str, size_t len, void *u), void* u)
{
int rc = 0;

Expand All @@ -784,12 +812,10 @@ int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, i
else if (verify)
{
char* peername = NULL;
int port;
size_t hostname_len;

X509* cert = SSL_get_peer_certificate(ssl);
hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL, MQTT_DEFAULT_PORT);

Log(TRACE_PROTOCOL, -1, "X509_check_host for hostname %s", hostname);
rc = X509_check_host(cert, hostname, hostname_len, 0, &peername);
if (rc == 1)
Log(TRACE_PROTOCOL, -1, "peername from X509_check_host is %s", peername);
Expand Down
10 changes: 8 additions & 2 deletions src/SSLSocket.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,22 @@
/** if we should handle openssl initialization (bool_value == 1) or depend on it to be initalized externally (bool_value == 0) */
void SSLSocket_handleOpensslInit(int bool_value);

/** Get the host name to use for verification either from the URI or options */
const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts,
size_t* hostname_len);

int SSLSocket_initialize(void);
void SSLSocket_terminate(void);
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len);
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
const char* hostname, size_t hostname_len);

int SSLSocket_getch(SSL* ssl, SOCKET socket, char* c);
char *SSLSocket_getdata(SSL* ssl, SOCKET socket, size_t bytes, size_t* actual_len, int* rc);

int SSLSocket_close(networkHandles* net);
int SSLSocket_putdatas(SSL* ssl, SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs);
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u);
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len,
int verify, int (*cb)(const char *str, size_t len, void *u), void* u);

SOCKET SSLSocket_getPendingRead(void);
int SSLSocket_continueWrite(pending_writes* pw);
Expand Down
Loading