Skip to content

Commit

Permalink
feat: add support for connection cookie if db is 23c or higher
Browse files Browse the repository at this point in the history
  • Loading branch information
lovetodream committed Sep 12, 2023
1 parent bce8742 commit 5f2230b
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct AuthenticationStateMachine {
}

enum Action {
case sendAuthenticationPhaseOne(AuthContext)
case sendAuthenticationPhaseOne(AuthContext, ConnectionCookie?)
case sendAuthenticationPhaseTwo(
AuthContext,
OracleBackendMessage.Parameter
Expand All @@ -24,10 +24,12 @@ struct AuthenticationStateMachine {
}

let authContext: AuthContext
let cookie: ConnectionCookie?
var state: State

init(authContext: AuthContext) {
init(authContext: AuthContext, cookie: ConnectionCookie?) {
self.authContext = authContext
self.cookie = cookie
self.state = .initialized
}

Expand All @@ -36,7 +38,7 @@ struct AuthenticationStateMachine {
preconditionFailure("Unexpected state")
}
self.state = .authenticationPhaseOneSent
return .sendAuthenticationPhaseOne(self.authContext)
return .sendAuthenticationPhaseOne(self.authContext, self.cookie)
}

mutating func parameterReceived(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ struct ConnectionStateMachine {
case sendDataTypes

// Authentication Actions
case provideAuthenticationContext
case sendAuthenticationPhaseOne(AuthContext)
case provideAuthenticationContext(ConnectionCookie?)
case sendAuthenticationPhaseOne(AuthContext, ConnectionCookie?)
case sendAuthenticationPhaseTwo(
AuthContext, OracleBackendMessage.Parameter
)
Expand Down Expand Up @@ -126,9 +126,9 @@ struct ConnectionStateMachine {
}

mutating func provideAuthenticationContext(
_ authContext: AuthContext
_ authContext: AuthContext, cookie: ConnectionCookie?
) -> ConnectionAction {
self.startAuthentication(authContext)
self.startAuthentication(authContext, cookie: cookie)
}

mutating func close(
Expand Down Expand Up @@ -214,15 +214,15 @@ struct ConnectionStateMachine {
case .notQuiescing:
switch self.state {
case .initialized,
.connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
.waitingToStartAuthentication,
.authenticating,
.extendedQuery,
.ping,
.commit,
.rollback:
.connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
.waitingToStartAuthentication,
.authenticating,
.extendedQuery,
.ping,
.commit,
.rollback:
self.taskQueue.append(task)
return .wait

Expand Down Expand Up @@ -313,10 +313,35 @@ struct ConnectionStateMachine {
}
}

mutating func acceptReceived() -> ConnectionAction {
mutating func acceptReceived(
_ accept: OracleBackendMessage.Accept, description: Description
) -> ConnectionAction {
guard case .connectMessageSent = state else {
preconditionFailure()
}

let capabilities = accept.newCapabilities

if capabilities.protocolVersion < Constants.TNS_VERSION_MIN_ACCEPTED {
return self.errorHappened(.serverVersionNotSupported)
}

if
capabilities.supportsOOB && capabilities.protocolVersion >=
Constants.TNS_VERSION_MIN_OOB_CHECK
{
// TODO: Perform OOB Check
// send OUT_OF_BAND + reset marker message through socket
}

// Starting in 23c, a cookie can be sent along with the protocol, data
// types and authorization messages without waiting for the server to
// respond to each of the messages in turn
if let dbUUID = accept.dbCookieUUID, let cookie = ConnectionCookieManager.shared.get(by: dbUUID, description: description) {
self.state = .waitingToStartAuthentication
return .provideAuthenticationContext(cookie)
}

self.state = .protocolMessageSent
return .sendProtocol
}
Expand Down Expand Up @@ -375,18 +400,18 @@ struct ConnectionStateMachine {
preconditionFailure()
}
self.state = .waitingToStartAuthentication
return .provideAuthenticationContext
return .provideAuthenticationContext(nil)
}

mutating func parameterReceived(
parameters: OracleBackendMessage.Parameter
) -> ConnectionAction {
switch self.state {
case .initialized,
.connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
.waitingToStartAuthentication:
.connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
.waitingToStartAuthentication:
preconditionFailure()

case .authenticating(var authState):
Expand All @@ -407,11 +432,11 @@ struct ConnectionStateMachine {
mutating func markerReceived() -> ConnectionAction {
switch self.state {
case .initialized,
.waitingToStartAuthentication,
.readyForQuery,
.closed:
.waitingToStartAuthentication,
.readyForQuery,
.closed:
preconditionFailure()
case .connectMessageSent,
case .connectMessageSent,
.protocolMessageSent,
.dataTypesMessageSent,
.authenticating,
Expand Down Expand Up @@ -645,15 +670,19 @@ struct ConnectionStateMachine {

// MARK: - Private Methods -

private mutating func startAuthentication(_ authContext: AuthContext) -> ConnectionAction {
private mutating func startAuthentication(
_ authContext: AuthContext, cookie: ConnectionCookie?
) -> ConnectionAction {
guard case .waitingToStartAuthentication = state else {
preconditionFailure(
"Can only start authentication after connection is established"
)
}

return self.avoidingStateMachineCoW { machine in
var authState = AuthenticationStateMachine(authContext: authContext)
var authState = AuthenticationStateMachine(
authContext: authContext, cookie: cookie
)
let action = authState.start()
machine.state = .authenticating(authState)
return machine.modify(with: action)
Expand Down Expand Up @@ -813,6 +842,7 @@ extension ConnectionStateMachine {
case .connectionError,
.messageDecodingFailure,
.unexpectedBackendMessage,
.serverVersionNotSupported,
.uncleanShutdown:
return true
case .queryCancelled, .nationalCharsetNotSupported:
Expand Down Expand Up @@ -911,8 +941,8 @@ extension ConnectionStateMachine {
with action: AuthenticationStateMachine.Action
) -> ConnectionAction {
switch action {
case .sendAuthenticationPhaseOne(let authContext):
return .sendAuthenticationPhaseOne(authContext)
case .sendAuthenticationPhaseOne(let authContext, let cookie):
return .sendAuthenticationPhaseOne(authContext, cookie)
case .sendAuthenticationPhaseTwo(let authContext, let parameters):
return .sendAuthenticationPhaseTwo(authContext, parameters)
case .wait:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ extension OracleConnection {

/// The service name of the database.
var serviceName: String
/// The system identifier (SID) of the database.
///
/// - Note: Using a ``serviceName`` instead is recommended by Oracle.
var sid: String?

/// Authorization mode to use.
var mode: AuthenticationMode = .default
Expand Down Expand Up @@ -181,7 +185,7 @@ extension OracleConnection {
value: defaultUsername()
)
private static func defaultUsername() -> String {
#if os(iOS) || os(tvOS)
#if os(iOS) || os(tvOS) || os(watchOS)
return "unknown"
#else
return ProcessInfo.processInfo.userName
Expand Down
33 changes: 33 additions & 0 deletions Sources/OracleNIO/ConnectionCookieManager.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import NIOCore
import struct Foundation.UUID

struct ConnectionCookieManager: Sendable {
static var shared = ConnectionCookieManager()

private var store: [String: ConnectionCookie] = [:]

private init() { }

func get(
by uuid: UUID, description: Description
) -> ConnectionCookie? {
let key = uuid.uuidString +
(description.serviceName ?? description.sid ?? "")
return store[key]
}

func set() {
// TOOD
}
}

struct ConnectionCookie: Sendable {
var protocolVersion: UInt8
var serverBanner: ByteBuffer
var charsetID: UInt16
var nationalCharsetID: UInt16
var flags: UInt8
var compileCapabilities: ByteBuffer
var runtimeCapabilities: ByteBuffer
var populated: Bool
}
1 change: 1 addition & 0 deletions Sources/OracleNIO/Constants.swift
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ enum Constants {
static let TNS_VERSION_MIN_ACCEPTED = 315 // 12.1
static let TNS_VERSION_MIN_LARGE_SDU = 315
static let TNS_VERSION_MIN_OOB_CHECK = 318
static let TNS_VERSION_MIN_UUID = 319

// MARK: Control packet types
static let TNS_CONTROL_TYPE_INBAND_NOTIFICATION = 8
Expand Down
6 changes: 6 additions & 0 deletions Sources/OracleNIO/MessageType.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ enum MessageType: UInt8, CustomStringConvertible {
case serverSidePiggyback = 23
case onewayFN = 26
case implicitResultset = 27
case renegotiate = 28
case cookie = 30

var description: String {
switch self {
Expand Down Expand Up @@ -57,6 +59,10 @@ enum MessageType: UInt8, CustomStringConvertible {
return "ONEWAY_FN"
case .implicitResultset:
return "IMPLICIT_RESULTSET"
case .renegotiate:
return "RENEGOTIATE"
case .cookie:
return "COOKIE"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,25 +110,57 @@ struct OracleFrontendMessageEncoder {
self.endRequest(packetType: isConnectStringToLong ? .data : .connect)
}

mutating func `protocol`() {
mutating func cookie(_ cookie: ConnectionCookie, authContext: AuthContext) {
self.clearIfNeeded()

self.startRequest()

self.buffer.writeInteger(MessageType.protocol.rawValue)
buffer.writeInteger(UInt8(6)) // protocol version (8.1 and higher)
buffer.writeInteger(UInt8(0)) // `array` terminator
buffer.writeString(Constants.DRIVER_NAME)
buffer.writeInteger(UInt8(0)) // `NULL` terminator
self.protocol0()

self.buffer.writeMultipleIntegers(
MessageType.cookie.rawValue,
UInt8(1), // cookie version
cookie.protocolVersion
)
self.buffer.writeInteger(cookie.charsetID, endianness: .little)
self.buffer.writeInteger(cookie.flags)
self.buffer.writeInteger(cookie.nationalCharsetID, endianness: .little)
cookie.serverBanner._encodeRaw(into: &self.buffer, context: .default)
cookie.compileCapabilities
._encodeRaw(into: &self.buffer, context: .default)
cookie.runtimeCapabilities
._encodeRaw(into: &self.buffer, context: .default)
self.dataTypes0()
self.authenticationPhaseOne0(authContext: authContext)

self.endRequest()
}

mutating func `protocol`() {
self.clearIfNeeded()

self.startRequest()
self.protocol0()
self.endRequest()
}

private mutating func protocol0() {
self.buffer.writeInteger(MessageType.protocol.rawValue)
self.buffer.writeInteger(UInt8(6)) // protocol version (8.1 and higher)
self.buffer.writeInteger(UInt8(0)) // `array` terminator
self.buffer.writeString(Constants.DRIVER_NAME)
self.buffer.writeInteger(UInt8(0)) // `NULL` terminator
}

mutating func dataTypes() {
self.clearIfNeeded()

self.startRequest()
self.dataTypes0()
self.endRequest()
}

private mutating func dataTypes0() {
self.buffer.writeInteger(MessageType.dataTypes.rawValue, as: UInt8.self)
self.buffer.writeInteger(Constants.TNS_CHARSET_UTF8, endianness: .little)
self.buffer.writeInteger(Constants.TNS_CHARSET_UTF8, endianness: .little)
Expand All @@ -150,13 +182,17 @@ struct OracleFrontendMessageEncoder {
}

self.buffer.writeInteger(UInt16(0))

self.endRequest()
}

mutating func authenticationPhaseOne(authContext: AuthContext) {
self.clearIfNeeded()

self.startRequest()
self.authenticationPhaseOne0(authContext: authContext)
self.endRequest()
}

mutating func authenticationPhaseOne0(authContext: AuthContext) {
// 1. Setup

let newPassword = authContext.newPassword
Expand All @@ -176,8 +212,6 @@ struct OracleFrontendMessageEncoder {

let numberOfPairs: UInt32 = 5

self.startRequest()

self.writeBasicAuthData(
authContext: authContext,
authPhase: .authPhaseOne,
Expand Down Expand Up @@ -211,8 +245,6 @@ struct OracleFrontendMessageEncoder {
value: authContext.username,
out: &self.buffer
)

self.endRequest()
}

mutating func authenticationPhaseTwo(
Expand Down
Loading

0 comments on commit 5f2230b

Please sign in to comment.