diff --git a/dataframe.cabal b/dataframe.cabal index c10646b..38c0d69 100644 --- a/dataframe.cabal +++ b/dataframe.cabal @@ -91,7 +91,17 @@ library DataFrame.Lazy.IO.CSV, DataFrame.Lazy.Internal.DataFrame, DataFrame.Monad, - DataFrame.DecisionTree + DataFrame.DecisionTree, + DataFrame.Typed.Types, + DataFrame.Typed.Schema, + DataFrame.Typed.Freeze, + DataFrame.Typed.Access, + DataFrame.Typed.Operations, + DataFrame.Typed.Join, + DataFrame.Typed.Aggregate, + DataFrame.Typed.TH, + DataFrame.Typed.Expr, + DataFrame.Typed build-depends: base >= 4 && <5, aeson >= 0.11.0.0 && < 3, array >= 0.5.4.0 && < 0.6, diff --git a/src/DataFrame/Internal/Column.hs b/src/DataFrame/Internal/Column.hs index f3c3b33..650907e 100644 --- a/src/DataFrame/Internal/Column.hs +++ b/src/DataFrame/Internal/Column.hs @@ -620,83 +620,40 @@ zipColumns (OptionalColumn optcolumn) (OptionalColumn optother) = BoxedColumn (V -- | Merge two columns using `These`. mergeColumns :: Column -> Column -> Column -mergeColumns (BoxedColumn column) (BoxedColumn other) = BoxedColumn (VG.zipWith These column other) -mergeColumns (BoxedColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! i)) - ) -mergeColumns (BoxedColumn column) (OptionalColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (other VG.! i) - then This (column VG.! i) - else These (column VG.! i) (fromJust $ other VG.! i) - ) - ) -mergeColumns (UnboxedColumn column) (BoxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! i)) - ) -mergeColumns (UnboxedColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - (\i -> These (column VG.! i) (other VG.! 
i)) - ) -mergeColumns (UnboxedColumn column) (OptionalColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (other VG.! i) - then This (column VG.! i) - else These (column VG.! i) (fromJust $ other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (BoxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) - then That (other VG.! i) - else These (fromJust $ column VG.! i) (other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (UnboxedColumn other) = - BoxedColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) - then That (other VG.! i) - else These (fromJust $ column VG.! i) (other VG.! i) - ) - ) -mergeColumns (OptionalColumn column) (OptionalColumn other) = - OptionalColumn - ( VB.generate - (min (VG.length column) (VG.length other)) - ( \i -> - if isNothing (column VG.! i) && isNothing (other VG.! i) - then Nothing - else - ( if isNothing (column VG.! i) - then Just (That (fromJust $ other VG.! i)) - else - ( if isNothing (other VG.! i) - then Just (This (fromJust $ column VG.! i)) - else Just (These (fromJust $ column VG.! i) (fromJust $ other VG.! 
i)) - ) - ) - ) - ) +mergeColumns colA colB = case (colA, colB) of + (OptionalColumn c1, OptionalColumn c2) -> + OptionalColumn $ mkVec c1 c2 $ \v1 v2 -> + case (v1, v2) of + (Nothing, Nothing) -> Nothing + (Just x, Nothing) -> Just (This x) + (Nothing, Just y) -> Just (That y) + (Just x, Just y) -> Just (These x y) + (OptionalColumn c1, BoxedColumn c2) -> optReq c1 c2 + (OptionalColumn c1, UnboxedColumn c2) -> optReq c1 c2 + (BoxedColumn c1, OptionalColumn c2) -> reqOpt c1 c2 + (UnboxedColumn c1, OptionalColumn c2) -> reqOpt c1 c2 + (BoxedColumn c1, BoxedColumn c2) -> reqReq c1 c2 + (BoxedColumn c1, UnboxedColumn c2) -> reqReq c1 c2 + (UnboxedColumn c1, BoxedColumn c2) -> reqReq c1 c2 + (UnboxedColumn c1, UnboxedColumn c2) -> reqReq c1 c2 + where + mkVec c1 c2 combineElements = + VB.generate + (min (VG.length c1) (VG.length c2)) + (\i -> combineElements (c1 VG.! i) (c2 VG.! i)) + {-# INLINE mkVec #-} + + reqReq c1 c2 = BoxedColumn $ mkVec c1 c2 These + + reqOpt c1 c2 = BoxedColumn $ mkVec c1 c2 $ \v1 v2 -> + case v2 of + Nothing -> This v1 + Just y -> These v1 y + + optReq c1 c2 = BoxedColumn $ mkVec c1 c2 $ \v1 v2 -> + case v1 of + Nothing -> That v2 + Just x -> These x v2 {-# INLINE mergeColumns #-} -- | An internal, column version of zipWith. diff --git a/src/DataFrame/Typed.hs b/src/DataFrame/Typed.hs new file mode 100644 index 0000000..2af5917 --- /dev/null +++ b/src/DataFrame/Typed.hs @@ -0,0 +1,221 @@ +{-# LANGUAGE DataKinds #-} + +{- | +Module : DataFrame.Typed +Copyright : (c) 2025 +License : MIT +Maintainer : mschavinda@gmail.com +Stability : experimental + +A type-safe layer over the @dataframe@ library. + +This module provides 'TypedDataFrame', a phantom-typed wrapper around +the untyped 'DataFrame' that tracks column names and types at compile time. +All operations delegate to the untyped core at runtime; the phantom type +is updated at compile time to reflect schema changes. 
+ +== Key difference from untyped API: TExpr + +All expression-taking operations use 'TExpr' (typed expressions) instead +of raw @Expr@. Column references are validated at compile time: + +@ +{\-\# LANGUAGE DataKinds, TypeApplications, TypeOperators \#-\} +import qualified DataFrame.Typed as T + +type People = '[T.Column \"name\" Text, T.Column \"age\" Int] + +main = do + raw <- D.readCsv \"people.csv\" + case T.freeze \@People raw of + Nothing -> putStrLn \"Schema mismatch!\" + Just df -> do + let adults = T.filterWhere (T.col \@\"age\" T..>=. T.lit 18) df + let names = T.columnAsList \@\"name\" adults -- :: [Text] + print names +@ + +Column references like @T.col \@\"age\"@ are checked at compile time — if the +column doesn't exist or has the wrong type, you get a type error, not a +runtime exception. + +== filterAllJust tracks Maybe-stripping + +@ +df :: TypedDataFrame '[Column \"x\" (Maybe Double), Column \"y\" Int] +T.filterAllJust df :: TypedDataFrame '[Column \"x\" Double, Column \"y\" Int] +@ + +== Typed aggregation (Option B) + +@ +result = T.aggregate + (T.agg \@\"total\" (T.sum (T.col \@\"salary\")) + $ T.agg \@\"count\" (T.count (T.col \@\"salary\")) + $ T.aggNil) + (T.groupBy \@'[\"dept\"] employees) +@ +-} +module DataFrame.Typed ( + -- * Core types + TypedDataFrame, + Column, + TypedGrouped, + These (..), + + -- * Typed expressions + TExpr (..), + col, + lit, + ifThenElse, + lift, + lift2, + + -- * Comparison operators + (.==.), + (./=.), + (.<.), + (.<=.), + (.>=.), + (.>.), + + -- * Logical operators + (.&&.), + (.||.), + DataFrame.Typed.Expr.not, + + -- * Aggregation expression combinators + DataFrame.Typed.Expr.sum, + mean, + count, + DataFrame.Typed.Expr.minimum, + DataFrame.Typed.Expr.maximum, + collect, + + -- * Typed sort orders + TSortOrder (..), + asc, + desc, + + -- * Named expression helper + DataFrame.Typed.Expr.as, + + -- * Freeze / thaw boundary + freeze, + freezeWithError, + thaw, + unsafeFreeze, + + -- * Typed column access + 
columnAsVector, + columnAsList, + + -- * Schema-preserving operations + filterWhere, + filter, + filterBy, + filterAllJust, + filterJust, + filterNothing, + sortBy, + take, + takeLast, + drop, + dropLast, + range, + cube, + distinct, + sample, + shuffle, + + -- * Schema-modifying operations + derive, + impute, + select, + exclude, + rename, + renameMany, + insert, + insertColumn, + insertVector, + cloneColumn, + dropColumn, + replaceColumn, + + -- * Metadata + dimensions, + nRows, + nColumns, + columnNames, + + -- * Vertical merge + append, + + -- * Joins + innerJoin, + leftJoin, + rightJoin, + fullOuterJoin, + + -- * GroupBy and Aggregation (Option B) + groupBy, + agg, + aggNil, + aggregate, + aggregateUntyped, + + -- * Template Haskell + deriveSchema, + deriveSchemaFromCsvFile, + + -- * Schema type families (for advanced use) + Lookup, + HasName, + SubsetSchema, + ExcludeSchema, + RenameInSchema, + RemoveColumn, + Impute, + Append, + Reverse, + StripAllMaybe, + StripMaybeAt, + GroupKeyColumns, + InnerJoinSchema, + LeftJoinSchema, + RightJoinSchema, + FullOuterJoinSchema, + AssertAbsent, + AssertPresent, + + -- * Constraints + KnownSchema (..), + AllKnownSymbol (..), + + -- * Pipe operator + (|>), +) where + +import Prelude hiding (drop, filter, take) + +import DataFrame.Typed.Access (columnAsList, columnAsVector) +import DataFrame.Typed.Aggregate ( + agg, + aggNil, + aggregate, + aggregateUntyped, + groupBy, + ) +import DataFrame.Typed.Expr +import DataFrame.Typed.Freeze (freeze, freezeWithError, thaw, unsafeFreeze) +import DataFrame.Typed.Join (fullOuterJoin, innerJoin, leftJoin, rightJoin) +import DataFrame.Typed.Operations +import DataFrame.Typed.Schema +import DataFrame.Typed.TH (deriveSchema, deriveSchemaFromCsvFile) +import DataFrame.Typed.Types ( + Column, + TSortOrder (..), + These (..), + TypedDataFrame, + TypedGrouped, + ) diff --git a/src/DataFrame/Typed/Access.hs b/src/DataFrame/Typed/Access.hs new file mode 100644 index 0000000..268fa91 --- /dev/null 
+++ b/src/DataFrame/Typed/Access.hs @@ -0,0 +1,55 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Access ( + -- * Typed column access + columnAsVector, + columnAsList, +) where + +import Control.Exception (throw) +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import qualified Data.Vector as V +import GHC.TypeLits (KnownSymbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import DataFrame.Internal.Expression (Expr (Col)) +import qualified DataFrame.Operations.Core as D +import DataFrame.Typed.Schema (AssertPresent, Lookup) +import DataFrame.Typed.Types (TypedDataFrame (..)) + +{- | Retrieve a column as a boxed 'Vector', with the type determined by +the schema. The column must exist (enforced at compile time). +-} +columnAsVector :: + forall name cols a. + ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TypedDataFrame cols -> V.Vector a +columnAsVector (TDF df) = + either throw id $ D.columnAsVector (Col @a colName) df + where + colName = T.pack (symbolVal (Proxy @name)) + +-- | Retrieve a column as a list, with the type determined by the schema. +columnAsList :: + forall name cols a. 
+ ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TypedDataFrame cols -> [a] +columnAsList (TDF df) = + D.columnAsList (Col @a colName) df + where + colName = T.pack (symbolVal (Proxy @name)) diff --git a/src/DataFrame/Typed/Aggregate.hs b/src/DataFrame/Typed/Aggregate.hs new file mode 100644 index 0000000..ac538c1 --- /dev/null +++ b/src/DataFrame/Typed/Aggregate.hs @@ -0,0 +1,97 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Aggregate ( + -- * Typed groupBy + groupBy, + + -- * Typed aggregation builder (Option B) + agg, + aggNil, + + -- * Running aggregations + aggregate, + + -- * Escape hatch + aggregateUntyped, +) where + +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Internal.Expression (NamedExpr) +import qualified DataFrame.Operations.Aggregation as DA + +import DataFrame.Typed.Freeze (unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types + +{- | Group a typed DataFrame by one or more key columns. + +@ +grouped = groupBy \@'[\"department\"] employees +@ +-} +groupBy :: + forall (keys :: [Symbol]) cols. + (AllKnownSymbol keys) => + TypedDataFrame cols -> TypedGrouped keys cols +groupBy (TDF df) = TGD (DA.groupBy (symbolVals @keys) df) + +-- | The empty aggregation — no output columns beyond the group keys. +aggNil :: TAgg keys cols '[] +aggNil = TAggNil + +{- | Add one aggregation to the builder. + +Each call prepends a @Column name a@ to the result schema and records +the runtime 'NamedExpr'. 
The expression is validated against the +source schema @cols@ at compile time. + +@ +agg \@\"total_sales\" (sum (col \@\"salary\")) + $ agg \@\"avg_price\" (mean (col \@\"price\")) + $ aggNil +@ +-} +agg :: + forall name a keys cols aggs. + ( KnownSymbol name + , Columnable a + ) => + TExpr cols a -> TAgg keys cols aggs -> TAgg keys cols (Column name a ': aggs) +agg = TAggCons colName + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Run a typed aggregation. + +Result schema = grouping key columns ++ aggregated columns (in declaration order). + +@ +result = aggregate + (agg \@\"total\" (sum (col \@\"salary\")) $ agg \@\"count\" (count (col \@\"salary\")) $ aggNil) + (groupBy \@'[\"dept\"] employees) +-- result :: TDF '[Column \"dept\" Text, Column \"total\" Double, Column \"count\" Int] +@ +-} +aggregate :: + forall keys cols aggs. + TAgg keys cols aggs -> + TypedGrouped keys cols -> + TypedDataFrame (Append (GroupKeyColumns keys cols) (Reverse aggs)) +aggregate tagg (TGD gdf) = + unsafeFreeze (DA.aggregate (taggToNamedExprs tagg) gdf) + +-- | Escape hatch: run an untyped aggregation and return a raw 'DataFrame'. +aggregateUntyped :: [NamedExpr] -> TypedGrouped keys cols -> D.DataFrame +aggregateUntyped exprs (TGD gdf) = DA.aggregate exprs gdf diff --git a/src/DataFrame/Typed/Expr.hs b/src/DataFrame/Typed/Expr.hs new file mode 100644 index 0000000..61c1900 --- /dev/null +++ b/src/DataFrame/Typed/Expr.hs @@ -0,0 +1,265 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + +{- | Type-safe expression construction for typed DataFrames. 
+ +Unlike the untyped @Expr a@ where column references are unchecked strings, +'TExpr' ensures at compile time that: + +* Referenced columns exist in the schema +* Column types match the expression type + +== Example + +@ +type Schema = '[Column \"age\" Int, Column \"salary\" Double] + +-- This compiles: +goodExpr :: TExpr Schema Double +goodExpr = col \@\"salary\" + +-- This gives a compile-time error (column not found): +badExpr :: TExpr Schema Double +badExpr = col \@\"nonexistent\" + +-- This gives a compile-time error (type mismatch): +wrongType :: TExpr Schema Int +wrongType = col \@\"salary\" -- salary is Double, not Int +@ +-} +module DataFrame.Typed.Expr ( + -- * Core typed expression type (re-exported from Types) + TExpr (..), + + -- * Column reference (schema-checked) + col, + + -- * Literals + lit, + + -- * Conditional + ifThenElse, + + -- * Unary / binary lifting + lift, + lift2, + + -- * Comparison operators + (.==.), + (./=.), + (.<.), + (.<=.), + (.>=.), + (.>.), + + -- * Logical operators + (.&&.), + (.||.), + DataFrame.Typed.Expr.not, + + -- * Aggregation combinators + sum, + mean, + count, + minimum, + maximum, + collect, + + -- * Named expression helper + as, + + -- * Sort helpers + asc, + desc, +) where + +import Data.Proxy (Proxy (..)) +import Data.String (IsString (..)) +import qualified Data.Text as T +import qualified Data.Vector.Unboxed as VU +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) + +import DataFrame.Internal.Column (Columnable) +import DataFrame.Internal.Expression ( + AggStrategy (..), + BinaryOp (..), + Expr (..), + NamedExpr, + UExpr (..), + UnaryOp (..), + ) +import DataFrame.Internal.Statistics +import DataFrame.Typed.Schema (AssertPresent, Lookup) +import DataFrame.Typed.Types (TExpr (..), TSortOrder (..)) +import Prelude hiding (maximum, minimum, sum) + +{- | Create a typed column reference. This is the key type-safety entry point. + +The column name must exist in @cols@ and its type must match @a@. 
+Both checks happen at compile time via type families. + +@ +salary :: TExpr '[Column \"salary\" Double] Double +salary = col \@\"salary\" +@ +-} +col :: + forall (name :: Symbol) cols a. + ( KnownSymbol name + , a ~ Lookup name cols + , Columnable a + , AssertPresent name cols + ) => + TExpr cols a +col = TExpr (Col (T.pack (symbolVal (Proxy @name)))) + +{- | Create a literal expression. Valid for any schema since it +references no columns. +-} +lit :: (Columnable a) => a -> TExpr cols a +lit = TExpr . Lit + +-- | Conditional expression. +ifThenElse :: + (Columnable a) => + TExpr cols Bool -> TExpr cols a -> TExpr cols a -> TExpr cols a +ifThenElse (TExpr c) (TExpr t) (TExpr e) = TExpr (If c t e) + +------------------------------------------------------------------------------- +-- Numeric instances (mirror Expr's instances) +------------------------------------------------------------------------------- + +instance (Num a, Columnable a) => Num (TExpr cols a) where + (TExpr a) + (TExpr b) = TExpr (a + b) + (TExpr a) - (TExpr b) = TExpr (a - b) + (TExpr a) * (TExpr b) = TExpr (a * b) + negate (TExpr a) = TExpr (negate a) + abs (TExpr a) = TExpr (abs a) + signum (TExpr a) = TExpr (signum a) + fromInteger = TExpr . fromInteger + +instance (Fractional a, Columnable a) => Fractional (TExpr cols a) where + fromRational = TExpr . 
fromRational + (TExpr a) / (TExpr b) = TExpr (a / b) + +instance (Floating a, Columnable a) => Floating (TExpr cols a) where + pi = TExpr pi + exp (TExpr a) = TExpr (exp a) + sqrt (TExpr a) = TExpr (sqrt a) + log (TExpr a) = TExpr (log a) + (TExpr a) ** (TExpr b) = TExpr (a ** b) + logBase (TExpr a) (TExpr b) = TExpr (logBase a b) + sin (TExpr a) = TExpr (sin a) + cos (TExpr a) = TExpr (cos a) + tan (TExpr a) = TExpr (tan a) + asin (TExpr a) = TExpr (asin a) + acos (TExpr a) = TExpr (acos a) + atan (TExpr a) = TExpr (atan a) + sinh (TExpr a) = TExpr (sinh a) + cosh (TExpr a) = TExpr (cosh a) + asinh (TExpr a) = TExpr (asinh a) + acosh (TExpr a) = TExpr (acosh a) + atanh (TExpr a) = TExpr (atanh a) + +instance (IsString a, Columnable a) => IsString (TExpr cols a) where + fromString = TExpr . fromString + +------------------------------------------------------------------------------- +-- Lifting arbitrary functions +------------------------------------------------------------------------------- + +-- | Lift a unary function into a typed expression. +lift :: + (Columnable a, Columnable b) => (a -> b) -> TExpr cols a -> TExpr cols b +lift f (TExpr e) = TExpr (Unary (MkUnaryOp f "unaryUdf" Nothing) e) + +-- | Lift a binary function into typed expressions. +lift2 :: + (Columnable a, Columnable b, Columnable c) => + (a -> b -> c) -> TExpr cols a -> TExpr cols b -> TExpr cols c +lift2 f (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp f "binaryUdf" Nothing False 0) a b) + +infixl 4 .==., ./=., .<., .<=., .>=., .>. +infixr 3 .&&. +infixr 2 .||. + +(.==.) :: + (Columnable a, Eq a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.==.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (==) "eq" (Just "==") True 4) a b) + +(./=.) :: + (Columnable a, Eq a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(./=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (/=) "neq" (Just "/=") True 4) a b) + +(.<.) 
:: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.<.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (<) "lt" (Just "<") False 4) a b) + +(.<=.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.<=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (<=) "leq" (Just "<=") False 4) a b) + +(.>=.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.>=.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (>=) "geq" (Just ">=") False 4) a b) + +(.>.) :: + (Columnable a, Ord a) => TExpr cols a -> TExpr cols a -> TExpr cols Bool +(.>.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (>) "gt" (Just ">") False 4) a b) + +(.&&.) :: TExpr cols Bool -> TExpr cols Bool -> TExpr cols Bool +(.&&.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (&&) "and" (Just "&&") True 3) a b) + +(.||.) :: TExpr cols Bool -> TExpr cols Bool -> TExpr cols Bool +(.||.) (TExpr a) (TExpr b) = TExpr (Binary (MkBinaryOp (||) "or" (Just "||") True 2) a b) + +not :: TExpr cols Bool -> TExpr cols Bool +not (TExpr e) = TExpr (Unary (MkUnaryOp Prelude.not "not" (Just "!")) e) + +------------------------------------------------------------------------------- +-- Aggregation combinators +------------------------------------------------------------------------------- + +sum :: (Columnable a, Num a) => TExpr cols a -> TExpr cols a +sum (TExpr e) = TExpr (Agg (FoldAgg "sum" Nothing (+)) e) + +mean :: (Columnable a, Real a, VU.Unbox a) => TExpr cols a -> TExpr cols Double +mean (TExpr e) = TExpr (Agg (CollectAgg "mean" mean') e) + +count :: (Columnable a) => TExpr cols a -> TExpr cols Int +count (TExpr e) = TExpr (Agg (FoldAgg "count" (Just 0) (\acc _ -> acc + 1)) e) + +minimum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +minimum (TExpr e) = TExpr (Agg (FoldAgg "minimum" Nothing min) e) + +maximum :: (Columnable a, Ord a) => TExpr cols a -> TExpr cols a +maximum (TExpr e) = TExpr (Agg (FoldAgg "maximum" Nothing 
max) e) + +collect :: (Columnable a) => TExpr cols a -> TExpr cols [a] +collect (TExpr e) = TExpr (Agg (FoldAgg "collect" (Just []) (flip (:))) e) + +------------------------------------------------------------------------------- +-- Named expression helper +------------------------------------------------------------------------------- + +-- | Create a 'NamedExpr' for use with 'aggregateUntyped'. +as :: (Columnable a) => TExpr cols a -> T.Text -> NamedExpr +as (TExpr e) name = (name, UExpr e) + +-- | Create an ascending sort order from a typed expression. +asc :: (Columnable a) => TExpr cols a -> TSortOrder cols +asc = Asc + +-- | Create a descending sort order from a typed expression. +desc :: (Columnable a) => TExpr cols a -> TSortOrder cols +desc = Desc diff --git a/src/DataFrame/Typed/Freeze.hs b/src/DataFrame/Typed/Freeze.hs new file mode 100644 index 0000000..fb00de2 --- /dev/null +++ b/src/DataFrame/Typed/Freeze.hs @@ -0,0 +1,86 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} + +module DataFrame.Typed.Freeze ( + -- * Safe boundary + freeze, + freezeWithError, + + -- * Escape hatches + thaw, + unsafeFreeze, +) where + +import qualified Data.Text as T +import Type.Reflection (SomeTypeRep) + +import qualified DataFrame.Internal.Column as C +import qualified DataFrame.Internal.DataFrame as D +import DataFrame.Operations.Core (columnNames) +import DataFrame.Typed.Schema (KnownSchema (..)) +import DataFrame.Typed.Types (TypedDataFrame (..)) + +{- | Validate that an untyped 'DataFrame' matches the expected schema @cols@, +then wrap it. Returns 'Nothing' on mismatch. +-} +freeze :: + forall cols. 
(KnownSchema cols) => D.DataFrame -> Maybe (TypedDataFrame cols) +freeze df = case validateSchema @cols df of + Left _ -> Nothing + Right _ -> Just (TDF df) + +-- | Like 'freeze' but returns a descriptive error message on failure. +freezeWithError :: + forall cols. + (KnownSchema cols) => + D.DataFrame -> Either T.Text (TypedDataFrame cols) +freezeWithError df = case validateSchema @cols df of + Left err -> Left err + Right _ -> Right (TDF df) + +{- | Unwrap a typed DataFrame back to the untyped representation. +Always safe; discards type information. +-} +thaw :: TypedDataFrame cols -> D.DataFrame +thaw (TDF df) = df + +{- | Wrap an untyped DataFrame without any validation. +Used internally after delegation where the library guarantees schema correctness. +-} +unsafeFreeze :: D.DataFrame -> TypedDataFrame cols +unsafeFreeze = TDF + +validateSchema :: + forall cols. + (KnownSchema cols) => + D.DataFrame -> Either T.Text () +validateSchema df = mapM_ checkCol (schemaEvidence @cols) + where + checkCol :: (T.Text, SomeTypeRep) -> Either T.Text () + checkCol (name, expectedRep) = case D.getColumn name df of + Nothing -> + Left $ + "Column '" + <> name + <> "' not found in DataFrame. " + <> "Available columns: " + <> T.pack (show (columnNames df)) + Just col -> + if matchesType expectedRep col + then Right () + else + Left $ + "Type mismatch on column '" + <> name + <> "': expected " + <> T.pack (show expectedRep) + <> ", got " + <> T.pack (C.columnTypeString col) + +-- | Check if a Column's element type matches the expected SomeTypeRep. 
+matchesType :: SomeTypeRep -> C.Column -> Bool +matchesType expected col = T.pack (show expected) == T.pack (C.columnTypeString col) diff --git a/src/DataFrame/Typed/Join.hs b/src/DataFrame/Typed/Join.hs new file mode 100644 index 0000000..fdb4928 --- /dev/null +++ b/src/DataFrame/Typed/Join.hs @@ -0,0 +1,72 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} + +module DataFrame.Typed.Join ( + -- * Typed joins + innerJoin, + leftJoin, + rightJoin, + fullOuterJoin, +) where + +import GHC.TypeLits (Symbol) + +import qualified DataFrame.Operations.Join as DJ + +import DataFrame.Typed.Freeze (unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types (TypedDataFrame (..)) + +-- | Typed inner join on one or more key columns. +innerJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (InnerJoinSchema keys left right) +innerJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.innerJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed left join. +leftJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (LeftJoinSchema keys left right) +leftJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.leftJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed right join. +rightJoin :: + forall (keys :: [Symbol]) left right. + (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (RightJoinSchema keys left right) +rightJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.rightJoin keyNames r l) + where + keyNames = symbolVals @keys + +-- | Typed full outer join. +fullOuterJoin :: + forall (keys :: [Symbol]) left right. 
+ (AllKnownSymbol keys) => + TypedDataFrame left -> + TypedDataFrame right -> + TypedDataFrame (FullOuterJoinSchema keys left right) +fullOuterJoin (TDF l) (TDF r) = + unsafeFreeze (DJ.fullOuterJoin keyNames r l) + where + keyNames = symbolVals @keys diff --git a/src/DataFrame/Typed/Operations.hs b/src/DataFrame/Typed/Operations.hs new file mode 100644 index 0000000..52854aa --- /dev/null +++ b/src/DataFrame/Typed/Operations.hs @@ -0,0 +1,389 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} + +module DataFrame.Typed.Operations ( + -- * Schema-preserving operations + filterWhere, + filter, + filterBy, + filterAllJust, + filterJust, + filterNothing, + sortBy, + take, + takeLast, + drop, + dropLast, + range, + cube, + distinct, + sample, + shuffle, + + -- * Schema-modifying operations + derive, + impute, + select, + exclude, + rename, + renameMany, + insert, + insertColumn, + insertVector, + cloneColumn, + dropColumn, + replaceColumn, + + -- * Metadata + dimensions, + nRows, + nColumns, + columnNames, + + -- * Vertical merge + append, + + -- * Pipe operator + (|>), +) where + +import Data.Function ((&)) +import Data.Proxy (Proxy (..)) +import qualified Data.Text as T +import qualified Data.Vector as V +import GHC.TypeLits (KnownSymbol, Symbol, symbolVal) +import System.Random (RandomGen) +import Prelude hiding (drop, filter, take) + +import qualified DataFrame.Functions as DF +import DataFrame.Internal.Column (Columnable) +import qualified DataFrame.Internal.Column as C +import qualified DataFrame.Operations.Aggregation as DA +import qualified DataFrame.Operations.Core as D +import DataFrame.Operations.Merge () +import qualified DataFrame.Operations.Permutation as D +import qualified 
DataFrame.Operations.Subset as D +import qualified DataFrame.Operations.Transformations as D + +-- Semigroup instance + +import DataFrame.Typed.Freeze (unsafeFreeze) +import DataFrame.Typed.Schema +import DataFrame.Typed.Types (TExpr (..), TSortOrder (..), TypedDataFrame (..)) +import qualified DataFrame.Typed.Types as T + +-- | Pipe operator, re-exported for convenience. +(|>) :: a -> (a -> b) -> b +(|>) = (&) + +infixl 1 |> + +------------------------------------------------------------------------------- +-- Schema-preserving operations +------------------------------------------------------------------------------- + +{- | Filter rows where a boolean expression evaluates to True. +The expression is validated against the schema at compile time. +-} +filterWhere :: TExpr cols Bool -> TypedDataFrame cols -> TypedDataFrame cols +filterWhere (TExpr expr) (TDF df) = TDF (D.filterWhere expr df) + +-- | Filter rows by applying a predicate to a typed expression. +filter :: + (Columnable a) => + TExpr cols a -> (a -> Bool) -> TypedDataFrame cols -> TypedDataFrame cols +filter (TExpr expr) pred' (TDF df) = TDF (D.filter expr pred' df) + +-- | Filter rows by a predicate on a column expression (flipped argument order). +filterBy :: + (Columnable a) => + (a -> Bool) -> TExpr cols a -> TypedDataFrame cols -> TypedDataFrame cols +filterBy pred' (TExpr expr) (TDF df) = TDF (D.filterBy pred' expr df) + +{- | Keep only rows where ALL Optional columns have Just values. +Strips 'Maybe' from all column types in the result schema. + +@ +df :: TDF '[Column \"x\" (Maybe Double), Column \"y\" Int] +filterAllJust df :: TDF '[Column \"x\" Double, Column \"y\" Int] +@ +-} +filterAllJust :: TypedDataFrame cols -> TypedDataFrame (StripAllMaybe cols) +filterAllJust (TDF df) = unsafeFreeze (D.filterAllJust df) + +{- | Keep only rows where the named column has Just values. +Strips 'Maybe' from that column's type in the result schema. 
+ +@ +filterJust \@\"x\" df +@ +-} +filterJust :: + forall name cols. + ( KnownSymbol name + , AssertPresent name cols + ) => + TypedDataFrame cols -> TypedDataFrame (StripMaybeAt name cols) +filterJust (TDF df) = unsafeFreeze (D.filterJust colName df) + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Keep only rows where the named column has Nothing. +Schema is preserved (column types unchanged, just fewer rows). +-} +filterNothing :: + forall name cols. + ( KnownSymbol name + , AssertPresent name cols + ) => + TypedDataFrame cols -> TypedDataFrame cols +filterNothing (TDF df) = TDF (D.filterNothing colName df) + where + colName = T.pack (symbolVal (Proxy @name)) + +{- | Sort by the given typed sort orders. +Sort orders reference columns that are validated against the schema. +-} +sortBy :: [TSortOrder cols] -> TypedDataFrame cols -> TypedDataFrame cols +sortBy ords (TDF df) = TDF (D.sortBy (map toUntypedSort ords) df) + where + toUntypedSort :: TSortOrder cols -> D.SortOrder + toUntypedSort (Asc (TExpr e)) = D.Asc e + toUntypedSort (Desc (TExpr e)) = D.Desc e + +-- | Take the first @n@ rows. +take :: Int -> TypedDataFrame cols -> TypedDataFrame cols +take n (TDF df) = TDF (D.take n df) + +-- | Take the last @n@ rows. +takeLast :: Int -> TypedDataFrame cols -> TypedDataFrame cols +takeLast n (TDF df) = TDF (D.takeLast n df) + +-- | Drop the first @n@ rows. +drop :: Int -> TypedDataFrame cols -> TypedDataFrame cols +drop n (TDF df) = TDF (D.drop n df) + +-- | Drop the last @n@ rows. +dropLast :: Int -> TypedDataFrame cols -> TypedDataFrame cols +dropLast n (TDF df) = TDF (D.dropLast n df) + +-- | Take rows in the given range (start, end). +range :: (Int, Int) -> TypedDataFrame cols -> TypedDataFrame cols +range r (TDF df) = TDF (D.range r df) + +-- | Take a sub-cube of the DataFrame. +cube :: (Int, Int) -> TypedDataFrame cols -> TypedDataFrame cols +cube c (TDF df) = TDF (D.cube c df) + +-- | Remove duplicate rows. 
distinct :: TypedDataFrame cols -> TypedDataFrame cols
distinct (TDF df) = TDF (DA.distinct df)

-- | Randomly sample a fraction of rows.
sample ::
    (RandomGen g) => g -> Double -> TypedDataFrame cols -> TypedDataFrame cols
sample g frac (TDF df) = TDF (D.sample g frac df)

-- | Shuffle all rows randomly.
shuffle :: (RandomGen g) => g -> TypedDataFrame cols -> TypedDataFrame cols
shuffle g (TDF df) = TDF (D.shuffle g df)

-------------------------------------------------------------------------------
-- Schema-modifying operations
-------------------------------------------------------------------------------

{- | Derive a new column from a typed expression. The column name must NOT
already exist in the schema (enforced at compile time via 'AssertAbsent').
The expression is validated against the current schema. In the result schema
the new column is appended at the END (via 'Snoc').

@
df' = derive \@\"total\" (col \@\"price\" * col \@\"qty\") df
-- df' :: TDF (Snoc originalCols (Column \"total\" Double))
@
-}
derive ::
    forall name a cols.
    ( KnownSymbol name
    , Columnable a
    , AssertAbsent name cols
    ) =>
    TExpr cols a ->
    TypedDataFrame cols ->
    TypedDataFrame (Snoc cols (T.Column name a))
derive (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df)
  where
    colName = T.pack (symbolVal (Proxy @name))

{- | Fill the missing values of an optional column with a default value.

The named column must exist in the schema with type @Maybe a@; both facts are
checked at compile time, mirroring the constraints of 'replaceColumn'. The
column is re-derived under the same name from
@fromMaybe value (col \@(Maybe a) name)@, and the result schema records it
with type @a@ (see 'Impute').

@
impute \@\"x\" 0 df
@
-}
impute ::
    forall name a cols.
    ( KnownSymbol name
    , Columnable a
    , Maybe a ~ Lookup name cols
    , AssertPresent name cols
    ) =>
    a ->
    TypedDataFrame cols ->
    TypedDataFrame (Impute name cols)
impute value (TDF df) =
    unsafeFreeze
        (D.derive colName (DF.fromMaybe value (DF.col @(Maybe a) colName)) df)
  where
    colName = T.pack (symbolVal (Proxy @name))

{- | Select a subset of columns by name.
The result schema lists the columns in the order of @names@
(see 'SubsetSchema').
-}
select ::
    forall (names :: [Symbol]) cols.
    (AllKnownSymbol names) =>
    TypedDataFrame cols -> TypedDataFrame (SubsetSchema names cols)
select (TDF df) = unsafeFreeze (D.select (symbolVals @names) df)

-- | Exclude columns by name.
exclude ::
    forall (names :: [Symbol]) cols.
    (AllKnownSymbol names) =>
    TypedDataFrame cols -> TypedDataFrame (ExcludeSchema names cols)
exclude (TDF df) = unsafeFreeze (D.exclude (symbolVals @names) df)

{- | Rename a column.
A missing @old@ column is rejected at compile time by 'RenameInSchema'.
NOTE(review): nothing here checks that @new@ is not already in the schema.
-}
rename ::
    forall old new cols.
    (KnownSymbol old, KnownSymbol new) =>
    TypedDataFrame cols -> TypedDataFrame (RenameInSchema old new cols)
rename (TDF df) = unsafeFreeze (D.rename oldName newName df)
  where
    oldName = T.pack (symbolVal (Proxy @old))
    newName = T.pack (symbolVal (Proxy @new))

{- | Rename multiple columns from a type-level list of pairs.
The pairs are applied one after another, left to right, each on the result
of the previous rename.
-}
renameMany ::
    forall (pairs :: [(Symbol, Symbol)]) cols.
    (AllKnownPairs pairs) =>
    TypedDataFrame cols -> TypedDataFrame (RenameManyInSchema pairs cols)
renameMany (TDF df) = unsafeFreeze (foldRenames (pairVals @pairs) df)
  where
    foldRenames [] df' = df'
    foldRenames ((old, new) : rest) df' = foldRenames rest (D.rename old new df')

{- | Insert a new column from a Foldable container.
The name must not already be in the schema ('AssertAbsent').
-}
insert ::
    forall name a cols t.
    ( KnownSymbol name
    , Columnable a
    , Foldable t
    , AssertAbsent name cols
    ) =>
    t a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols)
