Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions drizzle-orm/src/mysql-core/columns/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export class MySqlCustomColumn<T extends ColumnBaseConfig<'custom', 'MySqlCustom
private sqlName: string;
private mapTo?: (value: T['data']) => T['driverParam'];
private mapFrom?: (value: T['driverParam']) => T['data'];
private selectFn?: (columnName: string, decoder: any) => SQL<T['data']>;

constructor(
table: AnyMySqlTable<{ name: T['tableName'] }>,
Expand All @@ -72,6 +73,7 @@ export class MySqlCustomColumn<T extends ColumnBaseConfig<'custom', 'MySqlCustom
this.sqlName = config.customTypeParams.dataType(config.fieldConfig);
this.mapTo = config.customTypeParams.toDriver;
this.mapFrom = config.customTypeParams.fromDriver;
this.selectFn = config.customTypeParams.selectFromDb;
}

getSQLType(): string {
Expand All @@ -85,6 +87,16 @@ export class MySqlCustomColumn<T extends ColumnBaseConfig<'custom', 'MySqlCustom
override mapToDriverValue(value: T['data']): T['driverParam'] {
return typeof this.mapTo === 'function' ? this.mapTo(value) : value as T['data'];
}

/**
* Returns custom SQL expression for selecting this column, if defined.
*/
getSelectSQL(columnName?: string): SQL<T['data']> | this {
if (this.selectFn) {
return this.selectFn(columnName || this.name, this);
}
return this;
}
}

