Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 36 additions & 19 deletions dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
//
// See examples for details.
type Dispatcher struct {
mu sync.RWMutex
serviceMap map[string]*serviceData
}

Expand Down Expand Up @@ -60,6 +61,8 @@ func NewDispatcher() *Dispatcher {
//
// See examples for details.
func (d *Dispatcher) AddFunc(funcName string, f interface{}) {
d.mu.Lock()
defer d.mu.Unlock()
sd, ok := d.serviceMap[""]
if !ok {
sd = &serviceData{
Expand Down Expand Up @@ -90,6 +93,8 @@ func (d *Dispatcher) AddFunc(funcName string, f interface{}) {
//
// All public methods must conform requirements described in AddFunc().
func (d *Dispatcher) AddService(serviceName string, service interface{}) {
d.mu.Lock()
defer d.mu.Unlock()
if serviceName == "" {
logPanic("gorpc.Dispatcher: serviceName cannot be empty")
}
Expand Down Expand Up @@ -242,7 +247,7 @@ func validateType(t reflect.Type) (err error) {
})

switch t.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer:
case reflect.Chan, reflect.Func, reflect.UnsafePointer:
err = fmt.Errorf("%s. Found [%s]", t.Kind(), t)
return err
case reflect.Array, reflect.Slice:
Expand Down Expand Up @@ -308,18 +313,14 @@ func init() {
// The returned HandlerFunc must be assigned to Server.Handler or
// passed to New*Server().
func (d *Dispatcher) NewHandlerFunc() HandlerFunc {
if len(d.serviceMap) == 0 {
logPanic("gorpc.Dispatcher: register at least one service before calling HandlerFunc()")
}

serviceMap := copyServiceMap(d.serviceMap)

return func(clientAddr string, request interface{}) interface{} {
req, ok := request.(*dispatcherRequest)
if !ok {
logPanic("gorpc.Dispatcher: unsupported request type received from the client: %T", request)
}
return dispatchRequest(serviceMap, clientAddr, req)
d.mu.RLock()
defer d.mu.RUnlock()
return dispatchRequest(d.serviceMap, clientAddr, req)
}
}

Expand Down Expand Up @@ -446,13 +447,9 @@ type DispatcherClient struct {
serviceName string
}

// NewFuncClient returns a client suitable for calling functions registered
// via AddFunc().
func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient {
if len(d.serviceMap) == 0 || d.serviceMap[""] == nil {
logPanic("gorpc.Dispatcher: register at least one function with AddFunc() before calling NewFuncClient()")
}

// NewDispatcherFuncClient returns a client suitable for calling functions
// registered via AddFunc().
func NewDispatcherFuncClient(c *Client) *DispatcherClient {
return &DispatcherClient{
c: c,
}
Expand All @@ -462,15 +459,35 @@ func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient {
// of the service with name serviceName registered via AddService().
//
// It is safe creating multiple service clients over a single underlying client.
func NewDispatcherServiceClient(serviceName string, c *Client) *DispatcherClient {
return &DispatcherClient{
c: c,
serviceName: serviceName,
}
}

// NewFuncClient checks and returns a client suitable for calling functions
// registered via AddFunc().
func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient {
d.mu.RLock()
defer d.mu.RUnlock()
if len(d.serviceMap) == 0 || d.serviceMap[""] == nil {
logPanic("gorpc.Dispatcher: register at least one function with AddFunc() before calling NewFuncClient()")
}

return NewDispatcherFuncClient(c)
}

// NewServiceClient checks and returns a client suitable for calling methods
// of the service with name serviceName registered via AddService().
func (d *Dispatcher) NewServiceClient(serviceName string, c *Client) *DispatcherClient {
d.mu.RLock()
defer d.mu.RUnlock()
if len(d.serviceMap) == 0 || d.serviceMap[serviceName] == nil {
logPanic("gorpc.Dispatcher: service [%s] must be registered with AddService() before calling NewServiceClient()", serviceName)
}

return &DispatcherClient{
c: c,
serviceName: serviceName,
}
return NewDispatcherServiceClient(serviceName, c)
}

// Call calls the given function with the given request.
Expand Down
80 changes: 44 additions & 36 deletions dispatcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,13 @@ package gorpc
import (
"bytes"
"fmt"
"io"
"reflect"
"sync/atomic"
"testing"
"time"
"unsafe"
)

func TestDispatcherNewHandlerNoFuncs(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
d.NewHandlerFunc()
})
}

func TestDispatcherNewFuncClientNoFuncs(t *testing.T) {
c := NewTCPClient(getRandomAddr())

Expand Down Expand Up @@ -92,16 +84,6 @@ func TestDispatcherChanArg(t *testing.T) {
})
}

func TestDispatcherInterfaceArg(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
d.AddFunc("foo", func(req io.Reader) {})
})
testPanic(t, func() {
d.AddFunc("foo", func(req interface{}) {})
})
}

func TestDispatcherUnsafePointerArg(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
Expand All @@ -123,16 +105,6 @@ func TestDispatcherChanRes(t *testing.T) {
})
}

func TestDispatcherInterfaceRes(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
d.AddFunc("foo", func() (res io.Reader) { return })
})
testPanic(t, func() {
d.AddFunc("foo", func() (res interface{}) { return })
})
}

func TestDispatcherUnsafePointerRes(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
Expand All @@ -143,7 +115,7 @@ func TestDispatcherUnsafePointerRes(t *testing.T) {
func TestDispatcherStructWithInvalidFields(t *testing.T) {
type InvalidMsg struct {
B int
A io.Reader
A chan bool
}

d := NewDispatcher()
Expand All @@ -152,13 +124,6 @@ func TestDispatcherStructWithInvalidFields(t *testing.T) {
})
}

func TestDispatcherInvalidMap(t *testing.T) {
d := NewDispatcher()
testPanic(t, func() {
d.AddFunc("foo", func(req map[string]interface{}) {})
})
}

func TestDispatcherPassStructArgByValue(t *testing.T) {
type RequestType struct {
a int
Expand Down Expand Up @@ -289,6 +254,26 @@ func TestDispatcherInvalidArgType(t *testing.T) {
})
}

func TestDispatcherFuncLater(t *testing.T) {
d := NewDispatcher()
d.AddFunc("foo", func(request string) {})
testDispatcherFunc(t, d, func(dc *DispatcherClient) {
res, err := dc.Call("foo", nil)
if err == nil {
t.Fatalf("Expected non-nil error")
}
if res != nil {
t.Fatalf("Expected nil response. Got %+v", res)
}

d.AddFunc("foo0", func() {})
_, err = dc.Call("foo0", nil)
if err != nil {
t.Fatalf("Expected nil error")
}
})
}

func TestDispatcherUnknownFuncCall(t *testing.T) {
d := NewDispatcher()
d.AddFunc("foo", func(request string) {})
Expand Down Expand Up @@ -987,6 +972,29 @@ func TestDispatcherServiceUnknownMethodCall(t *testing.T) {
testDispatcherService(t, d, "qwerty", func(dc *DispatcherClient) { testUnknownFuncs(t, dc) })
}

func TestDispatcherServiceLater(t *testing.T) {
d := NewDispatcher()
c, s := getClientServer(t, d)
defer s.Stop()
defer c.Stop()

dc := NewDispatcherServiceClient("qwerty", c)

res, err := dc.Call("Get", nil)
if err == nil {
t.Fatalf("Error expected")
}
if res != nil {
t.Fatalf("Expected nil response. Got %+v", res)
}

d.AddService("qwerty", &testService{})
_, err = dc.Call("Get", nil)
if err != nil {
t.Fatalf("Expected nil error")
}
}

func TestDispatcherServicePrivateMethodCall(t *testing.T) {
service := &testService{}

Expand Down