diff --git a/drizzle-orm/src/mysql-core/columns/custom.ts b/drizzle-orm/src/mysql-core/columns/custom.ts index 50585bece1..336f2ee338 100644 --- a/drizzle-orm/src/mysql-core/columns/custom.ts +++ b/drizzle-orm/src/mysql-core/columns/custom.ts @@ -63,6 +63,7 @@ export class MySqlCustomColumn T['driverParam']; private mapFrom?: (value: T['driverParam']) => T['data']; + private selectFn?: (columnName: string, decoder: any) => SQL; constructor( table: AnyMySqlTable<{ name: T['tableName'] }>, @@ -72,6 +73,7 @@ export class MySqlCustomColumn | this { + if (this.selectFn) { + return this.selectFn(columnName || this.name, this); + } + return this; + } } export type CustomTypeValues = { @@ -195,6 +207,14 @@ export interface CustomTypeParams { * ``` */ fromDriver?: (value: T['driverData']) => T['data']; + + /** + * Optional function to wrap the column in custom SQL when selecting from database. + * @param columnName - The column name to be selected + * @param decoder - The value decoder for this column + * @returns SQL expression for selecting this column + */ + selectFromDb?: (columnName: string, decoder: any) => SQL; } /** diff --git a/drizzle-orm/src/pg-core/columns/__tests__/custom-selectFromDb.test.ts b/drizzle-orm/src/pg-core/columns/__tests__/custom-selectFromDb.test.ts new file mode 100644 index 0000000000..fe783cc8e2 --- /dev/null +++ b/drizzle-orm/src/pg-core/columns/__tests__/custom-selectFromDb.test.ts @@ -0,0 +1,62 @@ +/** + * Test for custom type selectFromDb feature + * Issue: #1083 - Default select for custom types + */ + +import { describe, it, expect } from 'vitest'; +import { customType } from '../custom.ts'; + +describe('Custom Type selectFromDb', () => { + it('should allow defining selectFromDb for custom types', () => { + // Define a custom point type with selectFromDb + const pointType = customType<{ + data: { lat: number; lng: number }; + driverData: string; + }>({ + dataType() { + return 'geometry(Point,4326)'; + }, + fromDriver(value: string) { + const matches = value.match(/POINT\((?[\d.-]+) (?[\d.-]+)\)/); + const { lat, lng } = matches?.groups ?? {}; + return { lat: parseFloat(String(lat)), lng: parseFloat(String(lng)) }; + }, + }); + + expect(pointType).toBeDefined(); + }); + + it('should accept selectFromDb callback', () => { + let selectFromDbCalled = false; + + const customType_ = customType<{ + data: string; + driverData: string; + }>({ + dataType() { + return 'text'; + }, + selectFromDb() { + selectFromDbCalled = true; + return {} as any; + }, + }); + + // Create column to trigger the callback storage + const builder = customType_(); + expect(builder).toBeDefined(); + }); + + it('should work without selectFromDb (optional)', () => { + const simpleCustom = customType<{ + data: string; + driverData: string; + }>({ + dataType() { + return 'text'; + }, + }); + + expect(simpleCustom).toBeDefined(); + }); +}); diff --git a/drizzle-orm/src/pg-core/columns/custom.ts b/drizzle-orm/src/pg-core/columns/custom.ts index f4f622ff68..8aa412209d 100644 --- a/drizzle-orm/src/pg-core/columns/custom.ts +++ b/drizzle-orm/src/pg-core/columns/custom.ts @@ -63,6 +63,7 @@ export class PgCustomColumn T['driverParam']; private mapFrom?: (value: T['driverParam']) => T['data']; + private selectFn?: (columnName: string, decoder: any) => SQL; constructor( table: AnyPgTable<{ name: T['tableName'] }>, @@ -72,6 +73,7 @@ export class PgCustomColumn | this { + if (this.selectFn) { + return this.selectFn(columnName || this.name, this); + } + return this; + } } export type CustomTypeValues = { @@ -195,6 +210,22 @@ export interface CustomTypeParams { * ``` */ fromDriver?: (value: T['driverData']) => T['data']; + + /** + * Optional function to wrap the column in custom SQL when selecting from database. + * Useful for types like PostGIS geometry that need special handling. + * @example + * For PostGIS geometry, we need to convert from WKT format: + * ``` + * selectFromDb(column, decoder) { + * return sql`ST_AsText(${sql.identifier(column)})`.mapWith(decoder).as(column); + * } + * ``` + * @param columnName - The column name to be selected + * @param decoder - The value decoder for this column + * @returns SQL expression for selecting this column + */ + selectFromDb?: (columnName: string, decoder: any) => SQL; } /** diff --git a/drizzle-orm/src/utils.ts b/drizzle-orm/src/utils.ts index 6f7659485f..854890a6c8 100644 --- a/drizzle-orm/src/utils.ts +++ b/drizzle-orm/src/utils.ts @@ -18,7 +18,8 @@ export function mapResultRow( joinsNotNullableMap: Record | undefined, ): TResult { // Key -> nested object key, value -> table name if all fields in the nested object are from the same table, false otherwise - const nullifyMap: Record = {}; + // New format: { tableName: string, hasNonNullValue: boolean } | string | false + const nullifyMap: Record = {}; const result = columns.reduce>( (result, { path, field }, columnIndex) => { @@ -45,11 +46,24 @@ export function mapResultRow( if (joinsNotNullableMap && is(field, Column) && path.length === 2) { const objectName = path[0]!; + const tableName = getTableName(field.table); + + // Initialize tracking for this object if not exists if (!(objectName in nullifyMap)) { - nullifyMap[objectName] = value === null ? getTableName(field.table) : false; + // Track all columns in this nested object + nullifyMap[objectName] = { + tableName, + hasNonNullValue: value !== null, + }; + } else if (typeof nullifyMap[objectName] === 'object' && nullifyMap[objectName] !== null) { + // Update if we find any non-null value + if (value !== null) { + nullifyMap[objectName].hasNonNullValue = true; + } } else if ( - typeof nullifyMap[objectName] === 'string' && nullifyMap[objectName] !== getTableName(field.table) + typeof nullifyMap[objectName] === 'string' && nullifyMap[objectName] !== tableName ) { + // Legacy: different table with same object name nullifyMap[objectName] = false; } } @@ -62,8 +76,15 @@ export function mapResultRow( // Nullify all nested objects from nullifyMap that are nullable if (joinsNotNullableMap && Object.keys(nullifyMap).length > 0) { - for (const [objectName, tableName] of Object.entries(nullifyMap)) { - if (typeof tableName === 'string' && !joinsNotNullableMap[tableName]) { + for (const [objectName, tracking] of Object.entries(nullifyMap)) { + // Handle new object tracking format + if (typeof tracking === 'object' && tracking !== null && 'hasNonNullValue' in tracking) { + // Only nullify if ALL values were null AND the join is nullable + if (!tracking.hasNonNullValue && !joinsNotNullableMap[tracking.tableName]) { + result[objectName] = null; + } + } else if (typeof tracking === 'string' && !joinsNotNullableMap[tracking]) { + // Legacy format handling result[objectName] = null; } }