diff --git a/Mavenfile b/Mavenfile index 1581b9c1..9a724fba 100644 --- a/Mavenfile +++ b/Mavenfile @@ -82,9 +82,12 @@ plugin :clean do 'failOnError' => 'false' ) end -jar 'org.jruby:jruby-core', '9.2.0.0', :scope => :provided +jruby_compile_compat = '9.2.1.0' # due load_ext can use 9.2.0.0 +jar 'org.jruby:jruby-core', jruby_compile_compat, :scope => :provided # for invoker generated classes we need to add javax.annotation when on Java > 8 jar 'javax.annotation:javax.annotation-api', '1.3.1', :scope => :compile +# a test dependency to provide digest and other stdlib bits, needed when loading OpenSSL in Java unit tests +jar 'org.jruby:jruby-stdlib', jruby_compile_compat, :scope => :test jar 'junit:junit', '[4.13.1,)', :scope => :test # NOTE: to build on Java 11 - installing gems fails (due old jossl) with: diff --git a/pom.xml b/pom.xml index 3b8009dc..7e85e21c 100644 --- a/pom.xml +++ b/pom.xml @@ -98,7 +98,7 @@ DO NOT MODIFY - GENERATED CODE org.jruby jruby-core - 9.2.0.0 + 9.2.1.0 provided @@ -107,6 +107,12 @@ DO NOT MODIFY - GENERATED CODE 1.3.1 compile + + org.jruby + jruby-stdlib + 9.2.1.0 + test + junit junit @@ -275,6 +281,7 @@ DO NOT MODIFY - GENERATED CODE 1.8 1.8 + 8 UTF-8 true true diff --git a/src/main/java/org/jruby/ext/openssl/SSLSocket.java b/src/main/java/org/jruby/ext/openssl/SSLSocket.java index ebc9cb03..5d71cc87 100644 --- a/src/main/java/org/jruby/ext/openssl/SSLSocket.java +++ b/src/main/java/org/jruby/ext/openssl/SSLSocket.java @@ -141,14 +141,15 @@ private static CallSite callSite(final CallSite[] sites, final CallSiteIndex ind return sites[ index.ordinal() ]; } - private SSLContext sslContext; + private static final ByteBuffer EMPTY_DATA = ByteBuffer.allocate(0).asReadOnlyBuffer(); + + SSLContext sslContext; private SSLEngine engine; private RubyIO io; - private ByteBuffer appReadData; - private ByteBuffer netReadData; - private ByteBuffer netWriteData; - private final ByteBuffer dummy = ByteBuffer.allocate(0); // could be static + ByteBuffer appReadData; + ByteBuffer netReadData; + ByteBuffer netWriteData; private boolean initialHandshake = false; private transient long initializeTime; @@ -209,7 +210,7 @@ private IRubyObject fallback_set_io_nonblock_checked(ThreadContext context, Ruby private static final String SESSION_SOCKET_ID = "socket_id"; - private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) { + SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean server) { SSLEngine engine = this.engine; if ( engine != null ) return engine; @@ -553,10 +554,6 @@ private static void writeWouldBlock(final Ruby runtime, final boolean exception, result[0] = WRITE_WOULD_BLOCK_RESULT; } - private void doHandshake(final boolean blocking) throws IOException { - doHandshake(blocking, true); - } - // might return :wait_readable | :wait_writable in case (true, false) private IRubyObject doHandshake(final boolean blocking, final boolean exception) throws IOException { while (true) { @@ -578,7 +575,7 @@ private IRubyObject doHandshake(final boolean blocking, final boolean exception) doTasks(); break; case NEED_UNWRAP: - if (readAndUnwrap(blocking) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) { + if (readAndUnwrap(blocking, exception) == -1 && handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED) { throw new SSLHandshakeException("Socket closed"); } // during initialHandshake, calling readAndUnwrap that results UNDERFLOW does not mean writable. @@ -614,7 +611,7 @@ private IRubyObject doHandshake(final boolean blocking, final boolean exception) private void doWrap(final boolean blocking) throws IOException { netWriteData.clear(); - SSLEngineResult result = engine.wrap(dummy, netWriteData); + SSLEngineResult result = engine.wrap(EMPTY_DATA.duplicate(), netWriteData); netWriteData.flip(); handshakeStatus = result.getHandshakeStatus(); status = result.getStatus(); @@ -689,7 +686,9 @@ public int write(ByteBuffer src, boolean blocking) throws SSLException, IOExcept if ( netWriteData.hasRemaining() ) { flushData(blocking); } - netWriteData.clear(); + // compact() to preserve encrypted bytes flushData could not send (non-blocking partial write) + // clear() would discard them, corrupting the TLS record stream: + netWriteData.compact(); final SSLEngineResult result = engine.wrap(src, netWriteData); if ( result.getStatus() == SSLEngineResult.Status.CLOSED ) { throw getRuntime().newIOError("closed SSL engine"); @@ -704,11 +703,15 @@ public int write(ByteBuffer src, boolean blocking) throws SSLException, IOExcept } public int read(final ByteBuffer dst, final boolean blocking) throws IOException { + return read(dst, blocking, true); + } + + private int read(final ByteBuffer dst, final boolean blocking, final boolean exception) throws IOException { if ( initialHandshake ) return 0; if ( engine.isInboundDone() ) return -1; if ( ! appReadData.hasRemaining() ) { - int appBytesProduced = readAndUnwrap(blocking); + int appBytesProduced = readAndUnwrap(blocking, exception); if (appBytesProduced == -1 || appBytesProduced == 0) { return appBytesProduced; } @@ -719,7 +722,7 @@ public int read(final ByteBuffer dst, final boolean blocking) throws IOException return limit; } - private int readAndUnwrap(final boolean blocking) throws IOException { + private int readAndUnwrap(final boolean blocking, final boolean exception) throws IOException { final int bytesRead = socketChannelImpl().read(netReadData); if ( bytesRead == -1 ) { if ( ! netReadData.hasRemaining() || @@ -727,9 +730,8 @@ private int readAndUnwrap(final boolean blocking) throws IOException { closeInbound(); return -1; } - // inbound channel has been already closed but closeInbound() must - // be defered till the last engine.unwrap() call. - // peerNetData could not be empty. + // inbound channel has been already closed but closeInbound() must be defered till + // the last engine.unwrap() call; peerNetData could not be empty } appReadData.clear(); netReadData.flip(); @@ -768,7 +770,7 @@ private int readAndUnwrap(final boolean blocking) throws IOException { handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_TASK || handshakeStatus == SSLEngineResult.HandshakeStatus.NEED_WRAP || handshakeStatus == SSLEngineResult.HandshakeStatus.FINISHED ) ) { - doHandshake(blocking); + doHandshake(blocking, exception); } return appReadData.remaining(); } @@ -793,7 +795,7 @@ private void doShutdown() throws IOException { } netWriteData.clear(); try { - engine.wrap(dummy, netWriteData); // send close (after sslEngine.closeOutbound) + engine.wrap(EMPTY_DATA.duplicate(), netWriteData); // send close (after sslEngine.closeOutbound) } catch (SSLException e) { debug(getRuntime(), "SSLSocket.doShutdown", e); @@ -808,10 +810,10 @@ private void doShutdown() throws IOException { } /** - * @return the (@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol} + * @return the {@link RubyString} buffer or :wait_readable / :wait_writeable {@link RubySymbol} */ - private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject len, final IRubyObject buff, - final boolean blocking, final boolean exception) { + private IRubyObject sysreadImpl(final ThreadContext context, + final IRubyObject len, final IRubyObject buff, final boolean blocking, final boolean exception) { final Ruby runtime = context.runtime; final int length = RubyNumeric.fix2int(len); @@ -831,6 +833,14 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l } try { + // Flush pending write data before reading (after write_nonblock encrypted bytes may still be buffered) + if ( engine != null && netWriteData.hasRemaining() ) { + if ( flushData(blocking) && ! blocking ) { + if ( exception ) throw newSSLErrorWaitWritable(runtime, "write would block"); + return runtime.newSymbol("wait_writable"); + } + } + // So we need to make sure to only block when there is no data left to process if ( engine == null || ! ( appReadData.hasRemaining() || netReadData.position() > 0 ) ) { final Object ex = waitSelect(SelectionKey.OP_READ, blocking, exception); @@ -839,12 +849,12 @@ private IRubyObject sysreadImpl(final ThreadContext context, final IRubyObject l final ByteBuffer dst = ByteBuffer.allocate(length); int read = -1; - // ensure >0 bytes read; sysread is blocking read. + // ensure > 0 bytes read; sysread is blocking read while ( read <= 0 ) { if ( engine == null ) { read = socketChannelImpl().read(dst); } else { - read = read(dst, blocking); + read = read(dst, blocking, exception); } if ( read == -1 ) { @@ -1226,7 +1236,7 @@ public IRubyObject ssl_version(ThreadContext context) { return context.runtime.newString( engine.getSession().getProtocol() ); } - private transient SocketChannelImpl socketChannel; + transient SocketChannelImpl socketChannel; private SocketChannelImpl socketChannelImpl() { if ( socketChannel != null ) return socketChannel; @@ -1241,7 +1251,7 @@ private SocketChannelImpl socketChannelImpl() { throw new IllegalStateException("unknow channel impl: " + channel + " of type " + channel.getClass().getName()); } - private interface SocketChannelImpl { + interface SocketChannelImpl { boolean isOpen() ; diff --git a/src/test/java/org/jruby/ext/openssl/OpenSSLHelper.java b/src/test/java/org/jruby/ext/openssl/OpenSSLHelper.java new file mode 100644 index 00000000..5c2bf19d --- /dev/null +++ b/src/test/java/org/jruby/ext/openssl/OpenSSLHelper.java @@ -0,0 +1,70 @@ +package org.jruby.ext.openssl; + +import org.jruby.Ruby; +import org.jruby.runtime.ThreadContext; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +abstract class OpenSSLHelper { + + protected Ruby runtime; + + void setUpRuntime() throws ClassNotFoundException { + runtime = Ruby.newInstance(); + loadOpenSSL(runtime); + } + + void tearDownRuntime() { + if (runtime != null) runtime.tearDown(false); + } + + protected void loadOpenSSL(final Ruby runtime) throws ClassNotFoundException { + // prepend lib/ so openssl.rb + jopenssl/ are loaded instead of bundled OpenSSL in jruby-stdlib + final String libDir = new File("lib").getAbsolutePath(); + runtime.evalScriptlet("$LOAD_PATH.unshift '" + libDir + "'"); + runtime.evalScriptlet("require 'openssl'"); + + // sanity: verify openssl was loaded from the project, not jruby-stdlib : + final String versionFile = new File(libDir, "jopenssl/version.rb").getAbsolutePath(); + final String expectedVersion = runtime.evalScriptlet( + "File.read('" + versionFile + "').match( /.*\\sVERSION\\s*=\\s*['\"](.*)['\"]/ )[1]") + .toString(); + final String loadedVersion = runtime.evalScriptlet("JOpenSSL::VERSION").toString(); + assertEquals("OpenSSL must be loaded from project (got version " + loadedVersion + + "), not from jruby-stdlib", expectedVersion, loadedVersion); + + // Also check the Java extension classes were resolved from the project, not jruby-stdlib : + final String classOrigin = runtime.getJRubyClassLoader() + .loadClass("org.jruby.ext.openssl.OpenSSL") + .getProtectionDomain().getCodeSource().getLocation().toString(); + assertTrue("OpenSSL.class (via JRuby classloader) come from project, got: " + classOrigin, + classOrigin.endsWith("/pkg/classes/")); + } + + // HELPERS + + public ThreadContext currentContext() { + return runtime.getCurrentContext(); + } + + public static String readResource(final String resource) { + int n; + try (InputStream in = SSLSocketTest.class.getResourceAsStream(resource)) { + if (in == null) throw new IllegalArgumentException(resource + " not found on classpath"); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + byte[] buf = new byte[8192]; + while ((n = in.read(buf)) != -1) out.write(buf, 0, n); + return new String(out.toByteArray(), StandardCharsets.UTF_8); + } catch (IOException e) { + throw new IllegalStateException("failed to load" + resource, e); + } + } +} diff --git a/src/test/java/org/jruby/ext/openssl/SSLSocketTest.java b/src/test/java/org/jruby/ext/openssl/SSLSocketTest.java new file mode 100644 index 00000000..2c7c29ed --- /dev/null +++ b/src/test/java/org/jruby/ext/openssl/SSLSocketTest.java @@ -0,0 +1,251 @@ +package org.jruby.ext.openssl; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import javax.net.ssl.SSLEngine; + +import org.jruby.Ruby; +import org.jruby.RubyArray; +import org.jruby.RubyFixnum; +import org.jruby.RubyHash; +import org.jruby.RubyInteger; +import org.jruby.RubyString; +import org.jruby.exceptions.RaiseException; +import org.jruby.runtime.builtin.IRubyObject; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class SSLSocketTest extends OpenSSLHelper { + + /** Loads the ssl_pair.rb script that creates a connected SSL socket pair. */ + private String start_ssl_server_rb() { return readResource("/start_ssl_server.rb"); } + + @Before + public void setUp() throws Exception { + setUpRuntime(); + } + + @After + public void tearDown() { + tearDownRuntime(); + } + + /** + * Real-world scenario: {@code gem push} sends a large POST body via {@code syswrite_nonblock}, + * then reads the HTTP response via {@code sysread}. + * + * Approximates the {@code gem push} scenario: + *
    + *
  1. Write 256KB via {@code syswrite_nonblock} in a loop (the net/http POST pattern)
  2. + *
  3. Server reads via {@code sysread} and counts bytes
  4. + *
  5. Assert: server received exactly what client sent
  6. + *
