diff --git a/dispatcher.go b/dispatcher.go index 4eee2fc..f563f51 100644 --- a/dispatcher.go +++ b/dispatcher.go @@ -22,6 +22,7 @@ import ( // // See examples for details. type Dispatcher struct { + mu sync.RWMutex serviceMap map[string]*serviceData } @@ -60,6 +61,8 @@ func NewDispatcher() *Dispatcher { // // See examples for details. func (d *Dispatcher) AddFunc(funcName string, f interface{}) { + d.mu.Lock() + defer d.mu.Unlock() sd, ok := d.serviceMap[""] if !ok { sd = &serviceData{ @@ -90,6 +93,8 @@ func (d *Dispatcher) AddFunc(funcName string, f interface{}) { // // All public methods must conform requirements described in AddFunc(). func (d *Dispatcher) AddService(serviceName string, service interface{}) { + d.mu.Lock() + defer d.mu.Unlock() if serviceName == "" { logPanic("gorpc.Dispatcher: serviceName cannot be empty") } @@ -242,7 +247,7 @@ func validateType(t reflect.Type) (err error) { }) switch t.Kind() { - case reflect.Chan, reflect.Func, reflect.Interface, reflect.UnsafePointer: + case reflect.Chan, reflect.Func, reflect.UnsafePointer: err = fmt.Errorf("%s. Found [%s]", t.Kind(), t) return err case reflect.Array, reflect.Slice: @@ -308,18 +313,14 @@ func init() { // The returned HandlerFunc must be assigned to Server.Handler or // passed to New*Server(). func (d *Dispatcher) NewHandlerFunc() HandlerFunc { - if len(d.serviceMap) == 0 { - logPanic("gorpc.Dispatcher: register at least one service before calling HandlerFunc()") - } - - serviceMap := copyServiceMap(d.serviceMap) - return func(clientAddr string, request interface{}) interface{} { req, ok := request.(*dispatcherRequest) if !ok { logPanic("gorpc.Dispatcher: unsupported request type received from the client: %T", request) } - return dispatchRequest(serviceMap, clientAddr, req) + d.mu.RLock() + defer d.mu.RUnlock() + return dispatchRequest(d.serviceMap, clientAddr, req) } } @@ -446,13 +447,9 @@ type DispatcherClient struct { serviceName string } -// NewFuncClient returns a client suitable for calling functions registered -// via AddFunc(). -func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient { - if len(d.serviceMap) == 0 || d.serviceMap[""] == nil { - logPanic("gorpc.Dispatcher: register at least one function with AddFunc() before calling NewFuncClient()") - } - +// NewDispatcherFuncClient returns a client suitable for calling functions +// registered via AddFunc(). +func NewDispatcherFuncClient(c *Client) *DispatcherClient { return &DispatcherClient{ c: c, } @@ -462,15 +459,35 @@ func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient { // of the service with name serviceName registered via AddService(). // // It is safe creating multiple service clients over a single underlying client. +func NewDispatcherServiceClient(serviceName string, c *Client) *DispatcherClient { + return &DispatcherClient{ + c: c, + serviceName: serviceName, + } +} + +// NewFuncClient checks and returns a client suitable for calling functions +// registered via AddFunc(). +func (d *Dispatcher) NewFuncClient(c *Client) *DispatcherClient { + d.mu.RLock() + defer d.mu.RUnlock() + if len(d.serviceMap) == 0 || d.serviceMap[""] == nil { + logPanic("gorpc.Dispatcher: register at least one function with AddFunc() before calling NewFuncClient()") + } + + return NewDispatcherFuncClient(c) +} + +// NewServiceClient checks and returns a client suitable for calling methods +// of the service with name serviceName registered via AddService(). func (d *Dispatcher) NewServiceClient(serviceName string, c *Client) *DispatcherClient { + d.mu.RLock() + defer d.mu.RUnlock() if len(d.serviceMap) == 0 || d.serviceMap[serviceName] == nil { logPanic("gorpc.Dispatcher: service [%s] must be registered with AddService() before calling NewServiceClient()", serviceName) } - return &DispatcherClient{ - c: c, - serviceName: serviceName, - } + return NewDispatcherServiceClient(serviceName, c) } // Call calls the given function with the given request. diff --git a/dispatcher_test.go b/dispatcher_test.go index 1d59f05..b740246 100644 --- a/dispatcher_test.go +++ b/dispatcher_test.go @@ -3,7 +3,6 @@ package gorpc import ( "bytes" "fmt" - "io" "reflect" "sync/atomic" "testing" @@ -11,13 +10,6 @@ import ( "unsafe" ) -func TestDispatcherNewHandlerNoFuncs(t *testing.T) { - d := NewDispatcher() - testPanic(t, func() { - d.NewHandlerFunc() - }) -} - func TestDispatcherNewFuncClientNoFuncs(t *testing.T) { c := NewTCPClient(getRandomAddr()) @@ -92,16 +84,6 @@ func TestDispatcherChanArg(t *testing.T) { }) } -func TestDispatcherInterfaceArg(t *testing.T) { - d := NewDispatcher() - testPanic(t, func() { - d.AddFunc("foo", func(req io.Reader) {}) - }) - testPanic(t, func() { - d.AddFunc("foo", func(req interface{}) {}) - }) -} - func TestDispatcherUnsafePointerArg(t *testing.T) { d := NewDispatcher() testPanic(t, func() { @@ -123,16 +105,6 @@ func TestDispatcherChanRes(t *testing.T) { }) } -func TestDispatcherInterfaceRes(t *testing.T) { - d := NewDispatcher() - testPanic(t, func() { - d.AddFunc("foo", func() (res io.Reader) { return }) - }) - testPanic(t, func() { - d.AddFunc("foo", func() (res interface{}) { return }) - }) -} - func TestDispatcherUnsafePointerRes(t *testing.T) { d := NewDispatcher() testPanic(t, func() { @@ -143,7 +115,7 @@ func TestDispatcherUnsafePointerRes(t *testing.T) { func TestDispatcherStructWithInvalidFields(t *testing.T) { type InvalidMsg struct { B int - A io.Reader + A chan bool } d := NewDispatcher() @@ -152,13 +124,6 @@ func TestDispatcherStructWithInvalidFields(t *testing.T) { }) } -func TestDispatcherInvalidMap(t *testing.T) { - d := NewDispatcher() - testPanic(t, func() { - d.AddFunc("foo", func(req map[string]interface{}) {}) - }) -} - func TestDispatcherPassStructArgByValue(t *testing.T) { type RequestType struct { a int @@ -289,6 +254,26 @@ func TestDispatcherInvalidArgType(t *testing.T) { }) } +func TestDispatcherFuncLater(t *testing.T) { + d := NewDispatcher() + d.AddFunc("foo", func(request string) {}) + testDispatcherFunc(t, d, func(dc *DispatcherClient) { + res, err := dc.Call("foo", nil) + if err == nil { + t.Fatalf("Expected non-nil error") + } + if res != nil { + t.Fatalf("Expected nil response. Got %+v", res) + } + + d.AddFunc("foo0", func() {}) + _, err = dc.Call("foo0", nil) + if err != nil { + t.Fatalf("Expected nil error") + } + }) +} + func TestDispatcherUnknownFuncCall(t *testing.T) { d := NewDispatcher() d.AddFunc("foo", func(request string) {}) @@ -987,6 +972,29 @@ func TestDispatcherServiceUnknownMethodCall(t *testing.T) { testDispatcherService(t, d, "qwerty", func(dc *DispatcherClient) { testUnknownFuncs(t, dc) }) } +func TestDispatcherServiceLater(t *testing.T) { + d := NewDispatcher() + c, s := getClientServer(t, d) + defer s.Stop() + defer c.Stop() + + dc := NewDispatcherServiceClient("qwerty", c) + + res, err := dc.Call("Get", nil) + if err == nil { + t.Fatalf("Error expected") + } + if res != nil { + t.Fatalf("Expected nil response. Got %+v", res) + } + + d.AddService("qwerty", &testService{}) + _, err = dc.Call("Get", nil) + if err != nil { + t.Fatalf("Expected nil error") + } +} + func TestDispatcherServicePrivateMethodCall(t *testing.T) { service := &testService{}