insert xs (TDF df) = unsafeFreeze (D.insert colName xs df)
  where
    colName = T.pack (symbolVal (Proxy @name))

{- | Insert a raw 'Column' value.
The element type @a@ recorded in the schema is chosen by the caller; this
function does not compare it with the contents of the raw column.
-}
insertColumn ::
    forall name a cols.
    ( KnownSymbol name
    , Columnable a
    , AssertAbsent name cols
    ) =>
    C.Column -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols)
insertColumn col (TDF df) = unsafeFreeze (D.insertColumn colName col df)
  where
    colName = T.pack (symbolVal (Proxy @name))

-- | Insert a boxed 'Vector'.
insertVector ::
    forall name a cols.
    ( KnownSymbol name
    , Columnable a
    , AssertAbsent name cols
    ) =>
    V.Vector a -> TypedDataFrame cols -> TypedDataFrame (T.Column name a ': cols)
insertVector vec (TDF df) = unsafeFreeze (D.insertVector colName vec df)
  where
    colName = T.pack (symbolVal (Proxy @name))

{- | Clone an existing column under a new name.
The clone takes its element type from the original column ('Lookup').
-}
cloneColumn ::
    forall old new cols.
    ( KnownSymbol old
    , KnownSymbol new
    , AssertPresent old cols
    , AssertAbsent new cols
    ) =>
    TypedDataFrame cols -> TypedDataFrame (T.Column new (Lookup old cols) ': cols)
cloneColumn (TDF df) = unsafeFreeze (D.cloneColumn oldName newName df)
  where
    oldName = T.pack (symbolVal (Proxy @old))
    newName = T.pack (symbolVal (Proxy @new))

-- | Drop a column by name.
dropColumn ::
    forall name cols.
    ( KnownSymbol name
    , AssertPresent name cols
    ) =>
    TypedDataFrame cols -> TypedDataFrame (RemoveColumn name cols)
dropColumn (TDF df) = unsafeFreeze (D.exclude [colName] df)
  where
    colName = T.pack (symbolVal (Proxy @name))

{- | Replace an existing column with new values derived from a typed expression.
The column must already exist and the new type must match.
-}
replaceColumn ::
    forall name a cols.
    ( KnownSymbol name
    , Columnable a
    , a ~ Lookup name cols
    , AssertPresent name cols
    ) =>
    TExpr cols a -> TypedDataFrame cols -> TypedDataFrame cols
replaceColumn (TExpr expr) (TDF df) = unsafeFreeze (D.derive colName expr df)
  where
    colName = T.pack (symbolVal (Proxy @name))

-- | Vertically merge two DataFrames with the same schema.
append :: TypedDataFrame cols -> TypedDataFrame cols -> TypedDataFrame cols
append (TDF a) (TDF b) = TDF (a <> b)

-------------------------------------------------------------------------------
-- Metadata (pass-through)
-------------------------------------------------------------------------------

-- | Dimensions of the underlying frame, as reported by the untyped API.
dimensions :: TypedDataFrame cols -> (Int, Int)
dimensions (TDF df) = D.dimensions df

-- | Number of rows of the underlying frame.
nRows :: TypedDataFrame cols -> Int
nRows (TDF df) = D.nRows df

-- | Number of columns of the underlying frame.
nColumns :: TypedDataFrame cols -> Int
nColumns (TDF df) = D.nColumns df

-- | Runtime column names of the underlying frame.
columnNames :: TypedDataFrame cols -> [T.Text]
columnNames (TDF df) = D.columnNames df

-------------------------------------------------------------------------------
-- Internal helpers
-------------------------------------------------------------------------------

-- | Helper class for extracting [(Text, Text)] from type-level pairs.
class AllKnownPairs (pairs :: [(Symbol, Symbol)]) where
    pairVals :: [(T.Text, T.Text)]

instance AllKnownPairs '[] where
    pairVals = []

instance
    (KnownSymbol a, KnownSymbol b, AllKnownPairs rest) =>
    AllKnownPairs ('(a, b) ': rest)
    where
    pairVals =
        ( T.pack (symbolVal (Proxy @a))
        , T.pack (symbolVal (Proxy @b))
        )
            : pairVals @rest
diff --git a/src/DataFrame/Typed/Schema.hs b/src/DataFrame/Typed/Schema.hs
new file mode 100644
index 0000000..7510b7a
--- /dev/null
+++ b/src/DataFrame/Typed/Schema.hs
@@ -0,0 +1,339 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}

module DataFrame.Typed.Schema (
    -- * Type families for schema manipulation
    Lookup,
    HasName,
    RemoveColumn,
    Impute,
    SubsetSchema,
    ExcludeSchema,
    RenameInSchema,
    RenameManyInSchema,
    Append,
    Snoc,
    Reverse,
    ColumnNames,
    AssertAbsent,
    AssertPresent,
    IsElem,

    -- * Maybe-stripping families
    StripAllMaybe,
    StripMaybeAt,

    -- * Join schema families
    SharedNames,
    UniqueLeft,
    InnerJoinSchema,
    LeftJoinSchema,
    RightJoinSchema,
    FullOuterJoinSchema,
    WrapMaybe,
    WrapMaybeColumns,
    CollidingColumns,

    -- * GroupBy helpers
    GroupKeyColumns,

    -- * KnownSchema class
    KnownSchema (..),

    -- * Helpers
    AllKnownSymbol (..),
) where

import Data.Kind (Constraint, Type)
import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import Data.These (These)
import GHC.TypeLits
import Type.Reflection (SomeTypeRep, Typeable, someTypeRep)

import DataFrame.Internal.Column (Columnable)
import DataFrame.Internal.Types (If)
import DataFrame.Typed.Types (Column)

{- | Look up the element type of a column by name.
Reduces to a 'TypeError' when the name is not in the schema.
-}
type family Lookup (name :: Symbol) (cols :: [Type]) :: Type where
    Lookup name (Column name a ': _) = a
    Lookup name (Column _ _ ': rest) = Lookup name rest
    Lookup name '[] =
        TypeError
            ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema")

{- | Result schema of imputing the named column: @Column name (Maybe a)@
becomes @Column name a@; every other column is kept as it is.

Reduces to a 'TypeError' when the named column is not a @Maybe@ column, and
also when the name is not in the schema at all (like 'Lookup'), so a
misspelled column name is reported at compile time instead of being
accepted silently.
-}
type family Impute (name :: Symbol) (cols :: [Type]) :: [Type] where
    Impute name (Column name (Maybe a) ': rest) = Column name a ': rest
    Impute name (Column name _ ': rest) =
        TypeError
            ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' is not of kind Maybe *")
    Impute name (col ': rest) = col ': Impute name rest
    -- Reaching the end of the list means the column was never found.
    Impute name '[] =
        TypeError
            ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema")

-- | Add an element to the end of a type-level list.
type family Snoc (xs :: [k]) (x :: k) :: [k] where
    Snoc '[] x = '[x]
    Snoc (y ': ys) x = y ': Snoc ys x

-- | Check whether a column name exists in a schema (type-level Bool).
type family HasName (name :: Symbol) (cols :: [Type]) :: Bool where
    HasName name (Column name _ ': _) = 'True
    HasName name (Column _ _ ': rest) = HasName name rest
    HasName name '[] = 'False

{- | Remove a column by name from a schema.
Only the first matching column is removed. A name that is not present
leaves the schema unchanged (no type error is raised here).
-}
type family RemoveColumn (name :: Symbol) (cols :: [Type]) :: [Type] where
    RemoveColumn name (Column name _ ': rest) = rest
    RemoveColumn name (col ': rest) = col ': RemoveColumn name rest
    RemoveColumn name '[] = '[]

{- | Select a subset of columns by a list of names.
The result follows the order of @names@, not the order of @cols@; each
element type comes from 'Lookup'.
-}
type family SubsetSchema (names :: [Symbol]) (cols :: [Type]) :: [Type] where
    SubsetSchema '[] cols = '[]
    SubsetSchema (n ': ns) cols = Column n (Lookup n cols) ': SubsetSchema ns cols

{- | Exclude columns by a list of names.
Names that do not occur in the schema are ignored.
-}
type family ExcludeSchema (names :: [Symbol]) (cols :: [Type]) :: [Type] where
    ExcludeSchema names '[] = '[]
    ExcludeSchema names (Column n a ': rest) =
        If
            (IsElem n names)
            (ExcludeSchema names rest)
            (Column n a ': ExcludeSchema names rest)

-- | Type-level elem for Symbols
type family IsElem (x :: Symbol) (xs :: [Symbol]) :: Bool where
    IsElem x '[] = 'False
    IsElem x (x ': _) = 'True
    IsElem x (_ ': xs) = IsElem x xs

{- | Rename a column in the schema.
Only the first column named @old@ is renamed; its position and element type
are kept. Reduces to a 'TypeError' when @old@ is not in the schema.
-}
type family RenameInSchema (old :: Symbol) (new :: Symbol) (cols :: [Type]) :: [Type] where
    RenameInSchema old new (Column old a ': rest) = Column new a ': rest
    RenameInSchema old new (col ': rest) = col ': RenameInSchema old new rest
    RenameInSchema old new '[] =
        TypeError
            ('Text "Cannot rename: column '" ':<>: 'Text old ':<>: 'Text "' not found")

{- | Rename multiple columns.
Pairs are applied left to right, each one to the schema produced by the
previous rename.
-}
type family RenameManyInSchema (pairs :: [(Symbol, Symbol)]) (cols :: [Type]) :: [Type] where
    RenameManyInSchema '[] cols = cols
    RenameManyInSchema ('(old, new) ': rest) cols =
        RenameManyInSchema rest (RenameInSchema old new cols)

-- | Append two type-level lists.
type family Append (xs :: [k]) (ys :: [k]) :: [k] where
    Append '[] ys = ys
    Append (x ': xs) ys = x ': Append xs ys

-- | Reverse a type-level list.
type family Reverse (xs :: [Type]) :: [Type] where
    Reverse xs = ReverseAcc xs '[]

-- Accumulator-passing worker for 'Reverse': elements are moved one by one
-- from the front of @xs@ onto the front of @acc@.
type family ReverseAcc (xs :: [Type]) (acc :: [Type]) :: [Type] where
    ReverseAcc '[] acc = acc
    ReverseAcc (x ': xs) acc = ReverseAcc xs (x ': acc)

-- | Extract column names as a type-level list of Symbols.
type family ColumnNames (cols :: [Type]) :: [Symbol] where
    ColumnNames '[] = '[]
    ColumnNames (Column n _ ': rest) = n ': ColumnNames rest

-- | Assert that a column name is absent from the schema (for derive/insert).
type family AssertAbsent (name :: Symbol) (cols :: [Type]) :: Constraint where
    AssertAbsent name cols = AssertAbsentHelper name (HasName name cols) cols

-- Dispatches on the result of 'HasName': 'False is the trivially satisfied
-- constraint, 'True is a custom compile-time error.
type family
    AssertAbsentHelper (name :: Symbol) (found :: Bool) (cols :: [Type]) ::
        Constraint
    where
    AssertAbsentHelper name 'False cols = ()
    AssertAbsentHelper name 'True cols =
        TypeError
            ( 'Text "Column '"
                ':<>: 'Text name
                ':<>: 'Text "' already exists in schema. "
                ':<>: 'Text "Use replaceColumn to overwrite."
            )

-- | Assert that a column name is present in the schema.
type family AssertPresent (name :: Symbol) (cols :: [Type]) :: Constraint where
    AssertPresent name cols = AssertPresentHelper name (HasName name cols) cols

-- Dispatches on the result of 'HasName': 'True is the trivially satisfied
-- constraint, 'False is a custom compile-time error.
type family
    AssertPresentHelper (name :: Symbol) (found :: Bool) (cols :: [Type]) ::
        Constraint
    where
    AssertPresentHelper name 'True cols = ()
    AssertPresentHelper name 'False cols =
        TypeError
            ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema")

{- | Strip 'Maybe' from all columns. Used by 'filterAllJust'.

@Column "x" (Maybe Double)@ becomes @Column "x" Double@.
@Column "y" Int@ stays @Column "y" Int@.

Only the outermost 'Maybe' of each column is removed. The equation order
matters: the @Maybe@ equation must come before the general one.
-}
type family StripAllMaybe (cols :: [Type]) :: [Type] where
    StripAllMaybe '[] = '[]
    StripAllMaybe (Column n (Maybe a) ': rest) = Column n a ': StripAllMaybe rest
    StripAllMaybe (Column n a ': rest) = Column n a ': StripAllMaybe rest

{- | Strip 'Maybe' from a single named column. Used by 'filterJust'.

@StripMaybeAt "x" '[Column "x" (Maybe Double), Column "y" Int]@
  = @'[Column "x" Double, Column "y" Int]@

A named column that is not a 'Maybe' column is left unchanged; a name that
is not in the schema reduces to a 'TypeError'.
-}
type family StripMaybeAt (name :: Symbol) (cols :: [Type]) :: [Type] where
    StripMaybeAt name (Column name (Maybe a) ': rest) = Column name a ': rest
    StripMaybeAt name (Column name a ': rest) = Column name a ': rest
    StripMaybeAt name (col ': rest) = col ': StripMaybeAt name rest
    StripMaybeAt name '[] =
        TypeError
            ('Text "Column '" ':<>: 'Text name ':<>: 'Text "' not found in schema")

-- | Extract column names that appear in both schemas (in @left@ order).
type family SharedNames (left :: [Type]) (right :: [Type]) :: [Symbol] where
    SharedNames '[] right = '[]
    SharedNames (Column n _ ': rest) right =
        If (HasName n right) (n ': SharedNames rest right) (SharedNames rest right)

{- | Columns from @left@ whose names do NOT appear in @rightNames@.
Note that the second argument is a list of names, not a schema.
-}
type family UniqueLeft (left :: [Type]) (rightNames :: [Symbol]) :: [Type] where
    UniqueLeft '[] _ = '[]
    UniqueLeft (Column n a ': rest) rn =
        If (IsElem n rn) (UniqueLeft rest rn) (Column n a ': UniqueLeft rest rn)

-- | Wrap column types in Maybe.
type family WrapMaybe (cols :: [Type]) :: [Type] where
    WrapMaybe '[] = '[]
    WrapMaybe (Column n a ': rest) = Column n (Maybe a) ': WrapMaybe rest

-- | Wrap selected columns in Maybe by name list.
type family WrapMaybeColumns (names :: [Symbol]) (cols :: [Type]) :: [Type] where
    WrapMaybeColumns names '[] = '[]
    WrapMaybeColumns names (Column n a ': rest) =
        If
            (IsElem n names)
            (Column n (Maybe a) ': WrapMaybeColumns names rest)
            (Column n a ': WrapMaybeColumns names rest)

{- | Columns in left whose names collide with right (excluding keys).
Each colliding column keeps its name and gets the element type
@These a b@, where @a@ is the left type and @b@ the right type.
-}
type family CollidingColumns (left :: [Type]) (right :: [Type]) (keys :: [Symbol]) :: [Type] where
    CollidingColumns '[] _ _ = '[]
    CollidingColumns (Column n a ': rest) right keys =
        If
            (IsElem n keys)
            (CollidingColumns rest right keys)
            ( If
                (HasName n right)
                (Column n (These a (Lookup n right)) ': CollidingColumns rest right keys)
                (CollidingColumns rest right keys)
            )

{- | Inner join result schema.
Layout: key columns (typed from @left@), then columns only in @left@, then
columns only in @right@, then the non-key columns present on both sides
(see 'CollidingColumns').
-}
type family InnerJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where
    InnerJoinSchema keys left right =
        Append
            (SubsetSchema keys left)
            ( Append
                (UniqueLeft left (Append keys (ColumnNames right)))
                ( Append
                    (UniqueLeft right (Append keys (ColumnNames left)))
                    (CollidingColumns left right keys)
                )
            )

{- | Left join result schema.
Same layout as 'InnerJoinSchema', except that the columns only in @right@
are wrapped in 'Maybe'.
-}
type family LeftJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where
    LeftJoinSchema keys left right =
        Append
            (SubsetSchema keys left)
            ( Append
                (UniqueLeft left (Append keys (ColumnNames right)))
                ( Append
                    (WrapMaybe (UniqueLeft right (Append keys (ColumnNames left))))
                    (CollidingColumns left right keys)
                )
            )

{- | Right join result schema.
Key columns are typed from @right@, and the columns only in @left@ are
wrapped in 'Maybe'.
-}
type family RightJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) :: [Type] where
    RightJoinSchema keys left right =
        Append
            (SubsetSchema keys right)
            ( Append
                (WrapMaybe (UniqueLeft left (Append keys (ColumnNames right))))
                ( Append
                    (UniqueLeft right (Append keys (ColumnNames left)))
                    (CollidingColumns left right keys)
                )
            )

{- | Full outer join result schema.
The key columns and the one-sided columns of both inputs are all wrapped in
'Maybe'; colliding columns are typed by 'CollidingColumns' without an extra
'Maybe'.
-}
type family
    FullOuterJoinSchema (keys :: [Symbol]) (left :: [Type]) (right :: [Type]) ::
        [Type]
    where
    FullOuterJoinSchema keys left right =
        Append
            (WrapMaybe (SubsetSchema keys left))
            ( Append
                (WrapMaybe (UniqueLeft left (Append keys (ColumnNames right))))
                ( Append
                    (WrapMaybe (UniqueLeft right (Append keys (ColumnNames left))))
                    (CollidingColumns left right keys)
                )
            )

-------------------------------------------------------------------------------
-- GroupBy helpers
-------------------------------------------------------------------------------

{- | Extract Column entries from a schema whose names appear in @keys@.
The result keeps the order of @cols@, not the order of @keys@.
-}
type family GroupKeyColumns (keys :: [Symbol]) (cols :: [Type]) :: [Type] where
    GroupKeyColumns keys '[] = '[]
    GroupKeyColumns keys (Column n a ': rest) =
        If
            (IsElem n keys)
            (Column n a ': GroupKeyColumns keys rest)
            (GroupKeyColumns keys rest)

{- | Provides runtime evidence of a schema: a list of (name, TypeRep) pairs,
one per column, in schema order.
-}
class KnownSchema (cols :: [Type]) where
    schemaEvidence :: [(T.Text, SomeTypeRep)]

instance KnownSchema '[] where
    schemaEvidence = []

instance
    (KnownSymbol name, Typeable a, Columnable a, KnownSchema rest) =>
    KnownSchema (Column name a ': rest)
    where
    schemaEvidence =
        (T.pack (symbolVal (Proxy @name)), someTypeRep (Proxy @a))
            : schemaEvidence @rest

-- | A class that provides a list of 'Text' values for a type-level list of Symbols.
class AllKnownSymbol (names :: [Symbol]) where
    symbolVals :: [T.Text]

instance AllKnownSymbol '[] where
    symbolVals = []

instance (KnownSymbol n, AllKnownSymbol ns) => AllKnownSymbol (n ': ns) where
    symbolVals = T.pack (symbolVal (Proxy @n)) : symbolVals @ns
diff --git a/src/DataFrame/Typed/TH.hs b/src/DataFrame/Typed/TH.hs
new file mode 100644
index 0000000..c1d460c
--- /dev/null
+++ b/src/DataFrame/Typed/TH.hs
@@ -0,0 +1,94 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TypeApplications #-}

module DataFrame.Typed.TH (
    -- * Schema inference
    deriveSchema,
    deriveSchemaFromCsvFile,

    -- * Re-export for TH splices
    TypedDataFrame,
    Column,
) where

import Control.Monad.IO.Class
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Text as T

import Language.Haskell.TH

import qualified DataFrame.IO.CSV as D
import qualified DataFrame.Internal.Column as C
import qualified DataFrame.Internal.DataFrame as D
import DataFrame.Typed.Types (Column, TypedDataFrame)

{- | Generate a type synonym for a schema based on an existing 'DataFrame'.

Use it as a declaration splice, e.g. @$(deriveSchema \"IrisSchema\" irisDF)@,
which generates a declaration of the form

@
type IrisSchema = '[Column \"sepal_length\" Double, ...]
@

Columns appear in the synonym in the order of their column index. The splice
fails if a column has a type that 'parseTypeString' does not support.
-}
deriveSchema :: String -> D.DataFrame -> DecsQ
deriveSchema typeName df = do
    let cols = getSchemaInfo df
    let names = map fst cols
    -- Defensive check: the names are the keys of the column-index map, so a
    -- duplicate is not expected here.
    case findDuplicate names of
        Just dup -> fail $ "Duplicate column name in DataFrame: " ++ T.unpack dup
        Nothing -> pure ()
    colTypes <- mapM mkColumnType cols
    -- Build the promoted list  t1 ': t2 ': ... ': '[]
    let schemaType = foldr (\t acc -> PromotedConsT `AppT` t `AppT` acc) PromotedNilT colTypes
    let synName = mkName typeName
    pure [TySynD synName [] schemaType]

{- | Like 'deriveSchema', but the frame is obtained by reading the CSV file at
@path@ while the splice runs.
-}
deriveSchemaFromCsvFile :: String -> String -> DecsQ
deriveSchemaFromCsvFile typeName path = do
    df <- liftIO (D.readCsv path)
    deriveSchema typeName df

-- | Column names paired with their type strings, ordered by column index.
getSchemaInfo :: D.DataFrame -> [(T.Text, String)]
getSchemaInfo df =
    let orderedNames = map fst $ L.sortOn snd $ M.toList (D.columnIndices df)
     in map (\name -> (name, getColumnTypeStr name df)) orderedNames

{- | Type string of the named column, as given by 'C.columnTypeString'.
Calls 'error' if the column is missing; 'getSchemaInfo' only passes names
taken from the frame's own column index.
-}
getColumnTypeStr :: T.Text -> D.DataFrame -> String
getColumnTypeStr name df = case D.getColumn name df of
    Just col -> C.columnTypeString col
    Nothing -> error $ "Column not found: " ++ T.unpack name

-- | Build the type @Column \"name\" ty@ for one column.
mkColumnType :: (T.Text, String) -> Q Type
mkColumnType (name, tyStr) = do
    ty <- parseTypeString tyStr
    let nameLit = LitT (StrTyLit (T.unpack name))
    pure $ ConT ''Column `AppT` nameLit `AppT` ty

{- | Turn a column type string into a Template Haskell 'Type'.

Recognised strings are @Int@, @Double@, @Float@, @Bool@, @Char@, @String@,
@Text@ and @Integer@, plus @Maybe @ followed by any recognised string.
Anything else makes the splice fail with an explanatory message.
-}
parseTypeString :: String -> Q Type
parseTypeString "Int" = pure $ ConT ''Int
parseTypeString "Double" = pure $ ConT ''Double
parseTypeString "Float" = pure $ ConT ''Float
parseTypeString "Bool" = pure $ ConT ''Bool
parseTypeString "Char" = pure $ ConT ''Char
parseTypeString "String" = pure $ ConT ''String
parseTypeString "Text" = pure $ ConT ''T.Text
parseTypeString "Integer" = pure $ ConT ''Integer
parseTypeString s
    | "Maybe " `L.isPrefixOf` s = do
        -- 6 == length "Maybe "
        inner <- parseTypeString (L.drop 6 s)
        pure $ ConT ''Maybe `AppT` inner
parseTypeString s = fail $ "Unsupported column type in schema inference: " ++ s

-- | First element that occurs again later in the list, if any.
findDuplicate :: (Eq a) => [a] -> Maybe a
findDuplicate [] = Nothing
findDuplicate (x : xs)
    | x `elem` xs = Just x
    | otherwise = findDuplicate xs
diff --git a/src/DataFrame/Typed/Types.hs b/src/DataFrame/Typed/Types.hs
new file mode 100644
index 0000000..286a3bf
--- /dev/null
+++ b/src/DataFrame/Typed/Types.hs
@@ -0,0 +1,117 @@
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

module DataFrame.Typed.Types (
    -- * Core phantom-typed wrapper
    TypedDataFrame (..),

    -- * Column phantom type (no constructors)
    Column,

    -- * Typed expressions (schema-validated)
    TExpr (..),

    -- * Typed sort orders
    TSortOrder (..),

    -- * Grouped typed dataframe
    TypedGrouped (..),

    -- * Typed aggregation builder (Option B)
    TAgg (..),
    taggToNamedExprs,

    -- * Re-export These
    These (..),
) where

import Data.Kind (Type)
import Data.These (These (..))
import GHC.TypeLits (Symbol)

import qualified Data.Text as T
import DataFrame.Internal.Column (Columnable)
import qualified DataFrame.Internal.DataFrame as D
import DataFrame.Internal.Expression (Expr, NamedExpr, UExpr (..))

{- | A phantom-typed wrapper over the untyped 'DataFrame'.

The type parameter @cols@ is a type-level list of @Column name ty@ entries
that tracks the schema at compile time. All operations delegate to the
untyped core at runtime and update the phantom type at compile time.
-}
newtype TypedDataFrame (cols :: [Type]) = TDF {unTDF :: D.DataFrame}

instance Show (TypedDataFrame cols) where
    show (TDF df) = show df

instance Eq (TypedDataFrame cols) where
    (TDF a) == (TDF b) = a == b

{- | A phantom type that pairs a type-level column name ('Symbol')
with its element type. Has no value-level constructors — used
purely at the type level to describe schemas.
-}
data Column (name :: Symbol) (a :: Type)

{- | An expression of result type @a@, tagged with the schema @cols@ it is
meant to be evaluated against.

The wrapper adds nothing at runtime: it holds a plain untyped 'Expr a'.
The intent is that values are built through the type-safe combinators
('col', 'lit', arithmetic operations), which check column references
against the schema.

'unTExpr' gives back the underlying 'Expr' for use with the untyped API.
-}
newtype TExpr (cols :: [Type]) a = TExpr {unTExpr :: Expr a}

{- | A sort order (ascending or descending) over a typed expression of the
schema @cols@. The element type of the expression is hidden.
-}
data TSortOrder (cols :: [Type]) where
    Asc :: (Columnable a) => TExpr cols a -> TSortOrder cols
    Desc :: (Columnable a) => TExpr cols a -> TSortOrder cols

{- | A 'GroupedDataFrame' tagged with its grouping keys and the schema of
the frame it came from.
-}
newtype TypedGrouped (keys :: [Symbol]) (cols :: [Type])
    = TGD {unTGD :: D.GroupedDataFrame}

{- | A typed aggregation builder (Option B).

At the term level it is a list of (column name, aggregation expression)
entries; at the type level the @aggs@ index collects one 'Column' per
entry, with each @agg@ call adding its column to the front.

Usage:

@
agg \@\"total\" (F.sum salary)
    $ agg \@\"avg_age\" (F.mean age)
    $ aggNil
@

NOTE(review): 'TAggCons' itself does not tie the 'T.Text' name to the
type-level @name@; presumably the @agg@ smart constructor does - confirm.
-}
data TAgg (keys :: [Symbol]) (cols :: [Type]) (aggs :: [Type]) where
    TAggNil :: TAgg keys cols '[]
    TAggCons ::
        (Columnable a) =>
        -- | column name
        T.Text ->
        -- | typed aggregation expression
        TExpr cols a ->
        -- | rest
        TAgg keys cols aggs ->
        TAgg keys cols (Column name a ': aggs)

{- | Collect the runtime 'NamedExpr' entries of a 'TAgg'.

The entries come out in the reverse of the cons order, i.e. the entry
closest to 'TAggNil' is first.
-}
taggToNamedExprs :: TAgg keys cols aggs -> [NamedExpr]
taggToNamedExprs = go []
  where
    -- Walk the builder front to back, pushing each entry onto the
    -- accumulator; this yields the entries in reversed order.
    go :: [NamedExpr] -> TAgg keys cols aggs -> [NamedExpr]
    go acc TAggNil = acc
    go acc (TAggCons name (TExpr expr) rest) = go ((name, UExpr expr) : acc) rest
diff --git a/tests/Main.hs b/tests/Main.hs index 37820ed..6569e1b 100644 --- a/tests/Main.hs +++ b/tests/Main.hs @@ -5134,15 +5134,16 @@ isSuccessful _ = False main :: IO () main = do result <- runTestTT tests - -- Property tests - propRes <- - mapM - (quickCheckWithResult stdArgs) - Operations.Subset.tests - monadRes <- mapM (quickCheckWithResult stdArgs) Monad.tests - if failures result > 0 - || errors result > 0 - || not (all isSuccessful propRes) - || not (all isSuccessful monadRes) + if failures result > 0 || errors result > 0 then Exit.exitFailure - else Exit.exitSuccess + else do + -- Property tests + propRes <- + mapM + (quickCheckWithResult stdArgs) + Operations.Subset.tests + monadRes <- mapM (quickCheckWithResult stdArgs) Monad.tests + if not (all isSuccessful propRes) + || not (all isSuccessful monadRes) + then Exit.exitFailure + else Exit.exitSuccess diff --git a/tests/Operations/Aggregations.hs b/tests/Operations/Aggregations.hs index d43b0ae..7bb2307 100644 --- a/tests/Operations/Aggregations.hs +++ b/tests/Operations/Aggregations.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -7,6 +8,7 @@ import qualified Data.Text as T import qualified DataFrame as D import qualified DataFrame.Functions as F import qualified DataFrame.Internal.Column as DI +import qualified DataFrame.Typed as DT import Data.Function import DataFrame.Operators @@ -18,7 +20,7 @@ values = , ("test2", DI.fromList ([12, 11 .. 1] :: [Int])) , ("test3", DI.fromList ([1 .. 12] :: [Int])) , ("test4", DI.fromList ['a' .. 'l']) - , ("test4", DI.fromList (map show ['a' .. 'l'])) + , ("test5", DI.fromList (map show ['a' .. 'l'])) , ("test6", DI.fromList ([1 ..
12] :: [Integer])) ] @@ -42,6 +44,33 @@ foldAggregation = ) ) +foldAggregationTyped :: Test +foldAggregationTyped = + TestCase + ( assertEqual + "Typed counting elements after grouping gives correct numbers" + ( D.fromNamedColumns + [ ("test1", DI.fromList [1 :: Int, 2, 3]) + , ("test2_count", DI.fromList [6 :: Int, 3, 3]) + ] + ) + ( testData + & either (error . show) id + . DT.freezeWithError + @[ DT.Column "test1" Int + , DT.Column "test2" Int + , DT.Column "test3" Int + , DT.Column "test4" Char + , DT.Column "test5" String + , DT.Column "test6" Integer + ] + & DT.groupBy @'["test1"] + & DT.aggregate (DT.agg @"test2_count" (DT.count (DT.col @"test2")) DT.aggNil) + & DT.sortBy [DT.asc (DT.col @"test1")] + & DT.thaw + ) + ) + numericAggregation :: Test numericAggregation = TestCase @@ -59,6 +88,33 @@ numericAggregation = ) ) +numericAggregationTyped :: Test +numericAggregationTyped = + TestCase + ( assertEqual + "Typed mean works for ints" + ( D.fromNamedColumns + [ ("test1", DI.fromList [1 :: Int, 2, 3]) + , ("test2_mean", DI.fromList [6.5 :: Double, 8.0, 5.0]) + ] + ) + ( testData + & either (error . show) id + . 
DT.freezeWithError + @[ DT.Column "test1" Int + , DT.Column "test2" Int + , DT.Column "test3" Int + , DT.Column "test4" Char + , DT.Column "test5" String + , DT.Column "test6" Integer + ] + & DT.groupBy @'["test1"] + & DT.aggregate (DT.agg @"test2_mean" (DT.mean (DT.col @"test2")) DT.aggNil) + & DT.sortBy [DT.asc (DT.col @"test1")] + & DT.thaw + ) + ) + numericAggregationOfUnaggregatedUnaryOp :: Test numericAggregationOfUnaggregatedUnaryOp = TestCase @@ -154,7 +210,9 @@ aggregationOnNoRows = tests :: [Test] tests = [ TestLabel "foldAggregation" foldAggregation + , TestLabel "foldAggregationTyped" foldAggregationTyped , TestLabel "numericAggregation" numericAggregation + , TestLabel "numericAggregationTyped" numericAggregationTyped , TestLabel "numericAggregationOfUnaggregatedUnaryOp" numericAggregationOfUnaggregatedUnaryOp diff --git a/tests/Operations/Derive.hs b/tests/Operations/Derive.hs index 88692cf..0675ea3 100644 --- a/tests/Operations/Derive.hs +++ b/tests/Operations/Derive.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeApplications #-} @@ -10,6 +11,7 @@ import qualified DataFrame as D import qualified DataFrame.Functions as F import qualified DataFrame.Internal.Column as DI import qualified DataFrame.Internal.DataFrame as DI +import qualified DataFrame.Typed as DT import Test.HUnit @@ -44,7 +46,33 @@ deriveWAI = ) ) +deriveWAITyped :: Test +deriveWAITyped = + TestCase + ( assertEqual + "typed derive works with column expression" + (zipWith (\n c -> show n ++ [c]) [1 .. 26] ['a' .. 'z']) + ( DT.columnAsList @"test4" $ + DT.derive + @"test4" + ( DT.lift2 + (++) + (DT.lift show (DT.col @"test1")) + (DT.lift (: ([] :: [Char])) (DT.col @"test3")) + ) + ( either + (error . 
show) + id + ( DT.freezeWithError + @[DT.Column "test1" Int, DT.Column "test2" String, DT.Column "test3" Char] + testData + ) + ) + ) + ) + tests :: [Test] tests = [ TestLabel "deriveWAI" deriveWAI + , TestLabel "deriveWAITyped" deriveWAITyped ] diff --git a/tests/Operations/Join.hs b/tests/Operations/Join.hs index e34757d..24cde38 100644 --- a/tests/Operations/Join.hs +++ b/tests/Operations/Join.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE DataKinds #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE TypeApplications #-} @@ -8,6 +9,7 @@ import Data.These import qualified DataFrame as D import qualified DataFrame.Functions as F import DataFrame.Operations.Join +import qualified DataFrame.Typed as DT import Test.HUnit df1 :: D.DataFrame @@ -66,6 +68,54 @@ testRightJoin = (D.sortBy [D.Asc (F.col @Text "key")] (rightJoin ["key"] df2 df1)) ) +tdf1 :: DT.TypedDataFrame [DT.Column "key" Text, DT.Column "A" Text] +tdf1 = either (error . show) id (DT.freezeWithError df1) + +tdf2 :: DT.TypedDataFrame [DT.Column "key" Text, DT.Column "B" Text] +tdf2 = either (error . 
show) id (DT.freezeWithError df2) + +testInnerJoinTyped :: Test +testInnerJoinTyped = + TestCase + ( assertEqual + "Test typed inner join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2"]) + , ("A", D.fromList ["A0" :: Text, "A1", "A2"]) + , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.innerJoin @'["key"] tdf1 tdf2)) + ) + +testLeftJoinTyped :: Test +testLeftJoinTyped = + TestCase + ( assertEqual + "Test typed left join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2", "K3", "K4", "K5"]) + , ("A", D.fromList ["A0" :: Text, "A1", "A2", "A3", "A4", "A5"]) + , ("B", D.fromList [Just "B0", Just "B1" :: Maybe Text, Just "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.leftJoin @'["key"] tdf1 tdf2)) + ) + +testRightJoinTyped :: Test +testRightJoinTyped = + TestCase + ( assertEqual + "Test typed right join with single key" + ( D.fromNamedColumns + [ ("key", D.fromList ["K0" :: Text, "K1", "K2"]) + , ("A", D.fromList [Just "A0" :: Maybe Text, Just "A1", Just "A2"]) + , ("B", D.fromList ["B0" :: Text, "B1", "B2"]) + ] + ) + (DT.thaw $ DT.sortBy [DT.asc (DT.col @"key")] (DT.rightJoin @'["key"] tdf1 tdf2)) + ) + staffDf :: D.DataFrame staffDf = D.fromRows @@ -206,8 +256,11 @@ testOuterJoinWithCollisions = tests :: [Test] tests = [ TestLabel "innerJoin" testInnerJoin + , TestLabel "testInnerJoinTyped" testInnerJoinTyped , TestLabel "leftJoin" testLeftJoin + , TestLabel "testLeftJoinTyped" testLeftJoinTyped , TestLabel "rightJoin" testRightJoin + , TestLabel "testRightJoinTyped" testRightJoinTyped , TestLabel "fullOuterJoin" testFullOuterJoin , TestLabel "innerJoinWithCollisions" testInnerJoinWithCollisions , TestLabel "leftJoinWithCollisions" testLeftJoinWithCollisions