Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion core/api/core.api
Original file line number Diff line number Diff line change
Expand Up @@ -5756,7 +5756,7 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/Aggregations
}

public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator : org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler, org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler, org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler {
public fun <init> (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler;Ljava/lang/String;)V
public fun <init> (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorAggregationHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler;Ljava/lang/String;Ljava/util/Map;)V
public fun aggregateMultipleColumns (Lkotlin/sequences/Sequence;)Ljava/lang/Object;
public fun aggregateSequence (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)Ljava/lang/Object;
public fun aggregateSingleColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Object;
Expand All @@ -5769,6 +5769,7 @@ public final class org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/
public final fun getInputHandler ()Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorInputHandler;
public final fun getMultipleColumnsHandler ()Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/AggregatorMultipleColumnsHandler;
public final fun getName ()Ljava/lang/String;
public final fun getStatisticsParameters ()Ljava/util/Map;
public fun indexOfAggregationResultSingleSequence (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)I
public fun init (Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/Aggregator;)V
public fun preprocessAggregation (Lkotlin/sequences/Sequence;Lorg/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/ValueType;)Lkotlin/Pair;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public class Aggregator<in Value : Any, out Return : Any?>(
public val inputHandler: AggregatorInputHandler<Value, Return>,
public val multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
public val name: String,
public val statisticsParameters: Map<String, Any>,
) : AggregatorInputHandler<Value, Return> by inputHandler,
AggregatorMultipleColumnsHandler<Value, Return> by multipleColumnsHandler,
AggregatorAggregationHandler<Value, Return> by aggregationHandler {
Expand Down Expand Up @@ -75,13 +76,30 @@ public class Aggregator<in Value : Any, out Return : Any?>(
aggregationHandler: AggregatorAggregationHandler<Value, Return>,
inputHandler: AggregatorInputHandler<Value, Return>,
multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
statisticsParameters: Map<String, Any>,
): AggregatorProvider<Value, Return> =
AggregatorProvider { name ->
Aggregator(
aggregationHandler = aggregationHandler,
inputHandler = inputHandler,
multipleColumnsHandler = multipleColumnsHandler,
name = name,
statisticsParameters = statisticsParameters,
)
}

/**
 * Creates an [AggregatorProvider] for aggregators that have no statistics parameters.
 *
 * The returned provider builds an [Aggregator] from the given handlers and the name it is
 * invoked with, passing an empty map as `statisticsParameters`.
 *
 * @param aggregationHandler handler passed to the [Aggregator] as its `aggregationHandler`.
 * @param inputHandler handler passed to the [Aggregator] as its `inputHandler`.
 * @param multipleColumnsHandler handler passed to the [Aggregator] as its `multipleColumnsHandler`.
 * @return a provider that creates the [Aggregator] once a name is supplied.
 */
internal operator fun <Value : Any, Return : Any?> invoke(
    aggregationHandler: AggregatorAggregationHandler<Value, Return>,
    inputHandler: AggregatorInputHandler<Value, Return>,
    multipleColumnsHandler: AggregatorMultipleColumnsHandler<Value, Return>,
): AggregatorProvider<Value, Return> =
    AggregatorProvider { name ->
        Aggregator(
            aggregationHandler = aggregationHandler,
            inputHandler = inputHandler,
            multipleColumnsHandler = multipleColumnsHandler,
            name = name,
            // Named like every other argument (and like the sibling overload) instead of
            // relying on positional placement after named arguments.
            statisticsParameters = emptyMap(),
        )
    }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.aggregationHandlers.SelectingAggregationHandler
import org.jetbrains.kotlinx.dataframe.impl.columns.StatisticResult
import org.jetbrains.kotlinx.dataframe.impl.columns.ValueColumnInternal
import kotlin.reflect.KType

/**
Expand All @@ -26,13 +28,34 @@ public interface AggregatorAggregationHandler<in Value : Any, out Return : Any?>

/**
* Aggregates the data in the given column and computes a single resulting value.
* Calls [aggregateSequence].
* Calls [aggregateSequence]. If the column is a [ValueColumnInternal], the result is first looked up in
* the column's statistics cache and stored there after it has been computed.
*/
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return =
aggregateSequence(
public fun aggregateSingleColumn(column: DataColumn<Value?>): Return {
    // Only ValueColumnInternal carries a statistics cache; everything else is aggregated directly.
    if (column !is ValueColumnInternal<*>) {
        return aggregateSequence(
            values = column.asSequence(),
            valueType = column.type().toValueType(),
        )
    }

    // Throws IllegalStateException("Aggregator is required") when no aggregator is set.
    val aggregator = checkNotNull(this.aggregator) { "Aggregator is required" }

    // First level of the cache is keyed by aggregator name and is created on demand.
    val resultsByParameters = column.statistics.getOrPut(aggregator.name) {
        mutableMapOf<Map<String, Any>, StatisticResult>()
    }

    // Second level is keyed by the aggregator's parameters.
    // A missing entry means this statistic was never calculated for these parameters;
    // a cached `null` result is still a non-null StatisticResult wrapper.
    val cached = resultsByParameters[aggregator.statisticsParameters]
    if (cached != null) {
        return cached.value as Return
    }

    val statistic = aggregateSequence(
        values = column.asSequence(),
        valueType = column.type().toValueType(),
    )
    resultsByParameters[aggregator.statisticsParameters] = StatisticResult(statistic)
    // Return the freshly computed value directly instead of recursing just to re-read the cache.
    return statistic
}

/**
* Function that can give the return type of [aggregateSequence] as [KType], given the type of the input.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,24 @@ public object Aggregators {
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
stepOneSelector: Selector<Value, Return>,
statisticsParameters: Map<String, Any>,
) = Aggregator(
aggregationHandler = SelectingAggregationHandler(stepOneSelector, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

/**
 * Builds an aggregator from a [HybridAggregationHandler] (created from [reducer], [indexOfResult]
 * and [getReturnType]), an [AnyInputHandler] and a [FlatteningMultipleColumnsHandler].
 *
 * @param getReturnType passed to the [HybridAggregationHandler].
 * @param indexOfResult passed to the [HybridAggregationHandler].
 * @param reducer passed to the [HybridAggregationHandler].
 * @param statisticsParameters forwarded unchanged to the created aggregator.
 */
private fun <Value : Any, Return : Any?> flattenHybridForAny(
getReturnType: CalculateReturnType,
indexOfResult: IndexOfResult<Value>,
reducer: Reducer<Value, Return>,
statisticsParameters: Map<String, Any>,
) = Aggregator(
aggregationHandler = HybridAggregationHandler(reducer, indexOfResult, getReturnType),
inputHandler = AnyInputHandler(),
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

private fun <Value : Any, Return : Any?> twoStepReducingForAny(
Expand Down Expand Up @@ -83,20 +87,24 @@ public object Aggregators {

/**
 * Builds an aggregator for numbers from a [ReducingAggregationHandler] (created from [reducer]
 * and [getReturnType]), a [NumberInputHandler] and a [FlatteningMultipleColumnsHandler].
 *
 * @param getReturnType passed to the [ReducingAggregationHandler].
 * @param statisticsParameters forwarded unchanged to the created aggregator.
 * @param reducer passed to the [ReducingAggregationHandler]; declared last, after [statisticsParameters].
 */
private fun <Return : Number?> flattenReducingForNumbers(
getReturnType: CalculateReturnType,
statisticsParameters: Map<String, Any>,
reducer: Reducer<Number, Return>,
) = Aggregator(
aggregationHandler = ReducingAggregationHandler(reducer, getReturnType),
inputHandler = NumberInputHandler(),
multipleColumnsHandler = FlatteningMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

/**
 * Builds an aggregator for numbers from a [ReducingAggregationHandler] (created from [reducer]
 * and [getReturnType]), a [NumberInputHandler] and a [TwoStepMultipleColumnsHandler].
 *
 * @param getReturnType passed to the [ReducingAggregationHandler].
 * @param statisticsParameters forwarded unchanged to the created aggregator.
 * @param reducer passed to the [ReducingAggregationHandler]; declared last, after [statisticsParameters].
 */
private fun <Return : Number?> twoStepReducingForNumbers(
getReturnType: CalculateReturnType,
statisticsParameters: Map<String, Any>,
reducer: Reducer<Number, Return>,
) = Aggregator(
aggregationHandler = ReducingAggregationHandler(reducer, getReturnType),
inputHandler = NumberInputHandler(),
multipleColumnsHandler = TwoStepMultipleColumnsHandler(),
statisticsParameters = statisticsParameters,
)

/** @include [AggregatorOptionSwitch1] */
Expand All @@ -117,8 +125,9 @@ public object Aggregators {
by withOneOption { skipNaN: Boolean ->
twoStepSelectingForAny<Comparable<Any>, Comparable<Any>?>(
getReturnType = minTypeConversion,
stepOneSelector = { type -> minOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMin(type, skipNaN) },
stepOneSelector = { type -> minOrNull(type, skipNaN) },
statisticsParameters = mapOf<String, Any>(Pair("skipNaN", skipNaN)),
)
}

Expand All @@ -132,6 +141,7 @@ public object Aggregators {
getReturnType = maxTypeConversion,
stepOneSelector = { type -> maxOrNull(type, skipNaN) },
indexOfResult = { type -> indexOfMax(type, skipNaN) },
statisticsParameters = mapOf<String, Any>(Pair("skipNaN", skipNaN)),
)
}

Expand All @@ -140,17 +150,30 @@ public object Aggregators {
skipNaN: Boolean,
ddof: Int,
->
flattenReducingForNumbers(stdTypeConversion) { type ->
std(type, skipNaN, ddof)
}
flattenReducingForNumbers(
getReturnType = stdTypeConversion,
statisticsParameters = mapOf<String, Any>(
Pair("skipNaN", skipNaN),
Pair("ddof", ddof),
),
reducer = { type ->
std(type, skipNaN, ddof)
},
)
}

// step one: T: Number? -> Double
// step two: Double -> Double
public val mean: AggregatorOptionSwitch1<Boolean, Number, Double> by withOneOption { skipNaN: Boolean ->
twoStepReducingForNumbers(meanTypeConversion) { type ->
mean(type, skipNaN)
}
twoStepReducingForNumbers(
getReturnType = meanTypeConversion,
statisticsParameters = mapOf<String, Any>(
Pair("skipNaN", skipNaN),
),
reducer = { type ->
mean(type, skipNaN)
},
)
}

// T: primitive Number? -> Double?
Expand Down Expand Up @@ -187,6 +210,10 @@ public object Aggregators {
getReturnType = percentileConversion,
reducer = { type -> percentileOrNull(percentile, type, skipNaN) as Comparable<Any>? },
indexOfResult = { type -> indexOfPercentile(percentile, type, skipNaN) },
statisticsParameters = mapOf<String, Any>(
Pair("skipNaN", skipNaN),
Pair("percentile", percentile),
),
)
}

Expand Down Expand Up @@ -215,6 +242,7 @@ public object Aggregators {
getReturnType = medianConversion,
reducer = { type -> medianOrNull(type, skipNaN) as Comparable<Any>? },
indexOfResult = { type -> indexOfMedian(type, skipNaN) },
statisticsParameters = mapOf<String, Any>(Pair("skipNaN", skipNaN)),
)
}

Expand All @@ -223,8 +251,12 @@ public object Aggregators {
// Short -> Int
// Nothing -> Double
public val sum: AggregatorOptionSwitch1<Boolean, Number, Number> by withOneOption { skipNaN: Boolean ->
twoStepReducingForNumbers(sumTypeConversion) { type ->
sum(type, skipNaN)
}
twoStepReducingForNumbers(
getReturnType = sumTypeConversion,
statisticsParameters = mapOf<String, Any>(Pair("skipNaN", skipNaN)),
reducer = { type ->
sum(type, skipNaN)
},
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,22 @@ import org.jetbrains.kotlinx.dataframe.columns.ValueColumn
import kotlin.reflect.KType
import kotlin.reflect.full.withNullability

/**
 * Wrapper around a computed statistic so it can be stored as an entry of a statistics map.
 *
 * [value] is nullable, so a stored `null` result is still represented by a non-null
 * [StatisticResult] and can be told apart from a missing map entry.
 */
@JvmInline
internal value class StatisticResult(val value: Any?)

internal interface ValueColumnInternal<T> : ValueColumn<T> {
val statistics: MutableMap<String, MutableMap<Map<String, Any>, StatisticResult>>
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit dangerous to expose a mutable map, especially one as complicated as this one, for other parts of the library to modify.

I would move the logic of getting/storing statistics here and only call those functions in Aggregators.kt

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This way ValueColumnInternal could expose only functions like putStatisticCache(name, arguments, value) and getStatisticCacheOrNull(name, arguments), and the MutableMap could be made private inside ValueColumnImpl. It's less bug-prone that way :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And it allows you to avoid names like desiredStatisticNotConsideringParameters, which I have difficulty comprehending ;P

}

internal open class ValueColumnImpl<T>(
values: List<T>,
name: String,
type: KType,
val defaultValue: T? = null,
distinct: Lazy<Set<T>>? = null,
) : DataColumnImpl<T>(values, name, type, distinct),
ValueColumn<T> {
ValueColumn<T>,
ValueColumnInternal<T> {

override fun distinct() = ValueColumnImpl(toSet().toList(), name, type, defaultValue, distinct)

Expand Down Expand Up @@ -48,10 +56,13 @@ internal open class ValueColumnImpl<T>(
override fun defaultValue() = defaultValue

override fun forceResolve() = ResolvingValueColumn(this)

override val statistics = mutableMapOf<String, MutableMap<Map<String, Any>, StatisticResult>>()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmm maybe a name like statisticsCache is a bit more descriptive

}

internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
ValueColumn<T> by source,
ValueColumnInternal<T>,
ForceResolvedColumn<T> {

override fun resolve(context: ColumnResolutionContext) = super<ValueColumn>.resolve(context)
Expand All @@ -70,4 +81,6 @@ internal class ResolvingValueColumn<T>(override val source: ValueColumn<T>) :
override fun equals(other: Any?) = source.checkEquals(other)

override fun hashCode(): Int = source.hashCode()

override val statistics = mutableMapOf<String, MutableMap<Map<String, Any>, StatisticResult>>()
}
Loading