Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,31 @@ func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) {
}
}
}

// ErrPreprocessingBatch occurs when an error is encountered while preprocessing a batch.
// The two preprocessing steps are "prepare" (server-side SQL parse/plan) and
// "build" (client-side argument encoding).
type ErrPreprocessingBatch struct {
step string // "prepare" or "build"
sql string
err error
}

func newErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch {
return ErrPreprocessingBatch{step: step, sql: sql, err: err}
}

func (e ErrPreprocessingBatch) Error() string {
// intentionally not including the SQL query in the error message
// to avoid leaking potentially sensitive information into logs.
// If the user wants the SQL, they can call SQL().
return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err)
}

func (e ErrPreprocessingBatch) Unwrap() error {
return e.err
}

func (e ErrPreprocessingBatch) SQL() string {
return e.sql
}
47 changes: 45 additions & 2 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,8 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {

var n int32
err := br.QueryRow().Scan(&n)
if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") {
var pgErr *pgconn.PgError
if !(errors.As(err, &pgErr) && pgErr.Code == "42601") {
t.Errorf("rows.Err() => %v, want error code %v", err, 42601)
}

Expand All @@ -609,6 +610,48 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) {
})
}

func TestConnSendBatchErrorReturnsErrPreprocessingBatch(t *testing.T) {
t.Parallel()

ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
defer cancel()

// Only test exec modes that go through sendBatchExtendedWithDescription which wraps errors with ErrPreprocessingBatch.
modes := []pgx.QueryExecMode{
pgx.QueryExecModeCacheStatement,
pgx.QueryExecModeCacheDescribe,
pgx.QueryExecModeDescribeExec,
}

pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
var preprocessingErr pgx.ErrPreprocessingBatch

// Test prepare step failure: syntax error in a non-first statement.
batch := &pgx.Batch{}
batch.Queue("select 1")
batch.Queue("select 1 1") // syntax error triggers prepare failure

err := conn.SendBatch(ctx, batch).Close()
require.Error(t, err)
require.ErrorAs(t, err, &preprocessingErr)
assert.Equal(t, "select 1 1", preprocessingErr.SQL())
assert.NotContains(t, preprocessingErr.Error(), "select 1 1") // we don't want to leak the SQL query in the error message
assert.Contains(t, preprocessingErr.Error(), "error preprocessing batch (prepare)")

// Test build step failure: wrong number of arguments in a statement.
batch = &pgx.Batch{}
batch.Queue("select 1")
batch.Queue("select $1::int", 1, 2) // mismatched argument count triggers build failure

err = conn.SendBatch(ctx, batch).Close()
require.Error(t, err)
require.ErrorAs(t, err, &preprocessingErr)
assert.Equal(t, "select $1::int", preprocessingErr.SQL())
assert.NotContains(t, preprocessingErr.Error(), "select $1::int") // we don't want to leak the SQL query in the error message
assert.Contains(t, preprocessingErr.Error(), "error preprocessing batch (build)")
})
}

func TestConnSendBatchQueryRowInsert(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -1012,7 +1055,7 @@ func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) {
batch.Queue("select col1 from foo")
batch.Queue("select col1 from baz")
err := conn.SendBatch(ctx, batch).Close()
require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`)
require.ErrorContains(t, err, `relation "baz" does not exist (SQLSTATE 42P01)`)

mustExec(t, conn, `create temporary table baz(col1 text primary key);`)

Expand Down
5 changes: 2 additions & 3 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1194,7 +1194,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
for _, sd := range distinctNewQueries {
results, err := pipeline.GetResults()
if err != nil {
return err
return newErrPreprocessingBatch("prepare", sd.SQL, err)
}

resultSD, ok := results.(*pgconn.StatementDescription)
Expand Down Expand Up @@ -1228,8 +1228,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d
for _, bi := range b.QueuedQueries {
err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments)
if err != nil {
// we wrap the error so we the user can understand which query failed inside the batch
err = fmt.Errorf("error building query %s: %w", bi.SQL, err)
err = newErrPreprocessingBatch("build", bi.SQL, err)
return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true}
}

Expand Down