diff --git a/.gitignore b/.gitignore index 8e242c10d..a9285b01e 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ dist /.eslintcache .vscode/ manually-test-on-heroku.js +.history diff --git a/packages/pg/lib/client.js b/packages/pg/lib/client.js index 9200dded6..b3db96a66 100644 --- a/packages/pg/lib/client.js +++ b/packages/pg/lib/client.js @@ -81,6 +81,9 @@ class Client extends EventEmitter { keepAlive: c.keepAlive || false, keepAliveInitialDelayMillis: c.keepAliveInitialDelayMillis || 0, encoding: this.connectionParameters.client_encoding || 'utf8', + targetSessionAttrs: c.targetSessionAttrs || this.connectionParameters.targetSessionAttrs || null, + trustParameterStatus: c.trustParameterStatus || false, + Promise: this._Promise, }) this._queryQueue = [] this.binary = c.binary || defaults.binary @@ -155,7 +158,7 @@ class Client extends EventEmitter { } } - if (this.host && this.host.indexOf('/') === 0) { + if (!Array.isArray(this.host) && this.host && this.host.indexOf('/') === 0) { con.connect(this.host + '/.s.PGSQL.' + this.port) } else { con.connect(this.port, this.host) @@ -542,7 +545,7 @@ class Client extends EventEmitter { if (client.activeQuery === query) { const con = this.connection - if (this.host && this.host.indexOf('/') === 0) { + if (!Array.isArray(this.host) && this.host && this.host.indexOf('/') === 0) { con.connect(this.host + '/.s.PGSQL.' + this.port) } else { con.connect(this.port, this.host) diff --git a/packages/pg/lib/connection-parameters.js b/packages/pg/lib/connection-parameters.js index c153932bb..50bdf5547 100644 --- a/packages/pg/lib/connection-parameters.js +++ b/packages/pg/lib/connection-parameters.js @@ -67,9 +67,16 @@ class ConnectionParameters { this.database = this.user } - this.port = parseInt(val('port', config), 10) + const rawPort = val('port', config) + this.port = Array.isArray(rawPort) ? rawPort.map((p) => parseInt(p, 10)) : parseInt(rawPort, 10) this.host = val('host', config) + const hosts = Array.isArray(this.host) ? this.host : [this.host] + const ports = Array.isArray(this.port) ? this.port : [this.port] + if (ports.length !== 1 && ports.length !== hosts.length) { + throw new Error(`ports must have either 1 entry or the same number of entries as hosts (${hosts.length})`) + } + // "hiding" the password so it doesn't show up in stack traces // or if the client is console.logged Object.defineProperty(this, 'password', { @@ -111,6 +118,17 @@ class ConnectionParameters { this.idle_in_transaction_session_timeout = val('idle_in_transaction_session_timeout', config, false) this.query_timeout = val('query_timeout', config, false) + this.targetSessionAttrs = val('targetSessionAttrs', config) + + const validTargetSessionAttrs = ['any', 'read-write', 'read-only', 'primary', 'standby', 'prefer-standby'] + if (this.targetSessionAttrs && !validTargetSessionAttrs.includes(this.targetSessionAttrs)) { + throw new Error( + `invalid targetSessionAttrs value: "${this.targetSessionAttrs}". Must be one of: ${validTargetSessionAttrs.join( + ', ' + )}` + ) + } + if (config.connectionTimeoutMillis === undefined) { this.connect_timeout = process.env.PGCONNECT_TIMEOUT || 0 } else { diff --git a/packages/pg/lib/connection.js b/packages/pg/lib/connection.js index 027f93935..3e98a0c93 100644 --- a/packages/pg/lib/connection.js +++ b/packages/pg/lib/connection.js @@ -1,7 +1,7 @@ 'use strict' +const net = require('net') const EventEmitter = require('events').EventEmitter - const { parse, serialize } = require('pg-protocol') const { getStream, getSecureStream } = require('./stream') @@ -9,15 +9,28 @@ const flushBuffer = serialize.flush() const syncBuffer = serialize.sync() const endBuffer = serialize.end() -// TODO(bmc) support binary mode at some point +const PROBE_SHOW_TX_READ_ONLY = serialize.query('SHOW transaction_read_only') +const PROBE_SELECT_RECOVERY = serialize.query('SELECT pg_catalog.pg_is_in_recovery()') + +const PHASE = { + STARTUP: 'startup', + PROBE: 'probe', + DONE: 'done', +} + class Connection extends EventEmitter { constructor(config) { super() config = config || {} - this.stream = config.stream || getStream(config.ssl) - if (typeof this.stream === 'function') { - this.stream = this.stream(config) + if (typeof config.stream === 'function') { + this._streamFactory = config.stream + this._config = config + this.stream = config.stream(config) + } else { + this._streamFactory = null + this._config = null + this.stream = config.stream || getStream(config.ssl) } this._keepAlive = config.keepAlive @@ -26,6 +39,12 @@ class Connection extends EventEmitter { this.ssl = config.ssl || false this._ending = false this._emitMessage = false + this._targetSessionAttrs = config.targetSessionAttrs || null + this._trustParameterStatus = config.trustParameterStatus || false + this.host = null + this.port = null + this._streamErrorHandler = this._onStreamError.bind(this) + const self = this this.on('newListener', function (eventName) { if (eventName === 'message') { @@ -34,78 +53,257 @@ class Connection extends EventEmitter { }) } + _newStream() { + return this._streamFactory ? this._streamFactory(this._config) : getStream(this.ssl) + } + connect(port, host) { const self = this + const hosts = Array.isArray(host) ? host : [host] + const ports = Array.isArray(port) ? port : [port] + let hostIndex = 0 + this._connecting = true - this.stream.setNoDelay(true) - this.stream.connect(port, host) - this.stream.once('connect', function () { - if (self._keepAlive) { - self.stream.setKeepAlive(true, self._keepAliveInitialDelayMillis) + const targetAttrs = this._targetSessionAttrs + const shouldCheck = targetAttrs && targetAttrs !== 'any' + const probeType = shouldCheck ? getProbeType(targetAttrs) : null + + // prefer-standby: two-pass logic (libpq style) + // Pass 1: look for standby among all hosts + // Pass 2: if no standby found, accept any host + let preferStandbyPass = 1 + + const resetStream = () => { + self._send(endBuffer) + self.stream.removeAllListeners() + self.stream.destroy() + self.stream = self._newStream() + } + + const tryNextOrFail = () => { + if (hostIndex + 1 < hosts.length) { + hostIndex++ + resetStream() + attemptConnect() + } else if (targetAttrs === 'prefer-standby' && preferStandbyPass === 1) { + preferStandbyPass = 2 + hostIndex = 0 + resetStream() + attemptConnect() + } else { + self._connecting = false + self.emit('error', new Error('None of the hosts satisfy target_session_attrs="' + targetAttrs + '"')) } - self.emit('connect') - }) + } - const reportStreamError = function (error) { - // errors about disconnections should be ignored during disconnect - if (self._ending && (error.code === 'ECONNRESET' || error.code === 'EPIPE')) { - return + const attemptConnect = () => { + const currentHost = hosts[hostIndex] + const currentPort = ports.length === 1 ? ports[0] : ports[hostIndex] + let connected = false + + self.host = currentHost + self.port = currentPort + + self.stream.setNoDelay(true) + self.stream.connect(currentPort, currentHost) + + self.stream.once('connect', function () { + connected = true + if (self._keepAlive) { + self.stream.setKeepAlive(true, self._keepAliveInitialDelayMillis) + } + if (!self.ssl) { + attachProbeOrPlain() + } + self.emit('connect') + }) + + const onStreamError = function (error) { + if (self._ending && (error.code === 'ECONNRESET' || error.code === 'EPIPE')) { + return + } + if (!connected) { + if (hostIndex + 1 < hosts.length) { + hostIndex++ + resetStream() + attemptConnect() + return + } else if (targetAttrs === 'prefer-standby' && preferStandbyPass === 1) { + preferStandbyPass = 2 + hostIndex = 0 + resetStream() + attemptConnect() + return + } + } + self._connecting = false + self.emit('error', error) + } + + self.stream.on('error', onStreamError) + + const onClose = function () { + self.emit('end') + } + self.stream.on('close', onClose) + + const attachProbeOrPlain = () => { + if (shouldCheck) { + self._runSessionAttrsCheck( + probeType, + targetAttrs, + hostIndex, + hosts, + preferStandbyPass, + tryNextOrFail, + onStreamError + ) + } else { + self._releaseConnectScope(onStreamError) + self.attachListeners(self.stream) + } + } + + if (self.ssl) { + self.stream.once('data', function (buffer) { + const responseCode = buffer.toString('utf8') + switch (responseCode) { + case 'S': + break + case 'N': + self.stream.end() + return self.emit('error', new Error('The server does not support SSL connections')) + default: + self.stream.end() + return self.emit('error', new Error('There was an error establishing an SSL connection')) + } + const options = { + socket: self.stream, + } + + if (self.ssl !== true) { + Object.assign(options, self.ssl) + if ('key' in self.ssl) { + options.key = self.ssl.key + } + } + + if (net.isIP && net.isIP(currentHost) === 0) { + options.servername = currentHost + } + + const tcpStream = self.stream + tcpStream.removeListener('close', onClose) + tcpStream.removeListener('error', onStreamError) + try { + self.stream = getSecureStream(options) + } catch (err) { + return self.emit('error', err) + } + attachProbeOrPlain() + self.stream.on('error', onStreamError) + self.stream.on('close', onClose) + self.emit('sslconnect') + }) } - self.emit('error', error) } - this.stream.on('error', reportStreamError) - this.stream.on('close', function () { - self.emit('end') - }) + attemptConnect() + } - if (!this.ssl) { - return this.attachListeners(this.stream) + _runSessionAttrsCheck(probeType, targetAttrs, hostIndex, hosts, preferStandbyPass, tryNextOrFail, onStreamError) { + const self = this + const trustParams = this._trustParameterStatus + let phase = PHASE.STARTUP + let probeRows = [] + let probeError = false + let backendParams = {} + + const done = (readyMsg) => { + self._releaseConnectScope(onStreamError) + self._connecting = false + backendParams = probeRows = null + phase = PHASE.DONE + if (self._emitMessage) self.emit('message', readyMsg) + self.emit('readyForQuery', readyMsg) } - this.stream.once('data', function (buffer) { - const responseCode = buffer.toString('utf8') - switch (responseCode) { - case 'S': // Server supports SSL connections, continue with a secure connection - break - case 'N': // Server does not support SSL connections - self.stream.end() - return self.emit('error', new Error('The server does not support SSL connections')) - default: - // Any other response byte, including 'E' (ErrorResponse) indicating a server error - self.stream.end() - return self.emit('error', new Error('There was an error establishing an SSL connection')) - } - const options = { - socket: self.stream, + parse(this.stream, function onMessage(msg) { + const eventName = msg.name === 'error' ? 'errorMessage' : msg.name + + if (phase === PHASE.DONE) { + if (self._emitMessage) self.emit('message', msg) + self.emit(eventName, msg) + return } - if (self.ssl !== true) { - Object.assign(options, self.ssl) + if (eventName === 'parameterStatus') { + backendParams[msg.parameterName] = msg.parameterValue + if (self._emitMessage) self.emit('message', msg) + self.emit(eventName, msg) + return + } - if ('key' in self.ssl) { - options.key = self.ssl.key + if (phase === PHASE.STARTUP) { + if (eventName === 'readyForQuery') { + if (trustParams && canDecideFromParams(targetAttrs, backendParams)) { + if (!hostMatches(targetAttrs, backendParams, hostIndex, hosts, preferStandbyPass)) { + tryNextOrFail() + return + } + return done(msg) + } + + phase = PHASE.PROBE + self._send(probeType === 'tx_read_only' ? PROBE_SHOW_TX_READ_ONLY : PROBE_SELECT_RECOVERY) + return } + if (self._emitMessage) self.emit('message', msg) + self.emit(eventName, msg) + return } - const net = require('net') - if (net.isIP && net.isIP(host) === 0) { - options.servername = host + // Probe: intercept response — don't emit to client + if (eventName === 'dataRow') { + probeRows.push(msg) + return } - try { - self.stream = getSecureStream(options) - } catch (err) { - return self.emit('error', err) + if (eventName === 'rowDescription' || eventName === 'commandComplete') { + return } - self.attachListeners(self.stream) - self.stream.on('error', reportStreamError) + if (eventName === 'errorMessage') { + probeError = true + return + } + if (eventName === 'readyForQuery') { + if (!probeError && probeRows.length >= 1) { + parseProbeResult(probeType, probeRows[0], backendParams) + } + + if (probeError || !hostMatches(targetAttrs, backendParams, hostIndex, hosts, preferStandbyPass)) { + tryNextOrFail() + return + } - self.emit('sslconnect') + return done(msg) + } }) } + _onStreamError(error) { + if (this._ending && (error.code === 'ECONNRESET' || error.code === 'EPIPE')) { + return + } + this.emit('error', error) + } + + _releaseConnectScope(reportStreamError) { + this.stream.removeListener('error', reportStreamError) + this.stream.on('error', this._streamErrorHandler) + } + attachListeners(stream) { parse(stream, (msg) => { const eventName = msg.name === 'error' ? 'errorMessage' : msg.name @@ -151,17 +349,14 @@ class Connection extends EventEmitter { this._send(serialize.query(text)) } - // send parse message parse(query) { this._send(serialize.parse(query)) } - // send bind message bind(config) { this._send(serialize.bind(config)) } - // send execute message execute(config) { this._send(serialize.execute(config)) } @@ -186,7 +381,6 @@ class Connection extends EventEmitter { } end() { - // 0x58 = 'X' this._ending = true if (!this._connecting || !this.stream.writable) { this.stream.end() @@ -218,4 +412,61 @@ class Connection extends EventEmitter { } } +function getProbeType(targetAttrs) { + switch (targetAttrs) { + case 'read-write': + case 'read-only': + return 'tx_read_only' + case 'primary': + case 'standby': + case 'prefer-standby': + return 'is_in_recovery' + default: + return null + } +} + +function parseProbeResult(probeType, row, params) { + const val = row.fields[0]?.toString('utf8') ?? null + if (val === null) return + if (probeType === 'tx_read_only') { + params.default_transaction_read_only = val + params.in_hot_standby = val + } else { + params.in_hot_standby = val === 't' ? 'on' : 'off' + } +} + +function canDecideFromParams(targetAttrs, params) { + switch (targetAttrs) { + case 'read-write': + case 'read-only': + return params.in_hot_standby !== undefined && params.default_transaction_read_only !== undefined + case 'primary': + case 'standby': + case 'prefer-standby': + return params.in_hot_standby !== undefined + default: + return false + } +} + +function hostMatches(targetAttrs, params, hostIndex, hosts, preferStandbyPass) { + switch (targetAttrs) { + case 'read-write': + return params.in_hot_standby !== 'on' && params.default_transaction_read_only !== 'on' + case 'read-only': + return params.in_hot_standby === 'on' || params.default_transaction_read_only === 'on' + case 'primary': + return params.in_hot_standby !== 'on' + case 'standby': + return params.in_hot_standby !== 'off' + case 'prefer-standby': + if (preferStandbyPass === 2) return true + return params.in_hot_standby !== 'off' || hostIndex + 1 >= hosts.length + default: + return true + } +} + module.exports = Connection diff --git a/packages/pg/test/unit/client/multihost-tests.js b/packages/pg/test/unit/client/multihost-tests.js new file mode 100644 index 000000000..17aa4ae35 --- /dev/null +++ b/packages/pg/test/unit/client/multihost-tests.js @@ -0,0 +1,80 @@ +'use strict' +const assert = require('assert') +const EventEmitter = require('events') +const helper = require('./test-helper') +const { Client } = helper + +const suite = new helper.Suite() + +function makeFakeConnection() { + const con = new EventEmitter() + con.connectCalls = [] + con.connect = function (port, host) { + con.connectCalls.push({ port, host }) + } + con.on = con.addListener.bind(con) + con.once = EventEmitter.prototype.once.bind(con) + con.removeAllListeners = EventEmitter.prototype.removeAllListeners.bind(con) + con._ending = false + con.requestSsl = function () {} + con.startup = function () {} + con.end = function () {} + return con +} + +suite.test('passes port array to connection.connect', function () { + const con = makeFakeConnection() + const client = new Client({ connection: con, host: ['localhost', '127.0.0.1'], port: [5432, 5433] }) + client._connect(function () {}) + assert.deepStrictEqual(client.port, [5432, 5433]) + assert.deepStrictEqual(con.connectCalls[0].port, [5432, 5433]) +}) + +suite.test('passes host array to connection.connect', function () { + const con = makeFakeConnection() + const client = new Client({ connection: con, host: ['h1', 'h2'], port: 5432 }) + client._connect(function () {}) + assert.deepStrictEqual(client.host, ['h1', 'h2']) + assert.deepStrictEqual(con.connectCalls[0].host, ['h1', 'h2']) +}) + +suite.test('passes host and port arrays together to connection.connect', function () { + const con = makeFakeConnection() + const client = new Client({ connection: con, host: ['h1', 'h2'], port: [5432, 5433] }) + client._connect(function () {}) + assert.deepStrictEqual(con.connectCalls[0], { port: [5432, 5433], host: ['h1', 'h2'] }) +}) + +// --- domain socket path is not broken by the array guard --- + +suite.test('domain socket path still works with single string host', function () { + const con = makeFakeConnection() + con.connect = function (path) { + con.connectCalls.push({ path }) + } + const client = new Client({ connection: con, host: '/tmp/', port: 5432 }) + client._connect(function () {}) + assert.ok(con.connectCalls[0].path.startsWith('/tmp/'), 'should use domain socket path') +}) + +// --- array host does NOT trigger domain socket path --- + +suite.test('array host with leading-slash element does not trigger domain socket', function () { + const con = makeFakeConnection() + const client = new Client({ connection: con, host: ['/tmp/', 'localhost'], port: 5432 }) + client._connect(function () {}) + // connect() must receive (port, host) signature, not a single socket path string + const call = con.connectCalls[0] + assert.ok('port' in call, 'should call connect(port, host) not connect(socketPath)') + assert.ok('host' in call, 'should call connect(port, host) not connect(socketPath)') +}) + +// --- single host / single port unchanged --- + +suite.test('single host and port are still passed as scalars', function () { + const con = makeFakeConnection() + const client = new Client({ connection: con, host: 'localhost', port: 5432 }) + client._connect(function () {}) + assert.strictEqual(con.connectCalls[0].port, 5432) + assert.strictEqual(con.connectCalls[0].host, 'localhost') +}) diff --git a/packages/pg/test/unit/connection-parameters/multihost-tests.js b/packages/pg/test/unit/connection-parameters/multihost-tests.js new file mode 100644 index 000000000..1bd55158c --- /dev/null +++ b/packages/pg/test/unit/connection-parameters/multihost-tests.js @@ -0,0 +1,82 @@ +'use strict' +const assert = require('assert') +const helper = require('../test-helper') +const ConnectionParameters = require('../../../lib/connection-parameters') + +for (const key in process.env) { + delete process.env[key] +} + +const suite = new helper.Suite() + +suite.test('single port as number is parsed to integer', function () { + const subject = new ConnectionParameters({ port: 5432 }) + assert.strictEqual(subject.port, 5432) +}) + +suite.test('single port as string is parsed to integer', function () { + const subject = new ConnectionParameters({ port: '5433' }) + assert.strictEqual(subject.port, 5433) +}) + +suite.test('port array of numbers is preserved as integer array', function () { + const subject = new ConnectionParameters({ host: ['h1', 'h2'], port: [5432, 5433] }) + assert.deepStrictEqual(subject.port, [5432, 5433]) +}) + +suite.test('port array of strings is mapped to integers', function () { + const subject = new ConnectionParameters({ host: ['h1', 'h2', 'h3'], port: ['5432', '5433', '5434'] }) + assert.deepStrictEqual(subject.port, [5432, 5433, 5434]) +}) + +suite.test('port array with single element is preserved as array', function () { + const subject = new ConnectionParameters({ port: [5432] }) + assert.deepStrictEqual(subject.port, [5432]) +}) + +suite.test('single host string is preserved', function () { + const subject = new ConnectionParameters({ host: 'localhost' }) + assert.strictEqual(subject.host, 'localhost') +}) + +suite.test('host array is passed through unchanged', function () { + const subject = new ConnectionParameters({ host: ['host1', 'host2', 'host3'] }) + assert.deepStrictEqual(subject.host, ['host1', 'host2', 'host3']) +}) + +suite.test('host array with single element is preserved as array', function () { + const subject = new ConnectionParameters({ host: ['localhost'] }) + assert.deepStrictEqual(subject.host, ['localhost']) +}) + +suite.test('host and port arrays are both passed through', function () { + const subject = new ConnectionParameters({ host: ['h1', 'h2'], port: [5432, 5433] }) + assert.deepStrictEqual(subject.host, ['h1', 'h2']) + assert.deepStrictEqual(subject.port, [5432, 5433]) +}) + +suite.test('isDomainSocket is false when host is an array', function () { + const subject = new ConnectionParameters({ host: ['/tmp/', 'localhost'] }) + assert.strictEqual(subject.isDomainSocket, false) +}) + +suite.test('invalid targetSessionAttrs throws', function () { + assert.throws( + () => new ConnectionParameters({ targetSessionAttrs: 'read-mostly' }), + /invalid targetSessionAttrs value/ + ) +}) + +suite.test('valid targetSessionAttrs values do not throw', function () { + const valid = ['any', 'read-write', 'read-only', 'primary', 'standby', 'prefer-standby'] + for (const value of valid) { + assert.doesNotThrow(() => new ConnectionParameters({ targetSessionAttrs: value })) + } +}) + +suite.test('mismatched ports and hosts count throws', function () { + assert.throws( + () => new ConnectionParameters({ host: ['h1', 'h2', 'h3'], port: [5432, 5433] }), + /ports must have either 1 entry/ + ) +}) diff --git a/packages/pg/test/unit/connection/multihost-tests.js b/packages/pg/test/unit/connection/multihost-tests.js new file mode 100644 index 000000000..0c4409d45 --- /dev/null +++ b/packages/pg/test/unit/connection/multihost-tests.js @@ -0,0 +1,549 @@ +'use strict' +const helper = require('./test-helper') +const Connection = require('../../../lib/connection') +const assert = require('assert') + +const suite = new helper.Suite() +const { MemoryStream } = helper + +function makeStream() { + const stream = new MemoryStream() + stream.destroy = function () {} + return stream +} + +function makeErrorMessageBuf() { + // 'E' + length + 'S' + 'ERROR\0' + '\0' + const content = Buffer.concat([Buffer.from('SERROR\0'), Buffer.from([0x00])]) + const len = 4 + content.length + const buf = Buffer.allocUnsafe(1 + len) + buf[0] = 0x45 // 'E' + buf.writeUInt32BE(len, 1) + content.copy(buf, 5) + return buf +} + +function makeParameterStatusBuf(name, value) { + const n = Buffer.from(name + '\0') + const v = Buffer.from(value + '\0') + const len = 4 + n.length + v.length + const buf = Buffer.allocUnsafe(1 + len) + buf[0] = 0x53 // 'S' + buf.writeUInt32BE(len, 1) + n.copy(buf, 5) + v.copy(buf, 5 + n.length) + return buf +} + +function makeReadyForQueryBuf() { + return Buffer.from([0x5a, 0x00, 0x00, 0x00, 0x05, 0x49]) // 'Z' len=5 status='I' +} + +function makeDataRowBuf(fields) { + const bufs = fields.map((f) => (Buffer.isBuffer(f) ? f : Buffer.from(f))) + let dataLen = 2 // Int16 field count + for (const f of bufs) dataLen += 4 + f.length // Int32 len + data + const totalLen = 4 + dataLen // Int32 length field includes itself + const buf = Buffer.allocUnsafe(1 + totalLen) + buf[0] = 0x44 // 'D' + buf.writeUInt32BE(totalLen, 1) + buf.writeUInt16BE(bufs.length, 5) + let offset = 7 + for (const f of bufs) { + buf.writeInt32BE(f.length, offset) + offset += 4 + f.copy(buf, offset) + offset += f.length + } + return buf +} + +function makeRowDescriptionBuf() { + // 'T', length=6, field count=0 + return Buffer.from([0x54, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00]) +} + +function makeCommandCompleteBuf() { + const tag = Buffer.from('SELECT 1\0') + const len = 4 + tag.length + const buf = Buffer.allocUnsafe(1 + len) + buf[0] = 0x43 // 'C' + buf.writeUInt32BE(len, 1) + tag.copy(buf, 5) + return buf +} + +function simulateReadyForQuery(stream, params) { + for (const [key, value] of Object.entries(params)) { + stream.emit('data', makeParameterStatusBuf(key, value)) + } + stream.emit('data', makeReadyForQueryBuf()) +} + +suite.test('connects to single host', function (done) { + const stream = makeStream() + let connectPort, connectHost + stream.connect = function (port, host) { + connectPort = port + connectHost = host + } + const con = new Connection({ stream: stream }) + con.once('connect', function () { + assert.equal(connectPort, 5432) + assert.equal(connectHost, 'localhost') + done() + }) + con.connect(5432, 'localhost') + stream.emit('connect') +}) + +suite.test('connects to first host when multiple are given', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const connectCalls = [] + streams.forEach((s) => { + s.connect = function (port, host) { + connectCalls.push({ port, host }) + } + }) + const con = new Connection({ stream: () => streams[streamIndex++] }) + con.once('connect', function () { + assert.equal(connectCalls.length, 1) + assert.equal(connectCalls[0].host, 'host1') + assert.equal(connectCalls[0].port, 5432) + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + streams[0].emit('connect') +}) + +suite.test('stream factory receives same config on failover streams', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const factoryArgs = [] + const config = { + ssl: false, + stream: function (opts) { + factoryArgs.push(opts) + return streams[streamIndex++] + }, + } + const con = new Connection(config) + con.once('connect', function () { + assert.equal(factoryArgs.length, 2) + assert.strictEqual(factoryArgs[0], config) + assert.strictEqual(factoryArgs[1], config) + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + const err = new Error('Connection refused') + err.code = 'ECONNREFUSED' + streams[0].emit('error', err) + streams[1].emit('connect') +}) + +suite.test('falls back to second host on connection error', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const connectCalls = [] + streams.forEach((s) => { + s.connect = function (port, host) { + connectCalls.push({ port, host }) + } + }) + const con = new Connection({ stream: () => streams[streamIndex++] }) + con.once('connect', function () { + assert.equal(connectCalls.length, 2) + assert.equal(connectCalls[0].host, 'host1') + assert.equal(connectCalls[1].host, 'host2') + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + const err = new Error('Connection refused') + err.code = 'ECONNREFUSED' + streams[0].emit('error', err) + streams[1].emit('connect') +}) + +suite.test('uses matching port for each host by index', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const connectCalls = [] + streams.forEach((s) => { + s.connect = function (port, host) { + connectCalls.push({ port, host }) + } + }) + const con = new Connection({ stream: () => streams[streamIndex++] }) + con.once('connect', function () { + assert.equal(connectCalls[0].port, 5432) + assert.equal(connectCalls[1].port, 5433) + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + const err = new Error('Connection refused') + err.code = 'ECONNREFUSED' + streams[0].emit('error', err) + streams[1].emit('connect') +}) + +suite.test('reuses single port for all hosts when port is not an array', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const connectPorts = [] + streams.forEach((s) => { + s.connect = function (port) { + connectPorts.push(port) + } + }) + const con = new Connection({ stream: () => streams[streamIndex++] }) + con.once('connect', function () { + assert.equal(connectPorts[0], 5432) + assert.equal(connectPorts[1], 5432) + done() + }) + con.connect(5432, ['host1', 'host2']) + const err = new Error('Connection refused') + err.code = 'ECONNREFUSED' + streams[0].emit('error', err) + streams[1].emit('connect') +}) + +suite.test('emits error after all hosts fail', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ stream: () => streams[streamIndex++] }) + assert.emits(con, 'error', function () { + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + const err1 = new Error('Connection refused') + err1.code = 'ECONNREFUSED' + streams[0].emit('error', err1) + const err2 = new Error('Connection refused') + err2.code = 'ECONNREFUSED' + streams[1].emit('error', err2) +}) + +suite.test('does not fall back after successful connect', function (done) { + const stream = makeStream() + const con = new Connection({ stream: stream }) + con.once('connect', function () { + assert.emits(con, 'error', function (err) { + assert.equal(err.code, 'ECONNRESET') + done() + }) + const err = new Error('Connection reset') + err.code = 'ECONNRESET' + stream.emit('error', err) + }) + con.connect([5432, 5433], ['host1', 'host2']) + stream.emit('connect') +}) + +suite.test('targetSessionAttrs=any does not intercept readyForQuery', function (done) { + const stream = makeStream() + const con = new Connection({ targetSessionAttrs: 'any', stream: stream }) + con.once('readyForQuery', function () { + done() + }) + con.connect(5432, 'localhost') + stream.emit('connect') + con.emit('readyForQuery', {}) +}) + +suite.test('targetSessionAttrs=read-write skips hot standby and uses primary', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-write', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + assert.equal(streamIndex, 2) + done() + }) + con.connect([5432, 5433], ['standby', 'primary']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'on', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('targetSessionAttrs=read-write skips read-only and uses writable', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-write', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['readonly', 'writable']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'off', default_transaction_read_only: 'on' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('targetSessionAttrs=read-only skips primary and uses standby', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-only', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['primary', 'standby']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'on', default_transaction_read_only: 'on' }) +}) + +suite.test('targetSessionAttrs=primary skips standby and uses primary', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'primary', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['standby', 'primary']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'on', default_transaction_read_only: 'on' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('targetSessionAttrs=standby skips primary and uses hot standby', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'standby', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['primary', 'standby']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'on', default_transaction_read_only: 'on' }) +}) + +suite.test('targetSessionAttrs=prefer-standby uses standby when available', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'prefer-standby', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + assert.equal(streamIndex, 2) + done() + }) + con.connect([5432, 5433], ['primary', 'standby']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'on', default_transaction_read_only: 'on' }) +}) + +suite.test('targetSessionAttrs=prefer-standby falls back to primary when no standby available', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'prefer-standby', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['primary1', 'primary2']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('emits error when no host satisfies targetSessionAttrs', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-write', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + assert.emits(con, 'error', function (err) { + assert.ok(err.message.includes('read-write')) + done() + }) + con.connect([5432, 5433], ['standby1', 'standby2']) + streams[0].emit('connect') + simulateReadyForQuery(streams[0], { in_hot_standby: 'on', default_transaction_read_only: 'off' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'on', default_transaction_read_only: 'off' }) +}) + +suite.test('resets backend params between hosts when checking targetSessionAttrs', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'primary', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + done() + }) + con.connect([5432, 5433], ['standby', 'primary']) + streams[0].emit('connect') + // standby sends in_hot_standby=on → skip + simulateReadyForQuery(streams[0], { in_hot_standby: 'on', default_transaction_read_only: 'on' }) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('fetches session state via SHOW query when not provided in ParameterStatus', function (done) { + const stream = makeStream() + const con = new Connection({ + targetSessionAttrs: 'read-write', + stream: stream, + }) + con.once('readyForQuery', function () { + done() + }) + con.connect(5432, 'localhost') + stream.emit('connect') + stream.emit('data', makeReadyForQueryBuf()) + stream.emit('data', makeDataRowBuf([Buffer.from('off')])) + stream.emit('data', makeReadyForQueryBuf()) +}) + +suite.test('tries next host when SHOW query returns standby state', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-write', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + assert.equal(streamIndex, 2) + done() + }) + con.connect([5432, 5433], ['standby', 'primary']) + streams[0].emit('connect') + streams[0].emit('data', makeReadyForQueryBuf()) + streams[0].emit('data', makeDataRowBuf([Buffer.from('on')])) // transaction_read_only=on + streams[0].emit('data', makeReadyForQueryBuf()) + streams[1].emit('connect') + simulateReadyForQuery(streams[1], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('prefer-standby triggers pass 2 when all hosts fail TCP in pass 1', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream(), makeStream()] + const connectHosts = [] + streams.forEach((s) => { + s.connect = function (_port, host) { + connectHosts.push(host) + } + }) + const con = new Connection({ + targetSessionAttrs: 'prefer-standby', + trustParameterStatus: true, + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + // pass 2 reconnects from the beginning of the host list + assert.equal(connectHosts[2], 'host1') + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + const err1 = new Error('Connection refused') + err1.code = 'ECONNREFUSED' + streams[0].emit('error', err1) + const err2 = new Error('Connection refused') + err2.code = 'ECONNREFUSED' + streams[1].emit('error', err2) + streams[2].emit('connect') + simulateReadyForQuery(streams[2], { in_hot_standby: 'off', default_transaction_read_only: 'off' }) +}) + +suite.test('probe error causes next host to be tried', function (done) { + let streamIndex = 0 + const streams = [makeStream(), makeStream()] + const con = new Connection({ + targetSessionAttrs: 'read-write', + stream: () => streams[streamIndex++], + }) + con.once('readyForQuery', function () { + assert.equal(streamIndex, 2) + done() + }) + con.connect([5432, 5433], ['host1', 'host2']) + streams[0].emit('connect') + streams[0].emit('data', makeReadyForQueryBuf()) + streams[0].emit('data', makeErrorMessageBuf()) + streams[0].emit('data', makeReadyForQueryBuf()) + streams[1].emit('connect') + streams[1].emit('data', makeReadyForQueryBuf()) + streams[1].emit('data', makeDataRowBuf([Buffer.from('off')])) + streams[1].emit('data', makeReadyForQueryBuf()) +}) + +suite.test('read-only host accepted when tx_read_only probe returns on', function (done) { + const stream = makeStream() + const con = new Connection({ + targetSessionAttrs: 'read-only', + stream: stream, + }) + con.once('readyForQuery', function () { + done() + }) + con.connect(5432, 'localhost') + stream.emit('connect') + stream.emit('data', makeReadyForQueryBuf()) + stream.emit('data', makeDataRowBuf([Buffer.from('on')])) + stream.emit('data', makeReadyForQueryBuf()) +}) + +suite.test('swallows rowDescription and commandComplete during SHOW fetch', function (done) { + const stream = makeStream() + const con = new Connection({ + targetSessionAttrs: 'read-write', + stream: stream, + }) + const unexpectedEvents = [] + for (const evt of ['rowDescription', 'commandComplete']) { + con.on(evt, function () { + unexpectedEvents.push(evt) + }) + } + con.once('readyForQuery', function () { + assert.equal(unexpectedEvents.length, 0) + done() + }) + con.connect(5432, 'localhost') + stream.emit('connect') + stream.emit('data', makeReadyForQueryBuf()) + stream.emit('data', makeRowDescriptionBuf()) + stream.emit('data', makeDataRowBuf([Buffer.from('off')])) // transaction_read_only=off + stream.emit('data', makeCommandCompleteBuf()) + stream.emit('data', makeReadyForQueryBuf()) +})