+ * + * With the old {@code clear()} bug, encrypted bytes were silently + * discarded during partial non-blocking writes, so the server would + * receive fewer bytes than sent. + */ + @Test + public void syswriteNonblockDataIntegrity() throws Exception { + final RubyArray pair = (RubyArray) runtime.evalScriptlet(start_ssl_server_rb()); + SSLSocket client = (SSLSocket) pair.entry(0).toJava(SSLSocket.class); + SSLSocket server = (SSLSocket) pair.entry(1).toJava(SSLSocket.class); + + try { + // Server: read all data in a background thread, counting bytes + final long[] serverReceived = { 0 }; + Thread serverReader = startServerReader(server, serverReceived); + + // Client: write 256KB in 4KB chunks via syswrite_nonblock + byte[] chunk = new byte[4096]; + java.util.Arrays.fill(chunk, (byte) 'P'); // P for POST body + RubyString payload = RubyString.newString(runtime, chunk); + + long totalSent = 0; + for (int i = 0; i < 64; i++) { // 64 * 4KB = 256KB + try { + IRubyObject written = client.syswrite_nonblock(currentContext(), payload); + totalSent += ((RubyInteger) written).getLongValue(); + } catch (RaiseException e) { + if ("OpenSSL::SSL::SSLErrorWaitWritable".equals(e.getException().getMetaClass().getName())) { + System.out.println("syswrite_nonblock expected: " + e.getMessage()); + // Expected: non-blocking write would block — retry as blocking + IRubyObject written = client.syswrite(currentContext(), payload); + totalSent += ((RubyInteger) written).getLongValue(); + } else { + System.err.println("syswrite_nonblock unexpected: " + e.getMessage()); + throw e; + } + } + } + assertTrue("should have sent data", totalSent > 0); + + // Close client to signal EOF, let server finish reading + client.callMethod(currentContext(), "close"); + serverReader.join(10_000); + + assertEquals( + "server must receive exactly what client sent — mismatch means encrypted bytes were lost!", + totalSent, serverReceived[0] + ); + } finally { + closeQuietly(pair); + } + } + + private Thread startServerReader(final SSLSocket server, final long[] serverReceived) { + Thread serverReader = new Thread(() -> { + try { + RubyFixnum len = RubyFixnum.newFixnum(runtime, 8192); + while (true) { + IRubyObject data = server.sysread(currentContext(), len); + serverReceived[0] += ((RubyString) data).getByteList().getRealSize(); + } + } catch (RaiseException e) { + String errorName = e.getException().getMetaClass().getName(); + if ("EOFError".equals(errorName) || "IOError".equals(errorName)) { // client closes connection + System.out.println("server-reader expected: " + e.getMessage()); + } else { + System.err.println("server-reader unexpected: " + e.getMessage()); + e.printStackTrace(System.err); + throw e; + } + } + }); + serverReader.start(); + return serverReader; + } + + /** + * After saturating the TCP send buffer with {@code syswrite_nonblock}, + * inspect {@code netWriteData} to verify the buffer is consistent. + */ + @Test + public void syswriteNonblockNetWriteDataConsistency() { + final RubyArray pair = (RubyArray) runtime.evalScriptlet(start_ssl_server_rb()); + SSLSocket client = (SSLSocket) pair.entry(0).toJava(SSLSocket.class); + + try { + assertNotNull("netWriteData initialized after handshake", client.netWriteData); + + // Saturate: server is not reading yet, so backpressure builds + byte[] chunk = new byte[16384]; + java.util.Arrays.fill(chunk, (byte) 'S'); + RubyString payload = RubyString.newString(runtime, chunk); + + int successWrites = 0; + for (int i = 0; i < 200; i++) { + try { + client.syswrite_nonblock(currentContext(), payload); + successWrites++; + } catch (RaiseException e) { + if ("OpenSSL::SSL::SSLErrorWaitWritable".equals(e.getException().getMetaClass().getName())) { + System.out.println("saturate-loop expected: " + e.getMessage()); + break; // buffer saturated — expected + } + System.err.println("saturate-loop unexpected: " + e.getMessage()); + throw e; + } + } + assertTrue("at least one write should succeed", successWrites > 0); + + ByteBuffer netWriteData = client.netWriteData; + assertTrue("position <= limit", netWriteData.position() <= netWriteData.limit()); + assertTrue("limit <= capacity", netWriteData.limit() <= netWriteData.capacity()); + + // If there are unflushed bytes, compact() preserved them + if (netWriteData.remaining() > 0) { + // The bytes should be valid TLS record data, not zeroed memory + byte b = netWriteData.get(netWriteData.position()); + assertNotEquals("preserved bytes should be TLS data, not zeroed", 0, b); + } + + } finally { + closeQuietly(pair); + } + } + + private void closeQuietly(final RubyArray sslPair) { + for (int i = 0; i < sslPair.getLength(); i++) { + final IRubyObject elem = sslPair.entry(i); + try { elem.callMethod(currentContext(), "close"); } + catch (RaiseException e) { // already closed? + System.err.println("close raised (" + elem.inspect() + ") : " + e.getMessage()); + } + } + } + + // ---------- + + /** + * MRI's ossl_ssl_read_internal returns :wait_writable (or raises SSLErrorWaitWritable / "write would block") + * when SSL_read hits SSL_ERROR_WANT_WRITE. Pending netWriteData is JRuby's equivalent state. + */ + @Test + public void sysreadNonblockReturnsWaitWritableWhenPendingEncryptedBytesRemain() { + final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1)); + final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false); + engine.setUseClientMode(true); + + socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 }); + + final RubyHash opts = RubyHash.newKwargs(runtime, "exception", runtime.getFalse()); // exception: false + final IRubyObject result = socket.sysread_nonblock(currentContext(), runtime.newFixnum(1), opts); + + assertEquals("wait_writable", result.asJavaString()); + assertEquals(1, socket.netWriteData.remaining()); + } + + @Test + public void sysreadNonblockRaisesWaitWritableWhenPendingEncryptedBytesRemain() { + final SSLSocket socket = newSSLSocket(runtime, partialWriteChannel(1)); + final SSLEngine engine = socket.ossl_ssl_setup(currentContext(), false); + engine.setUseClientMode(true); + + socket.netWriteData = ByteBuffer.wrap(new byte[] { 1, 2 }); + + try { + socket.sysread_nonblock(currentContext(), runtime.newFixnum(1)); + fail("expected SSLErrorWaitWritable"); + } + catch (RaiseException ex) { + assertEquals("OpenSSL::SSL::SSLErrorWaitWritable", ex.getException().getMetaClass().getName()); + assertTrue(ex.getMessage().contains("write would block")); + assertEquals(1, socket.netWriteData.remaining()); + } + } + + private static SSLSocket newSSLSocket(final Ruby runtime, final SSLSocket.SocketChannelImpl socketChannel) { + final SSLContext sslContext = new SSLContext(runtime); + sslContext.doSetup(runtime.getCurrentContext()); + final SSLSocket sslSocket = new SSLSocket(runtime, runtime.getObject()); + sslSocket.sslContext = sslContext; + sslSocket.socketChannel = socketChannel; + return sslSocket; + } + + private static SSLSocket.SocketChannelImpl partialWriteChannel(final int bytesPerWrite) { + return new SSLSocket.SocketChannelImpl() { + public boolean isOpen() { return true; } + public int read(final ByteBuffer dst) { return 0; } + public int write(final ByteBuffer src) { + final int written = Math.min(bytesPerWrite, src.remaining()); + src.position(src.position() + written); + return written; + } + public int getRemotePort() { return 443; } + public boolean isSelectable() { return false; } + public boolean isBlocking() { return false; } + public void configureBlocking(final boolean block) { } + public SelectionKey register(final Selector selector, final int ops) throws IOException { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/src/test/resources/start_ssl_server.rb b/src/test/resources/start_ssl_server.rb new file mode 100644 index 00000000..2e30eb96 --- /dev/null +++ b/src/test/resources/start_ssl_server.rb @@ -0,0 +1,37 @@ +# Creates a connected SSL socket pair for Java unit tests. +# Returns [client_ssl, server_ssl] +# +# OpenSSL extension is loaded by SSLSocketTest.setUp via OpenSSL.load(runtime). + +require 'socket' + +key = OpenSSL::PKey::RSA.new(2048) +cert = OpenSSL::X509::Certificate.new +cert.version = 2 +cert.serial = 1 +cert.subject = cert.issuer = OpenSSL::X509::Name.parse('/CN=Test') +cert.public_key = key.public_key +cert.not_before = Time.now +cert.not_after = Time.now + 3600 +cert.sign(key, OpenSSL::Digest::SHA256.new) + +tcp_server = TCPServer.new('127.0.0.1', 0) +port = tcp_server.local_address.ip_port +ctx = OpenSSL::SSL::SSLContext.new +ctx.cert = cert +ctx.key = key +ssl_server = OpenSSL::SSL::SSLServer.new(tcp_server, ctx) +ssl_server.start_immediately = true + +server_ssl = nil +server_thread = Thread.new { server_ssl = ssl_server.accept } + +sock = TCPSocket.new('127.0.0.1', port) +sock.setsockopt(Socket::SOL_SOCKET, Socket::SO_SNDBUF, 4096) +sock.setsockopt(Socket::SOL_SOCKET, Socket::SO_RCVBUF, 4096) +client_ssl = OpenSSL::SSL::SSLSocket.new(sock) +client_ssl.sync_close = true +client_ssl.connect +server_thread.join(5) + +[client_ssl, server_ssl] diff --git a/src/test/ruby/ssl/test_write_flush.rb b/src/test/ruby/ssl/test_write_flush.rb new file mode 100644 index 00000000..f3bd7bd1 --- /dev/null +++ b/src/test/ruby/ssl/test_write_flush.rb @@ -0,0 +1,166 @@ +# frozen_string_literal: false +require File.expand_path('test_helper', File.dirname(__FILE__)) + +class TestSSLWriteFlush < TestCase + + include SSLTestHelper + + # Exercises the write_nonblock -> read transition used by net/http for POST + # requests. The bug (clear() instead of compact()) loses encrypted bytes that + # remain in netWriteData after a partial flushData on the *last* write_nonblock. + # + # We run multiple request/response rounds on the same TLS connection with + # varying payload sizes to increase the probability that at least one round + # triggers a partial flush at the write->read boundary. + # + # NOTE: On localhost the loopback interface rarely causes partial socket writes, + # so this test may not reliably catch regressions to clear(). The definitive + # coverage is in the Java-level SSLSocketTest which can control buffer state + # directly. This test serves as an integration smoke test for the write->read + # data path. + # + # https://github.com/jruby/jruby-openssl/issues/242 + def test_write_nonblock_data_integrity + # Payload sizes chosen to exercise different alignments with the TLS record + # layer (~16 KB records) and socket send buffer. Primes avoid lucky alignment. + payload_sizes = [ + 8_191, # just under 8 KB — fits in one TLS record + 16_381, # just under 16 KB — nearly one full TLS record + 65_521, # ~64 KB — several TLS records, common chunk size + 262_139, # ~256 KB — large payload, many partial flushes likely + ] + + # The server reads a 4-byte big-endian length prefix, then that many bytes + # of payload, and responds with "OK:" where hex_digest is the + # SHA-256 of the received payload. This is repeated for each payload size. + server_proc = proc { |ctx, ssl| + begin + payload_sizes.length.times do + # read 4-byte length prefix + header = read_exactly(ssl, 4) + break unless header && header.bytesize == 4 + expected_len = header.unpack('N')[0] + + # read payload + payload = read_exactly(ssl, expected_len) + break unless payload && payload.bytesize == expected_len + + digest = OpenSSL::Digest::SHA256.hexdigest(payload) + response = "OK:#{digest}" + ssl.write(response) + end + ensure + ssl.close rescue nil + end + } + + start_server0(PORT, OpenSSL::SSL::VERIFY_NONE, true, + server_proc: server_proc) do |server, port| + sock = TCPSocket.new("127.0.0.1", port) + # Constrain the send buffer to make partial flushes more likely. + # The kernel may round this up, but even a modest reduction helps. + sock.setsockopt(Socket::SOL_SOCKET, Socket::SO_SNDBUF, 2048) + sock.setsockopt(Socket::IPPROTO_TCP, Socket::TCP_NODELAY, 1) + + ssl = OpenSSL::SSL::SSLSocket.new(sock) + ssl.connect + ssl.sync_close = true + + payload_sizes.each do |size| + data = generate_test_data(size) + expected_digest = OpenSSL::Digest::SHA256.hexdigest(data) + + # Send length-prefixed payload via write_nonblock + message = [size].pack('N') + data + write_nonblock_all(ssl, message) + + # Immediately switch to reading — this is where the bug manifests: + # if compact() was replaced with clear(), residual encrypted bytes + # from the last write_nonblock are lost and the server never + # receives the complete payload. + response = read_with_timeout(ssl, 5) + + assert_equal "OK:#{expected_digest}", response, + "Data integrity failure for #{size}-byte payload: " \ + "server did not receive the complete payload or it was corrupted" + end + + ssl.close + end + end + + private + + # Generate non-trivial test data that won't compress well in TLS. + # Uses a seeded PRNG so failures are reproducible, and avoids + # OpenSSL::Random which has a per-call size limit in some BC versions. + def generate_test_data(size) + rng = Random.new(size) # seeded for reproducibility + (0...size).map { rng.rand(256).chr }.join.b + end + + # Write all of +data+ via write_nonblock, retrying on WaitWritable. + # Does NOT do any extra flushing after the last write — this is critical + # for exercising the bug where clear() loses the tail of encrypted data. + def write_nonblock_all(ssl, data) + remaining = data + while remaining.bytesize > 0 + begin + written = ssl.write_nonblock(remaining) + remaining = remaining.byteslice(written..-1) + rescue IO::WaitWritable + IO.select(nil, [ssl]) + retry + end + end + end + + # Read a complete response from the SSL socket with a timeout. + # Returns the accumulated data, or fails the test on timeout. + def read_with_timeout(ssl, timeout_sec) + response = "" + deadline = Time.now + timeout_sec + loop do + remaining = deadline - Time.now + if remaining <= 0 + flunk "Timed out after #{timeout_sec}s waiting for server response " \ + "(got #{response.bytesize} bytes so far: #{response.inspect[0, 80]})" + end + if IO.select([ssl], nil, nil, [remaining, 0.5].min) + begin + chunk = ssl.read_nonblock(16384, exception: false) + case chunk + when :wait_readable then next + when nil then break # EOF + else + response << chunk + # Our protocol responses are short ("OK:<64 hex chars>"), so if + # we've received a plausible amount we can stop. + break if response.include?("OK:") && response.bytesize >= 67 + end + rescue EOFError + break + end + end + end + response + end + + # Read exactly +n+ bytes from an SSL socket, retrying partial reads. + def self.read_exactly(ssl, n) + buf = "" + while buf.bytesize < n + chunk = ssl.readpartial(n - buf.bytesize) + buf << chunk + end + buf + rescue EOFError + buf + end + + # Instance method wrapper for use in server_proc + def read_exactly(ssl, n) + self.class.read_exactly(ssl, n) + end + +end