Skip to content

Commit ea9b35a

Browse files
committed
feat: Add fixity and precedence to binary operations.
1 parent ba5c3df commit ea9b35a

File tree

2 files changed

+45
-29
lines changed

2 files changed

+45
-29
lines changed

src/DataFrame/Functions.hs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ import Text.Regex.TDFA
4444
import Prelude hiding (maximum, minimum)
4545
import Prelude as P
4646

47+
infix 8 .^^, `div`
4748
infix 4 .==, .<, .<=, .>=, .>, ./=
4849
infixr 3 .&&
4950
infixr 2 .||
@@ -125,10 +126,10 @@ toDouble =
125126
)
126127

127128
div :: (Integral a, Columnable a) => Expr a -> Expr a -> Expr a
128-
div = lift2Decorated Prelude.div "div" (Just "//") False 2
129+
div = lift2Decorated Prelude.div "div" (Just "//") False 7
129130

130131
mod :: (Integral a, Columnable a) => Expr a -> Expr a -> Expr a
131-
mod = lift2Decorated Prelude.mod "mod" Nothing False 2
132+
mod = lift2Decorated Prelude.mod "mod" Nothing False 7
132133

133134
(.==) :: (Columnable a, Eq a) => Expr a -> Expr a -> Expr Bool
134135
(.==) =
@@ -138,7 +139,7 @@ mod = lift2Decorated Prelude.mod "mod" Nothing False 2
138139
, binaryName = "eq"
139140
, binarySymbol = Just "=="
140141
, binaryCommutative = True
141-
, binaryPrecedence = 1
142+
, binaryPrecedence = 4
142143
}
143144
)
144145

@@ -150,7 +151,7 @@ mod = lift2Decorated Prelude.mod "mod" Nothing False 2
150151
, binaryName = "neq"
151152
, binarySymbol = Just "/="
152153
, binaryCommutative = True
153-
, binaryPrecedence = 1
154+
, binaryPrecedence = 4
154155
}
155156
)
156157

@@ -165,7 +166,7 @@ eq = (.==)
165166
, binaryName = "lt"
166167
, binarySymbol = Just "<"
167168
, binaryCommutative = False
168-
, binaryPrecedence = 1
169+
, binaryPrecedence = 4
169170
}
170171
)
171172

@@ -180,7 +181,7 @@ lt = (.<)
180181
, binaryName = "gt"
181182
, binarySymbol = Just ">"
182183
, binaryCommutative = False
183-
, binaryPrecedence = 1
184+
, binaryPrecedence = 4
184185
}
185186
)
186187

@@ -195,7 +196,7 @@ gt = (.>)
195196
, binaryName = "leq"
196197
, binarySymbol = Just "<="
197198
, binaryCommutative = False
198-
, binaryPrecedence = 1
199+
, binaryPrecedence = 4
199200
}
200201
)
201202

@@ -210,7 +211,7 @@ leq = (.<=)
210211
, binaryName = "geq"
211212
, binarySymbol = Just ">="
212213
, binaryCommutative = False
213-
, binaryPrecedence = 1
214+
, binaryPrecedence = 4
214215
}
215216
)
216217

@@ -228,7 +229,7 @@ and = (.&&)
228229
, binaryName = "and"
229230
, binarySymbol = Just "&&"
230231
, binaryCommutative = True
231-
, binaryPrecedence = 1
232+
, binaryPrecedence = 3
232233
}
233234
)
234235

@@ -243,7 +244,7 @@ or = (.||)
243244
, binaryName = "or"
244245
, binarySymbol = Just "||"
245246
, binaryCommutative = True
246-
, binaryPrecedence = 1
247+
, binaryPrecedence = 2
247248
}
248249
)
249250

@@ -328,10 +329,10 @@ zScore :: Expr Double -> Expr Double
328329
zScore c = (c - mean c) / stddev c
329330

330331
pow :: (Columnable a, Num a) => Expr a -> Int -> Expr a
331-
pow _ 0 = Lit 1
332-
pow (Lit n) i = Lit (n ^ i)
333-
pow expr 1 = expr
334-
pow expr i = lift2Decorated (^) "pow" (Just "^") False 3 expr (lit i)
332+
pow expr i = lift2Decorated (^) "pow" (Just "^") False 8 expr (lit i)
333+
334+
(.^^) :: (Columnable a, Num a) => Expr a -> Int -> Expr a
335+
(.^^) = pow
335336

336337
relu :: (Columnable a, Num a, Ord a) => Expr a -> Expr a
337338
relu = lift (Prelude.max 0)

