Building a PostgreSQL Wire Protocol Server using Vanilla, Modern Java 21

This tutorial is meant to serve as a dual-purpose guide:

  • Demonstrate how to implement the Simple Query Protocol, where Java is an implementation detail
  • Show practical examples of most of the new features since JDK 17, including:
    • Records (JEP 395)
    • Sealed Types (JEP 360/JEP 409)
    • Pattern Matching for Switch (JEP 406)
    • Virtual Threads aka Project Loom (JEP 425)
    • Foreign-Function & Memory API (FFM) aka Project Panama (JEP 424)
    • (Also give a practical example of java.nio's AsynchronousChannelGroup and AsynchronousServerSocketChannel, for which there are few examples online)

Outline

This tutorial will be broken up into a series of 2 parts:

  1. Because we are diligent coders who know to Do The Simplest Thing That Could Possibly Work™, our first step will be a horrifically ugly, but functional, implementation of the PostgreSQL wire protocol server.

  2. Then, we will refactor our code to be more readable and maintainable.

Introduction to the PostgreSQL Wire Protocol

The PostgreSQL wire protocol is a binary protocol that is used to communicate between a PostgreSQL client and server.

The protocol is documented in the PostgreSQL Protocol Documentation.

In my opinion, this documentation is not the most understandable. If you want to learn more about the protocol, I recommend the following presentation:

What we're primarily concerned with today is the following:

  • Postgres clients send two types of message to the server: Startup Messages and Commands
  • Optionally, the Startup message can be preceded by an SSL Negotiation message, where the client asks the server if it supports SSL

Visualized as a Mermaid diagram:

There are many types of commands, but today we'll only be concerned with the Query command, which is used to execute a SQL query.

Knowing this, we can now start implementing our server.

Step 1: Doing The Simplest Thing That Could Possibly Work (TM)

We'll begin by implementing a basic java.nio.channels.AsynchronousServerSocketChannel server, which will accept connections and print out the messages it receives:

Step 1.1 - Initial Server Skeleton

Below is the initial skeleton of our server.

  • We create a java.nio.channels.AsynchronousServerSocketChannel and bind it to localhost and the default Postgres port (5432).
  • An ExecutorService is created, which will be used to create a java.nio.channels.AsynchronousChannelGroup for our server.
  • We use the newVirtualThreadPerTaskExecutor method, which returns an executor that starts a new Loom virtual thread for each task submitted to the server thread pool.
  • Then, we accept connections and print out the messages we receive.
package postgres.wire.protocol;

import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.AsynchronousChannelGroup;
import java.nio.channels.AsynchronousServerSocketChannel;
import java.nio.channels.AsynchronousSocketChannel;
import java.nio.channels.CompletionHandler;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

class AsynchronousSocketServer {
    private static final String HOST = "localhost";
    private static final int PORT = 5432;

    public static void main(String[] args) throws Exception {
        ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
        AsynchronousChannelGroup group = AsynchronousChannelGroup.withThreadPool(executor);

        try (AsynchronousServerSocketChannel server = AsynchronousServerSocketChannel.open(group)) {
            server.bind(new InetSocketAddress(HOST, PORT));
            System.out.println("[SERVER] Listening on " + HOST + ":" + PORT);

            for (;;) {
                Future<AsynchronousSocketChannel> future = server.accept();
                AsynchronousSocketChannel client = future.get();
                System.out.println("[SERVER] Accepted connection from " + client.getRemoteAddress());
                ByteBuffer buffer = ByteBuffer.allocate(1024);
                client.read(buffer, buffer, new CompletionHandler<>() {
                    @Override
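                    public void completed(Integer result, ByteBuffer attachment) {
                        if (result == -1) {
                            // -1 means the client closed the connection: stop reading
                            // instead of re-issuing reads that complete with -1 forever.
                            try {
                                client.close();
                            } catch (java.io.IOException e) {
                                System.err.println("[SERVER] Failed to close client: " + e);
                            }
                            return;
                        }
                        attachment.flip();
                        onMessageReceived(client, attachment);
                        attachment.clear();
                        client.read(attachment, attachment, this);
                    }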

                    @Override
                    public void failed(Throwable exc, ByteBuffer attachment) {
                        System.err.println("[SERVER] Failed to read from client: " + exc);
                        exc.printStackTrace();
                    }
                });
            }
        }
    }

    private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
        System.out.println("[SERVER] Received message from client: " + client);
        System.out.println("[SERVER] Buffer: " + buffer);
    }
}

class MainSimplest {
    public static void main(String[] args) throws Exception {
        AsynchronousSocketServer.main(args);
    }
}
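
As a quick sanity check of the executor used above, the following minimal sketch submits one task to newVirtualThreadPerTaskExecutor and reports whether it ran on a virtual thread. The class and method names are hypothetical and not part of the tutorial's code, and the sketch assumes a JDK where virtual threads are available without preview flags (Java 21 or later).

```java
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

// Hypothetical helper, for illustration only.
final class VirtualThreadCheck {

    /** Runs one task on a virtual-thread-per-task executor and reports whether it ran on a virtual thread. */
    static boolean runsOnVirtualThread() {
        ExecutorService executor = Executors.newVirtualThreadPerTaskExecutor();
        try {
            Callable<Boolean> task = () -> Thread.currentThread().isVirtual();
            return executor.submit(task).get();
        } catch (InterruptedException | ExecutionException e) {
            throw new IllegalStateException("Task did not complete", e);
        } finally {
            executor.shutdown();
        }
    }

    public static void main(String[] args) {
        System.out.println("Ran on a virtual thread: " + runsOnVirtualThread());
    }
}
```

Each task submitted to this executor gets its own virtual thread, so the channel group's completion handlers never compete for a fixed-size pool.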

If we start this up, we should see:

Running Gradle on WSL...

> Task :app:compileJava

> Task :app:processResources NO-SOURCE
Note: Some input files use preview features of Java SE 21.
Note: Recompile with -Xlint:preview for details.
> Task :app:classes

> Task :app:MainSimplestWalkThrough.main()
[SERVER] Listening on localhost:5432

Connecting to the Server with psql

Now, we can connect to our server using psql:

$ psql -h localhost -p 5432 -U postgres

We should see psql hang at the prompt, and the server should print out the following:

[SERVER] Accepted connection from /127.0.0.1:41826
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:41826]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]

Great! We're now able to receive messages from the client.

Step 1.2 - Responding to the SSL Negotiation message and the Startup Message

What we want to do is just make sure that we're able to receive messages from the client, and respond to the:

  • Initial SSL Negotiation message with a 'N' byte (for No)
  • Startup Message with an AuthenticationOk message

Finally, we'll write a:

  • BackendKeyData message, which gives the client a process ID and secret key that it can later use to issue cancel requests for this connection
  • ReadyForQuery message, which indicates that the server is ready to accept commands.

Below is the updated code:

    private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
        System.out.println("[SERVER] Received message from client: " + client);
        System.out.println("[SERVER] Buffer: " + buffer);

        // First, write 'N' for SSL negotiation
        ByteBuffer response = ByteBuffer.allocate(1);
        response.put((byte) 'N');
        response.flip();
        Future<Integer> writeResult = client.write(response);

        // Then, write AuthenticationOk
        ByteBuffer authOk = ByteBuffer.allocate(9);
        authOk.put((byte) 'R'); // 'R' for AuthenticationOk
        authOk.putInt(8); // Length
        authOk.putInt(0); // AuthenticationOk
        authOk.flip();
        writeResult = client.write(authOk);

        // Then, write BackendKeyData
        ByteBuffer backendKeyData = ByteBuffer.allocate(13); // 1 (type) + 4 (length) + 4 (process ID) + 4 (secret key)
        backendKeyData.put((byte) 'K'); // Message type
        backendKeyData.putInt(12); // Message length
        backendKeyData.putInt(1234); // Process ID
        backendKeyData.putInt(5678); // Secret key
        backendKeyData.flip();
        writeResult = client.write(backendKeyData);

        // Then, write ReadyForQuery
        ByteBuffer readyForQuery = ByteBuffer.allocate(6);
        readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
        readyForQuery.putInt(5); // Length
        readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
        readyForQuery.flip();
        writeResult = client.write(readyForQuery);

        try {
            writeResult.get();
        } catch (Exception e) {
            System.err.println("[SERVER] Failed to write to client: " + e);
        }
    }
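
Apart from the single-byte 'N' reply, every backend message written above has the same shape: one type byte, an int32 length that counts itself plus the payload (but not the type byte), then the payload. Here is a minimal sketch of that framing rule. The frame helper and the PgFrames class are hypothetical names introduced for illustration, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, for illustration only.
final class PgFrames {

    /** Frames a backend message: type byte, int32 length (4 + payload size), payload. */
    static ByteBuffer frame(char type, byte[] payload) {
        int length = 4 + payload.length;                   // the length field counts itself
        ByteBuffer buffer = ByteBuffer.allocate(1 + length); // +1 for the type byte
        buffer.put((byte) type);
        buffer.putInt(length);
        buffer.put(payload);
        buffer.flip();                                     // switch from writing to reading
        return buffer;
    }

    public static void main(String[] args) {
        // ReadyForQuery: payload is the single transaction status byte 'I'
        ByteBuffer ready = frame('Z', new byte[] {'I'});
        System.out.println("ReadyForQuery bytes: " + ready.remaining());

        // BackendKeyData: payload is two int32 values (process ID, secret key)
        byte[] keys = ByteBuffer.allocate(8).putInt(1234).putInt(5678).array();
        System.out.println("BackendKeyData bytes: " + frame('K', keys).remaining());
    }
}
```

This is why ReadyForQuery needs a 6-byte buffer with a length of 5, and BackendKeyData needs a 13-byte buffer with a length of 12.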

From this point on, it'll be useful to be able to visualize the messages that we're sending and receiving.

There are a few tools that you can use to do this:

I recommend using Wireshark's GUI, since it's the easiest to use. For this tutorial I'll be using pgs-debug, for two reasons:

  • Wireshark doesn't work on WSL
  • I want to be able to paste the ASCII output into the tutorial
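
If neither tool is an option, a rough in-process alternative is to hex-dump each buffer before writing it or after reading it. This is only a sketch: HexDump and hex are hypothetical names, and the output is raw bytes rather than the decoded view that a protocol-aware tool provides.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, for illustration only.
final class HexDump {

    /** Renders the readable bytes of a buffer as space-separated hex, without moving its position. */
    static String hex(ByteBuffer buffer) {
        StringBuilder out = new StringBuilder();
        for (int i = buffer.position(); i < buffer.limit(); i++) {
            if (out.length() > 0) {
                out.append(' ');
            }
            out.append(String.format("%02x", buffer.get(i) & 0xff)); // absolute get: position is untouched
        }
        return out.toString();
    }

    public static void main(String[] args) {
        // The ReadyForQuery message: 'Z', int32 length 5, status 'I'
        ByteBuffer readyForQuery = ByteBuffer.allocate(6);
        readyForQuery.put((byte) 'Z').putInt(5).put((byte) 'I').flip();
        System.out.println(hex(readyForQuery)); // prints "5a 00 00 00 05 49"
    }
}
```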

NOTE: If you want a video tutorial on how to capture Postgres traffic with Wireshark, I have a short demo on my pgprotokt repo here:

To capture the output with pgs-debug, I'll be using the command:

# Capture on loopback interface
$ sudo pgs-debug --interface lo

If we start up the server and connect with psql, we should see the following:

  • psql client:
[user@MSI ~]$ psql -h localhost -p 5432 -U postgres
psql (15.0, server 0.0.0)
WARNING: psql major version 15, server major version 0.0.
         Some psql features might not work.
Type "help" for help.

postgres=>
  • pgs-debug output:
[user@MSI ~]$ sudo pgs-debug --interface lo
Packet: t=1673886702.924458, session=213070643347544
PGSQL: type=SSLRequest, F -> B
SSL REQUEST

Packet: t=1673886702.928187, session=213070643347544
PGSQL: type=SSLAnswer, B -> F
SSL BACKEND ANSWER: N

Packet: t=1673886702.928222, session=213070643347544
PGSQL: type=StartupMessage, F -> B
STARTUP MESSAGE version: 3
  application_name=psql
  database=postgres
  client_encoding=UTF8
  user=postgres


Packet: t=1673886702.928318, session=213070643347544
PGSQL: type=AuthenticationOk, B -> F
AUTHENTIFICATION REQUEST code=0 (SUCCESS)

Packet: t=1673886702.970239, session=213070643347544
PGSQL: type=BackendKeyData, B -> F
BACKEND KEY DATA pid=1234, key=5678

Packet: t=1673886702.970239, session=213070643347544
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>
  • Server output:
[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:47544
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:47544]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:47544]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]

Step 1.3 - Differentiating between the SSL/Authentication request, and Command messages

We need to be able to differentiate between the SSL Negotiation message, the Startup Message, and the standard command messages.

This is so that we can properly route the messages to the appropriate handlers. Otherwise we wouldn't be able to serve more than one client at a time.

To do this, we can create some predicate helpers for testing the message type to determine whether it's the SSL Request or Startup Message.

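    static Predicate<ByteBuffer> isSSLRequest = (ByteBuffer b) -> {
        // Check the size first: shorter messages (such as the 5-byte Terminate, whose
        // byte 4 is also 0x04) would otherwise hit an IndexOutOfBoundsException below.
        return b.remaining() >= 8
                && b.get(4) == 0x04
                && b.get(5) == (byte) 0xd2
                && b.get(6) == 0x16
                && b.get(7) == 0x2f;
    };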

    static Predicate<ByteBuffer> isStartupMessage = (ByteBuffer b) -> {
        return b.remaining() > 8
                && b.get(4) == 0x00
                && b.get(5) == 0x03 // Protocol version 3
                && b.get(6) == 0x00
                && b.get(7) == 0x00;
    };


    private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
        System.out.println("[SERVER] Received message from client: " + client);
        System.out.println("[SERVER] Buffer: " + buffer);

        Future<Integer> writeResult = null;

        if (isSSLRequest.test(buffer)) {
            System.out.println("[SERVER] SSL Request");
            ByteBuffer sslResponse = ByteBuffer.allocate(1);
            sslResponse.put((byte) 'N');
            sslResponse.flip();
            writeResult = client.write(sslResponse);
        } else if (isStartupMessage.test(buffer)) {
            System.out.println("[SERVER] Startup Message");

            // Then, write AuthenticationOk
            ByteBuffer authOk = ByteBuffer.allocate(9);
            authOk.put((byte) 'R'); // 'R' for AuthenticationOk
            authOk.putInt(8); // Length
            authOk.putInt(0); // AuthenticationOk
            authOk.flip();
            writeResult = client.write(authOk);

            // Then, write BackendKeyData
            ByteBuffer backendKeyData = ByteBuffer.allocate(13); // 1 (type) + 4 (length) + 4 (process ID) + 4 (secret key)
            backendKeyData.put((byte) 'K'); // Message type
            backendKeyData.putInt(12); // Message length
            backendKeyData.putInt(1234); // Process ID
            backendKeyData.putInt(5678); // Secret key
            backendKeyData.flip();
            writeResult = client.write(backendKeyData);

            // Then, write ReadyForQuery
            ByteBuffer readyForQuery = ByteBuffer.allocate(6);
            readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
            readyForQuery.putInt(5); // Length
            readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
            readyForQuery.flip();
            writeResult = client.write(readyForQuery);
        } else {
            System.out.println("[SERVER] Unknown message");
        }

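        if (writeResult == null) {
            // Nothing was written for an unknown message, so there is no result to wait for.
            return;
        }

        try {
            System.out.println("[SERVER] Write result: " + writeResult.get());
        } catch (Exception e) {
            System.err.println("[SERVER] Failed to write to client: " + e);
            e.printStackTrace();
        }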
    }
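
The bytes compared in the predicates above are two int32 values from the protocol. The SSLRequest code is 80877103 (1234 in the upper 16 bits and 5679 in the lower 16 bits), and protocol version 3.0 is 196608 (major version 3 in the upper 16 bits, minor version 0 in the lower). The sketch below performs the same checks by reading one int instead of four bytes. StartupCodes and its members are hypothetical names, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, for illustration only.
final class StartupCodes {
    static final int SSL_REQUEST_CODE = (1234 << 16) | 5679; // 80877103, bytes 04 d2 16 2f
    static final int PROTOCOL_3_0 = 3 << 16;                 // 196608, bytes 00 03 00 00

    /** True if the readable bytes look like an SSLRequest (int32 length, then the request code). */
    static boolean isSslRequest(ByteBuffer b) {
        return b.remaining() >= 8 && b.getInt(b.position() + 4) == SSL_REQUEST_CODE;
    }

    /** True if the readable bytes look like a version 3.0 StartupMessage. */
    static boolean isStartupMessage(ByteBuffer b) {
        return b.remaining() > 8 && b.getInt(b.position() + 4) == PROTOCOL_3_0;
    }

    public static void main(String[] args) {
        ByteBuffer sslRequest = ByteBuffer.allocate(8);
        sslRequest.putInt(8).putInt(SSL_REQUEST_CODE).flip();
        System.out.println("SSL request? " + isSslRequest(sslRequest));
    }
}
```

ByteBuffer is big-endian by default, which matches the network byte order the protocol uses, so getInt needs no extra configuration here.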

If we restart the server and reconnect with psql, we should now see:

[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:35090
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:35090]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] SSL Request
[SERVER] Write result: 1
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:35090]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]
[SERVER] Startup Message
[SERVER] Write result: 6

Step 1.4 - Handle a query and return data rows

Now, the moment you've likely all been waiting for. Let's handle a query and return some data rows.

"Handle" in this case means that we'll just return a hard-coded set of rows, and not actually query a database (sorry to disappoint!). I did say "no libraries" =(

To do this, we'll need to handle the Query message, and then send a RowDescription message, followed by a DataRow message for each row, and finally a CommandComplete message.

  • For our RowDescription message, we'll send two columns, with names "id" and "name"
  • For our DataRow message, we'll send two rows, with values (1, "one") and (2, "two")

To complete the cycle, we lastly need to follow up with a ReadyForQuery message.

This is the hairiest part of the protocol, so the annotated code below should hopefully help you understand what's going on:

} else {
    System.out.println("[SERVER] Unknown message");
    // Let's assume it's a query message, and just send a simple response
    // First we send a RowDescription. We'll send two columns, with names "id" and "name"
    ByteBuffer rowDescription = ByteBuffer.allocate(51);
    rowDescription.put((byte) 'T'); // 'T' for RowDescription
    rowDescription.putInt(50); // Length
    rowDescription.putShort((short) 2); // Number of fields/columns
    // For each field/column:
    rowDescription.put("id".getBytes()).put((byte) 0); // Column name of column 1 (null-terminated)
    rowDescription.putInt(0); // Object ID of column 1
    rowDescription.putShort((short) 0); // Attribute number of column 1
    rowDescription.putInt(23); // Data type OID of column 1
    rowDescription.putShort((short) 4); // Data type size of column 1
    rowDescription.putInt(-1); // Type modifier of column 1
    rowDescription.putShort((short) 0); // Format code of column 1

    rowDescription.put("name".getBytes()).put((byte) 0); // Column name of column 2 (null-terminated)
    rowDescription.putInt(0); // Object ID of column 2
    rowDescription.putShort((short) 0); // Attribute number of column 2
    rowDescription.putInt(25); // Data type OID of column 2
    rowDescription.putShort((short) -1); // Data type size of column 2
    rowDescription.putInt(-1); // Type modifier of column 2
    rowDescription.putShort((short) 0); // Format code of column 2
    rowDescription.flip();
    writeResult = client.write(rowDescription);

    // Then we send a DataRow for each row. We'll send two rows, with values (1, "one") and (2, "two")
    ByteBuffer dataRow1 = ByteBuffer.allocate(19);
    dataRow1.put((byte) 'D'); // 'D' for DataRow
    dataRow1.putInt(18); // Length (1-4)
    dataRow1.putShort((short) 2); // Number of columns (5-6)
    dataRow1.putInt(1); // Length of column 1 (7-10)
    dataRow1.put((byte) '1'); // Value of column 1 (11-11)
    dataRow1.putInt(3); // Length of column 2 (12-15)
    dataRow1.put("one".getBytes()); // Value of column 2 (16-18)
    dataRow1.flip();
    writeResult = client.write(dataRow1);

    ByteBuffer dataRow2 = ByteBuffer.allocate(19);
    dataRow2.put((byte) 'D'); // 'D' for DataRow
    dataRow2.putInt(18); // Length
    dataRow2.putShort((short) 2); // Number of columns
    dataRow2.putInt(1); // Length of column 1
    dataRow2.put((byte) '2'); // Value of column 1
    dataRow2.putInt(3); // Length of column 2
    dataRow2.put("two".getBytes()); // Value of column 2
    dataRow2.flip();
    writeResult = client.write(dataRow2);

    // We send a CommandComplete
    ByteBuffer commandComplete = ByteBuffer.allocate(14);
    commandComplete.put((byte) 'C'); // 'C' for CommandComplete
    commandComplete.putInt(13); // Length
    commandComplete.put("SELECT 2".getBytes()); // Command tag
    commandComplete.put((byte) 0); // Null terminator
    commandComplete.flip();
    writeResult = client.write(commandComplete);

    // Finally, write ReadyForQuery
    ByteBuffer readyForQuery = ByteBuffer.allocate(6);
    readyForQuery.put((byte) 'Z'); // 'Z' for ReadyForQuery
    readyForQuery.putInt(5); // Length
    readyForQuery.put((byte) 'I'); // Transaction status indicator, 'I' for idle
    readyForQuery.flip();
    writeResult = client.write(readyForQuery);
}
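
The hard-coded lengths above (50, 18 and 13) are not arbitrary. Each one is the 4-byte length field plus the bytes that follow it, and the buffer is one byte larger to hold the message type. The sketch below recomputes them. MessageLengths and its methods are hypothetical names used only for this illustration.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical helper, for illustration only.
final class MessageLengths {

    /** Bytes needed for a null-terminated UTF-8 string. */
    static int cString(String s) {
        return s.getBytes(StandardCharsets.UTF_8).length + 1;
    }

    /** RowDescription: int32 length + int16 field count + per field (name\0 + 18 fixed bytes). */
    static int rowDescription(String... columnNames) {
        int length = 4 + 2;
        for (String name : columnNames) {
            length += cString(name) + 18; // 4 + 2 + 4 + 2 + 4 + 2 bytes of field metadata
        }
        return length;
    }

    /** DataRow with text values: int32 length + int16 column count + per value (int32 size + bytes). */
    static int dataRow(String... textValues) {
        int length = 4 + 2;
        for (String value : textValues) {
            length += 4 + value.getBytes(StandardCharsets.UTF_8).length;
        }
        return length;
    }

    /** CommandComplete: int32 length + null-terminated command tag. */
    static int commandComplete(String tag) {
        return 4 + cString(tag);
    }

    public static void main(String[] args) {
        System.out.println(rowDescription("id", "name")); // prints 50
        System.out.println(dataRow("1", "one"));          // prints 18
        System.out.println(commandComplete("SELECT 2"));  // prints 13
    }
}
```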

If we run this, we should see the following output:

  • psql Client:
$ psql -h localhost -p 5432 -U postgres
psql (15.0, server 0.0.0)
WARNING: psql major version 15, server major version 0.0.
         Some psql features might not work.
Type "help" for help.

postgres=> select 1;
 id | name
----+------
  1 | one
  2 | two
(2 rows)

postgres=>
  • pgs-debug output:
Packet: t=1673887134.207655, session=213070643346422
PGSQL: type=SSLRequest, F -> B
SSL REQUEST

Packet: t=1673887134.210943, session=213070643346422
PGSQL: type=SSLAnswer, B -> F
SSL BACKEND ANSWER: N

Packet: t=1673887134.210986, session=213070643346422
PGSQL: type=StartupMessage, F -> B
STARTUP MESSAGE version: 3
  client_encoding=UTF8
  user=postgres
  database=postgres
  application_name=psql


Packet: t=1673887134.211448, session=213070643346422
PGSQL: type=AuthenticationOk, B -> F
AUTHENTIFICATION REQUEST code=0 (SUCCESS)

Packet: t=1673887134.260990, session=213070643346422
PGSQL: type=BackendKeyData, B -> F
BACKEND KEY DATA pid=1234, key=5678

Packet: t=1673887134.260990, session=213070643346422
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>

Packet: t=1673887136.771401, session=213070643346422
PGSQL: type=Query, F -> B
QUERY query=select 1;

Packet: t=1673887136.772593, session=213070643346422
PGSQL: type=RowDescription, B -> F
ROW DESCRIPTION: num_fields=2
  ---[Field 01]---
  name='id'
  type=23
  type_len=4
  type_mod=4294967295
  relid=0
  attnum=0
  format=0
  ---[Field 02]---
  name='name'
  type=25
  type_len=65535
  type_mod=4294967295
  relid=0
  attnum=0
  format=0


Packet: t=1673887136.772650, session=213070643346422
PGSQL: type=DataRow, B -> F
DATA ROW num_values=2
  ---[Value 0001]---
  length=1
  value='1'
  ---[Value 0002]---
  length=3
  value='one'


Packet: t=1673887136.772717, session=213070643346422
PGSQL: type=DataRow, B -> F
DATA ROW num_values=2
  ---[Value 0001]---
  length=1
  value='2'
  ---[Value 0002]---
  length=3
  value='two'


Packet: t=1673887136.772759, session=213070643346422
PGSQL: type=CommandComplete, B -> F
COMMAND COMPLETE command='SELECT 2'

Packet: t=1673887136.772792, session=213070643346422
PGSQL: type=ReadyForQuery, B -> F
READY FOR QUERY type=<IDLE>
  • Server output:
[SERVER] Listening on localhost:5432
[SERVER] Accepted connection from /127.0.0.1:46422
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=8 cap=1024]
[SERVER] SSL Request
[SERVER] Write result: 1
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=84 cap=1024]
[SERVER] Startup Message
[SERVER] Write result: 6
[SERVER] Received message from client: sun.nio.ch.UnixAsynchronousSocketChannelImpl[connected local=/127.0.0.1:5432 remote=/127.0.0.1:46422]
[SERVER] Buffer: java.nio.HeapByteBuffer[pos=0 lim=15 cap=1024]
[SERVER] Unknown message
[SERVER] Write result: 6
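
The 15-byte buffer in the server output is the Query message for "select 1;": 1 type byte ('Q'), a 4-byte length, the 9 bytes of query text, and a null terminator. Here is a minimal sketch that builds and reads back such a message with plain ByteBuffer calls. QueryFrames and its methods are hypothetical names, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

// Hypothetical helper, for illustration only.
final class QueryFrames {

    /** Builds a Query message: 'Q', int32 length, query text, null terminator. */
    static ByteBuffer encodeQuery(String sql) {
        byte[] text = sql.getBytes(StandardCharsets.UTF_8);
        ByteBuffer buffer = ByteBuffer.allocate(1 + 4 + text.length + 1);
        buffer.put((byte) 'Q');
        buffer.putInt(4 + text.length + 1); // length counts itself, the text and the terminator
        buffer.put(text);
        buffer.put((byte) 0);
        buffer.flip();
        return buffer;
    }

    /** Extracts the query text from a Query message without moving the caller's position. */
    static String decodeQuery(ByteBuffer message) {
        ByteBuffer view = message.duplicate(); // independent position, shared content
        if (view.get() != 'Q') {
            throw new IllegalArgumentException("Not a Query message");
        }
        int length = view.getInt();
        byte[] text = new byte[length - 4 - 1]; // drop the length field and the terminator
        view.get(text);
        return new String(text, StandardCharsets.UTF_8);
    }

    public static void main(String[] args) {
        ByteBuffer query = encodeQuery("select 1;");
        System.out.println(query.remaining());  // prints 15
        System.out.println(decodeQuery(query)); // prints select 1;
    }
}
```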

Step 2: Refactoring

Now that we have a working server, we can start refactoring the code to make it more readable and maintainable.

This is where we'll pull in some of the more advanced features of recent Java versions, like Sealed Interfaces, Records, and Pattern Matching.

Our plan for refactoring is going to be to:

  • Pull both the Client -> Server, and Server -> Client messages into their own sealed interfaces, where each message type is represented by a record.
  • We'll then use encoders/decoders to convert between the raw bytes and the message types
    • And do this succinctly, using pattern matching in a switch expression whose case blocks yield the result.
  • Along the way, we'll see how we can use some of the API methods from the Foreign Function and Memory JEP to help us work with C-Strings from raw bytes.

In a Mermaid diagram, the hierarchy of sealed interface types for client and server messages looks like this:

Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 1. Client Side)

We'll start with the client side, since it's a bit simpler.

Most of this code is pretty straightforward, though there's one interesting bit that I'll point out.

When we decode the StartupMessage, what we have are a series of null-terminated key/value string pairs.

The Foreign Function and Memory JEP provides a MemorySegment class that we can use to work with raw bytes. One of the convenience methods on MemorySegment is getUtf8String(long offset), which reads a null-terminated UTF-8 string starting at the given offset.

This method is incredibly useful: previously, reading a null-terminated C-String from a ByteBuffer required a lot of boilerplate. It returns a Java String directly, so the only extra bookkeeping is advancing the offset past each string and its null terminator. (The method was renamed getString when the API was finalized in Java 22.)
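
For comparison, this is roughly what the ByteBuffer-only boilerplate looks like. It is a sketch under the assumption that the parameters are UTF-8 and that the list ends with an extra null byte, as in the StartupMessage. CStrings, readCString and readParameters are hypothetical names, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical helper, for illustration only.
final class CStrings {

    /** Reads the null-terminated UTF-8 string that starts at the absolute index offset. */
    static String readCString(ByteBuffer buffer, int offset) {
        int end = offset;
        while (end < buffer.limit() && buffer.get(end) != 0) {
            end++; // scan forward to the terminator (or the end of the data)
        }
        byte[] bytes = new byte[end - offset];
        for (int i = 0; i < bytes.length; i++) {
            bytes[i] = buffer.get(offset + i);
        }
        return new String(bytes, StandardCharsets.UTF_8);
    }

    /** Reads name\0value\0 pairs starting at offset until the extra trailing \0 or the limit. */
    static Map<String, String> readParameters(ByteBuffer buffer, int offset) {
        Map<String, String> parameters = new LinkedHashMap<>();
        while (offset < buffer.limit() && buffer.get(offset) != 0) {
            String name = readCString(buffer, offset);
            offset += name.getBytes(StandardCharsets.UTF_8).length + 1; // skip the string and its terminator
            String value = readCString(buffer, offset);
            offset += value.getBytes(StandardCharsets.UTF_8).length + 1;
            parameters.put(name, value);
        }
        return parameters;
    }

    public static void main(String[] args) {
        byte[] raw = "user\0postgres\0database\0postgres\0\0".getBytes(StandardCharsets.UTF_8);
        System.out.println(readParameters(ByteBuffer.wrap(raw), 0));
    }
}
```

In a real StartupMessage the parameters start at offset 8, after the int32 length and the int32 protocol version.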

With these records and decoders in place, our dispatch logic becomes a succinct pattern-matching switch:

private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
    PGClientMessage message = PGClientMessage.decode(buffer);
    switch (message) {
        case PGClientMessage.SSLNegotation ssl -> handleSSLRequest(ssl, client);
        case PGClientMessage.StartupMessage startup -> handleStartupMessage(startup, client);
        case PGClientMessage.QueryMessage query -> handleQueryMessage(query, client);
    }
}

The full code for the client message types and decoders is below:

sealed interface PGClientMessage permits
    PGClientMessage.SSLNegotation,
    PGClientMessage.StartupMessage,
    PGClientMessage.QueryMessage {

    record SSLNegotation() implements PGClientMessage {
    }

    record StartupMessage(Map<String, String> parameters) implements PGClientMessage {
    }

    record QueryMessage(String query) implements PGClientMessage {
    }

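    Predicate<ByteBuffer> isSSLRequest = (ByteBuffer b) -> {
        // Check the size first: shorter messages (such as the 5-byte Terminate, whose
        // byte 4 is also 0x04) would otherwise hit an IndexOutOfBoundsException below.
        return b.remaining() >= 8
                && b.get(4) == 0x04
                && b.get(5) == (byte) 0xd2
                && b.get(6) == 0x16
                && b.get(7) == 0x2f;
    };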

    Predicate<ByteBuffer> isStartupMessage = (ByteBuffer b) -> {
        return b.remaining() > 8
                && b.get(4) == 0x00
                && b.get(5) == 0x03 // Protocol version 3
                && b.get(6) == 0x00
                && b.get(7) == 0x00;
    };

    static PGClientMessage decode(ByteBuffer buffer) {
        if (isSSLRequest.test(buffer)) {
            return new SSLNegotation();
        } else if (isStartupMessage.test(buffer)) {
            var segment = MemorySegment.ofBuffer(buffer);
            var length = buffer.getInt(0);
            var parameters = new HashMap<String, String>();
            var offset = 8;
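            while (offset < length - 1) {
                var name = segment.getUtf8String(offset);
                // Advance by the encoded byte count (not the char count) plus the null terminator
                offset += name.getBytes(StandardCharsets.UTF_8).length + 1;
                var value = segment.getUtf8String(offset);
                offset += value.getBytes(StandardCharsets.UTF_8).length + 1;
                parameters.put(name, value);
            }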
            return new StartupMessage(parameters);
        } else {
            // Assume it's a query message
            var query = MemorySegment.ofBuffer(buffer).getUtf8String(5);
            return new QueryMessage(query);
        }
    }
}

// In "AsynchronousSocketServer"
private static void onMessageReceived(AsynchronousSocketChannel client, ByteBuffer buffer) {
    System.out.println("[SERVER] Received message from client: " + client);
    System.out.println("[SERVER] Buffer: " + buffer);

    PGClientMessage message = PGClientMessage.decode(buffer);
    switch (message) {
        case PGClientMessage.SSLNegotation ssl -> handleSSLRequest(ssl, client);
        case PGClientMessage.StartupMessage startup -> handleStartupMessage(startup, client);
        case PGClientMessage.QueryMessage query -> handleQueryMessage(query, client);
    }
}

// Where each of those methods contains the previous logic it held in the "if/else" statement, for example:
private static void handleSSLRequest(PGClientMessage.SSLNegotation sslRequest, AsynchronousSocketChannel client) {
    System.out.println("[SERVER] SSL Request: " + sslRequest);
    ByteBuffer sslResponse = ByteBuffer.allocate(1);
    sslResponse.put((byte) 'N');
    sslResponse.flip();
    try {
        client.write(sslResponse).get();
    } catch (Exception e) {
        e.printStackTrace();
    }
}

Refactoring into Sealed Interfaces with Records and Message Decoders (Pt 2. Server Side)

The final section of our journey is just more of the same, so I will keep it brief.

It does what it says on the tin: we refactor the server side to use sealed interfaces with records and message encoders.

sealed interface PGServerMessage permits
        PGServerMessage.AuthenticationRequest, PGServerMessage.BackendKeyData,
        PGServerMessage.ReadyForQuery, PGServerMessage.RowDescription,
        PGServerMessage.DataRow, PGServerMessage.CommandComplete {

    record AuthenticationRequest() implements PGServerMessage {
    }

    record BackendKeyData(int processId, int secretKey) implements PGServerMessage {
    }

    record ReadyForQuery() implements PGServerMessage {
    }

    record RowDescription(List<RowDescription.Field> fields) implements PGServerMessage {

        record Field(
                String name,
                int tableObjectId,
                int columnAttributeNumber,
                int dataTypeObjectId,
                int dataTypeSize,
                int typeModifier,
                int formatCode) {

            int length() {
                // 4 (int tableObjectId) + 2 (short columnAttributeNumber) + 4 (int dataTypeObjectId) + 2 (short dataTypeSize) + 4 (int typeModifier) + 2 (short formatCode)
                // 4 + 2 + 4 + 2 + 4 + 2 = 18
                // Add the UTF-8 byte length of the name, plus 1 for null terminator '\0'
                return 18 + name.getBytes(StandardCharsets.UTF_8).length + 1;
            }
        }
    }

    record DataRow(List<ByteBuffer> values) implements PGServerMessage {
    }

    record CommandComplete(String commandTag) implements PGServerMessage {
    }

    static ByteBuffer encode(PGServerMessage message) {
        return switch (message) {
            case AuthenticationRequest auth -> {
                var buffer = ByteBuffer.allocate(9);
                buffer.put((byte) 'R'); // 'R' for AuthenticationRequest
                buffer.putInt(8); // Length
                buffer.putInt(0); // Authentication type, 0 for OK
                buffer.flip();
                yield buffer;
            }
            case BackendKeyData keyData -> {
                var buffer = ByteBuffer.allocate(13);
                buffer.put((byte) 'K'); // 'K' for BackendKeyData
                buffer.putInt(12); // Length
                buffer.putInt(keyData.processId()); // Process ID
                buffer.putInt(keyData.secretKey()); // Secret key
                buffer.flip();
                yield buffer;
            }
            case ReadyForQuery ready -> {
                var buffer = ByteBuffer.allocate(6);
                buffer.put((byte) 'Z'); // 'Z' for ReadyForQuery
                buffer.putInt(5); // Length
                buffer.put((byte) 'I'); // Transaction status indicator, 'I' for idle
                buffer.flip();
                yield buffer;
            }
            case RowDescription rowDesc -> {
                var fields = rowDesc.fields();
                var length = 4 + 2 + fields.stream().mapToInt(RowDescription.Field::length).sum();
                var buffer = ByteBuffer.allocate(length + 1)
                        .put((byte) 'T')
                        .putInt(length)
                        .putShort((short) fields.size());
                fields.forEach(field -> buffer
                        .put(field.name().getBytes(StandardCharsets.UTF_8))
                        .put((byte) 0) // null-terminated
                        .putInt(field.tableObjectId())
                        .putShort((short) field.columnAttributeNumber())
                        .putInt(field.dataTypeObjectId())
                        .putShort((short) field.dataTypeSize())
                        .putInt(field.typeModifier())
                        .putShort((short) field.formatCode()));
                buffer.flip();
                yield buffer;
            }
            case DataRow dataRow -> {
                var values = dataRow.values();
                // For each value, we need to add 4 bytes for the length, plus the length of the value
                var length = 4 + 2 + values.stream().map(it -> it.remaining() + 4).reduce(0, Integer::sum);
                var buffer = ByteBuffer.allocate(length + 1) // +1 for msg type
                        .put((byte) 'D')
                        .putInt(length) // +4 for length
                        .putShort((short) values.size()); // +2 for number of columns
                for (var value : values) {
                    buffer.putInt(value.remaining());
                    buffer.put(value);
                }
                buffer.flip();
                yield buffer;
            }
            case CommandComplete cmdComplete -> {
                var commandTag = cmdComplete.commandTag();
                var length = 4 + commandTag.length() + 1;
                yield ByteBuffer.allocate(length + 1) // +1 for msg type
                        .put((byte) 'C')
                        .putInt(length) // +4 for length
                        .put(commandTag.getBytes(StandardCharsets.UTF_8))
                        .put((byte) 0) // null terminator
                        .flip();
            }
        };
    }
}

// In "AsynchronousSocketServer"
private static void handleStartupMessage(PGClientMessage.StartupMessage startup, AsynchronousSocketChannel client) {
    System.out.println("[SERVER] Startup Message: " + startup);

    Future<Integer> writeResult;

    // Then, write AuthenticationOk
    PGServerMessage.AuthenticationRequest authRequest = new PGServerMessage.AuthenticationRequest();
    writeResult = client.write(PGServerMessage.encode(authRequest));


    // Then, write BackendKeyData
    PGServerMessage.BackendKeyData backendKeyData = new PGServerMessage.BackendKeyData(1234, 5678);
    writeResult = client.write(PGServerMessage.encode(backendKeyData));

    // Then, write ReadyForQuery
    PGServerMessage.ReadyForQuery readyForQuery = new PGServerMessage.ReadyForQuery();
    writeResult = client.write(PGServerMessage.encode(readyForQuery));

    try {
        writeResult.get();
    } catch (Exception e) {
        e.printStackTrace();
    }
}

private static void handleQueryMessage(PGClientMessage.QueryMessage query, AsynchronousSocketChannel client) {
    System.out.println("[SERVER] Query Message: " + query);

    Future<Integer> writeResult;

    // Let's assume it's a query message, and just send a simple response
    // First we send a RowDescription. We'll send two columns, with names "id" and "name"
    PGServerMessage.RowDescription rowDescription = new PGServerMessage.RowDescription(List.of(
            new PGServerMessage.RowDescription.Field("id", 0, 0, 23, 4, -1, 0),
            new PGServerMessage.RowDescription.Field("name", 0, 0, 25, -1, -1, 0)
    ));
    writeResult = client.write(PGServerMessage.encode(rowDescription));


    // Then we send a DataRow for each row. We'll send two rows, with values (1, "one") and (2, "two")
    PGServerMessage.DataRow dataRow1 = new PGServerMessage.DataRow(List.of(
            ByteBuffer.wrap("1".getBytes(StandardCharsets.UTF_8)),
            ByteBuffer.wrap("one".getBytes(StandardCharsets.UTF_8))
    ));
    writeResult = client.write(PGServerMessage.encode(dataRow1));

    PGServerMessage.DataRow dataRow2 = new PGServerMessage.DataRow(List.of(
            ByteBuffer.wrap("2".getBytes(StandardCharsets.UTF_8)),
            ByteBuffer.wrap("two".getBytes(StandardCharsets.UTF_8))
    ));

    writeResult = client.write(PGServerMessage.encode(dataRow2));

    // We send a CommandComplete
    PGServerMessage.CommandComplete commandComplete = new PGServerMessage.CommandComplete("SELECT 2");
    writeResult = client.write(PGServerMessage.encode(commandComplete));

    // Finally, write ReadyForQuery
    PGServerMessage.ReadyForQuery readyForQuery = new PGServerMessage.ReadyForQuery();
    writeResult = client.write(PGServerMessage.encode(readyForQuery));

    try {
        writeResult.get();
    } catch (Exception e) {
        e.printStackTrace();
    }
}
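
One caveat about the handlers above: they issue several client.write calls back to back and only wait on the last Future. AsynchronousSocketChannel does not allow a new write to start while an earlier one is still pending (it throws WritePendingException), so this relies on each small write completing immediately. One possible way around that, sketched below, is to concatenate the encoded messages and issue a single write. PgBuffers and concat are hypothetical names, not part of the tutorial's code.

```java
import java.nio.ByteBuffer;

// Hypothetical helper, for illustration only.
final class PgBuffers {

    /** Copies the readable bytes of each buffer, in order, into one new buffer that is ready to be written. */
    static ByteBuffer concat(ByteBuffer... parts) {
        int total = 0;
        for (ByteBuffer part : parts) {
            total += part.remaining();
        }
        ByteBuffer combined = ByteBuffer.allocate(total);
        for (ByteBuffer part : parts) {
            combined.put(part.duplicate()); // duplicate() leaves the caller's position untouched
        }
        combined.flip();
        return combined;
    }

    public static void main(String[] args) {
        ByteBuffer first = ByteBuffer.wrap(new byte[] {1, 2});
        ByteBuffer second = ByteBuffer.wrap(new byte[] {3});
        System.out.println(concat(first, second).remaining()); // prints 3
    }
}
```

A handler could then call client.write(PgBuffers.concat(...)).get() once per response. Since a single write may transfer fewer bytes than requested, a more careful version would repeat the write while the combined buffer still has bytes remaining.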