Skip to content

Commit ccb15dd

Browse files
committed
refactor: Simplify Expr GADT.
Move name and function to ADT and collapse all aggregations to one Agg case.
1 parent c54d57f commit ccb15dd

File tree

5 files changed

+485
-426
lines changed

5 files changed

+485
-426
lines changed

src/DataFrame/DecisionTree.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,8 +466,8 @@ pruneExpr (If cond trueBranch falseBranch) =
466466
(If condInner tInner _, _) | cond == condInner -> If cond tInner f
467467
(_, If condInner _ fInner) | cond == condInner -> If cond t fInner
468468
_ -> If cond t f
469-
pruneExpr (UnaryOp name op e) = UnaryOp name op (pruneExpr e)
470-
pruneExpr (BinaryOp name op l r) = BinaryOp name op (pruneExpr l) (pruneExpr r)
469+
pruneExpr (Unary op e) = Unary op (pruneExpr e)
470+
pruneExpr (Binary op l r) = Binary op (pruneExpr l) (pruneExpr r)
471471
pruneExpr e = e
472472

473473
buildGreedyTree ::

src/DataFrame/Functions.hs

Lines changed: 178 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,7 @@ import DataFrame.Internal.DataFrame (
1919
DataFrame (..),
2020
unsafeGetColumn,
2121
)
22-
import DataFrame.Internal.Expression (
23-
Expr (..),
24-
NamedExpr,
25-
UExpr (..),
26-
)
22+
import DataFrame.Internal.Expression hiding (normalize)
2723
import DataFrame.Internal.Statistics
2824

2925
import Control.Applicative
@@ -75,120 +71,237 @@ lit :: (Columnable a) => a -> Expr a
7571
lit = Lit
7672

7773
lift :: (Columnable a, Columnable b) => (a -> b) -> Expr a -> Expr b
78-
lift = UnaryOp "udf"
74+
lift f =
75+
Unary (MkUnaryOp{unaryFn = f, unaryName = "unaryUdf", unarySymbol = Nothing})
7976

8077
lift2 ::
8178
(Columnable c, Columnable b, Columnable a) =>
8279
(c -> b -> a) -> Expr c -> Expr b -> Expr a
83-
lift2 = BinaryOp "udf"
80+
lift2 f =
81+
Binary
82+
( MkBinaryOp
83+
{ binaryFn = f
84+
, binaryName = "binaryUdf"
85+
, binarySymbol = Nothing
86+
, binaryCommutative = False
87+
, binaryPrecedence = 0
88+
}
89+
)
90+
91+
liftDecorated ::
92+
(Columnable a, Columnable b) =>
93+
(a -> b) -> T.Text -> Maybe T.Text -> Expr a -> Expr b
94+
liftDecorated f name rep = Unary (MkUnaryOp{unaryFn = f, unaryName = name, unarySymbol = rep})
95+
96+
lift2Decorated ::
97+
(Columnable c, Columnable b, Columnable a) =>
98+
(c -> b -> a) ->
99+
T.Text ->
100+
Maybe T.Text ->
101+
Bool ->
102+
Int ->
103+
Expr c ->
104+
Expr b ->
105+
Expr a
106+
lift2Decorated f name rep comm prec =
107+
Binary
108+
( MkBinaryOp
109+
{ binaryFn = f
110+
, binaryName = name
111+
, binarySymbol = rep
112+
, binaryCommutative = comm
113+
, binaryPrecedence = prec
114+
}
115+
)
84116

85117
toDouble :: (Columnable a, Real a) => Expr a -> Expr Double
86-
toDouble = UnaryOp "toDouble" realToFrac
118+
toDouble =
119+
Unary
120+
( MkUnaryOp
121+
{ unaryFn = realToFrac
122+
, unaryName = "toDouble"
123+
, unarySymbol = Nothing
124+
}
125+
)
87126

88127
div :: (Integral a, Columnable a) => Expr a -> Expr a -> Expr a
89-
div = BinaryOp "div" Prelude.div
128+
div = lift2Decorated Prelude.div "div" (Just "//") False 2
90129

91130
mod :: (Integral a, Columnable a) => Expr a -> Expr a -> Expr a
92-
mod = BinaryOp "mod" Prelude.mod
131+
mod = lift2Decorated Prelude.mod "mod" Nothing False 2
93132