export type CustomTypeValues = {
Expand Down Expand Up @@ -195,6 +207,14 @@ export interface CustomTypeParams<T extends CustomTypeValues> {
* ```
*/
fromDriver?: (value: T['driverData']) => T['data'];

/**
* Optional function to wrap the column in custom SQL when selecting from database.
* @param columnName - The column name to be selected
* @param decoder - The value decoder for this column
* @returns SQL expression for selecting this column
*/
selectFromDb?: (columnName: string, decoder: any) => SQL<T['data']>;
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/**
* Test for custom type selectFromDb feature
* Issue: #1083 - Default select for custom types
*/

import { describe, it, expect } from 'vitest';
import { customType } from '../custom.ts';

describe('Custom Type selectFromDb', () => {
it('should allow defining selectFromDb for custom types', () => {
// Define a custom point type with selectFromDb
const pointType = customType<{
data: { lat: number; lng: number };
driverData: string;
}>({
dataType() {
return 'geometry(Point,4326)';
},
fromDriver(value: string) {
const matches = value.match(/POINT\((?<lng>[\d.-]+) (?<lat>[\d.-]+)\)/);
const { lat, lng } = matches?.groups ?? {};
return { lat: parseFloat(String(lat)), lng: parseFloat(String(lng)) };
},
});

expect(pointType).toBeDefined();
});

it('should accept selectFromDb callback', () => {
let selectFromDbCalled = false;

const customType_ = customType<{
data: string;
driverData: string;
}>({
dataType() {
return 'text';
},
selectFromDb() {
selectFromDbCalled = true;
return {} as any;
},
});

// Create column to trigger the callback storage
const builder = customType_();
expect(builder).toBeDefined();
});

it('should work without selectFromDb (optional)', () => {
const simpleCustom = customType<{
data: string;
driverData: string;
}>({
dataType() {
return 'text';
},
});

expect(simpleCustom).toBeDefined();
});
});
31 changes: 31 additions & 0 deletions drizzle-orm/src/pg-core/columns/custom.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ export class PgCustomColumn<T extends ColumnBaseConfig<'custom', 'PgCustomColumn
private sqlName: string;
private mapTo?: (value: T['data']) => T['driverParam'];
private mapFrom?: (value: T['driverParam']) => T['data'];
private selectFn?: (columnName: string, decoder: any) => SQL<T['data']>;

constructor(
table: AnyPgTable<{ name: T['tableName'] }>,
Expand All @@ -72,6 +73,7 @@ export class PgCustomColumn<T extends ColumnBaseConfig<'custom', 'PgCustomColumn
this.sqlName = config.customTypeParams.dataType(config.fieldConfig);
this.mapTo = config.customTypeParams.toDriver;
this.mapFrom = config.customTypeParams.fromDriver;
this.selectFn = config.customTypeParams.selectFromDb;
}

getSQLType(): string {
Expand All @@ -85,6 +87,19 @@ export class PgCustomColumn<T extends ColumnBaseConfig<'custom', 'PgCustomColumn
override mapToDriverValue(value: T['data']): T['driverParam'] {
return typeof this.mapTo === 'function' ? this.mapTo(value) : value as T['data'];
}

/**
* Returns custom SQL expression for selecting this column, if defined.
* Otherwise returns the column itself.
* @param columnName - The column name to use in the SQL expression
* @returns SQL expression or the column itself
*/
getSelectSQL(columnName?: string): SQL<T['data']> | this {
if (this.selectFn) {
return this.selectFn(columnName || this.name, this);
}
return this;
}
}

export type CustomTypeValues = {
Expand Down Expand Up @@ -195,6 +210,22 @@ export interface CustomTypeParams<T extends CustomTypeValues> {
* ```
*/
fromDriver?: (value: T['driverData']) => T['data'];

/**
* Optional function to wrap the column in custom SQL when selecting from database.
* Useful for types like PostGIS geometry that need special handling.
* @example
* For PostGIS geometry, we need to convert from WKT format:
* ```
* selectFromDb(column, decoder) {
* return sql<Point>`ST_AsText(${sql.identifier(column)})`.mapWith(decoder).as(column);
* }
* ```
* @param columnName - The column name to be selected
* @param decoder - The value decoder for this column
* @returns SQL expression for selecting this column
*/
selectFromDb?: (columnName: string, decoder: any) => SQL<T['data']>;
}

/**
Expand Down
31 changes: 26 additions & 5 deletions drizzle-orm/src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ export function mapResultRow<TResult>(
joinsNotNullableMap: Record<string, boolean> | undefined,
): TResult {
// Key -> nested object key, value -> table name if all fields in the nested object are from the same table, false otherwise
const nullifyMap: Record<string, string | false> = {};
// New format: { tableName: string, hasNonNullValue: boolean } | string | false
const nullifyMap: Record<string, { tableName: string; hasNonNullValue: boolean } | string | false> = {};

const result = columns.reduce<Record<string, any>>(
(result, { path, field }, columnIndex) => {
Expand All @@ -45,11 +46,24 @@ export function mapResultRow<TResult>(

if (joinsNotNullableMap && is(field, Column) && path.length === 2) {
const objectName = path[0]!;
const tableName = getTableName(field.table);

// Initialize tracking for this object if not exists
if (!(objectName in nullifyMap)) {
nullifyMap[objectName] = value === null ? getTableName(field.table) : false;
// Track all columns in this nested object
nullifyMap[objectName] = {
tableName,
hasNonNullValue: value !== null,
};
} else if (typeof nullifyMap[objectName] === 'object' && nullifyMap[objectName] !== null) {
// Update if we find any non-null value
if (value !== null) {
nullifyMap[objectName].hasNonNullValue = true;
}
} else if (
typeof nullifyMap[objectName] === 'string' && nullifyMap[objectName] !== getTableName(field.table)
typeof nullifyMap[objectName] === 'string' && nullifyMap[objectName] !== tableName
) {
// Legacy: different table with same object name
nullifyMap[objectName] = false;
}
}
Expand All @@ -62,8 +76,15 @@ export function mapResultRow<TResult>(

// Nullify all nested objects from nullifyMap that are nullable
if (joinsNotNullableMap && Object.keys(nullifyMap).length > 0) {
for (const [objectName, tableName] of Object.entries(nullifyMap)) {
if (typeof tableName === 'string' && !joinsNotNullableMap[tableName]) {
for (const [objectName, tracking] of Object.entries(nullifyMap)) {
// Handle new object tracking format
if (typeof tracking === 'object' && tracking !== null && 'hasNonNullValue' in tracking) {
// Only nullify if ALL values were null AND the join is nullable
if (!tracking.hasNonNullValue && !joinsNotNullableMap[tracking.tableName]) {
result[objectName] = null;
}
} else if (typeof tracking === 'string' && !joinsNotNullableMap[tracking]) {
// Legacy format handling
result[objectName] = null;
}
}
Expand Down