Skip to content

Commit fcb7e90

Browse files
committed
Add sortby tests and clean up test structure.
1 parent 292752d commit fcb7e90

File tree

9 files changed

+71
-27
lines changed

9 files changed

+71
-27
lines changed

dataframe.cabal

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,14 @@ benchmark dataframe-benchmark
106106

107107
test-suite tests
108108
type: exitcode-stdio-1.0
109-
main-is: DataFrameTests.hs
109+
main-is: Main.hs
110110
other-modules: Assertions,
111-
AddColumn,
112-
Apply,
113-
Filter,
114-
Sort,
115-
Take
111+
Operations.AddColumn,
112+
Operations.Apply,
113+
Operations.Filter,
114+
Operations.GroupBy,
115+
Operations.Sort,
116+
Operations.Take
116117
build-depends: base >= 4.17.2.0 && < 4.21,
117118
HUnit ^>= 1.6,
118119
random >= 1,

src/Data/DataFrame/Internal.hs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
module Data.DataFrame.Internal where
1313

14+
import qualified Data.List as L
1415
import qualified Data.Map as M
1516
import qualified Data.Map.Strict as MS
1617
import qualified Data.Set as S
@@ -183,14 +184,16 @@ instance Eq Column where
183184
case testEquality (typeRep @t1) (typeRep @t2) of
184185
Nothing -> False
185186
Just Refl -> a == b
187+
-- Note: comparing grouped columns is expensive (each group is sorted before comparison).
188+
-- We do this to keep tests stable; you should aggregate grouped columns soon after creating them.
186189
(==) (GroupedBoxedColumn (a :: V.Vector t1)) (GroupedBoxedColumn (b :: V.Vector t2)) =
187190
case testEquality (typeRep @t1) (typeRep @t2) of
188191
Nothing -> False
189-
Just Refl -> a == b
192+
Just Refl -> V.map (L.sort . VG.toList) a == V.map (L.sort . VG.toList) b
190193
(==) (GroupedUnboxedColumn (a :: V.Vector t1)) (GroupedUnboxedColumn (b :: V.Vector t2)) =
191194
case testEquality (typeRep @t1) (typeRep @t2) of
192195
Nothing -> False
193-
Just Refl -> a == b
196+
Just Refl -> V.map (L.sort . VG.toList) a == V.map (L.sort . VG.toList) b
194197
(==) _ _ = False
195198

196199
-- Traversing columns.
Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ import Test.HUnit
1717

1818
import Assertions
1919

20-
import qualified AddColumn
21-
import qualified Apply
22-
import qualified Filter
23-
import qualified Sort
24-
import qualified Take
20+
import qualified Operations.AddColumn
21+
import qualified Operations.Apply
22+
import qualified Operations.Filter
23+
import qualified Operations.GroupBy
24+
import qualified Operations.Sort
25+
import qualified Operations.Take
2526