94133
(.==) :: (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
95-
(.==) = BinaryOp "eq" (==)
134+
(.==) =
135+
Binary
136+
( MkBinaryOp
137+
{ binaryFn = (==)
138+
, binaryName = "eq"
139+
, binarySymbol = Just "=="
140+
, binaryCommutative = True
141+
, binaryPrecedence = 1
142+
}
143+
)
96144

97145
(./=) :: (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
98-
(./=) = BinaryOp "neq" (/=)
146+
(./=) =
147+
Binary
148+
( MkBinaryOp
149+
{ binaryFn = (/=)
150+
, binaryName = "neq"
151+
, binarySymbol = Just "/="
152+
, binaryCommutative = True
153+
, binaryPrecedence = 1
154+
}
155+
)
99156

100157
eq :: (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
101-
eq = BinaryOp "eq" (==)
158+
eq = (.==)
102159

103160
(.<) :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr Bool
104-
(.<) = BinaryOp "lt" (<)
161+
(.<) =
162+
Binary
163+
( MkBinaryOp
164+
{ binaryFn = (<)
165+
, binaryName = "lt"
166+
, binarySymbol = Just "<"
167+
, binaryCommutative = False
168+
, binaryPrecedence = 1
169+
}
170+
)
105171

106172
lt :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr Bool
107-
lt = BinaryOp "lt" (<)
173+
lt = (.<)
108174

109-
-- TODO: Generalize this pattern for other equality functions.
110175
(.>) :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr Bool
111-
(.>) = BinaryOp "gt" (>)
176+
(.>) =
177+
Binary
178+
( MkBinaryOp
179+
{ binaryFn = (>)
180+
, binaryName = "gt"
181+
, binarySymbol = Just ">"
182+
, binaryCommutative = False
183+
, binaryPrecedence = 1
184+
}
185+
)
112186

113187
gt :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr Bool
114188
gt = (.>)
115189

116190
(.<=) :: (Columnable a, Ord a, Eq a) => Expr a -> Expr a -> Expr Bool
117-
(.<=) = BinaryOp "leq" (<=)
191+
(.<=) =
192+
Binary
193+
( MkBinaryOp
194+
{ binaryFn = (<=)
195+
, binaryName = "leq"
196+
, binarySymbol = Just "<="
197+
, binaryCommutative = False
198+
, binaryPrecedence = 1
199+
}
200+
)
118201

119202
leq :: (Columnable a, Ord a, Eq a) => Expr a -> Expr a -> Expr Bool
120203
leq = (.<=)
121204

122205
(.>=) :: (Columnable a, Ord a, Eq a) => Expr a -> Expr a -> Expr Bool
123-
(.>=) = BinaryOp "geq" (>=)
206+
(.>=) =
207+
Binary
208+
( MkBinaryOp
209+
{ binaryFn = (>=)
210+
, binaryName = "geq"
211+
, binarySymbol = Just ">="
212+
, binaryCommutative = False
213+
, binaryPrecedence = 1
214+
}
215+
)
124216

125217
geq :: (Columnable a, Ord a, Eq a) => Expr a -> Expr a -> Expr Bool
126-
geq = BinaryOp "geq" (>=)
218+
geq = (.>=)
127219

128220
and :: Expr Bool -> Expr Bool -> Expr Bool
129-
and = BinaryOp "and" (&&)
221+
and = (.&&)
130222

131223
(.&&) :: Expr Bool -> Expr Bool -> Expr Bool
132-
(.&&) = BinaryOp "and" (&&)
224+
(.&&) =
225+
Binary
226+
( MkBinaryOp
227+
{ binaryFn = (&&)
228+
, binaryName = "and"
229+
, binarySymbol = Just "&&"
230+
, binaryCommutative = True
231+
, binaryPrecedence = 1
232+
}
233+
)
133234

134235
or :: Expr Bool -> Expr Bool -> Expr Bool
135-
or = BinaryOp "or" (||)
236+
or = (.||)
136237

137238
(.||) :: Expr Bool -> Expr Bool -> Expr Bool
138-
(.||) = BinaryOp "or" (||)
239+
(.||) =
240+
Binary
241+
( MkBinaryOp
242+
{ binaryFn = (||)
243+
, binaryName = "or"
244+
, binarySymbol = Just "||"
245+
, binaryCommutative = True
246+
, binaryPrecedence = 1
247+
}
248+
)
139249

140250
not :: Expr Bool -> Expr Bool
141-
not = UnaryOp "not" Prelude.not
251+
not =
252+
Unary
253+
(MkUnaryOp{unaryFn = Prelude.not, unaryName = "not", unarySymbol = Just "~"})
142254

143255
count :: (Columnable a) => Expr a -> Expr Int
144-
count expr = AggFold expr "count" 0 (\acc _ -> acc + 1)
256+
count = Agg (FoldAgg "count" (Just 0) (\acc _ -> acc + 1))
145257

146258
collect :: (Columnable a) => Expr a -> Expr [a]
147-
collect expr = AggFold expr "collect" [] (flip (:))
259+
collect = Agg (FoldAgg "collect" (Just []) (flip (:)))
148260

149261
mode :: (Ord a, Columnable a, Eq a) => Expr a -> Expr a
150-
mode expr =
151-
AggVector
152-
expr
153-
"mode"
154-
( fst
155-
. L.maximumBy (compare `on` snd)
156-
. M.toList
157-
. V.foldl' (\m e -> M.insertWith (+) e 1 m) M.empty
262+
mode =
263+
Agg
264+
( CollectAgg
265+
"mode"
266+
( fst
267+
. L.maximumBy (compare `on` snd)
268+
. M.toList
269+
. V.foldl' (\m e -> M.insertWith (+) e 1 m) M.empty
270+
)
158271
)
159272

160273
minimum :: (Columnable a, Ord a) => Expr a -> Expr a
161-
minimum expr = AggReduce expr "minimum" Prelude.min
274+
minimum = Agg (FoldAgg "minimum" Nothing Prelude.min)
162275

163276
maximum :: (Columnable a, Ord a) => Expr a -> Expr a
164-
maximum expr = AggReduce expr "maximum" Prelude.max
277+
maximum = Agg (FoldAgg "maximum" Nothing Prelude.max)
165278

166279
sum :: forall a. (Columnable a, Num a) => Expr a -> Expr a
167-
sum expr = AggReduce expr "sum" (+)
280+
sum = Agg (FoldAgg "sum" Nothing (+))
168281
{-# SPECIALIZE DataFrame.Functions.sum :: Expr Double -> Expr Double #-}
169282
{-# SPECIALIZE DataFrame.Functions.sum :: Expr Int -> Expr Int #-}
170283
{-# INLINEABLE DataFrame.Functions.sum #-}
171284

172285
sumMaybe :: forall a. (Columnable a, Num a) => Expr (Maybe a) -> Expr a
173-
sumMaybe expr = AggVector expr "sumMaybe" (P.sum . Maybe.catMaybes . V.toList)
286+
sumMaybe = Agg (CollectAgg "sumMaybe" (P.sum . Maybe.catMaybes . V.toList))
174287

175288
mean :: (Columnable a, Real a, VU.Unbox a) => Expr a -> Expr Double
176-
mean expr = AggNumericVector expr "mean" mean'
289+
mean = Agg (CollectAgg "mean" mean')
177290
{-# SPECIALIZE DataFrame.Functions.mean :: Expr Double -> Expr Double #-}
178291
{-# SPECIALIZE DataFrame.Functions.mean :: Expr Int -> Expr Double #-}
179292
{-# INLINEABLE DataFrame.Functions.mean #-}
180293

181294
meanMaybe :: forall a. (Columnable a, Real a) => Expr (Maybe a) -> Expr Double
182-
meanMaybe expr = AggVector expr "meanMaybe" (mean' . optionalToDoubleVector)
295+
meanMaybe = Agg (CollectAgg "meanMaybe" (mean' . optionalToDoubleVector))
183296

184297
variance :: (Columnable a, Real a, VU.Unbox a) => Expr a -> Expr Double
185-
variance expr = AggNumericVector expr "variance" variance'
298+
variance = Agg (CollectAgg "variance" variance')
186299

187300
median :: (Columnable a, Real a, VU.Unbox a) => Expr a -> Expr Double
188-
median expr = AggNumericVector expr "median" median'
301+
median = Agg (CollectAgg "median" median')
189302

190303
medianMaybe :: (Columnable a, Real a) => Expr (Maybe a) -> Expr Double
191-
medianMaybe expr = AggVector expr "meanMaybe" (median' . optionalToDoubleVector)
304+
-- Collecting aggregation: converts the optional values with
-- optionalToDoubleVector and then applies median'.
-- The label is "medianMaybe"; it was previously "meanMaybe", which collided
-- with the label used by the meanMaybe aggregation.
medianMaybe = Agg (CollectAgg "medianMaybe" (median' . optionalToDoubleVector))
192305

193306
optionalToDoubleVector :: (Real a) => V.Vector (Maybe a) -> VU.Vector Double
194307
optionalToDoubleVector =
@@ -198,17 +311,18 @@ optionalToDoubleVector =
198311
[]
199312

200313
percentile :: Int -> Expr Double -> Expr Double
201-
percentile n expr =
202-
AggNumericVector
203-
expr
204-
(T.pack $ "percentile " ++ show n)
205-
(percentile' n)
314+
percentile n =
315+
Agg
316+
( CollectAgg
317+
(T.pack $ "percentile " ++ show n)
318+
(percentile' n)
319+
)
206320

207321
stddev :: (Columnable a, Real a, VU.Unbox a) => Expr a -> Expr Double
208-
stddev expr = AggNumericVector expr "stddev" (sqrt . variance')
322+
stddev = Agg (CollectAgg "stddev" (sqrt . variance'))
209323

210324
stddevMaybe :: forall a. (Columnable a, Real a) => Expr (Maybe a) -> Expr Double
211-
stddevMaybe expr = AggVector expr "stddevMaybe" (sqrt . variance' . optionalToDoubleVector)
325+
stddevMaybe = Agg (CollectAgg "stddevMaybe" (sqrt . variance' . optionalToDoubleVector))
212326

213327
zScore :: Expr Double -> Expr Double
214328
zScore c = (c - mean c) / stddev c
@@ -217,36 +331,36 @@ pow :: (Columnable a, Num a) => Expr a -> Int -> Expr a
217331
pow _ 0 = Lit 1
218332
pow (Lit n) i = Lit (n ^ i)
219333
pow expr 1 = expr
220-
pow expr i = BinaryOp "pow" (^) expr (lit i)
334+
pow expr i = lift2Decorated (^) "pow" (Just "^") False 3 expr (lit i)
221335

222336
relu :: (Columnable a, Num a, Ord a) => Expr a -> Expr a
223-
relu = UnaryOp "relu" (Prelude.max 0)
337+
relu = lift (Prelude.max 0)
224338

225339
min :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr a
226-
min = BinaryOp "min" Prelude.min
340+
-- Wraps Prelude.min as a commutative binary operator (no infix symbol,
-- precedence 1) via lift2Decorated.
-- The operator name is "min"; it was previously "max", which gave this
-- operator the same name as the one built by 'max'.
min = lift2Decorated Prelude.min "min" Nothing True 1
227341

228342
max :: (Columnable a, Ord a) => Expr a -> Expr a -> Expr a
229-
max = BinaryOp "max" Prelude.max
343+
max = lift2Decorated Prelude.max "max" Nothing True 1
230344

231345
reduce ::
232346
forall a b.
233347
(Columnable a, Columnable b) => Expr b -> a -> (a -> b -> a) -> Expr a
234-
reduce expr = AggFold expr "foldUdf"
348+
reduce expr start f = Agg (FoldAgg "foldUdf" (Just start) f) expr
235349

236350
toMaybe :: (Columnable a) => Expr a -> Expr (Maybe a)
237-
toMaybe = UnaryOp "toMaybe" Just
351+
toMaybe = lift Just
238352

239353
fromMaybe :: (Columnable a) => a -> Expr (Maybe a) -> Expr a
240-
fromMaybe d = UnaryOp ("fromMaybe " <> T.pack (show d)) (Maybe.fromMaybe d)
354+
fromMaybe d = lift (Maybe.fromMaybe d)
241355

242356
isJust :: (Columnable a) => Expr (Maybe a) -> Expr Bool
243-
isJust = UnaryOp "isJust" Maybe.isJust
357+
isJust = lift Maybe.isJust
244358

245359
isNothing :: (Columnable a) => Expr (Maybe a) -> Expr Bool
246-
isNothing = UnaryOp "isNothing" Maybe.isNothing
360+
isNothing = lift Maybe.isNothing
247361

248362
fromJust :: (Columnable a) => Expr (Maybe a) -> Expr a
249-
fromJust = UnaryOp "fromJust" Maybe.fromJust
363+
fromJust = lift Maybe.fromJust
250364

251365
whenPresent ::
252366
forall a b.
@@ -262,7 +376,7 @@ whenBothPresent f = lift2 (\l r -> f <$> l <*> r)
262376
recode ::
263377
forall a b.
264378
(Columnable a, Columnable b) => [(a, b)] -> Expr a -> Expr (Maybe b)
265-
recode mapping = UnaryOp (T.pack ("recode " ++ show mapping)) (`lookup` mapping)
379+
recode mapping = lift (`lookup` mapping)
266380

267381
recodeWithCondition ::
268382
forall a b.
@@ -274,13 +388,10 @@ recodeWithCondition fallback ((cond, value) : rest) expr = ifThenElse (cond expr
274388
recodeWithDefault ::
275389
forall a b.
276390
(Columnable a, Columnable b) => b -> [(a, b)] -> Expr a -> Expr b
277-
recodeWithDefault d mapping =
278-
UnaryOp
279-
(T.pack ("recodeWithDefault " ++ show d ++ " " ++ show mapping))
280-
(Maybe.fromMaybe d . (`lookup` mapping))
391+
recodeWithDefault d mapping = lift (Maybe.fromMaybe d . (`lookup` mapping))
281392

282393
firstOrNothing :: (Columnable a) => Expr [a] -> Expr (Maybe a)
283-
firstOrNothing = UnaryOp "firstOrNothing" Maybe.listToMaybe
394+
firstOrNothing = lift Maybe.listToMaybe
284395

285396
lastOrNothing :: (Columnable a) => Expr [a] -> Expr (Maybe a)
286397
lastOrNothing = lift (Maybe.listToMaybe . reverse)

0 commit comments

Comments
 (0)