Skip to content

Commit edf5f76

Browse files
committed
#1580 Also using optional SSL serverName for certificate verification.
1 parent 98f3ca9 commit edf5f76

File tree

5 files changed

+73
-30
lines changed

5 files changed

+73
-30
lines changed

src/MQTTAsyncUtils.c

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2851,6 +2851,8 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
28512851
char* serverURI = m->serverURI;
28522852
#if defined(OPENSSL)
28532853
int default_port = MQTT_DEFAULT_PORT;
2854+
const char* hostname = NULL; // Host name for SNI & verification
2855+
size_t hostname_len = 0;
28542856
#endif
28552857

28562858
FUNC_ENTRY;
@@ -2918,7 +2920,6 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
29182920
if (m->ssl)
29192921
{
29202922
int port;
2921-
size_t hostname_len;
29222923
int setSocketForSSLrc = 0;
29232924

29242925
if (m->c->net.https_proxy) {
@@ -2927,19 +2928,19 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
29272928
goto exit;
29282929
}
29292930

2930-
hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, default_port);
2931+
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);
29312932
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
2932-
serverURI, hostname_len);
2933+
hostname, hostname_len);
29332934

29342935
if (setSocketForSSLrc != MQTTASYNC_SUCCESS)
29352936
{
29362937
if (m->c->session != NULL)
29372938
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
29382939
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
29392940
rc = m->c->sslopts->struct_version >= 3 ?
2940-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
2941+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
29412942
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
2942-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
2943+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
29432944
m->c->sslopts->verify, NULL, NULL);
29442945
if (rc == TCPSOCKET_INTERRUPTED)
29452946
{
@@ -3009,10 +3010,12 @@ static int MQTTAsync_connecting(MQTTAsyncs* m)
30093010
#if defined(OPENSSL)
30103011
else if (m->c->connect_state == SSL_IN_PROGRESS) /* SSL connect sent - wait for completion */
30113012
{
3013+
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);
3014+
30123015
rc = m->c->sslopts->struct_version >= 3 ?
3013-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
3016+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
30143017
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
3015-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
3018+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
30163019
m->c->sslopts->verify, NULL, NULL);
30173020
if (rc != 1)
30183021
goto exit;

