Skip to content

Commit 0483409

Browse files
committed
Implement sending plain data over tls socket before handshake.
1 parent 5f34973 commit 0483409

File tree

5 files changed

+43
-30
lines changed

5 files changed

+43
-30
lines changed

include/restc-cpp/url_encode.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
11
#pragma once
22

3-
#ifndef RESTC_CPP_URL_ENCODE_H_
4-
#define RESTC_CPP_URL_ENCODE_H_
5-
63
#include "restc-cpp.h"
74

85
#include <boost/utility/string_view.hpp>
@@ -13,4 +10,3 @@ std::string url_encode(const boost::string_view& src);
1310

1411
} // namespace
1512

16-
#endif // RESTC_CPP_URL_ENCODE_H_

src/ConnectionPoolImpl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,9 @@ class ConnectionPoolImpl
365365
}
366366
else {
367367
#ifdef RESTC_CPP_WITH_TLS
368-
socket = make_unique<TlsSocketImpl>(owner_.GetIoService(), owner_.GetTLSContext());
368+
socket = make_unique<TlsSocketImpl>(owner_.GetIoService(), owner_.GetTLSContext(),
369+
/*can_send_over_unupgraded_socket to send plain data over tls socket before handshake*/
370+
properties_->proxy.type == Request::Proxy::Type::HTTPS);
369371
#else
370372
throw NotImplementedException(
371373
"restc_cpp is compiled without TLS support");

src/RequestImpl.cpp

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,10 @@ void DoProxyConnect(const Connection::ptr_t& connection,
327327

328328
//sets status_code, reason_phrase in response
329329
stream->ReadServerResponse(proxy_response);
330+
stream->ReadHeaderLines(
331+
[](std::string&& name, std::string&& value) {
332+
RESTC_CPP_LOG_TRACE_("Read proxy header: " << name);
333+
});
330334

331335
} catch (const exception& ex) {
332336
RESTC_CPP_LOG_DEBUG_("DoProxyConnect: exception from ReceivingFromProxy: " << ex.what());
@@ -614,8 +618,8 @@ class RequestImpl : public Request {
614618
return request_buffer.str();
615619
}
616620

617-
//returns {protocol_type, host, service} instead of deprecated ip::resolver::query
618-
tuple<Connection::Type, string, string> GetRequestEndpoint() {
621+
//returns {host, service} instead of deprecated ip::resolver::query
622+
tuple<string, string> GetRequestEndpoint() {
619623
const auto proxy_type = properties_->proxy.type;
620624

621625
if (proxy_type == Request::Proxy::Type::SOCKS5) {
@@ -627,11 +631,7 @@ class RequestImpl : public Request {
627631
<< " Proxy at: "
628632
<< host << ':' << port);
629633

630-
// what connection type should we use for SOCKS tunnel?
631-
return { (parsed_url_.GetProtocol() == Url::Protocol::HTTPS)
632-
? Connection::Type::HTTPS
633-
: Connection::Type::HTTP,
634-
host, to_string(port) };
634+
return { host, to_string(port) };
635635
}
636636

637637
if ( (proxy_type == Request::Proxy::Type::HTTP &&
@@ -645,17 +645,11 @@ class RequestImpl : public Request {
645645
<< " Proxy at: "
646646
<< proxy.GetHost() << ':' << proxy.GetPort());
647647

648-
return { (proxy.GetProtocol() == Url::Protocol::HTTPS)
649-
? Connection::Type::HTTPS
650-
: Connection::Type::HTTP,
651-
proxy.GetHost().to_string(),
648+
return { proxy.GetHost().to_string(),
652649
proxy.GetPort().to_string() };
653650
}
654651

655-
return { (parsed_url_.GetProtocol() == Url::Protocol::HTTPS)
656-
? Connection::Type::HTTPS
657-
: Connection::Type::HTTP,
658-
parsed_url_.GetHost().to_string(),
652+
return { parsed_url_.GetHost().to_string(),
659653
parsed_url_.GetPort().to_string() };
660654
}
661655

@@ -756,15 +750,21 @@ class RequestImpl : public Request {
756750

757751
auto prot_filter = GetBindProtocols(properties_->bindToLocalAddress, ctx);
758752

753+
const Connection::Type protocol_type =
754+
(parsed_url_.GetProtocol() == Url::Protocol::HTTPS ||
755+
properties_->proxy.type == Request::Proxy::Type::HTTPS)
756+
? Connection::Type::HTTPS
757+
: Connection::Type::HTTP;
758+
759759
boost::asio::ip::tcp::resolver resolver(owner_.GetIoService());
760760
// Resolve the hostname
761-
const auto ep_tuple = GetRequestEndpoint(); //{protocol_type, host, service=port}
761+
const auto ep_tuple = GetRequestEndpoint(); //{host, service=port}
762762

763-
RESTC_CPP_LOG_TRACE_("Resolving " << get<1>(ep_tuple) << ":"
764-
<< get<2>(ep_tuple));
763+
RESTC_CPP_LOG_TRACE_("Resolving " << get<0>(ep_tuple) << ":"
764+
<< get<1>(ep_tuple));
765765

766-
auto address_it = resolver.async_resolve(/*host*/ get<1>(ep_tuple),
767-
/*port*/ get<2>(ep_tuple),
766+
auto address_it = resolver.async_resolve(/*host*/ get<0>(ep_tuple),
767+
/*port*/ get<1>(ep_tuple),
768768
ctx.GetYield());
769769
const decltype(address_it) addr_end;
770770

@@ -781,7 +781,7 @@ class RequestImpl : public Request {
781781
for(size_t retries = 0; retries < 8; ++retries) {
782782
// Get a connection from the pool
783783
auto connection = owner_.GetConnectionPool()->GetConnection(
784-
endpoint, /*protocol_type*/ get<0>(ep_tuple));
784+
endpoint, protocol_type);
785785

786786
// Connect if the connection is new.
787787
if (connection->GetSocket().IsOpen()) {

src/SocketImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ class SocketImpl : public Socket, protected ExceptionWrapper {
5656
}
5757

5858
void AsyncConnect(const boost::asio::ip::tcp::endpoint& ep,
59-
const std::string &host,
59+
const std::string &host,
6060
bool tcpNodelay,
6161
boost::asio::yield_context& yield) override {
6262
return WrapException<void>([&] {

src/TlsSocketImpl.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@ class TlsSocketImpl : public Socket, protected ExceptionWrapper {
2424

2525
using ssl_socket_t = boost::asio::ssl::stream<boost::asio::ip::tcp::socket>;
2626

27-
TlsSocketImpl(boost::asio::io_service& io_service, shared_ptr<boost::asio::ssl::context> ctx)
27+
TlsSocketImpl(boost::asio::io_service& io_service,
28+
shared_ptr<boost::asio::ssl::context> ctx,
29+
bool can_send_over_unupgraded_socket = false)
2830
{
2931
ssl_socket_ = std::make_unique<ssl_socket_t>(io_service, *ctx);
32+
can_send_over_unupgraded_socket_ = can_send_over_unupgraded_socket;
3033
}
3134

3235
boost::asio::ip::tcp::socket& GetSocket() override {
@@ -42,26 +45,36 @@ class TlsSocketImpl : public Socket, protected ExceptionWrapper {
4245
std::size_t AsyncReadSome(boost::asio::mutable_buffers_1 buffers,
4346
boost::asio::yield_context& yield) override {
4447
return WrapException<std::size_t>([&] {
48+
if (can_send_over_unupgraded_socket_)
49+
return ssl_socket_->next_layer().async_read_some(buffers, yield);
4550
return ssl_socket_->async_read_some(buffers, yield);
4651
});
4752
}
4853

4954
std::size_t AsyncRead(boost::asio::mutable_buffers_1 buffers,
5055
boost::asio::yield_context& yield) override {
5156
return WrapException<std::size_t>([&] {
57+
if (can_send_over_unupgraded_socket_)
58+
return boost::asio::async_read(ssl_socket_->next_layer(), buffers, yield);
5259
return boost::asio::async_read(*ssl_socket_, buffers, yield);
5360
});
5461
}
5562

5663
void AsyncWrite(const boost::asio::const_buffers_1& buffers,
5764
boost::asio::yield_context& yield) override {
58-
boost::asio::async_write(*ssl_socket_, buffers, yield);
65+
if (can_send_over_unupgraded_socket_)
66+
boost::asio::async_write(ssl_socket_->next_layer(), buffers, yield);
67+
else
68+
boost::asio::async_write(*ssl_socket_, buffers, yield);
5969
}
6070

6171
void AsyncWrite(const write_buffers_t& buffers,
6272
boost::asio::yield_context& yield) override {
6373
return WrapException<void>([&] {
64-
boost::asio::async_write(*ssl_socket_, buffers, yield);
74+
if (can_send_over_unupgraded_socket_)
75+
boost::asio::async_write(ssl_socket_->next_layer(), buffers, yield);
76+
else
77+
boost::asio::async_write(*ssl_socket_, buffers, yield);
6578
});
6679
}
6780

@@ -90,6 +103,7 @@ class TlsSocketImpl : public Socket, protected ExceptionWrapper {
90103
RESTC_CPP_LOG_TRACE_("AsyncConnect - Calling async_handshake");
91104
ssl_socket_->async_handshake(boost::asio::ssl::stream_base::client,
92105
yield);
106+
can_send_over_unupgraded_socket_ = false;
93107

94108
RESTC_CPP_LOG_TRACE_("AsyncConnect - Done");
95109
});
@@ -135,6 +149,7 @@ class TlsSocketImpl : public Socket, protected ExceptionWrapper {
135149

136150
private:
137151
std::unique_ptr<ssl_socket_t> ssl_socket_;
152+
bool can_send_over_unupgraded_socket_;
138153
};
139154

140155
} // restc_cpp

0 commit comments

Comments
 (0)