diff --git a/pgtype/pgtype.go b/pgtype/pgtype.go index b3ef32078..9758575ad 100644 --- a/pgtype/pgtype.go +++ b/pgtype/pgtype.go @@ -1497,7 +1497,18 @@ func TryWrapBuiltinTypeEncodePlan(value any) (plan WrappedEncodePlanNextSetter, case []byte: return &wrapByteSliceEncodePlan{}, byteSliceWrapper(value), true case fmt.Stringer: - return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true + // Check if the value is a renamed basic type. If it is, prefer the basic type encoding. + rv := reflect.ValueOf(value) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, + reflect.Float32, reflect.Float64, reflect.Bool, reflect.String: + // For renamed basic types, don't use Stringer interface automatically + // Let the specific type match above handle it + default: + // For structs and other complex types that implement Stringer, use the Stringer interface + return &wrapFmtStringerEncodePlan{}, fmtStringerWrapper{value}, true + } } return nil, nil, false diff --git a/values_test.go b/values_test.go index 116577d42..5a0b4ef46 100644 --- a/values_test.go +++ b/values_test.go @@ -1004,6 +1004,57 @@ func TestEncodeTypeRename(t *testing.T) { }) } +// Define custom types that are aliases of basic types but also implement fmt.Stringer +type StringerInt32 int32 +type StringerFloat64 float64 + +// Implement the String() method for these types +func (s StringerInt32) String() string { + return fmt.Sprintf("StringerInt32(%d)", int32(s)) +} + +func (s StringerFloat64) String() string { + return fmt.Sprintf("StringerFloat64(%f)", float64(s)) +} + +// TestStringerTypes tests custom type aliases that implement the fmt.Stringer interface +func TestStringerTypes(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + pgxtest.RunWithQueryExecModes(ctx, t, defaultConnTestRunner, nil, func(ctx context.Context, t testing.TB, conn *pgx.Conn) { + // Test values + inInt := StringerInt32(42) + var outInt StringerInt32 + + inFloat := StringerFloat64(553.36) + var outFloat StringerFloat64 + + // Register types with the connection + conn.TypeMap().RegisterDefaultPgType(inInt, "int4") + conn.TypeMap().RegisterDefaultPgType(inFloat, "float8") + + // Test that the underlying values are properly encoded/decoded, + // not the String() representation + err := conn.QueryRow(context.Background(), "select $1::int4, $2::float8", inInt, inFloat). + Scan(&outInt, &outFloat) + if err != nil { + t.Fatalf("Failed with Stringer types: %v", err) + } + + // Check that the values are correctly preserved (not converted to their String() representation) + if inInt != outInt { + t.Errorf("StringerInt32: expected %v, got %v", inInt, outInt) + } + + if inFloat != outFloat { + t.Errorf("StringerFloat64: expected %v, got %v", inFloat, outFloat) + } + }) +} + // func TestRowDecodeBinary(t *testing.T) { // t.Parallel()