src/MQTTClient.c

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -995,10 +995,14 @@ static thread_return_type WINAPI MQTTClient_run(void* n)
995995
#if defined(OPENSSL)
996996
else if (m->c->connect_state == SSL_IN_PROGRESS)
997997
{
998+
const char* hostname;
999+
size_t hostname_len;
1000+
1001+
hostname = SSLSocket_getHostName(m->serverURI, m->c->sslopts, &hostname_len);
9981002
rc = m->c->sslopts->struct_version >= 3 ?
999-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI,
1003+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
10001004
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
1001-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, m->serverURI,
1005+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
10021006
m->c->sslopts->verify, NULL, NULL);
10031007
if (rc == 1 || rc == SSL_FATAL)
10041008
{
@@ -1282,7 +1286,7 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c
12821286
#if defined(OPENSSL)
12831287
if (m->ssl)
12841288
{
1285-
int port1;
1289+
const char* hostname;
12861290
size_t hostname_len;
12871291
const char *topic;
12881292
int setSocketForSSLrc = 0;
@@ -1293,19 +1297,20 @@ static MQTTResponse MQTTClient_connectURIVersion(MQTTClient handle, MQTTClient_c
12931297
goto exit;
12941298
}
12951299

1296-
hostname_len = MQTTProtocol_addressPort(serverURI, &port1, &topic, MQTT_DEFAULT_PORT);
1300+
hostname = SSLSocket_getHostName(serverURI, m->c->sslopts, &hostname_len);
1301+
12971302
setSocketForSSLrc = SSLSocket_setSocketForSSL(&m->c->net, m->c->sslopts,
1298-
serverURI, hostname_len);
1303+
hostname, hostname_len);
12991304

13001305
if (setSocketForSSLrc != MQTTCLIENT_SUCCESS)
13011306
{
13021307
if (m->c->session != NULL)
13031308
if ((rc = SSL_set_session(m->c->net.ssl, m->c->session)) != 1)
13041309
Log(TRACE_MIN, -1, "Failed to set SSL session with stored data, non critical");
13051310
rc = m->c->sslopts->struct_version >= 3 ?
1306-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
1311+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
13071312
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
1308-
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, serverURI,
1313+
SSLSocket_connect(m->c->net.ssl, m->c->net.socket, hostname, hostname_len,
13091314
m->c->sslopts->verify, NULL, NULL);
13101315
if (rc == TCPSOCKET_INTERRUPTED)
13111316
m->c->connect_state = SSL_IN_PROGRESS; /* the connect is still in progress */
@@ -2759,11 +2764,15 @@ static MQTTPacket* MQTTClient_waitfor(MQTTClient handle, int packet_type, int* r
27592764
#if defined(OPENSSL)
27602765
else if (m->c->connect_state == SSL_IN_PROGRESS)
27612766
{
2767+
const char* hostname;
2768+
size_t hostname_len;
2769+
2770+
hostname = SSLSocket_getHostName(m->currentServerURI, m->c->sslopts, &hostname_len);
27622771

27632772
*rc = m->c->sslopts->struct_version >= 3 ?
2764-
SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI,
2773+
SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len,
27652774
m->c->sslopts->verify, m->c->sslopts->ssl_error_cb, m->c->sslopts->ssl_error_context) :
2766-
SSLSocket_connect(m->c->net.ssl, sock, m->currentServerURI,
2775+
SSLSocket_connect(m->c->net.ssl, sock, hostname, hostname_len,
27672776
m->c->sslopts->verify, NULL, NULL);
27682777
if (*rc == SSL_FATAL)
27692778
break;

src/MQTTProtocolOut.c

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,16 +290,21 @@ int MQTTProtocol_connect(const char* address, Clients* aClient, int unixsock, in
290290
#if defined(OPENSSL)
291291
if (ssl)
292292
{
293+
const char* hostname;
294+
size_t hostname_len;
295+
293296
if (aClient->net.https_proxy) {
294297
aClient->connect_state = PROXY_CONNECT_IN_PROGRESS;
295298
rc = Proxy_connect( &aClient->net, 1, address);
296299
}
297-
if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, address, addr_len) == 1)
300+
301+
hostname = SSLSocket_getHostName(address, aClient->sslopts, &hostname_len);
302+
if (rc == 0 && SSLSocket_setSocketForSSL(&aClient->net, aClient->sslopts, hostname, hostname_len) == 1)
298303
{
299304
rc = aClient->sslopts->struct_version >= 3 ?
300-
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address,
305+
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len,
301306
aClient->sslopts->verify, aClient->sslopts->ssl_error_cb, aClient->sslopts->ssl_error_context) :
302-
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, address,
307+
SSLSocket_connect(aClient->net.ssl, aClient->net.socket, hostname, hostname_len,
303308
aClient->sslopts->verify, NULL, NULL);
304309
if (rc == TCPSOCKET_INTERRUPTED)
305310
aClient->connect_state = SSL_IN_PROGRESS; /* SSL connect called - wait for completion */

src/SSLSocket.c

Lines changed: 29 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,32 @@ static int tls_ex_index_ssl_opts;
8585
#define snprintf _snprintf
8686
#endif
8787

88+
/**
89+
* Gets the hostname to use for SNI and host verification.
90+
* This will use the `serverName` in the SSL options if one is provided,
91+
* otherise will extract the host name from the URI and use that.
92+
* @param serverURI The server URI
93+
* @param opts The SSL options
94+
* @param hostname_len Gets the string length of the returned host name.
95+
* @return The host name to use for SNI and verification.
96+
*/
97+
const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts, size_t* hostname_len)
98+
{
99+
const char *hostname = NULL;
100+
int port;
101+
102+
/* If servername is set in the SSL options, use that for the hostname */
103+
if (opts->struct_version >= 6 && opts->serverName != NULL) {
104+
hostname = opts->serverName;
105+
*hostname_len = strnlen(hostname, MAXHOSTNAMELEN);
106+
}
107+
else {
108+
hostname = serverURI;
109+
*hostname_len = MQTTProtocol_addressPort(serverURI, &port, NULL, 0);
110+
}
111+
return hostname;
112+
}
113+
88114
/**
89115
* Gets the specific error corresponding to SOCKET_ERROR
90116
* @param aString the function that was being used when the error occurred
@@ -740,11 +766,6 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
740766
else
741767
SSLSocket_error("SSL_set_fd", net->ssl, net->socket, rc, NULL, NULL);
742768
}
743-
/* If servername is set in the options, use that for the hostname */
744-
if (opts->struct_version >= 6 && opts->serverName != NULL) {
745-
hostname = opts->serverName;
746-
hostname_len = strnlen(hostname, MAXHOSTNAMELEN);
747-
}
748769
hostname_plus_null = malloc(hostname_len + 1u );
749770
if (hostname_plus_null)
750771
{
@@ -769,7 +790,8 @@ int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
769790
/*
770791
* Return value: 1 - success, TCPSOCKET_INTERRUPTED - try again, anything else is failure
771792
*/
772-
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u)
793+
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len,
794+
int verify, int (*cb)(const char *str, size_t len, void *u), void* u)
773795
{
774796
int rc = 0;
775797

@@ -790,12 +812,10 @@ int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, i
790812
else if (verify)
791813
{
792814
char* peername = NULL;
793-
int port;
794-
size_t hostname_len;
795815

796816
X509* cert = SSL_get_peer_certificate(ssl);
797-
hostname_len = MQTTProtocol_addressPort(hostname, &port, NULL, MQTT_DEFAULT_PORT);
798817

818+
Log(TRACE_PROTOCOL, -1, "X509_check_host for hostname %s", hostname);
799819
rc = X509_check_host(cert, hostname, hostname_len, 0, &peername);
800820
if (rc == 1)
801821
Log(TRACE_PROTOCOL, -1, "peername from X509_check_host is %s", peername);

src/SSLSocket.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,22 @@
3636
/** if we should handle openssl initialization (bool_value == 1) or depend on it to be initalized externally (bool_value == 0) */
3737
void SSLSocket_handleOpensslInit(int bool_value);
3838

39+
/** Get the host name to use for verification either from the URI or options */
40+
const char* SSLSocket_getHostName(const char* serverURI, MQTTClient_SSLOptions* opts,
41+
size_t* hostname_len);
42+
3943
int SSLSocket_initialize(void);
4044
void SSLSocket_terminate(void);
41-
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts, const char* hostname, size_t hostname_len);
45+
int SSLSocket_setSocketForSSL(networkHandles* net, MQTTClient_SSLOptions* opts,
46+
const char* hostname, size_t hostname_len);
4247

4348
int SSLSocket_getch(SSL* ssl, SOCKET socket, char* c);
4449
char *SSLSocket_getdata(SSL* ssl, SOCKET socket, size_t bytes, size_t* actual_len, int* rc);
4550

4651
int SSLSocket_close(networkHandles* net);
4752
int SSLSocket_putdatas(SSL* ssl, SOCKET socket, char* buf0, size_t buf0len, PacketBuffers bufs);
48-
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, int verify, int (*cb)(const char *str, size_t len, void *u), void* u);
53+
int SSLSocket_connect(SSL* ssl, SOCKET sock, const char* hostname, size_t hostname_len,
54+
int verify, int (*cb)(const char *str, size_t len, void *u), void* u);
4955

5056
SOCKET SSLSocket_getPendingRead(void);
5157
int SSLSocket_continueWrite(pending_writes* pw);

0 commit comments

Comments
 (0)