From d275e53487cf26c8d11a994be9fa21e0bf1e894b Mon Sep 17 00:00:00 2001 From: Achille Roussel Date: Sat, 27 Dec 2025 12:17:40 -0800 Subject: [PATCH] pgconn: support configuring the pgproto3 tracer before opening a connection Signed-off-by: Achille Roussel --- pgconn/config.go | 9 + pgconn/pgconn.go | 4 + pgproto3/backend.go | 2 - pgproto3/frontend.go | 2 - pgproto3/trace.go | 322 ++++++++++++++++++++++++++++++------ pgproto3/trace_test.go | 362 +++++++++++++++++++++++++++++++++++++++++ 6 files changed, 651 insertions(+), 50 deletions(-) diff --git a/pgconn/config.go b/pgconn/config.go index d5914aad9..383134167 100644 --- a/pgconn/config.go +++ b/pgconn/config.go @@ -83,6 +83,15 @@ type Config struct { // that you close on FATAL errors by returning false. OnPgError PgErrorHandler + // Tracer is an io.Writer to which the PostgreSQL frontend/backend protocol messages will be logged. + // The format roughly mimics the format produced by the libpq C function PQtrace. Messages are logged + // from the connection handshake onwards, providing visibility into authentication, parameter status + // messages, backend key data, and all subsequent protocol traffic. + Tracer io.Writer + + // TracerOptions controls tracing behavior. Only relevant when Tracer is set. + TracerOptions pgproto3.TracerOptions + createdByParseConfig bool // Used to enforce created by ParseConfig rule. } diff --git a/pgconn/pgconn.go b/pgconn/pgconn.go index ab4fc514b..ecd493c41 100644 --- a/pgconn/pgconn.go +++ b/pgconn/pgconn.go @@ -369,6 +369,10 @@ func connectOne(ctx context.Context, config *Config, connectConfig *connectOneCo pgConn.bgReaderStarted = make(chan struct{}) pgConn.frontend = config.BuildFrontend(pgConn.bgReader, pgConn.conn) + if config.Tracer != nil { + pgConn.frontend.Trace(config.Tracer, config.TracerOptions) + } + startupMsg := pgproto3.StartupMessage{ ProtocolVersion: pgproto3.ProtocolVersionNumber, Parameters: make(map[string]string), diff --git a/pgproto3/backend.go b/pgproto3/backend.go index d9d0f370c..ba89058b8 100644 --- a/pgproto3/backend.go +++ b/pgproto3/backend.go @@ -1,7 +1,6 @@ package pgproto3 import ( - "bytes" "encoding/binary" "fmt" "io" @@ -105,7 +104,6 @@ func (b *Backend) Flush() error { func (b *Backend) Trace(w io.Writer, options TracerOptions) { b.tracer = &tracer{ w: w, - buf: &bytes.Buffer{}, TracerOptions: options, } } diff --git a/pgproto3/frontend.go b/pgproto3/frontend.go index 056e547cd..8939d4dc3 100644 --- a/pgproto3/frontend.go +++ b/pgproto3/frontend.go @@ -1,7 +1,6 @@ package pgproto3 import ( - "bytes" "encoding/binary" "errors" "fmt" @@ -124,7 +123,6 @@ func (f *Frontend) Flush() error { func (f *Frontend) Trace(w io.Writer, options TracerOptions) { f.tracer = &tracer{ w: w, - buf: &bytes.Buffer{}, TracerOptions: options, } } diff --git a/pgproto3/trace.go b/pgproto3/trace.go index 6cc7d3e36..ed8f90b50 100644 --- a/pgproto3/trace.go +++ b/pgproto3/trace.go @@ -2,10 +2,12 @@ package pgproto3 import ( "bytes" + "encoding/hex" + "errors" "fmt" "io" + "iter" "strconv" - "strings" "sync" "time" ) @@ -17,7 +19,7 @@ type tracer struct { mux sync.Mutex w io.Writer - buf *bytes.Buffer + buf bytes.Buffer } // TracerOptions controls tracing behavior. It is roughly equivalent to the libpq function PQsetTraceFlags. @@ -29,6 +31,212 @@ type TracerOptions struct { RegressMode bool } +const timestampFormat = "2006-01-02 15:04:05.000000" + +var ( + errUnclosedDoubleQuote = errors.New("unclosed double quote") + errUnclosedSingleQuote = errors.New("unclosed single quote") + errExpectedSingleQuote = errors.New("expected single quote") +) + +// Parse parses a single trace line into its components. +// Returns the timestamp (zero if SuppressTimestamps was true), actor ('F' or 'B'), +// message type name, encoded message size, and the args portion (may be empty). +func (opts TracerOptions) Parse(line []byte) (timestamp time.Time, actor byte, msgType string, size int32, args []byte, err error) { + // Parse fields by scanning for tabs manually to avoid allocation from bytes.Split + data := line + + // Parse timestamp if present + if !opts.SuppressTimestamps { + tabIdx := indexByte(data, '\t') + if tabIdx < 0 { + return time.Time{}, 0, "", 0, nil, errors.New("invalid trace line: not enough fields") + } + timestamp, err = time.Parse(timestampFormat, string(data[:tabIdx])) + if err != nil { + return time.Time{}, 0, "", 0, nil, fmt.Errorf("invalid timestamp: %w", err) + } + data = data[tabIdx+1:] + } + + // Parse actor + if len(data) < 1 { + return time.Time{}, 0, "", 0, nil, errors.New("invalid trace line: not enough fields") + } + actor = data[0] + if actor != 'F' && actor != 'B' { + return time.Time{}, 0, "", 0, nil, fmt.Errorf("invalid actor: expected 'F' or 'B', got '%c'", actor) + } + data = data[1:] + + // Expect tab after actor + if len(data) == 0 || data[0] != '\t' { + return time.Time{}, 0, "", 0, nil, errors.New("invalid actor: expected single character") + } + data = data[1:] + + // Parse message type + tabIdx := indexByte(data, '\t') + if tabIdx < 0 { + return time.Time{}, 0, "", 0, nil, errors.New("invalid trace line: not enough fields") + } + msgType = string(data[:tabIdx]) + data = data[tabIdx+1:] + + // Parse size + tabIdx = indexByte(data, '\t') + var sizeBytes []byte + if tabIdx < 0 { + sizeBytes = data + data = nil + } else { + sizeBytes = data[:tabIdx] + data = data[tabIdx+1:] + } + sizeVal, err := strconv.ParseInt(string(sizeBytes), 10, 32) + if err != nil { + return time.Time{}, 0, "", 0, nil, fmt.Errorf("invalid size: %w", err) + } + size = int32(sizeVal) + + // Remaining data is args + args = data + + return timestamp, actor, msgType, size, args, nil +} + +// indexByte returns the index of the first instance of c in s, or -1 if c is not present. +func indexByte(s []byte, c byte) int { + return bytes.IndexByte(s, c) +} + +// ParseArgs returns an iterator over space-separated arguments in the args portion. +// Each value is unquoted and unescaped: +// - Double-quoted strings: "value" → value (quotes removed) +// - Single-quoted strings with hex escapes: 'hello\x0aworld' → hello\nworld (unescaped) +// - Unquoted values returned as-is +func (opts TracerOptions) ParseArgs(args []byte) iter.Seq2[[]byte, error] { + return func(yield func([]byte, error) bool) { + data := args + + for len(data) > 0 { + // Skip leading spaces + for len(data) > 0 && data[0] == ' ' { + data = data[1:] + } + if len(data) == 0 { + break + } + + var value []byte + var err error + + switch data[0] { + case '"': + // Double-quoted string: find closing quote + end := bytes.IndexByte(data[1:], '"') + if end < 0 { + if !yield(nil, errUnclosedDoubleQuote) { + return + } + return + } + value = data[1 : end+1] + data = data[end+2:] + + case '\'': + // Single-quoted string with hex escapes + value, data, err = parseSingleQuoted(data) + if err != nil { + if !yield(nil, err) { + return + } + return + } + + default: + // Unquoted value: read until space + end := bytes.IndexByte(data, ' ') + if end < 0 { + value = data + data = nil + } else { + value = data[:end] + data = data[end:] + } + } + + if !yield(value, nil) { + return + } + } + } +} + +// parseSingleQuoted parses a single-quoted string with hex escapes. +// Returns the unescaped value, remaining data, and any error. +// Optimized to avoid allocations when there are no escape sequences. +func parseSingleQuoted(data []byte) (value []byte, remaining []byte, err error) { + if len(data) == 0 || data[0] != '\'' { + return nil, data, errExpectedSingleQuote + } + data = data[1:] + + // First, scan to find the closing quote and check if any escapes exist + hasEscape := false + closeIdx := -1 + for i := range data { + if data[i] == '\'' { + closeIdx = i + break + } + if data[i] == '\\' && i+1 < len(data) && data[i+1] == 'x' { + hasEscape = true + } + } + + if closeIdx < 0 { + return nil, nil, errUnclosedSingleQuote + } + + content := data[:closeIdx] + remaining = data[closeIdx+1:] + + // Fast path: no escapes, return a subslice directly + if !hasEscape { + return content, remaining, nil + } + + // Slow path: need to unescape + // Pre-calculate the result size to avoid reallocations + resultLen := 0 + for i := 0; i < len(content); i++ { + if len(content) >= i+4 && content[i] == '\\' && content[i+1] == 'x' { + resultLen++ + i += 3 // skip \xNN, loop will add 1 more + } else { + resultLen++ + } + } + + result := make([]byte, 0, resultLen) + var decoded [1]byte + for i := 0; i < len(content); i++ { + if len(content) >= i+4 && content[i] == '\\' && content[i+1] == 'x' { + _, err := hex.Decode(decoded[:], content[i+2:i+4]) + if err != nil { + return nil, data, fmt.Errorf("invalid hex escape: %w", err) + } + result = append(result, decoded[0]) + i += 3 + } else { + result = append(result, content[i]) + } + } + + return result, remaining, nil +} + func (t *tracer) traceMessage(sender byte, encodedLen int32, msg Message) { switch msg := msg.(type) { case *AuthenticationCleartextPassword: @@ -163,24 +371,24 @@ func (t *tracer) traceBackendKeyData(sender byte, encodedLen int32, msg *Backend if t.RegressMode { t.buf.WriteString("\t NNNN NNNN") } else { - fmt.Fprintf(t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) + fmt.Fprintf(&t.buf, "\t %d %d", msg.ProcessID, msg.SecretKey) } }) } func (t *tracer) traceBind(sender byte, encodedLen int32, msg *Bind) { t.writeTrace(sender, encodedLen, "Bind", func() { - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.DestinationPortal)), traceDoubleQuotedString([]byte(msg.PreparedStatement)), len(msg.ParameterFormatCodes)) + fmt.Fprintf(&t.buf, "\t %s %s %d", doubleQuotedString{&msg.DestinationPortal}, doubleQuotedString{&msg.PreparedStatement}, len(msg.ParameterFormatCodes)) for _, fc := range msg.ParameterFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) + fmt.Fprintf(&t.buf, " %d", fc) } - fmt.Fprintf(t.buf, " %d", len(msg.Parameters)) - for _, p := range msg.Parameters { - fmt.Fprintf(t.buf, " %s", traceSingleQuotedString(p)) + fmt.Fprintf(&t.buf, " %d", len(msg.Parameters)) + for i := range msg.Parameters { + fmt.Fprintf(&t.buf, " %s", singleQuotedEscaped{&msg.Parameters[i]}) } - fmt.Fprintf(t.buf, " %d", len(msg.ResultFormatCodes)) + fmt.Fprintf(&t.buf, " %d", len(msg.ResultFormatCodes)) for _, fc := range msg.ResultFormatCodes { - fmt.Fprintf(t.buf, " %d", fc) + fmt.Fprintf(&t.buf, " %d", fc) } }) } @@ -203,7 +411,7 @@ func (t *tracer) traceCloseComplete(sender byte, encodedLen int32, msg *CloseCom func (t *tracer) traceCommandComplete(sender byte, encodedLen int32, msg *CommandComplete) { t.writeTrace(sender, encodedLen, "CommandComplete", func() { - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString(msg.CommandTag)) + fmt.Fprintf(&t.buf, "\t %s", doubleQuotedBytes{&msg.CommandTag}) }) } @@ -221,7 +429,7 @@ func (t *tracer) traceCopyDone(sender byte, encodedLen int32, msg *CopyDone) { func (t *tracer) traceCopyFail(sender byte, encodedLen int32, msg *CopyFail) { t.writeTrace(sender, encodedLen, "CopyFail", func() { - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.Message))) + fmt.Fprintf(&t.buf, "\t %s", doubleQuotedString{&msg.Message}) }) } @@ -235,12 +443,12 @@ func (t *tracer) traceCopyOutResponse(sender byte, encodedLen int32, msg *CopyOu func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { t.writeTrace(sender, encodedLen, "DataRow", func() { - fmt.Fprintf(t.buf, "\t %d", len(msg.Values)) - for _, v := range msg.Values { - if v == nil { + fmt.Fprintf(&t.buf, "\t %d", len(msg.Values)) + for i := range msg.Values { + if msg.Values[i] == nil { t.buf.WriteString(" -1") } else { - fmt.Fprintf(t.buf, " %d %s", len(v), traceSingleQuotedString(v)) + fmt.Fprintf(&t.buf, " %d %s", len(msg.Values[i]), singleQuotedEscaped{&msg.Values[i]}) } } }) @@ -248,7 +456,7 @@ func (t *tracer) traceDataRow(sender byte, encodedLen int32, msg *DataRow) { func (t *tracer) traceDescribe(sender byte, encodedLen int32, msg *Describe) { t.writeTrace(sender, encodedLen, "Describe", func() { - fmt.Fprintf(t.buf, "\t %c %s", msg.ObjectType, traceDoubleQuotedString([]byte(msg.Name))) + fmt.Fprintf(&t.buf, "\t %c %s", msg.ObjectType, doubleQuotedString{&msg.Name}) }) } @@ -262,7 +470,7 @@ func (t *tracer) traceErrorResponse(sender byte, encodedLen int32, msg *ErrorRes func (t *tracer) TraceQueryute(sender byte, encodedLen int32, msg *Execute) { t.writeTrace(sender, encodedLen, "Execute", func() { - fmt.Fprintf(t.buf, "\t %s %d", traceDoubleQuotedString([]byte(msg.Portal)), msg.MaxRows) + fmt.Fprintf(&t.buf, "\t %s %d", doubleQuotedString{&msg.Portal}, msg.MaxRows) }) } @@ -292,7 +500,7 @@ func (t *tracer) traceNoticeResponse(sender byte, encodedLen int32, msg *NoticeR func (t *tracer) traceNotificationResponse(sender byte, encodedLen int32, msg *NotificationResponse) { t.writeTrace(sender, encodedLen, "NotificationResponse", func() { - fmt.Fprintf(t.buf, "\t %d %s %s", msg.PID, traceDoubleQuotedString([]byte(msg.Channel)), traceDoubleQuotedString([]byte(msg.Payload))) + fmt.Fprintf(&t.buf, "\t %d %s %s", msg.PID, doubleQuotedString{&msg.Channel}, doubleQuotedString{&msg.Payload}) }) } @@ -302,15 +510,15 @@ func (t *tracer) traceParameterDescription(sender byte, encodedLen int32, msg *P func (t *tracer) traceParameterStatus(sender byte, encodedLen int32, msg *ParameterStatus) { t.writeTrace(sender, encodedLen, "ParameterStatus", func() { - fmt.Fprintf(t.buf, "\t %s %s", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Value))) + fmt.Fprintf(&t.buf, "\t %s %s", doubleQuotedString{&msg.Name}, doubleQuotedString{&msg.Value}) }) } func (t *tracer) traceParse(sender byte, encodedLen int32, msg *Parse) { t.writeTrace(sender, encodedLen, "Parse", func() { - fmt.Fprintf(t.buf, "\t %s %s %d", traceDoubleQuotedString([]byte(msg.Name)), traceDoubleQuotedString([]byte(msg.Query)), len(msg.ParameterOIDs)) + fmt.Fprintf(&t.buf, "\t %s %s %d", doubleQuotedString{&msg.Name}, doubleQuotedString{&msg.Query}, len(msg.ParameterOIDs)) for _, oid := range msg.ParameterOIDs { - fmt.Fprintf(t.buf, " %d", oid) + fmt.Fprintf(&t.buf, " %d", oid) } }) } @@ -325,21 +533,21 @@ func (t *tracer) tracePortalSuspended(sender byte, encodedLen int32, msg *Portal func (t *tracer) traceQuery(sender byte, encodedLen int32, msg *Query) { t.writeTrace(sender, encodedLen, "Query", func() { - fmt.Fprintf(t.buf, "\t %s", traceDoubleQuotedString([]byte(msg.String))) + fmt.Fprintf(&t.buf, "\t %s", doubleQuotedString{&msg.String}) }) } func (t *tracer) traceReadyForQuery(sender byte, encodedLen int32, msg *ReadyForQuery) { t.writeTrace(sender, encodedLen, "ReadyForQuery", func() { - fmt.Fprintf(t.buf, "\t %c", msg.TxStatus) + fmt.Fprintf(&t.buf, "\t %c", msg.TxStatus) }) } func (t *tracer) traceRowDescription(sender byte, encodedLen int32, msg *RowDescription) { t.writeTrace(sender, encodedLen, "RowDescription", func() { - fmt.Fprintf(t.buf, "\t %d", len(msg.Fields)) - for _, fd := range msg.Fields { - fmt.Fprintf(t.buf, ` %s %d %d %d %d %d %d`, traceDoubleQuotedString(fd.Name), fd.TableOID, fd.TableAttributeNumber, fd.DataTypeOID, fd.DataTypeSize, fd.TypeModifier, fd.Format) + fmt.Fprintf(&t.buf, "\t %d", len(msg.Fields)) + for i := range msg.Fields { + fmt.Fprintf(&t.buf, ` %s %d %d %d %d %d %d`, doubleQuotedBytes{&msg.Fields[i].Name}, msg.Fields[i].TableOID, msg.Fields[i].TableAttributeNumber, msg.Fields[i].DataTypeOID, msg.Fields[i].DataTypeSize, msg.Fields[i].TypeModifier, msg.Fields[i].Format) } }) } @@ -365,7 +573,7 @@ func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, write defer t.mux.Unlock() defer func() { if t.buf.Cap() > 1024 { - t.buf = &bytes.Buffer{} + t.buf = bytes.Buffer{} } else { t.buf.Reset() } @@ -373,7 +581,7 @@ func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, write if !t.SuppressTimestamps { now := time.Now() - t.buf.WriteString(now.Format("2006-01-02 15:04:05.000000")) + t.buf.Write(now.AppendFormat(t.buf.AvailableBuffer(), timestampFormat)) t.buf.WriteByte('\t') } @@ -381,7 +589,7 @@ func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, write t.buf.WriteByte('\t') t.buf.WriteString(msgType) t.buf.WriteByte('\t') - t.buf.WriteString(strconv.FormatInt(int64(encodedLen), 10)) + t.buf.Write(strconv.AppendInt(t.buf.AvailableBuffer(), int64(encodedLen), 10)) if writeDetails != nil { writeDetails() @@ -391,26 +599,48 @@ func (t *tracer) writeTrace(sender byte, encodedLen int32, msgType string, write t.buf.WriteTo(t.w) } -// traceDoubleQuotedString returns t.buf as a double-quoted string without any escaping. It is roughly equivalent to -// pqTraceOutputString in libpq. -func traceDoubleQuotedString(buf []byte) string { - return `"` + string(buf) + `"` +// doubleQuotedString wraps a pointer to string for zero-copy double-quoted formatting. +// Using a pointer avoids copying the string header when passed to fmt.Fprintf. +type doubleQuotedString struct{ data *string } + +func (dq doubleQuotedString) Format(f fmt.State, verb rune) { + io.WriteString(f, `"`) + io.WriteString(f, *dq.data) + io.WriteString(f, `"`) } -// traceSingleQuotedString returns buf as a single-quoted string with non-printable characters hex-escaped. It is -// roughly equivalent to pqTraceOutputNchar in libpq. -func traceSingleQuotedString(buf []byte) string { - sb := &strings.Builder{} +// doubleQuotedBytes wraps a pointer to []byte for zero-copy double-quoted formatting. +// Using a pointer avoids copying the slice header when passed to fmt.Fprintf. +type doubleQuotedBytes struct{ data *[]byte } - sb.WriteByte('\'') - for _, b := range buf { - if b < 32 || b > 126 { - fmt.Fprintf(sb, `\x%x`, b) - } else { - sb.WriteByte(b) +func (dq doubleQuotedBytes) Format(f fmt.State, verb rune) { + io.WriteString(f, `"`) + f.Write(*dq.data) + io.WriteString(f, `"`) +} + +// singleQuotedEscaped wraps a pointer to []byte for zero-copy single-quoted formatting +// with hex escaping for non-printable characters (libpq style). +type singleQuotedEscaped struct{ data *[]byte } + +func (sq singleQuotedEscaped) Format(f fmt.State, verb rune) { + io.WriteString(f, `'`) + + data := *sq.data + for len(data) > 0 { + i := 0 + for i < len(data) && data[i] >= 32 && data[i] <= 126 { + i++ + } + if i > 0 { + f.Write(data[:i]) + data = data[i:] + } + if len(data) > 0 && (data[0] < 32 || data[0] > 126) { + fmt.Fprintf(f, `\x%x`, data[0]) + data = data[1:] } } - sb.WriteByte('\'') - return sb.String() + io.WriteString(f, `'`) } diff --git a/pgproto3/trace_test.go b/pgproto3/trace_test.go index c56a49912..802723344 100644 --- a/pgproto3/trace_test.go +++ b/pgproto3/trace_test.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgproto3" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -54,3 +55,364 @@ B ReadyForQuery 6 I require.Equal(t, expected, traceOutput.String()) } + +func TestTracerOptionsParse(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + line string + suppressTimestamps bool + wantTimestamp time.Time + wantActor byte + wantMsgType string + wantSize int32 + wantArgs string + wantErr bool + }{ + { + name: "with timestamp and args", + line: "2024-01-15 10:30:45.123456\tB\tParameterStatus\t25\t\"server_version\" \"15.4\"", + suppressTimestamps: false, + wantTimestamp: time.Date(2024, 1, 15, 10, 30, 45, 123456000, time.UTC), + wantActor: 'B', + wantMsgType: "ParameterStatus", + wantSize: 25, + wantArgs: "\"server_version\" \"15.4\"", + }, + { + name: "with timestamp no args", + line: "2024-01-15 10:30:45.123456\tF\tSync\t5", + suppressTimestamps: false, + wantTimestamp: time.Date(2024, 1, 15, 10, 30, 45, 123456000, time.UTC), + wantActor: 'F', + wantMsgType: "Sync", + wantSize: 5, + wantArgs: "", + }, + { + name: "suppress timestamps with args", + line: "B\tDataRow\t12\t 1 1 '1'", + suppressTimestamps: true, + wantActor: 'B', + wantMsgType: "DataRow", + wantSize: 12, + wantArgs: " 1 1 '1'", + }, + { + name: "suppress timestamps no args", + line: "F\tSync\t5", + suppressTimestamps: true, + wantActor: 'F', + wantMsgType: "Sync", + wantSize: 5, + wantArgs: "", + }, + { + name: "invalid actor", + line: "X\tSync\t5", + suppressTimestamps: true, + wantErr: true, + }, + { + name: "invalid actor multi-char", + line: "FB\tSync\t5", + suppressTimestamps: true, + wantErr: true, + }, + { + name: "invalid timestamp", + line: "not-a-timestamp\tB\tSync\t5", + suppressTimestamps: false, + wantErr: true, + }, + { + name: "not enough fields", + line: "B\tSync", + suppressTimestamps: true, + wantErr: true, + }, + { + name: "invalid size", + line: "B\tSync\tnotanumber", + suppressTimestamps: true, + wantErr: true, + }, + { + name: "args with tabs", + line: "B\tTest\t10\tfirst\tsecond", + suppressTimestamps: true, + wantActor: 'B', + wantMsgType: "Test", + wantSize: 10, + wantArgs: "first\tsecond", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := pgproto3.TracerOptions{SuppressTimestamps: tt.suppressTimestamps} + timestamp, actor, msgType, size, args, err := opts.Parse([]byte(tt.line)) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantTimestamp, timestamp) + assert.Equal(t, tt.wantActor, actor) + assert.Equal(t, tt.wantMsgType, msgType) + assert.Equal(t, tt.wantSize, size) + assert.Equal(t, tt.wantArgs, string(args)) + }) + } +} + +func TestTracerOptionsParseArgs(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + args string + want []string + wantErrs []bool // true if that iteration should yield an error + }{ + { + name: "double quoted strings", + args: `"hello" "world"`, + want: []string{"hello", "world"}, + }, + { + name: "single quoted strings", + args: `'hello' 'world'`, + want: []string{"hello", "world"}, + }, + { + name: "unquoted values", + args: `123 456 abc`, + want: []string{"123", "456", "abc"}, + }, + { + name: "mixed types", + args: `"name" 42 'value'`, + want: []string{"name", "42", "value"}, + }, + { + name: "single quoted with hex escape", + args: `'hello\x0aworld'`, + want: []string{"hello\nworld"}, + }, + { + name: "single quoted with multiple escapes", + args: `'\x00\x01\x02'`, + want: []string{"\x00\x01\x02"}, + }, + { + name: "single quoted with tab escape", + args: `'col1\x09col2'`, + want: []string{"col1\tcol2"}, + }, + { + name: "leading space in args", + args: ` "first" "second"`, + want: []string{"first", "second"}, + }, + { + name: "multiple spaces between args", + args: `"one" "two" "three"`, + want: []string{"one", "two", "three"}, + }, + { + name: "empty args", + args: ``, + want: nil, + }, + { + name: "only spaces", + args: ` `, + want: nil, + }, + { + name: "empty double quoted", + args: `""`, + want: []string{""}, + }, + { + name: "empty single quoted", + args: `''`, + want: []string{""}, + }, + { + name: "unclosed double quote", + args: `"hello`, + wantErrs: []bool{true}, + }, + { + name: "unclosed single quote", + args: `'hello`, + wantErrs: []bool{true}, + }, + { + name: "invalid hex escape", + args: `'hello\xgg'`, + wantErrs: []bool{true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := pgproto3.TracerOptions{} + var got []string + var gotErrs []bool + + for val, err := range opts.ParseArgs([]byte(tt.args)) { + if err != nil { + gotErrs = append(gotErrs, true) + } else { + got = append(got, string(val)) + gotErrs = append(gotErrs, false) + } + } + + if tt.wantErrs != nil { + assert.Equal(t, tt.wantErrs, gotErrs) + } else { + assert.Equal(t, tt.want, got) + for _, gotErr := range gotErrs { + assert.False(t, gotErr) + } + } + }) + } +} + +func TestParseRoundTrip(t *testing.T) { + t.Parallel() + + // Test that we can parse the output from the tracer + traceOutput := `F Parse 45 "" "select n from generate_series(1,5) n" 0 +F Bind 13 "" "" 0 0 0 +F Describe 7 P "" +F Execute 10 "" 0 +F Sync 5 +B ParseComplete 5 +B BindComplete 5 +B RowDescription 27 1 "n" 0 0 23 4 -1 0 +B DataRow 12 1 1 '1' +B DataRow 12 1 1 '2' +B CommandComplete 14 "SELECT 5" +B ReadyForQuery 6 I +` + + opts := pgproto3.TracerOptions{SuppressTimestamps: true} + lines := bytes.Split([]byte(traceOutput), []byte{'\n'}) + + for _, line := range lines { + if len(line) == 0 { + continue + } + + _, actor, msgType, size, _, err := opts.Parse(line) + require.NoError(t, err) + assert.True(t, actor == 'F' || actor == 'B') + assert.NotEmpty(t, msgType) + assert.Greater(t, size, int32(0)) + } +} + +func BenchmarkParse(b *testing.B) { + benchmarks := []struct { + name string + line []byte + suppressTimestamps bool + }{ + { + name: "with timestamp and args", + line: []byte("2024-01-15 10:30:45.123456\tB\tParameterStatus\t25\t\"server_version\" \"15.4\""), + suppressTimestamps: false, + }, + { + name: "with timestamp no args", + line: []byte("2024-01-15 10:30:45.123456\tF\tSync\t5"), + suppressTimestamps: false, + }, + { + name: "suppress timestamps with args", + line: []byte("B\tDataRow\t12\t 1 1 '1'"), + suppressTimestamps: true, + }, + { + name: "suppress timestamps no args", + line: []byte("F\tSync\t5"), + suppressTimestamps: true, + }, + { + name: "row description", + line: []byte("B\tRowDescription\t27\t 1 \"n\" 0 0 23 4 -1 0"), + suppressTimestamps: true, + }, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + opts := pgproto3.TracerOptions{SuppressTimestamps: bm.suppressTimestamps} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _, _, _, _, _ = opts.Parse(bm.line) + } + }) + } +} + +func BenchmarkParseArgs(b *testing.B) { + benchmarks := []struct { + name string + args []byte + }{ + { + name: "double quoted", + args: []byte(`"server_version" "15.4"`), + }, + { + name: "single quoted", + args: []byte(`'hello' 'world'`), + }, + { + name: "single quoted with escapes", + args: []byte(`'hello\x0aworld\x09tab'`), + }, + { + name: "mixed", + args: []byte(`"name" 42 'value'`), + }, + { + name: "data row", + args: []byte(` 1 1 '1'`), + }, + { + name: "row description", + args: []byte(` 1 "n" 0 0 23 4 -1 0`), + }, + { + name: "parse message", + args: []byte(` "" "select n from generate_series(1,5) n" 0`), + }, + } + + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + opts := pgproto3.TracerOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, err := range opts.ParseArgs(bm.args) { + if err != nil { + b.Fatal(err) + } + } + } + }) + } +}