diff --git a/CHANGELOG b/CHANGELOG index 6e6efe1..15b25bc 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,5 +1,6 @@ 0.2.7: * Add support for DATETIMEOFFSET + * Add exec variant returning number of affected rows 0.2.6: * Add support for SQLSTATE * Fix copying issues for error messages diff --git a/app/Main.hs b/app/Main.hs index 9ee1221..f18fc7e 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -5,12 +5,10 @@ module Main (main) where -import Data.List import Data.Time.LocalTime (ZonedTime(..)) import qualified Data.Text as T import qualified Data.Text.IO as T import Control.Exception -import qualified Data.Text as T import qualified Database.ODBC.Internal as ODBC import System.Environment import System.IO diff --git a/cbits/odbc.c b/cbits/odbc.c index 9ac39bd..ecaff61 100644 --- a/cbits/odbc.c +++ b/cbits/odbc.c @@ -269,6 +269,13 @@ RETCODE odbc_SQLNumResultCols(SQLHSTMT *hstmt, SQLSMALLINT *cols){ return SQLNumResultCols(*hstmt, cols); } +//////////////////////////////////////////////////////////////////////////////// +// Get rows + +RETCODE odbc_SQLRowCount(SQLHSTMT *hstmt, SQLLEN *rows){ + return SQLRowCount(*hstmt, rows); +} + //////////////////////////////////////////////////////////////////////////////// // Logs @@ -402,4 +409,4 @@ SQLSMALLINT TIMESTAMPOFFSET_STRUCT_timezone_hour(TIMESTAMPOFFSET_STRUCT *t){ SQLSMALLINT TIMESTAMPOFFSET_STRUCT_timezone_minute(TIMESTAMPOFFSET_STRUCT *t){ return t->timezone_minute; -} \ No newline at end of file +} diff --git a/odbc.cabal b/odbc.cabal index 3e05d82..f8ad94f 100644 --- a/odbc.cabal +++ b/odbc.cabal @@ -5,7 +5,7 @@ description: Haskell binding to the ODBC API. This has been tested suite runs on OS X, Windows and Linux. copyright: FP Complete 2018 maintainer: chrisdone@fpcomplete.com -version: 0.2.6 +version: 0.2.7 license: BSD3 license-file: LICENSE build-type: Simple diff --git a/src/Database/ODBC/Internal.hs b/src/Database/ODBC/Internal.hs index 5c6f357..3de7b74 100644 --- a/src/Database/ODBC/Internal.hs +++ b/src/Database/ODBC/Internal.hs @@ -29,6 +29,7 @@ module Database.ODBC.Internal , Connection -- * Executing queries , exec + , execAffectedRows , query , Value(..) , Binary(..) @@ -38,6 +39,7 @@ module Database.ODBC.Internal , Step(..) -- * Parameters , execWithParams + , execAffectedRowsWithParams , queryWithParams , streamWithParams , Param(..) @@ -280,7 +282,18 @@ exec :: -> m () exec conn string = execWithParams conn string mempty --- | Same as 'exec' but with parameters. +-- | Execute a statement on the database and returns number of affected rows. +-- +-- @since 0.2.7 +execAffectedRows :: + MonadIO m + => Connection -- ^ A connection to the database. + -> Text -- ^ SQL statement. + -> m Int +execAffectedRows conn string = execAffectedRowsWithParams conn string mempty +{-# INLINE execAffectedRows #-} + +-- | Same as 'execAffectedRows but with parameters. -- -- @since 0.2.4 execWithParams :: @@ -296,6 +309,22 @@ execWithParams conn string params = "exec" (\dbc -> withExecDirect dbc string params (fetchAllResults dbc))) +-- | Same as 'execAffectedRowsWithParams but returns number of affected rows. +-- +-- @since 0.2.7 +execAffectedRowsWithParams :: + MonadIO m + => Connection -- ^ A connection to the database. + -> Text -- ^ SQL query with ? inside. + -> [Param] -- ^ Params matching the ? in the query string. + -> m Int +execAffectedRowsWithParams conn string params = + withBound + (withHDBC + conn + "exec" + (\dbc -> withExecDirect dbc string params (fetchAllResults' dbc))) + -- | Query and return a list of rows. query :: MonadIO m @@ -549,6 +578,21 @@ fetchAllResults dbc stmt = do (retcode == sql_success || retcode == sql_success_with_info) (fetchAllResults dbc stmt) +-- | Fetch all results from possible multiple statements. +fetchAllResults' :: Ptr EnvAndDbc -> SQLHSTMT s -> IO Int +fetchAllResults' dbc stmt = countRows <* fetchAllResults dbc stmt + where + countRows = do + SQLLEN rows <- + withMalloc + (\sizep -> do + assertSuccess + dbc + "odbc_SQLRowCount" + (odbc_SQLRowCount stmt sizep) + peek sizep) + pure $! fromIntegral (max 0 rows) + -- | Fetch all rows from a statement. fetchStatementRows :: Ptr EnvAndDbc -> SQLHSTMT s -> IO [[(Column,Value)]] fetchStatementRows dbc stmt = do @@ -1089,7 +1133,7 @@ newtype SQLCHAR = SQLCHAR CChar deriving (Show, Eq, Storable) -- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L88 newtype SQLSMALLINT = SQLSMALLINT Int16 deriving (Show, Eq, Storable, Num, Integral, Enum, Ord, Real) --- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L64 +-- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L641 newtype SQLLEN = SQLLEN Int64 deriving (Show, Eq, Storable, Num) -- https://github.com/Microsoft/ODBC-Specification/blob/753d7e714b7eab9eaab4ad6105fdf4267d6ad6f6/Windows/inc/sqltypes.h#L65..L65 @@ -1168,6 +1212,9 @@ foreign import ccall "odbc odbc_SQLMoreResults" foreign import ccall "odbc odbc_SQLNumResultCols" odbc_SQLNumResultCols :: SQLHSTMT s -> Ptr SQLSMALLINT -> IO RETCODE +foreign import ccall "odbc odbc_SQLRowCount" + odbc_SQLRowCount :: SQLHSTMT s -> Ptr SQLLEN -> IO RETCODE + foreign import ccall "odbc odbc_SQLGetData" odbc_SQLGetData :: Ptr EnvAndDbc diff --git a/src/Database/ODBC/SQLServer.hs b/src/Database/ODBC/SQLServer.hs index a76941d..0cb7001 100644 --- a/src/Database/ODBC/SQLServer.hs +++ b/src/Database/ODBC/SQLServer.hs @@ -25,6 +25,7 @@ module Database.ODBC.SQLServer -- * Executing queries , exec + , execAffectedRows , query , Value(..) , Query @@ -79,8 +80,6 @@ import Data.Fixed import Data.Foldable import Data.Int import Data.Maybe -import Data.Monoid (Monoid, (<>)) -import Data.Semigroup (Semigroup) import Data.Sequence (Seq) import qualified Data.Sequence as Seq import Data.String @@ -482,6 +481,17 @@ exec c q = Internal.execWithParams c rendered params where (rendered, params) = renderedAndParams q +-- | Execute a statement on the database and return number of affected rows. +execAffectedRows :: + MonadIO m + => Connection -- ^ A connection to the database. + -> Query -- ^ SQL statement. + -> m Int +execAffectedRows c q = Internal.execAffectedRowsWithParams c rendered params + where + (rendered, params) = renderedAndParams q +{-# INLINE execAffectedRows #-} + -------------------------------------------------------------------------------- -- Query building @@ -496,7 +506,7 @@ renderedAndParams q = (renderParts parts', params) ValuePart v | Just {} <- valueToParam v -> case v of - TextValue t -> TextPart "CAST(? AS NVARCHAR(MAX))" + TextValue _ -> TextPart "CAST(? AS NVARCHAR(MAX))" _ -> TextPart "?" p -> p) parts diff --git a/src/Database/ODBC/TH.hs b/src/Database/ODBC/TH.hs index fd4501f..310b9cc 100644 --- a/src/Database/ODBC/TH.hs +++ b/src/Database/ODBC/TH.hs @@ -10,7 +10,6 @@ module Database.ODBC.TH import Control.DeepSeq import Data.Char import Data.List (foldl1') -import Data.Monoid ((<>)) import Language.Haskell.TH (Q, Exp) import qualified Language.Haskell.TH as TH import Language.Haskell.TH.Quote (QuasiQuoter(..)) diff --git a/test/Main.hs b/test/Main.hs index f24dc3f..c13eb59 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -64,6 +64,7 @@ spec = do (do describe "Connectivity" connectivity describe "Regression tests" regressions describe "Data retrieval" dataRetrieval + describe "Data affected" dataAffected describe "Big data" bigData) describe "Database.ODBC.SQLServer" @@ -192,6 +193,25 @@ connectivity = do (do sequence_ [connectWithString >>= Internal.close | _ <- [1 :: Int .. 10]] shouldBe True True) +dataAffected :: Spec +dataAffected = do + it + "Basic sanity check" + (do c <- connectWithString + _ <- Internal.execAffectedRows c "DROP TABLE IF EXISTS test" + arOnCreate <- Internal.execAffectedRows + c + "CREATE TABLE test (int integer, text text, bool bit, nt ntext, fl float)" + _ <- Internal.execAffectedRows + c + "INSERT INTO test VALUES (123, 'abc', 1, 'wib', 2.415), (456, 'def', 0, 'wibble',0.9999999999999), (NULL, NULL, NULL, NULL, NULL)" + arOnDelete <- Internal.execAffectedRows c "delete from test" + arOnDelete' <- Internal.execAffectedRows c "delete from test" + Internal.close c + shouldBe + [("create", arOnCreate), ("delete", arOnDelete), ("delete'", arOnDelete')] + [("create", 0), ("delete", 3), ("delete'", 0)]) + dataRetrieval :: Spec dataRetrieval = do it