Skip to content

Commit fa3137d

Browse files
authored
Support additional connection parameters (#361)
1 parent e60e495 commit fa3137d

File tree

9 files changed

+117
-25
lines changed

9 files changed

+117
-25
lines changed

Sources/PostgresNIO/Connection/PostgresConnection+Configuration.swift

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,14 +85,19 @@ extension PostgresConnection {
8585
/// This property is provided for compatibility with Amazon RDS Proxy, which requires it to be `false`.
8686
/// If you are not using Amazon RDS Proxy, you should leave this set to `true` (the default).
8787
public var requireBackendKeyData: Bool
88-
88+
89+
/// Additional parameters to send to the server on startup. The name value pairs are added to the initial
90+
/// startup message that the client sends to the server.
91+
public var additionalStartupParameters: [(String, String)]
92+
8993
/// Create an options structure with default values.
9094
///
9195
/// Most users should not need to adjust the defaults.
9296
public init() {
9397
self.connectTimeout = .seconds(10)
9498
self.tlsServerName = nil
9599
self.requireBackendKeyData = true
100+
self.additionalStartupParameters = []
96101
}
97102
}
98103

Sources/PostgresNIO/New/Connection State Machine/ConnectionStateMachine.swift

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,11 +1113,19 @@ struct SendPrepareStatement {
11131113
let query: String
11141114
}
11151115

1116-
struct AuthContext: Equatable, CustomDebugStringConvertible {
1117-
let username: String
1118-
let password: String?
1119-
let database: String?
1120-
1116+
struct AuthContext: CustomDebugStringConvertible {
1117+
var username: String
1118+
var password: String?
1119+
var database: String?
1120+
var additionalParameters: [(String, String)]
1121+
1122+
init(username: String, password: String? = nil, database: String? = nil, additionalParameters: [(String, String)] = []) {
1123+
self.username = username
1124+
self.password = password
1125+
self.database = database
1126+
self.additionalParameters = additionalParameters
1127+
}
1128+
11211129
var debugDescription: String {
11221130
"""
11231131
AuthContext(username: \(String(reflecting: self.username)), \
@@ -1127,6 +1135,22 @@ struct AuthContext: Equatable, CustomDebugStringConvertible {
11271135
}
11281136
}
11291137

1138+
extension AuthContext: Equatable {
1139+
static func ==(lhs: Self, rhs: Self) -> Bool {
1140+
guard lhs.username == rhs.username
1141+
&& lhs.password == rhs.password
1142+
&& lhs.database == rhs.database
1143+
&& lhs.additionalParameters.count == rhs.additionalParameters.count
1144+
else {
1145+
return false
1146+
}
1147+
1148+
return lhs.additionalParameters.elementsEqual(rhs.additionalParameters) { lhs, rhs in
1149+
lhs.0 == rhs.0 && lhs.1 == rhs.1
1150+
}
1151+
}
1152+
}
1153+
11301154
enum PasswordAuthencationMode: Equatable {
11311155
case cleartext
11321156
case md5(salt: UInt32)

Sources/PostgresNIO/New/PostgresChannelHandler.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
328328
case .wait:
329329
break
330330
case .sendStartupMessage(let authContext):
331-
self.encoder.startup(user: authContext.username, database: authContext.database)
331+
self.encoder.startup(user: authContext.username, database: authContext.database, options: authContext.additionalParameters)
332332
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
333333
case .sendSSLRequest:
334334
self.encoder.ssl()

Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ struct PostgresFrontendMessageEncoder {
2525
self.buffer = buffer
2626
}
2727

28-
mutating func startup(user: String, database: String?) {
28+
mutating func startup(user: String, database: String?, options: [(String, String)]) {
2929
self.clearIfNeeded()
3030
self.buffer.psqlLengthPrefixed { buffer in
3131
buffer.writeInteger(Self.startupVersionThree)
@@ -37,6 +37,13 @@ struct PostgresFrontendMessageEncoder {
3737
buffer.writeNullTerminatedString(database)
3838
}
3939

40+
// we don't send replication parameters, as the default is false and this is what we
41+
// need for a client
42+
for (key, value) in options {
43+
buffer.writeNullTerminatedString(key)
44+
buffer.writeNullTerminatedString(value)
45+
}
46+
4047
buffer.writeInteger(UInt8(0))
4148
}
4249
}

Tests/PostgresNIOTests/New/Extensions/PSQLFrontendMessageDecoder.swift

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
3939
case 196608:
4040
var user: String?
4141
var database: String?
42-
var options: String?
43-
42+
var options = [(String, String)]()
43+
4444
while let name = messageSlice.readNullTerminatedString(), messageSlice.readerIndex < finalIndex {
4545
let value = messageSlice.readNullTerminatedString()
4646

@@ -51,11 +51,10 @@ struct PSQLFrontendMessageDecoder: NIOSingleStepByteToMessageDecoder {
5151
case "database":
5252
database = value
5353

54-
case "options":
55-
options = value
56-
5754
default:
58-
break
55+
if let value = value {
56+
options.append((name, value))
57+
}
5958
}
6059
}
6160

Tests/PostgresNIOTests/New/Extensions/PostgresFrontendMessage.swift

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ enum PostgresFrontendMessage: Equatable {
103103
static let requestCode: Int32 = 80877103
104104
}
105105

106-
struct Startup: Hashable {
106+
struct Startup: Equatable {
107107
static let versionThree: Int32 = 0x00_03_00_00
108108

109109
/// Creates a `Startup` with "3.0" as the protocol version.
@@ -119,7 +119,7 @@ enum PostgresFrontendMessage: Equatable {
119119
/// The protocol version number is followed by one or more pairs of parameter
120120
/// name and value strings. A zero byte is required as a terminator after
121121
/// the last name/value pair. `user` is required, others are optional.
122-
struct Parameters: Hashable {
122+
struct Parameters: Equatable {
123123
enum Replication {
124124
case `true`
125125
case `false`
@@ -136,12 +136,33 @@ enum PostgresFrontendMessage: Equatable {
136136
/// of setting individual run-time parameters.) Spaces within this string are
137137
/// considered to separate arguments, unless escaped with a
138138
/// backslash (\); write \\ to represent a literal backslash.
139-
var options: String?
139+
var options: [(String, String)]
140140

141141
/// Used to connect in streaming replication mode, where a small set of
142142
/// replication commands can be issued instead of SQL statements. Value
143143
/// can be true, false, or database, and the default is false.
144144
var replication: Replication
145+
146+
static func ==(lhs: Self, rhs: Self) -> Bool {
147+
guard lhs.user == rhs.user
148+
&& lhs.database == rhs.database
149+
&& lhs.replication == rhs.replication
150+
&& lhs.options.count == rhs.options.count
151+
else {
152+
return false
153+
}
154+
155+
var lhsIterator = lhs.options.makeIterator()
156+
var rhsIterator = rhs.options.makeIterator()
157+
158+
while let lhsNext = lhsIterator.next(), let rhsNext = rhsIterator.next() {
159+
guard lhsNext.0 == rhsNext.0 && lhsNext.1 == rhsNext.1 else {
160+
return false
161+
}
162+
}
163+
return true
164+
}
165+
145166
}
146167

147168
var parameters: Parameters

Tests/PostgresNIOTests/New/Messages/StartupTests.swift

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class StartupTests: XCTestCase {
1111
let user = "test"
1212
let database = "abc123"
1313

14-
encoder.startup(user: user, database: database)
14+
encoder.startup(user: user, database: database, options: [])
1515
byteBuffer = encoder.flushBuffer()
1616

1717
let byteBufferLength = Int32(byteBuffer.readableBytes)
@@ -32,7 +32,7 @@ class StartupTests: XCTestCase {
3232

3333
let user = "test"
3434

35-
encoder.startup(user: user, database: nil)
35+
encoder.startup(user: user, database: nil, options: [])
3636
byteBuffer = encoder.flushBuffer()
3737

3838
let byteBufferLength = Int32(byteBuffer.readableBytes)
@@ -44,4 +44,41 @@ class StartupTests: XCTestCase {
4444

4545
XCTAssertEqual(byteBuffer.readableBytes, 0)
4646
}
47+
48+
func testStartupMessageWithAdditionalOptions() {
49+
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
50+
var byteBuffer = ByteBuffer()
51+
52+
let user = "test"
53+
let database = "abc123"
54+
55+
encoder.startup(user: user, database: database, options: [("some", "options")])
56+
byteBuffer = encoder.flushBuffer()
57+
58+
let byteBufferLength = Int32(byteBuffer.readableBytes)
59+
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
60+
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
61+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
62+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
63+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
64+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
65+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some")
66+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
67+
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
68+
69+
XCTAssertEqual(byteBuffer.readableBytes, 0)
70+
}
71+
}
72+
73+
extension PostgresFrontendMessage.Startup.Parameters.Replication {
74+
var stringValue: String {
75+
switch self {
76+
case .true:
77+
return "true"
78+
case .false:
79+
return "false"
80+
case .database:
81+
return "replication"
82+
}
83+
}
4784
}

Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,8 @@ class PostgresChannelHandlerTests: XCTestCase {
3737

3838
XCTAssertEqual(startup.parameters.user, config.username)
3939
XCTAssertEqual(startup.parameters.database, config.database)
40-
XCTAssertEqual(startup.parameters.options, nil)
41-
XCTAssertEqual(startup.parameters.replication, .false)
42-
40+
XCTAssert(startup.parameters.options.isEmpty)
41+
4342
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.authentication(.ok)))
4443
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678))))
4544
XCTAssertNoThrow(try embedded.writeInbound(PostgresBackendMessage.readyForQuery(.idle)))
@@ -209,7 +208,7 @@ class PostgresChannelHandlerTests: XCTestCase {
209208

210209
XCTAssertEqual(startup.parameters.user, config.username)
211210
XCTAssertEqual(startup.parameters.database, config.database)
212-
XCTAssertEqual(startup.parameters.options, nil)
211+
XCTAssert(startup.parameters.options.isEmpty)
213212
XCTAssertEqual(startup.parameters.replication, .false)
214213

215214
var buffer = ByteBuffer()
@@ -282,7 +281,7 @@ extension AuthContext {
282281
PostgresFrontendMessage.Startup.Parameters(
283282
user: self.username,
284283
database: self.database,
285-
options: nil,
284+
options: self.additionalParameters,
286285
replication: .false
287286
)
288287
}

Tests/PostgresNIOTests/New/PostgresConnectionTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ class PostgresConnectionTests: XCTestCase {
602602

603603
async let connectionPromise = PostgresConnection.connect(on: eventLoop, configuration: configuration, id: 1, logger: self.logger)
604604
let message = try await channel.waitForOutboundWrite(as: PostgresFrontendMessage.self)
605-
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", replication: .false))))
605+
XCTAssertEqual(message, .startup(.versionThree(parameters: .init(user: "username", database: "database", options: [], replication: .false))))
606606
try await channel.writeInbound(PostgresBackendMessage.authentication(.ok))
607607
try await channel.writeInbound(PostgresBackendMessage.backendKeyData(.init(processID: 1234, secretKey: 5678)))
608608
try await channel.writeInbound(PostgresBackendMessage.readyForQuery(.idle))

0 commit comments

Comments
 (0)