src/DataFrame/Internal/Expression.hs

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ instance (Num a, Columnable a) => Num (Expr a) where
6868
, binaryName = "add"
6969
, binarySymbol = Just "+"
7070
, binaryCommutative = True
71-
, binaryPrecedence = 1
71+
, binaryPrecedence = 6
7272
}
7373
)
7474

@@ -80,7 +80,7 @@ instance (Num a, Columnable a) => Num (Expr a) where
8080
, binaryName = "sub"
8181
, binarySymbol = Just "-"
8282
, binaryCommutative = False
83-
, binaryPrecedence = 1
83+
, binaryPrecedence = 6
8484
}
8585
)
8686

@@ -92,7 +92,7 @@ instance (Num a, Columnable a) => Num (Expr a) where
9292
, binaryName = "mult"
9393
, binarySymbol = Just "*"
9494
, binaryCommutative = True
95-
, binaryPrecedence = 2
95+
, binaryPrecedence = 7
9696
}
9797
)
9898

@@ -133,7 +133,7 @@ instance (Fractional a, Columnable a) => Fractional (Expr a) where
133133
, binaryName = "divide"
134134
, binarySymbol = Just "/"
135135
, binaryCommutative = True
136-
, binaryPrecedence = 2
136+
, binaryPrecedence = 7
137137
}
138138
)
139139

@@ -160,7 +160,7 @@ instance (Floating a, Columnable a) => Floating (Expr a) where
160160
, binaryName = "exponentiate"
161161
, binarySymbol = Just "**"
162162
, binaryCommutative = False
163-
, binaryPrecedence = 3
163+
, binaryPrecedence = 8
164164
}
165165
)
166166
log :: (Floating a, Columnable a) => Expr a -> Expr a
@@ -338,22 +338,37 @@ getColumns (Binary op l r) = getColumns l <> getColumns r
338338
getColumns (Agg strategy expr) = getColumns expr
339339

340340
prettyPrint :: Expr a -> String
341-
prettyPrint = go 0
341+
prettyPrint = go 0 0
342342
where
343-
go :: Int -> Expr a -> String
344-
go prec expr = case expr of
343+
indent :: Int -> String
344+
indent n = replicate (n * 2) ' '
345+
346+
go :: Int -> Int -> Expr a -> String
347+
go depth prec expr = case expr of
345348
Col name -> T.unpack name
346349
Lit value -> show value
347350
If cond t e ->
348-
"if (" ++ go 0 cond ++ ") then (" ++ go 0 t ++ ") else (" ++ go 0 e ++ ")"
351+
let inner =
352+
"if "
353+
++ go (depth + 1) 0 cond
354+
++ "\n"
355+
++ indent (depth + 1)
356+
++ "then "
357+
++ go (depth + 1) 0 t
358+
++ "\n"
359+
++ indent (depth + 1)
360+
++ "else "
361+
++ go (depth + 1) 0 e
362+
in if prec > 0 then "(" ++ inner ++ ")" else inner
349363
Unary op arg -> case unarySymbol op of
350-
Nothing -> T.unpack (unaryName op) ++ "(" ++ go 0 arg ++ ")"
351-
Just sym -> T.unpack sym ++ "(" ++ go 0 arg ++ ")"
364+
Nothing -> T.unpack (unaryName op) ++ "(" ++ go depth 0 arg ++ ")"
365+
Just sym -> T.unpack sym ++ "(" ++ go depth 0 arg ++ ")"
352366
Binary op l r ->
353367
let p = binaryPrecedence op
354368
inner = case binarySymbol op of
355-
Just name -> go p l ++ " " ++ T.unpack name ++ " " ++ go p r
356-
Nothing -> T.unpack (binaryName op) ++ "(" ++ go p l ++ ", " ++ go p r ++ ")"
369+
Just name -> go depth p l ++ " " ++ T.unpack name ++ " " ++ go depth p r
370+
Nothing ->
371+
T.unpack (binaryName op) ++ "(" ++ go depth p l ++ ", " ++ go depth p r ++ ")"
357372
in if prec > p then "(" ++ inner ++ ")" else inner
358-
Agg (CollectAgg op _) arg -> T.unpack op ++ "(" ++ go 0 arg ++ ")"
359-
Agg (FoldAgg op _ _) arg -> T.unpack op ++ "(" ++ go 0 arg ++ ")"
373+
Agg (CollectAgg op _) arg -> T.unpack op ++ "(" ++ go depth 0 arg ++ ")"
374+
Agg (FoldAgg op _ _) arg -> T.unpack op ++ "(" ++ go depth 0 arg ++ ")"

0 commit comments

Comments
 (0)