diff --git a/batch.go b/batch.go index 2307e3cc8..dabf87ea5 100644 --- a/batch.go +++ b/batch.go @@ -505,3 +505,31 @@ func invalidateCachesOnBatchResultsError(conn *Conn, b *Batch, err error) { } } } + +// ErrPreprocessingBatch occurs when an error is encountered while preprocessing a batch. +// The two preprocessing steps are "prepare" (server-side SQL parse/plan) and +// "build" (client-side argument encoding). +type ErrPreprocessingBatch struct { + step string // "prepare" or "build" + sql string + err error +} + +func newErrPreprocessingBatch(step, sql string, err error) ErrPreprocessingBatch { + return ErrPreprocessingBatch{step: step, sql: sql, err: err} +} + +func (e ErrPreprocessingBatch) Error() string { + // intentionally not including the SQL query in the error message + // to avoid leaking potentially sensitive information into logs. + // If the user wants the SQL, they can call SQL(). + return fmt.Sprintf("error preprocessing batch (%s): %v", e.step, e.err) +} + +func (e ErrPreprocessingBatch) Unwrap() error { + return e.err +} + +func (e ErrPreprocessingBatch) SQL() string { + return e.sql +} diff --git a/batch_test.go b/batch_test.go index d49b4e862..8afa501b3 100644 --- a/batch_test.go +++ b/batch_test.go @@ -598,7 +598,8 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { var n int32 err := br.QueryRow().Scan(&n) - if pgErr, ok := err.(*pgconn.PgError); !(ok && pgErr.Code == "42601") { + var pgErr *pgconn.PgError + if !(errors.As(err, &pgErr) && pgErr.Code == "42601") { t.Errorf("rows.Err() => %v, want error code %v", err, 42601) } @@ -609,6 +610,48 @@ func TestConnSendBatchQuerySyntaxError(t *testing.T) { }) } +func TestConnSendBatchErrorReturnsErrPreprocessingBatch(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + // Only test exec modes that go through sendBatchExtendedWithDescription which wraps errors with ErrPreprocessingBatch. + modes := []pgx.QueryExecMode{ + pgx.QueryExecModeCacheStatement, + pgx.QueryExecModeCacheDescribe, + pgx.QueryExecModeDescribeExec, + } + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, modes, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + var preprocessingErr pgx.ErrPreprocessingBatch + + // Test prepare step failure: syntax error in a non-first statement. + batch := &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select 1 1") // syntax error triggers prepare failure + + err := conn.SendBatch(ctx, batch).Close() + require.Error(t, err) + require.ErrorAs(t, err, &preprocessingErr) + assert.Equal(t, "select 1 1", preprocessingErr.SQL()) + assert.NotContains(t, preprocessingErr.Error(), "select 1 1") // we don't want to leak the SQL query in the error message + assert.Contains(t, preprocessingErr.Error(), "error preprocessing batch (prepare)") + + // Test build step failure: wrong number of arguments in a statement. + batch = &pgx.Batch{} + batch.Queue("select 1") + batch.Queue("select $1::int", 1, 2) // mismatched argument count triggers build failure + + err = conn.SendBatch(ctx, batch).Close() + require.Error(t, err) + require.ErrorAs(t, err, &preprocessingErr) + assert.Equal(t, "select $1::int", preprocessingErr.SQL()) + assert.NotContains(t, preprocessingErr.Error(), "select $1::int") // we don't want to leak the SQL query in the error message + assert.Contains(t, preprocessingErr.Error(), "error preprocessing batch (build)") + }) +} + func TestConnSendBatchQueryRowInsert(t *testing.T) { t.Parallel() @@ -1012,7 +1055,7 @@ func TestConnSendBatchErrorDoesNotLeaveOrphanedPreparedStatement(t *testing.T) { batch.Queue("select col1 from foo") batch.Queue("select col1 from baz") err := conn.SendBatch(ctx, batch).Close() - require.EqualError(t, err, `ERROR: relation "baz" does not exist (SQLSTATE 42P01)`) + require.ErrorContains(t, err, `relation "baz" does not exist (SQLSTATE 42P01)`) mustExec(t, conn, `create temporary table baz(col1 text primary key);`) diff --git a/conn.go b/conn.go index 1fc0b5680..fb43225e4 100644 --- a/conn.go +++ b/conn.go @@ -1194,7 +1194,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d for _, sd := range distinctNewQueries { results, err := pipeline.GetResults() if err != nil { - return err + return newErrPreprocessingBatch("prepare", sd.SQL, err) } resultSD, ok := results.(*pgconn.StatementDescription) @@ -1228,8 +1228,7 @@ func (c *Conn) sendBatchExtendedWithDescription(ctx context.Context, b *Batch, d for _, bi := range b.QueuedQueries { err := c.eqb.Build(c.typeMap, bi.sd, bi.Arguments) if err != nil { - // we wrap the error so we the user can understand which query failed inside the batch - err = fmt.Errorf("error building query %s: %w", bi.SQL, err) + err = newErrPreprocessingBatch("build", bi.SQL, err) return &pipelineBatchResults{ctx: ctx, conn: c, err: err, closed: true} }