2627
testData :: D.DataFrame
2728
testData = D.fromList [ ("test1", DI.toColumn ([1..26] :: [Int]))
@@ -68,11 +69,12 @@ parseTests = [
6869

6970
tests :: Test
7071
tests = TestList $ dimensionsTest
71-
++ AddColumn.tests
72-
++ Apply.tests
73-
++ Filter.tests
74-
++ Sort.tests
75-
++ Take.tests
72+
++ Operations.AddColumn.tests
73+
++ Operations.Apply.tests
74+
++ Operations.Filter.tests
75+
++ Operations.GroupBy.tests
76+
++ Operations.Sort.tests
77+
++ Operations.Take.tests
7678
++ parseTests
7779

7880
main :: IO ()
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{-# LANGUAGE TypeApplications #-}
22
{-# LANGUAGE OverloadedStrings #-}
3-
module AddColumn where
3+
module Operations.AddColumn where
44

55
import qualified Data.DataFrame as D
66
import qualified Data.DataFrame.Internal as DI

tests/Apply.hs renamed to tests/Operations/Apply.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{-# LANGUAGE TypeApplications #-}
22
{-# LANGUAGE OverloadedStrings #-}
33
{-# LANGUAGE TupleSections #-}
4-
module Apply where
4+
module Operations.Apply where
55

66
import qualified Data.DataFrame as D
77
import qualified Data.DataFrame.Internal as DI
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{-# LANGUAGE TypeApplications #-}
22
{-# LANGUAGE OverloadedStrings #-}
3-
module Filter where
3+
module Operations.Filter where
44

55
import qualified Data.DataFrame as D
66
import qualified Data.DataFrame.Internal as DI

tests/Operations/GroupBy.hs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
{-# LANGUAGE OverloadedStrings #-}
2+
module Operations.GroupBy where
3+
4+
import qualified Data.DataFrame as D
5+
import qualified Data.DataFrame.Internal as DI
6+
import qualified Data.DataFrame.Errors as DE
7+
import qualified Data.DataFrame.Operations as DO
8+
import qualified Data.DataFrame.Util as DU
9+
import qualified Data.Text as T
10+
import qualified Data.Vector as V
11+
import qualified Data.Vector.Unboxed as VU
12+
13+
import Assertions
14+
import Test.HUnit
15+
16+
values :: [(T.Text, DI.Column)]
17+
values = [ ("test1", DI.toColumn (concatMap (replicate 10) [1 :: Int, 2, 3, 4]))
18+
, ("test2", DI.toColumn (take 40 $ cycle [1 :: Int,2]))
19+
, ("test3", DI.toColumn [(1 :: Int)..40])
20+
, ("test4", DI.toColumn (reverse [(1 :: Int)..40]))
21+
]
22+
23+
testData :: D.DataFrame
24+
testData = D.fromList values
25+
26+
groupBySingleRowWAI :: Test
27+
groupBySingleRowWAI = TestCase (assertEqual "Groups by single column"
28+
(D.fromList [("test1", DI.toColumn [(1::Int)..4]),
29+
-- Each of the 4 unique test1 values gets a group of 10 elements alternating 1 and 2
30+
("test2", DI.GroupedUnboxedColumn (V.replicate 4 $ VU.fromList (take 10 $ cycle [1 :: Int, 2]))),
31+
("test3", DI.GroupedUnboxedColumn (V.generate 4 (\i -> VU.fromList [(i * 10 + 1)..((i + 1) * 10)]))),
32+
("test4", DI.GroupedUnboxedColumn (V.generate 4 (\i -> VU.fromList [(((3 - i) + 1) * 10),(((3 - i) + 1) * 10 - 1)..((3 - i) * 10 + 1)])))
33+
])
34+
(D.groupBy ["test1"] testData D.|> D.sortBy D.Ascending "test1"))
35+
36+
tests :: [Test]
37+
tests = [ TestLabel "groupBySingleRowWAI" groupBySingleRowWAI
38+
, TestLabel "groupBySingleRowWAI" groupBySingleRowWAI
39+
]
40+

tests/Sort.hs renamed to tests/Operations/Sort.hs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
{-# LANGUAGE TypeApplications #-}
21
{-# LANGUAGE OverloadedStrings #-}
3-
module Sort where
2+
module Operations.Sort where
43

54
import qualified Data.DataFrame as D
65
import qualified Data.DataFrame.Internal as DI
@@ -11,7 +10,6 @@ import qualified Data.Vector as V
1110
import qualified Data.Vector.Unboxed as VU
1211

1312
import Assertions
14-
1513
import Control.Monad
1614
import Data.Char
1715
import System.Random
@@ -29,13 +27,13 @@ testData :: D.DataFrame
2927
testData = D.fromList values
3028

3129
sortByAscendingWAI :: Test
32-
sortByAscendingWAI = TestCase (assertEqual "Non existent filter value returns no rows"
30+
sortByAscendingWAI = TestCase (assertEqual "Sorting rows by ascending works as intended"
3331
(D.fromList [("test1", DI.toColumn [(1::Int)..26]),
3432
("test2", DI.toColumn ['a'..'z'])])
3533
(D.sortBy D.Ascending "test1" testData))
3634

3735
sortByDescendingWAI :: Test
38-
sortByDescendingWAI = TestCase (assertEqual "Non existent filter value returns no rows"
36+
sortByDescendingWAI = TestCase (assertEqual "Sorting rows by descending works as intended"
3937
(D.fromList [("test1", DI.toColumn $ reverse [(1::Int)..26]),
4038
("test2", DI.toColumn $ reverse ['a'..'z'])])
4139
(D.sortBy D.Descending "test1" testData))

tests/Take.hs renamed to tests/Operations/Take.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
{-# LANGUAGE OverloadedStrings #-}
2-
module Take where
2+
module Operations.Take where
33

44
import qualified Data.DataFrame as D
55
import qualified Data.DataFrame.Internal as DI

0 commit comments

Comments
 (0)