diff --git a/api/src/main/java/org/apache/iceberg/Schema.java b/api/src/main/java/org/apache/iceberg/Schema.java
index fbc557f6979a..3e59998be476 100644
--- a/api/src/main/java/org/apache/iceberg/Schema.java
+++ b/api/src/main/java/org/apache/iceberg/Schema.java
@@ -105,6 +105,10 @@ public Schema(List<NestedField> columns, Set<Integer> identifierFieldIds, TypeUt
     this(DEFAULT_SCHEMA_ID, columns, identifierFieldIds, getId);
   }
 
+  public Schema(List<NestedField> columns, TypeUtil.GetID getId) {
+    this(DEFAULT_SCHEMA_ID, columns, ImmutableSet.of(), getId);
+  }
+
   public Schema(int schemaId, List<NestedField> columns) {
     this(schemaId, columns, ImmutableSet.of());
   }
diff --git a/api/src/main/java/org/apache/iceberg/types/TypeUtil.java b/api/src/main/java/org/apache/iceberg/types/TypeUtil.java
index b1c556be0667..09478dc01285 100644
--- a/api/src/main/java/org/apache/iceberg/types/TypeUtil.java
+++ b/api/src/main/java/org/apache/iceberg/types/TypeUtil.java
@@ -601,6 +601,52 @@ public interface GetID {
     int get(int oldId);
   }
 
+  /**
+   * Creates a function that reassigns specified field IDs.
+   *
+   * <p>This is useful for merging schemas where some field IDs in one schema might conflict with
+   * IDs already in use by another schema. The function will reassign the provided IDs to new unused
+   * IDs, while preserving other IDs.
+   *
+   * @param conflictingIds the set of conflicting field IDs that should be reassigned
+   * @param allUsedIds the set of field IDs that are already in use and cannot be reused
+   * @return a function that reassigns conflicting field IDs while preserving others
+   */
+  public static GetID reassignConflictingIds(Set<Integer> conflictingIds, Set<Integer> allUsedIds) {
+    return new ReassignConflictingIds(conflictingIds, allUsedIds);
+  }
+
+  private static class ReassignConflictingIds implements GetID {
+    private final Set<Integer> conflictingIds;
+    private final Set<Integer> allUsedIds;
+    private final AtomicInteger nextId;
+
+    private ReassignConflictingIds(Set<Integer> conflictingIds, Set<Integer> allUsedIds) {
+      this.conflictingIds = conflictingIds;
+      this.allUsedIds = allUsedIds;
+      this.nextId = new AtomicInteger();
+    }
+
+    @Override
+    public int get(int oldId) {
+      if (conflictingIds.contains(oldId)) {
+        return nextAvailableId();
+      } else {
+        return oldId;
+      }
+    }
+
+    private int nextAvailableId() {
+      int candidateId = nextId.incrementAndGet();
+
+      while (allUsedIds.contains(candidateId)) {
+        candidateId = nextId.incrementAndGet();
+      }
+
+      return candidateId;
+    }
+  }
+
   public static class SchemaVisitor<T> {
     public void beforeField(Types.NestedField field) {}
 
diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
index fcf5fbeb2acb..4c3713d3fff3 100644
--- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
+++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/SparkSchemaUtil.java
@@ -42,6 +42,7 @@
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.catalog.Column;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 
 /** Helper methods for working with Spark/Hive metadata. */
@@ -369,4 +370,8 @@ public static Map<Integer, String> indexQuotedNameById(Schema schema) {
     Function<String, String> quotingFunc = name -> String.format("`%s`", name.replace("`", "``"));
     return TypeUtil.indexQuotedNameById(schema.asStruct(), quotingFunc);
   }
+
+  public static StructType toStructType(List<StructField> fields) {
+    return new StructType(fields.toArray(new StructField[0]));
+  }
 }
diff --git a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
index 744d9f28a985..5be1ba1b297c 100644
--- a/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
+++ b/spark/v4.1/spark/src/main/java/org/apache/iceberg/spark/source/SparkScanBuilder.java
@@ -22,9 +22,7 @@
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
 import org.apache.iceberg.BaseTable;
 import org.apache.iceberg.BatchScan;
 import org.apache.iceberg.FileScanTask;
@@ -49,7 +47,6 @@
 import org.apache.iceberg.io.CloseableIterable;
 import org.apache.iceberg.metrics.InMemoryMetricsReporter;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
 import org.apache.iceberg.relocated.com.google.common.collect.Sets;
 import org.apache.iceberg.spark.Spark3Util;
@@ -60,6 +57,7 @@
 import org.apache.iceberg.spark.SparkV2Filters;
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.TypeUtil.GetID;
 import org.apache.iceberg.types.Types;
 import org.apache.iceberg.util.SnapshotUtil;
 import org.apache.spark.sql.SparkSession;
@@ -97,10 +95,10 @@ public class SparkScanBuilder
   private final Table table;
   private final CaseInsensitiveStringMap options;
   private final SparkReadConf readConf;
-  private final List<String> metaColumns = Lists.newArrayList();
+  private final Set<String> metaFieldNames = Sets.newLinkedHashSet();
   private final InMemoryMetricsReporter metricsReporter;
 
-  private Schema schema;
+  private Schema projection;
   private boolean caseSensitive;
   private List<Expression> filterExpressions = null;
   private Predicate[] pushedPredicates = NO_PREDICATES;
@@ -114,7 +112,7 @@ public class SparkScanBuilder
       CaseInsensitiveStringMap options) {
     this.spark = spark;
     this.table = table;
-    this.schema = schema;
+    this.projection = schema;
    this.options = options;
     this.readConf = new SparkReadConf(spark, table, branch, options);
     this.caseSensitive = readConf.caseSensitive();
@@ -169,7 +167,7 @@ public Predicate[] pushPredicates(Predicate[] predicates) {
 
       if (expr != null) {
         // try binding the expression to ensure it can be pushed down
-        Binder.bind(schema.asStruct(), expr, caseSensitive);
+        Binder.bind(projection.asStruct(), expr, caseSensitive);
         expressions.add(expr);
         pushableFilters.add(predicate);
       }
@@ -211,7 +209,7 @@ public boolean pushAggregation(Aggregation aggregation) {
       try {
         Expression expr = SparkAggregates.convert(aggregateFunc);
         if (expr != null) {
-          Expression bound = Binder.bind(schema.asStruct(), expr, caseSensitive);
+          Expression bound = Binder.bind(projection.asStruct(), expr, caseSensitive);
           expressions.add((BoundAggregate<?, ?>) bound);
         } else {
           LOG.info(
@@ -232,7 +230,7 @@ public boolean pushAggregation(Aggregation aggregation) {
     }
 
     org.apache.iceberg.Scan scan =
-        buildIcebergBatchScan(true /* include Column Stats */, schemaWithMetadataColumns());
+        buildIcebergBatchScan(true /* include Column Stats */, projectionWithMetadataColumns());
 
     try (CloseableIterable<FileScanTask> fileScanTasks = scan.planFiles()) {
       for (FileScanTask task : fileScanTasks) {
@@ -321,74 +319,63 @@ private boolean metricsModeSupportsAggregatePushDown(List<BoundAggregate<?, ?>>
 
   @Override
   public void pruneColumns(StructType requestedSchema) {
-    StructType requestedProjection =
-        new StructType(
-            Stream.of(requestedSchema.fields())
-                .filter(field -> MetadataColumns.nonMetadataColumn(field.name()))
-                .toArray(StructField[]::new));
-
-    // the projection should include all columns that will be returned, including those only used in
-    // filters
-    this.schema =
-        SparkSchemaUtil.prune(schema, requestedProjection, filterExpression(), caseSensitive);
-
-    Stream.of(requestedSchema.fields())
-        .map(StructField::name)
-        .filter(MetadataColumns::isMetadataColumn)
-        .distinct()
-        .forEach(metaColumns::add);
-  }
-
-  private Schema schemaWithMetadataColumns() {
-    // metadata columns
-    List<Types.NestedField> metadataFields =
-        metaColumns.stream()
-            .distinct()
-            .map(name -> MetadataColumns.metadataColumn(table, name))
-            .collect(Collectors.toList());
-    Schema metadataSchema = calculateMetadataSchema(metadataFields);
-
-    // schema or rows returned by readers
-    return TypeUtil.join(schema, metadataSchema);
-  }
-
-  private Schema calculateMetadataSchema(List<Types.NestedField> metaColumnFields) {
-    Optional<Types.NestedField> partitionField =
-        metaColumnFields.stream()
-            .filter(f -> MetadataColumns.PARTITION_COLUMN_ID == f.fieldId())
-            .findFirst();
-
-    // only calculate potential column id collision if partition metadata column was requested
-    if (!partitionField.isPresent()) {
-      return new Schema(metaColumnFields);
-    }
-
-    Set<Integer> idsToReassign =
-        TypeUtil.indexById(partitionField.get().type().asStructType()).keySet();
-
-    // Calculate used ids by union metadata columns with all base table schemas
-    Set<Integer> currentlyUsedIds =
-        metaColumnFields.stream().map(Types.NestedField::fieldId).collect(Collectors.toSet());
-    Set<Integer> allUsedIds =
-        table.schemas().values().stream()
-            .map(currSchema -> TypeUtil.indexById(currSchema.asStruct()).keySet())
-            .reduce(currentlyUsedIds, Sets::union);
-
-    // Reassign selected ids to deduplicate with used ids.
-    AtomicInteger nextId = new AtomicInteger();
-    return new Schema(
-        metaColumnFields,
-        ImmutableSet.of(),
-        oldId -> {
-          if (!idsToReassign.contains(oldId)) {
-            return oldId;
-          }
-          int candidate = nextId.incrementAndGet();
-          while (allUsedIds.contains(candidate)) {
-            candidate = nextId.incrementAndGet();
-          }
-          return candidate;
-        });
+    List<StructField> dataFields = Lists.newArrayList();
+
+    for (StructField field : requestedSchema.fields()) {
+      if (MetadataColumns.isMetadataColumn(field.name())) {
+        metaFieldNames.add(field.name());
+      } else {
+        dataFields.add(field);
+      }
+    }
+
+    StructType requestedProjection = SparkSchemaUtil.toStructType(dataFields);
+    this.projection = prune(projection, requestedProjection);
+  }
+
+  // the projection should include all columns that will be returned,
+  // including those only used in filters
+  private Schema prune(Schema schema, StructType requestedSchema) {
+    return SparkSchemaUtil.prune(schema, requestedSchema, filterExpression(), caseSensitive);
+  }
+
+  // schema of rows that must be returned by readers
+  protected Schema projectionWithMetadataColumns() {
+    return TypeUtil.join(projection, calculateMetadataSchema());
+  }
+
+  // computes metadata schema avoiding conflicts between partition and data field IDs
+  private Schema calculateMetadataSchema() {
+    List<Types.NestedField> metaFields = metaFields();
+    Optional<Types.NestedField> partitionField = findPartitionField(metaFields);
+
+    if (partitionField.isEmpty()) {
+      return new Schema(metaFields);
+    }
+
+    Types.StructType partitionType = partitionField.get().type().asStructType();
+    Set<Integer> partitionFieldIds = TypeUtil.getProjectedIds(partitionType);
+    GetID getId = TypeUtil.reassignConflictingIds(partitionFieldIds, allUsedFieldIds());
+    return new Schema(metaFields, getId);
+  }
+
+  private List<Types.NestedField> metaFields() {
+    return metaFieldNames.stream()
+        .map(name -> MetadataColumns.metadataColumn(table, name))
+        .collect(Collectors.toList());
+  }
+
+  private Optional<Types.NestedField> findPartitionField(List<Types.NestedField> fields) {
+    return fields.stream()
+        .filter(field -> MetadataColumns.PARTITION_COLUMN_ID == field.fieldId())
+        .findFirst();
+  }
+
+  // collects used data field IDs across all known table schemas
+  private Set<Integer> allUsedFieldIds() {
+    return table.schemas().values().stream()
+        .flatMap(tableSchema -> TypeUtil.getProjectedIds(tableSchema.asStruct()).stream())
+        .collect(Collectors.toSet());
   }
 
   @Override
@@ -401,7 +388,7 @@ public Scan build() {
   }
 
   private Scan buildBatchScan() {
-    Schema expectedSchema = schemaWithMetadataColumns();
+    Schema expectedSchema = projectionWithMetadataColumns();
     return new SparkBatchQueryScan(
         spark,
         table,
@@ -573,7 +560,7 @@ public Scan buildChangelogScan() {
       }
     }
 
-    Schema expectedSchema = schemaWithMetadataColumns();
+    Schema expectedSchema = projectionWithMetadataColumns();
 
     IncrementalChangelogScan scan =
         table
@@ -642,7 +629,7 @@ public Scan buildMergeOnReadScan() {
           table,
           null,
           readConf,
-          schemaWithMetadataColumns(),
+          projectionWithMetadataColumns(),
           filterExpressions,
           metricsReporter::scanReport);
     }
@@ -655,7 +642,7 @@ public Scan buildMergeOnReadScan() {
     SparkReadConf adjustedReadConf =
         new SparkReadConf(spark, table, readConf.branch(), adjustedOptions);
 
-    Schema expectedSchema = schemaWithMetadataColumns();
+    Schema expectedSchema = projectionWithMetadataColumns();
 
     BatchScan scan =
         newBatchScan()
@@ -685,12 +672,12 @@ public Scan buildCopyOnWriteScan() {
           spark,
           table,
           readConf,
-          schemaWithMetadataColumns(),
+          projectionWithMetadataColumns(),
           filterExpressions,
           metricsReporter::scanReport);
     }
 
-    Schema expectedSchema = schemaWithMetadataColumns();
+    Schema expectedSchema = projectionWithMetadataColumns();
 
     BatchScan scan =
         newBatchScan()
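
Not part of the patch above: a minimal, hypothetical sketch of how the new `TypeUtil.reassignConflictingIds` helper and the new `Schema(List<NestedField>, TypeUtil.GetID)` constructor are meant to be combined, mirroring what `calculateMetadataSchema()` in `SparkScanBuilder` now does. The table layout, field names, and IDs below are made up for illustration; only the API calls come from this patch and the existing Iceberg API.

```java
import java.util.List;
import java.util.Set;
import org.apache.iceberg.Schema;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;

public class ReassignConflictingIdsExample {
  public static void main(String[] args) {
    // A table schema that already uses field IDs 1 and 2.
    Schema tableSchema =
        new Schema(
            Types.NestedField.required(1, "id", Types.LongType.get()),
            Types.NestedField.optional(2, "data", Types.StringType.get()));

    // A metadata-style field whose nested struct reuses ID 1 and would collide on join.
    // The "_partition" name and the IDs are illustrative, not the real partition metadata column.
    List<Types.NestedField> metaFields =
        List.of(
            Types.NestedField.optional(
                1000,
                "_partition",
                Types.StructType.of(
                    Types.NestedField.optional(1, "bucket", Types.IntegerType.get()))));

    // IDs nested in the struct are the ones that may conflict with data field IDs.
    Set<Integer> conflictingIds =
        TypeUtil.getProjectedIds(metaFields.get(0).type().asStructType());
    Set<Integer> allUsedIds = TypeUtil.getProjectedIds(tableSchema.asStruct());

    // Reassign only the conflicting IDs to values that are not in use; other IDs are preserved.
    TypeUtil.GetID getId = TypeUtil.reassignConflictingIds(conflictingIds, allUsedIds);
    Schema metadataSchema = new Schema(metaFields, getId);

    // The joined schema no longer contains duplicate field IDs.
    Schema joined = TypeUtil.join(tableSchema, metadataSchema);
    System.out.println(joined);
  }
}
```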