diff --git a/README.md b/README.md index 8b6b6fc..ad68ce7 100644 --- a/README.md +++ b/README.md @@ -1,45 +1,269 @@ -# Assert (c) Blake Mizerany and Keith Rarick -- MIT LICENCE +## Simple assertions for Go tests -## Assertions for Go tests +This is a small collection of assertions to make Go tests more concise and readable. -## Install +## Repository contains two packages: - $ go get github.com/bmizerany/assert +* **assert** - collection of assertion functions. +* **assertmysql** - collection of assertion functions to test programs that are using *github.com/ziutek/mymysql/mysql* mysql client but not only. -## Use +If you don't use *github.com/ziutek/mymysql/mysql* in your program you can still use just **assert** package. -**point.go** +## Assertions - package point +**assert.Equal** - type Point struct { - x, y int - } +Asserts parameters are equal. -**point_test.go** +```go +func Test_add(t *testing.T) { + exp := 3 + got := add(1, 3) + assert.Equal(t, exp, got) - package point + // Test with optional comment + assert.Equal(t, exp, got, "Add function is broken.") +} +``` - import ( - "testing" - "github.com/bmizerany/assert" - ) +Parameters **exp** and **got** could be of any type. - func TestAsserts(t *testing.T) { - p1 := Point{1, 1} - p2 := Point{2, 1} +**assert.NotEqual** - assert.Equal(t, p1, p2) - } +Works the same way as *assert.Equal()* but asserts parameters are not equal. -**output** - $ go test - --- FAIL: TestAsserts (0.00 seconds) - assert.go:15: /Users/flavio.barbosa/dev/stewie/src/point_test.go:12 - assert.go:24: ! X: 1 != 2 - FAIL +**assert.Error** -## Docs +Asserts parameter is an instance of *error*. - http://github.com/bmizerany/assert +```go +func Test_Error(t *testing.T) { + file, err := getFile("/file/path") + + assert.Error(err) + + // Test with optional comment + assert.Error(t, err, "Expected file error.") +} +``` + +**assert.NotError** + +Works the same way as *assert.Error()* but asserts parameter is not an *error*. + +**assert.Nil** + +Asserts parameter is nil. + +```go +func Test_nil(t *testing.T) { + nilValue := doWork() + + assert.Nil(nilValue) + + // Test with optional comment + assert.Nil(t, nilValue, "Some comment") +} +``` + +**assert.ErrorMsgContains** + +```go +func Test_WTF(t *testing.T) { + err := errors.New("WTF?") + assert.ErrorMsgContains(t, err, "WTF") +} +``` + +**assert.NotNil** + +Works the same way as *assert.Nil()* but asserts parameter is not a *nil* value. + +**assert.T** + +Asserts parameter is *true*. + +**assert.F** + +Asserts parameter is *false*. + +**assert.FileExists** + +Asserts file exists and is accessible. + +```go +func Test_WTF(t *testing.T) { + testLog := "/some/path/testlog.log" + createLogFile(testLog) + assert.FileExists(t, testLog) +} +``` + +**assert.EqualFuncPtr** + +Asserts two function pointers are pointing to the same function. + +```go +func A() string { + return "A" +} + +type T struct{ + fn1: func() string + fn2: func() string +} + +func Test_Ptr(t *testing.T) { + t := &T{A, A} + assert.EqualFuncPtr(t, t.fn1, t.fn2) + + assert.FuncPtr(t, t.fn1) + assert.FuncPtr(t, t.fn2) +} +``` + +**assert.FuncPtr** + +Asserts pointer is a function pointer. See previous example. + +**assert.Panic** + +Asserts function panics. + +```go +func A() string { + panic("Not implemented") +} + + +func Test_Panic(t *testing.T) { + + fn := func() { + A() + } + + assert.Panic(t, fn, "Not implemented") +} +``` + +## Assertion comments + +You may pass as many comments to assertions as you want. You can even use formating: + +```go +assert.Equal(t, exp, got, "Add function is broken for input (%d,%d).", 1, 2) +``` + +## Failing fast / slow + +By default all the tests are failing slow. Meaning all the assertions in the test case are run. But sometimes you need to fail the test case on first error. This is how you are doing it: + +```go +gotValue, err := doWork() +assert.Equal(t, someTestStruct, gotValue, assert.FAIL_FAST, "Did not expect this.") +assert.Equal(t, "expected", gotValue.field1) +``` + +In this case if first assertion fails the the whole test case fails and second assertion will never be run. + +## Example output + + --- FAIL: Test_add (0.00 seconds) + assert.go:199: /golang/src/a/a_test.go:9 + assert.go:212: ! 5 != 4 + --- FAIL: Test_notEqual (0.00 seconds) + assert.go:199: /golang/src/a/a_test.go:14 + assert.go:66: ! Values not supposed to be equal. + assert.go:67: ! 1 + --- FAIL: Test_nil (0.00 seconds) + assert.go:199: /golang/src/a/a_test.go:35 + assert.go:166: Expected nil got "asdads" + assert.go:168: ! - Test comment + assert.go:199: /golang/src/a/a_test.go:36 + assert.go:179: Did not expect nil + assert.go:181: ! - Some other test comment + +## Writing your own tests helpers or assertions + +See the code for **assertmysql** package which uses the **assert** package to implement its assertions. + +## Documentation + +* http://godoc.org/github.com/rzajac/assert/assert +* http://godoc.org/github.com/rzajac/assert/assertmysql + +## Installation + +To install assertions run: + + $ go get github.com/rzajac/assert + +## Import + +```go +package my_package + +import ( + "github.com/rzajac/assert/assert" + "github.com/rzajac/assert/assertmysql" +) +``` + +This imports both packages but you can use only one of them. + +## MySQL assertions + +* **assertmysql.Error** - asserts MySQL error. +* **assertmysql.NotError** - asserts error is not MySQL error. +* **assertmysql.ErrorCode** - asserts error is MySQL error and has specific code. +* **assertmysql.RowExists** - asserts row exists in a table. +* **assertmysql.RowDoesNotExist** - asserts row does not exists in a table. +* **assertmysql.TableExists** - asserts table exists in database. +* **assertmysql.TableDoesNotExists** - asserts table does not exist in database. +* **assertmysql.TableCount** - asserts database has x tables. +* **assertmysql.TableRowCount** - asserts database table has x rows. +* **assertmysql.TableNotEmpty** - asserts table is not empty (has at least one row). + +## MySQL helper functions + +* **assertmysql.GetMySqlError** - returns an error and returns *mysql.Error. +* **assertmysql.GetTableNames** - returns an array of database table names. +* **assertmysql.GetTableRowCount** - returns a number of rows in a table. +* **assertmysql.DropTable** - drops table from database. +* **assertmysql.TruncateTable** - truncates table. +* **assertmysql.DropAllTables** - drops all tables from database. + +## Setup MySQL assertions package + +**mydb_test.go** + +```go +package my_package + +import ( + "github.com/rzajac/assert/assert" + "github.com/rzajac/assert/assertmysql" + "testing" +) + +func init() { + assertmysql.InitMySqlAssertions("tcp", "", "localhost:3306", "test_user", "test_user_password", "test_database") +} + +func Test_createMasterTable(t *testing.T) { + dropAllSqlTables(t) + err := createMasterTable() + + assert.NotError(t, err) + assertmysql.TableExists(t, "master_table_name") + assertmysql.TableCount(t, 0) +} +``` + +**You need only one *init* function per package.** + +## License + +Released under the MIT License. +Assert (c) Blake Mizerany, Keith Rarick and Rafal Zajac diff --git a/assert.go b/assert.go deleted file mode 100644 index 7409f98..0000000 --- a/assert.go +++ /dev/null @@ -1,76 +0,0 @@ -package assert -// Testing helpers for doozer. - -import ( - "github.com/kr/pretty" - "reflect" - "testing" - "runtime" - "fmt" -) - -func assert(t *testing.T, result bool, f func(), cd int) { - if !result { - _, file, line, _ := runtime.Caller(cd + 1) - t.Errorf("%s:%d", file, line) - f() - t.FailNow() - } -} - -func equal(t *testing.T, exp, got interface{}, cd int, args ...interface{}) { - fn := func() { - for _, desc := range pretty.Diff(exp, got) { - t.Error("!", desc) - } - if len(args) > 0 { - t.Error("!", " -", fmt.Sprint(args...)) - } - } - result := reflect.DeepEqual(exp, got) - assert(t, result, fn, cd+1) -} - -func tt(t *testing.T, result bool, cd int, args ...interface{}) { - fn := func() { - t.Errorf("! Failure") - if len(args) > 0 { - t.Error("!", " -", fmt.Sprint(args...)) - } - } - assert(t, result, fn, cd+1) -} - -func T(t *testing.T, result bool, args ...interface{}) { - tt(t, result, 1, args...) -} - -func Tf(t *testing.T, result bool, format string, args ...interface{}) { - tt(t, result, 1, fmt.Sprintf(format, args...)) -} - -func Equal(t *testing.T, exp, got interface{}, args ...interface{}) { - equal(t, exp, got, 1, args...) -} - -func Equalf(t *testing.T, exp, got interface{}, format string, args ...interface{}) { - equal(t, exp, got, 1, fmt.Sprintf(format, args...)) -} - -func NotEqual(t *testing.T, exp, got interface{}, args ...interface{}) { - fn := func() { - t.Errorf("! Unexpected: <%#v>", exp) - if len(args) > 0 { - t.Error("!", " -", fmt.Sprint(args...)) - } - } - result := !reflect.DeepEqual(exp, got) - assert(t, result, fn, 1) -} - -func Panic(t *testing.T, err interface{}, fn func()) { - defer func() { - equal(t, err, recover(), 3) - }() - fn() -} diff --git a/assert/assert.go b/assert/assert.go new file mode 100644 index 0000000..90782eb --- /dev/null +++ b/assert/assert.go @@ -0,0 +1,428 @@ +// Assertions for Go tests +// +// Assert (c) Blake Mizerany, Keith Rarick and Rafal Zajac +// http://github.com/rzajac/assert +// +// Licensed under the MIT license + +// Package provides assertions for testing +package assert + +import ( + "fmt" + "github.com/rzajac/pretty" + "math" + "os" + "reflect" + "runtime" + "strings" + "testing" +) + +const ( + FAIL_FAST = true + FAIL_SLOW = false +) + +const ( + NESTING_0 = iota + NESTING_1 + NESTING_2 + NESTING_3 + NESTING_4 + NESTING_5 + NESTING_6 + NESTING_7 + NESTING_8 + NESTING_9 +) + +var DEF_FAIL_STRATEGY bool = false + +// SetFailStrategy sets test failing strategy. Set it to true for tests to +// fail fast. Failing fast means that test case will run s only till the first +// failing assertion. By default the strategy is to fail slow. +func SetFailStrategy(failFast bool) { + DEF_FAIL_STRATEGY = failFast +} + +// Equal asserts equality. The exp and got can be of any type. +// +// Examples: +// +// assert.Equal(m1, m2) +// assert.Equal(m1, m2, "Some comment about the test") +// assert.Equal(m1, m2, "Some comment %s", myVar.someString) +// assert.Equal(m1, m2, assert.NESTING_1, "Some comment %s", myVar.someString) +// assert.Equal(m1, m2, assert.NESTING_1, assert.FAIL_FAST, "Some comment %s", myVar.someString) +// assert.Equal(m1, m2, assert.FAIL_FAST, assert.NESTING_1, "Some comment %s", myVar.someString) +// +func Equal(t *testing.T, exp, got interface{}, args ...interface{}) { + equal(t, exp, got, args...) +} + +// NotEqual asserts exp is not equal to got. It works the same way as assert.Equal() +func NotEqual(t *testing.T, exp, got interface{}, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Error("! Values not supposed to be equal.") + t.Errorf("! %#v", exp) + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + result := !reflect.DeepEqual(exp, got) + Assert(t, result, nesting, failFast, fn) +} + +// T asserts result is true. +func T(t *testing.T, result bool, args ...interface{}) { + equal(t, true, result, args...) +} + +// F asserts result is false. +func F(t *testing.T, result bool, args ...interface{}) { + equal(t, false, result, args...) +} + +// Error asserts that err is of type error. +func Error(t *testing.T, err interface{}, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Error("Expected error.") + t.Errorf("Got: %v", err) + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + _, ok := err.(error) + Assert(t, ok, nesting, failFast, fn) +} + +// ErrorMsgContains asserts that err is of type error +// and error message contains specific string. +// +// Example: +// +// func Test_WTF(t *testing.T) { +// err := errors.New("WTF?") +// assert.ErrorMsgContains(t, err, "WTF") +// } +// +func ErrorMsgContains(t *testing.T, err interface{}, contains string, args ...interface{}) { + var ok bool + var e error + var wrongErrorMessage string + nesting, failFast, errorMsg := DecodeArgs(args...) + + fn := func() { + t.Errorf("Expected an error containing '%s'.", contains) + if wrongErrorMessage != "" { + t.Errorf("But got error: %s", wrongErrorMessage) + } + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + + e, ok = err.(error) + if ok { + if !strings.Contains(e.Error(), contains) { + wrongErrorMessage = e.Error() + Assert(t, false, nesting, failFast, fn) + } + } else { + Assert(t, false, nesting, failFast, fn) + } +} + +// NotError asserts err is not of type error. +// +// Example: +// +// var tests = []struct { +// expectedValue bool +// funcArg string +// }{ +// {false, "hello"}, +// {true, "HELLO"}, +// } +// +// func Test_doWork(t *testing.T) { +// for idx, test := range tests { +// gotValue, err := isUpperCase(test.funcArg) +// assert.NotError(t, err, "idx:%d", idx) +// assert.Equal(t, test.expectedValue, gotValue, "Idx:%d", idx) +// } +// } +// +func NotError(t *testing.T, err interface{}, args ...interface{}) { + var e error + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Errorf("Did not expect error: %s", e.Error()) + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + e, ok := err.(error) + Assert(t, !ok, nesting, failFast, fn) +} + +// Nil asserts v is nil. +func Nil(t *testing.T, v interface{}, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Errorf("Expected nil got %#v", v) + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + result := "" == fmt.Sprintf("%v", v) + Assert(t, result, nesting, failFast, fn) +} + +// Nil asserts v is not nil. +func NotNil(t *testing.T, v interface{}, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Errorf("Did not expect nil") + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + result := "" != fmt.Sprintf("%v", v) + Assert(t, result, nesting, failFast, fn) +} + +// FileExists asserts file exists. +func FileExists(t *testing.T, filePath string, args ...interface{}) { + var err error + exists := true + otherError := false + nesting, failFast, errorMsg := DecodeArgs(args...) + + if _, err = os.Stat(filePath); err != nil { + if os.IsNotExist(err) { + // file does not exist + exists = false + } else { + // other error + exists = false + otherError = true + } + } + + fn := func() { + if otherError { + t.Errorf("File %s exists but we got error: %s", filePath, err.Error()) + } else { + t.Errorf("File %s does not exist.", filePath) + } + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + + Assert(t, exists, nesting, failFast, fn) +} + +// EqualFuncPtr asserts exp and got point to the same function. +func EqualFuncPtr(t *testing.T, exp, got interface{}, args ...interface{}) { + var nameExp, nameGot string + expV := reflect.ValueOf(exp) + expKind := expV.Kind() + gotV := reflect.ValueOf(got) + gotKind := gotV.Kind() + + if !(expKind == reflect.Func || exp == nil || gotKind == reflect.Func || got == nil) { + nesting, failFast, _ := DecodeArgs(args...) + fn := func() { + if expKind != reflect.Func { + t.Error("exp must be a function or nil") + } + if gotKind != reflect.Func { + t.Error("got must be a function or nil") + } + } + Assert(t, false, nesting, failFast, fn) + return + } + + if expV.IsValid() { + fExp := runtime.FuncForPC(expV.Pointer()) + if fExp != nil { + nameExp = fExp.Name() + } + } + + if gotV.IsValid() { + fGot := runtime.FuncForPC(gotV.Pointer()) + if fGot != nil { + nameGot = fGot.Name() + } + } + + equal(t, nameExp, nameGot, args...) +} + +// FuncPtr asserts exp is a pointer to function. +func FuncPtr(t *testing.T, exp interface{}, args ...interface{}) { + expV := reflect.ValueOf(exp) + expKind := expV.Kind() + + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Error("Expected function.") + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + + Assert(t, expKind == reflect.Func, nesting, failFast, fn) +} + +// Panic asserts that calling fn will panic. +func Panic(t *testing.T, fn func(), errorMsg string, args ...interface{}) { + nesting, failFast, tstMsg := DecodeArgs(args...) + defer func() { + eMsg := recover() + var msg string + msg, ok := eMsg.(string) + if ok && !strings.Contains(msg, errorMsg) { + Equal(t, errorMsg, msg, nesting+1, failFast, errorMsg) + } + }() + fn() + T(t, false, nesting+1, failFast, "Expected panic: "+tstMsg) +} + +func ApproximatelyEqual(t *testing.T, exp, got, epsilon float64, args ...interface{}) { + var relativeError float64 + + if math.Abs(exp-got) < epsilon { + return + } + + expAbs := math.Abs(exp) + gotAbs := math.Abs(got) + + if gotAbs > expAbs { + relativeError = math.Abs((exp - got) / got) + } else { + relativeError = math.Abs((exp - got) / exp) + } + + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Errorf("Expected %f to be approximately equal to %f (relative error: %f, epsilon %f)", exp, got, relativeError, epsilon) + if len(errorMsg) > 0 { + t.Error(errorMsg) + } + } + + result := relativeError <= epsilon + Assert(t, result, nesting+1, failFast, fn) +} + +// Considered Internal + +// Assert is internal function but exported because it's used in other test packages. +func Assert(t *testing.T, result bool, nesting int, failFast bool, f func()) { + if !result { + nesting++ + for n := 0; n <= nesting; n++ { + _, file, line, _ := runtime.Caller(n) + if strings.Contains(file, "rzajac/assert/") { + continue + } + t.Errorf("%s:%d", file, line) + } + f() + if failFast { + t.FailNow() + } + } +} + +func equal(t *testing.T, exp, got interface{}, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + for _, desc := range pretty.Diff(exp, got) { + t.Error("!", desc) + } + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + result := reflect.DeepEqual(exp, got) + Assert(t, result, nesting+1, failFast, fn) +} + +func tt(t *testing.T, result bool, nesting int, args ...interface{}) { + nesting, failFast, errorMsg := DecodeArgs(args...) + fn := func() { + t.Errorf("! Failure") + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + Assert(t, result, nesting+1, failFast, fn) +} + +// DecodeArgs decodes additional test arguments. +// Returns nesting, fail strategy and error message. +func DecodeArgs(args ...interface{}) (int, bool, string) { + var str string + nesting := 1 // Default nesting + failFast := DEF_FAIL_STRATEGY // Fail fast by default + message := "" // Default error message + parsingMsg := false // Set to true when nesting and failFast already parsed + format := "" + msgArgs := make([]interface{}, 0) + + for _, arg := range args { + switch arg.(type) { + case string: + if parsingMsg { + msgArgs = append(msgArgs, arg) + } else { + str, _ = arg.(string) + if isFormat(str) { + format = str + } else { + msgArgs = append(msgArgs, arg) + } + } + parsingMsg = true + break + case int: + if parsingMsg { + msgArgs = append(msgArgs, arg) + } else { + nesting, _ = arg.(int) + } + break + case bool: + if parsingMsg { + msgArgs = append(msgArgs, arg) + } else { + failFast, _ = arg.(bool) + } + break + } + } + + if format != "" { + message = fmt.Sprintf(format, msgArgs...) + } else { + message = fmt.Sprint(msgArgs...) + } + + return nesting, failFast, message +} + +// isFormat returns true if sting is format string. +func isFormat(format string) bool { + return strings.Contains(format, "%") +} diff --git a/assert/assert_test.go b/assert/assert_test.go new file mode 100644 index 0000000..844fc63 --- /dev/null +++ b/assert/assert_test.go @@ -0,0 +1,112 @@ +package assert + +import ( + "errors" + // "errors" + "testing" +) + +func Test_decodeArgs(t *testing.T) { + + var tests = []struct { + nesting int + failFast bool + errorMsg string + args []interface{} + }{ + {1, false, "", []interface{}{}}, + {1, false, "", []interface{}{1}}, + {1, false, "", []interface{}{false}}, + {1, true, "", []interface{}{true}}, + {2, false, "", []interface{}{2}}, + {2, false, "", []interface{}{2, false}}, + {2, false, "", []interface{}{false, 2}}, + {1, false, "msg", []interface{}{"msg"}}, + {1, false, "msg a", []interface{}{"msg %s", "a"}}, + {1, false, "msg1", []interface{}{"msg%d", 1}}, + {1, false, "msg", []interface{}{false, "msg"}}, + {3, false, "msg", []interface{}{false, 3, "msg"}}, + {3, false, "msg", []interface{}{3, false, "msg"}}, + {3, true, "msg true", []interface{}{3, true, "msg %t", true}}, + } + + for idx, test := range tests { + nesting, failFast, errorMsg := DecodeArgs(test.args...) + + if test.nesting != nesting { + t.Fatalf("Test: %d. Expected nesting %d got %d", idx, test.nesting, nesting) + } + + if test.failFast != failFast { + t.Fatalf("Test: %d. Expected failFast %t got %t", idx, test.failFast, failFast) + } + + if test.errorMsg != errorMsg { + t.Fatalf("Test: %d. Expected errorMsg '%s' got '%s'", idx, test.errorMsg, errorMsg) + } + } +} + +func Test_isFormat(t *testing.T) { + + var tests = []struct { + exp bool + format string + }{ + {false, "abc"}, + {true, "abc %s"}, + {true, "abc %d"}, + {true, "abc %t"}, + } + + for idx, test := range tests { + got := isFormat(test.format) + if test.exp != got { + t.Fatalf("Expected %t got %t for test %d", test.exp, got, idx) + } + } +} + +func Test_Error(t *testing.T) { + err := errors.New("Test Error") + Error(t, err) +} + +func Test_NotError(t *testing.T) { + NotError(t, 1) + NotError(t, nil) + NotError(t, "aaa") +} + +func Test_ErrorMsgContains(t *testing.T) { + err := errors.New("WTF") + ErrorMsgContains(t, err, "WTF") +} + +func Test_F(t *testing.T) { + F(t, false) +} + +func Test_T(t *testing.T) { + T(t, true) +} + +func Test_Nil(t *testing.T) { + Nil(t, nil) +} + +func Test_NotNil(t *testing.T) { + NotNil(t, 1) +} + +func Test_Equal(t *testing.T) { + Equal(t, 1, 1) + Equal(t, nil, nil) + Equal(t, "aaa", "aaa") +} + +func Test_NotEqual(t *testing.T) { + NotEqual(t, 1, 2) + NotEqual(t, nil, 1) + NotEqual(t, "aaa", 1) +} diff --git a/assert_test.go b/assert_test.go deleted file mode 100644 index 162a590..0000000 --- a/assert_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package assert - -import ( - "testing" -) - -func TestLineNumbers(t *testing.T) { - Equal(t, "foo", "foo", "msg!") - //Equal(t, "foo", "bar", "this should blow up") -} - -func TestNotEqual(t *testing.T) { - NotEqual(t, "foo", "bar", "msg!") - //NotEqual(t, "foo", "foo", "this should blow up") -} diff --git a/assertmysql/assertmysql.go b/assertmysql/assertmysql.go new file mode 100644 index 0000000..422e63a --- /dev/null +++ b/assertmysql/assertmysql.go @@ -0,0 +1,222 @@ +// Assertions and helper functions for testing applications using MySQL +// +// Assert (c) Rafal Zajac +// http://github.com/rzajac/assert +// +// Licensed under the MIT license + +// Assertions and helper functions for testing applications using MySQL +// MySQL driver used: http://github.com/ziutek/mymysql +package assertmysql + +import ( + "fmt" + "github.com/rzajac/assert/assert" + "github.com/ziutek/mymysql/mysql" + _ "github.com/ziutek/mymysql/thrsafe" + "testing" +) + +const ( + DUMMY_MYSQL_ERROR_CODE = 9999 +) + +// The database connection to use. +var dbcon mysql.Conn + +// InitMySqlAssertions initializes assertmysql package. +func InitMySqlAssertions(proto, laddr, raddr, user, passwd, db string) { + dbcon = mysql.New(proto, laddr, raddr, user, passwd, db) + err := dbcon.Connect() + if err != nil { + panic(err) + } +} + +// Error asserts err is MySql error. +func Error(t *testing.T, err error, args ...interface{}) { + nesting, failFast, errorMsg := assert.DecodeArgs(args...) + fn := func() { + t.Error("Expected MySQL error.") + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + _, ok := err.(*mysql.Error) + assert.Assert(t, ok, nesting+1, failFast, fn) +} + +// NotError asserts err is not MySql error. +func NotError(t *testing.T, err error, args ...interface{}) { + nesting, failFast, errorMsg := assert.DecodeArgs(args...) + fn := func() { + t.Errorf("Did not expect MySQL error: %s", err.Error()) + if len(errorMsg) > 0 { + t.Error("!", " -", errorMsg) + } + } + _, ok := err.(*mysql.Error) + assert.Assert(t, !ok, nesting, failFast, fn) +} + +// ErrorCode asserts err is *mysql.Error and has code. +func ErrorCode(t *testing.T, err error, errorCode uint16) { + assert.Error(t, err, assert.NESTING_2, assert.FAIL_SLOW, "Expected error but got nil.") + + merr := GetMySqlErorr(err) + fn := func() { + t.Error("Expected MySQL error but got some other error.") + if err == nil { + t.Errorf("Got error: nil") + } else { + t.Errorf("Got error: %s", err.Error()) + } + } + _, ok := err.(*mysql.Error) + assert.Assert(t, ok, assert.NESTING_2, assert.FAIL_FAST, fn) + assert.Equal(t, errorCode, merr.Code, assert.NESTING_2, "Expected error code %d got %d.", errorCode, merr.Code) +} + +// RowExists asserts row exists in a table. +func RowExists(t *testing.T, tableName, pkName, selectValue interface{}) ([]mysql.Row, mysql.Result) { + value := fmt.Sprintf("%v", selectValue) + if _, ok := selectValue.(string); ok { + value = "'" + dbcon.Escape(value) + "'" + } + rows, res, err := dbcon.Query("SELECT * FROM %s WHERE %s = %s", tableName, pkName, value) + assert.NotError(t, err, assert.NESTING_2, assert.FAIL_FAST) + assert.T(t, len(rows) > 0, assert.NESTING_2, assert.FAIL_FAST, "Expected row with %s = %s to exist in the database.", pkName, value) + return rows, res +} + +// RowDoesNotExists asserts row exists in a table. +func RowDoesNotExists(t *testing.T, tableName, pkName, selectValue interface{}) ([]mysql.Row, mysql.Result) { + value := fmt.Sprintf("%v", selectValue) + if _, ok := selectValue.(string); ok { + value = "'" + dbcon.Escape(value) + "'" + } + rows, res, err := dbcon.Query("SELECT * FROM %s WHERE %s = %s", tableName, pkName, value) + assert.NotError(t, err, assert.NESTING_2, assert.FAIL_FAST) + assert.T(t, len(rows) == 0, assert.NESTING_2, assert.FAIL_FAST, "Did not expect row with %s = %s to exist in the database.", pkName, value) + return rows, res +} + +// TableExists asserts MySQL table exists in the database. +func TableExists(t *testing.T, tableName string) { + foundTable := tableExists(tableName) + assert.T(t, foundTable, assert.NESTING_2, "Table '%s' is not present in the database.", tableName) +} + +// TableDoesNotExists asserts MySQL table does not exists in the database. +func TableDoesNotExists(t *testing.T, tableName string) { + foundTable := tableExists(tableName) + assert.T(t, !foundTable, assert.NESTING_2, "Table '%s' is present in the database.", tableName) +} + +// TableCount asserts number of tables in the database. +func TableCount(t *testing.T, expectedCount int) { + sqlTables, _ := GetTableNames() + tablesCount := len(sqlTables) + assert.Equal(t, expectedCount, tablesCount, assert.NESTING_2, "Expected %d tables but got %d.", expectedCount, tablesCount) +} + +// TableRowCount asserts tableName has expectedRowCount rows. +func TableRowCount(t *testing.T, tableName string, expectedRowCount int) { + rowCount, err := GetTableRowCount(tableName) + assert.NotError(t, err, assert.NESTING_2) + assert.Equal(t, rowCount, expectedRowCount, assert.NESTING_2, "Expected %d rows in %s table but got %d.", expectedRowCount, tableName, rowCount) +} + +// TableNotEmpty asserts tableName is not empty. +func TableNotEmpty(t *testing.T, tableName string) { + rowCount, err := GetTableRowCount(tableName) + assert.NotError(t, err, assert.NESTING_2) + assert.T(t, rowCount > 0, assert.NESTING_2, "Expected table to have data.") +} + +// tableExists returns true if table exists in database. +func tableExists(tableName string) bool { + var foundTable bool + sqlTables, _ := GetTableNames() + for _, dbTableName := range sqlTables { + if dbTableName == tableName { + foundTable = true + break + } + } + return foundTable +} + +// helper functions + +// GetMySqlErorr helper function casts error to *mysql.Error. +// If err is not *mysql.Error it still returns *mysql.Error but +// with invalid (not used) MySQL error code 9999. +func GetMySqlErorr(err error) *mysql.Error { + var ok bool + var mysqle *mysql.Error + + mysqle, ok = err.(*mysql.Error) + + if !ok { + mysqle = new(mysql.Error) + if err == nil { + mysqle.Msg = []byte("") + } else { + mysqle.Msg = []byte(err.Error()) + } + mysqle.Code = DUMMY_MYSQL_ERROR_CODE + } + return mysqle +} + +// GetTableNames gets all MySQL table names. +func GetTableNames() ([]string, error) { + tables := make([]string, 0, 10) + rows, _, err := dbcon.Query("SHOW TABLES") + if err != nil { + return tables, err + } + for _, row := range rows { + tables = append(tables, row.Str(0)) + } + return tables, err +} + +// GetTableRowCount returns number of rows in a table. +func GetTableRowCount(tableName string) (int, error) { + row, _, err := dbcon.QueryFirst("SELECT count(1) FROM %s", tableName) + if err != nil { + return 0, err + } + return row.Int(0), err +} + +// DropTable drops MySQL table by name. +func DropTable(tableName string) error { + _, _, err := dbcon.Query("DROP TABLE IF EXISTS %s", tableName) + return err +} + +// TruncateTable truncates MySQL table by name. +func TruncateTable(tableName string) error { + _, _, err := dbcon.Query("TRUNCATE %s", tableName) + return err +} + +// DropAllTables drops all MySQL tables in selected database. +func DropAllTables() error { + var err error + var tableNames []string + tableNames, err = GetTableNames() + if err != nil { + return err + } + for _, tableName := range tableNames { + err = DropTable(tableName) + if err != nil { + return err + } + } + return err +} diff --git a/assertmysql/assertmysql_test.go b/assertmysql/assertmysql_test.go new file mode 100644 index 0000000..5a5dd48 --- /dev/null +++ b/assertmysql/assertmysql_test.go @@ -0,0 +1,27 @@ +package assertmysql + +import ( + "errors" + "fmt" + "github.com/ziutek/mymysql/mysql" + "testing" +) + +func Test_NotError(t *testing.T) { + err1 := errors.New("Test1") + NotError(t, err1) +} + +func Test_Error(t *testing.T) { + err1 := new(mysql.Error) + err1.Msg = []byte("Test1") + Error(t, err1) +} + +func Test_ErrorCode(t *testing.T) { + err1 := new(mysql.Error) + err1.Code = 123 + err1.Msg = []byte("Test1") + fmt.Println(err1) + ErrorCode(t, err1, 123) +} diff --git a/example/point.go b/example/point.go deleted file mode 100644 index 15789fe..0000000 --- a/example/point.go +++ /dev/null @@ -1,5 +0,0 @@ -package point - -type Point struct { - X, Y int -} diff --git a/example/point_test.go b/example/point_test.go deleted file mode 100644 index 34e791a..0000000 --- a/example/point_test.go +++ /dev/null @@ -1,13 +0,0 @@ -package point - -import ( - "testing" - "assert" -) - -func TestAsserts(t *testing.T) { - p1 := Point{1, 1} - p2 := Point{2, 1} - - assert.Equal(t, p1, p2) -} diff --git a/examples/assert_examples.go b/examples/assert_examples.go new file mode 100644 index 0000000..b68e584 --- /dev/null +++ b/examples/assert_examples.go @@ -0,0 +1,9 @@ +package assert_example + +type Point struct { + X, Y int +} + +func AddNumbers(x, y int) int { + return x + y +} diff --git a/examples/assert_examples_test.go b/examples/assert_examples_test.go new file mode 100644 index 0000000..9195d4b --- /dev/null +++ b/examples/assert_examples_test.go @@ -0,0 +1,29 @@ +package assert_example + +import ( + "github.com/rzajac/assert/assert" + "testing" +) + +func Test_AssertEqual(t *testing.T) { + p1 := Point{1, 1} + p2 := Point{2, 1} + assert.Equal(t, p1, p2) +} + +func Test_AddNumbers(t *testing.T) { + var tests = []struct { + x, y int + expResult int + }{ + {1, 2, 3}, + {0, 0, 0}, + {2, 3, 6}, // Intentional error + } + var gotValue int + + for idx, test := range tests { + gotValue = addNumbers(test.x, test.y) + assert.Equalf(t, test.expResult, gotValue, "For idx: %s", idx) + } +}