diff --git a/relayer/chainreader/loop/loop_reader.go b/relayer/chainreader/loop/loop_reader.go index 4266b31f..5f410cdc 100644 --- a/relayer/chainreader/loop/loop_reader.go +++ b/relayer/chainreader/loop/loop_reader.go @@ -7,6 +7,7 @@ import ( "fmt" "reflect" "strings" + "sync" "github.com/smartcontractkit/chainlink-aptos/relayer/codec" "github.com/smartcontractkit/chainlink-common/pkg/logger" @@ -34,9 +35,10 @@ func NewLoopChainReader(logger logger.Logger, cr types.ContractReader) types.Con type loopChainReader struct { services.Service types.UnimplementedContractReader - logger logger.Logger - cr types.ContractReader - moduleAddresses map[string]string + logger logger.Logger + cr types.ContractReader + moduleAddressesMu sync.RWMutex + moduleAddresses map[string]string } func (a *loopChainReader) Name() string { @@ -67,7 +69,7 @@ func (a *loopChainReader) GetLatestValue(ctx context.Context, readIdentifier str _, contractName, _ := readComponents[0], readComponents[1], readComponents[2] - _, ok := a.moduleAddresses[contractName] + ok := a.hasModuleAddress(contractName) if !ok { return fmt.Errorf("no such contract: %s", contractName) } @@ -103,7 +105,7 @@ func (a *loopChainReader) BatchGetLatestValues(ctx context.Context, request type for contract, requestBatch := range request { convertedBatch := []types.BatchRead{} for _, read := range requestBatch { - _, ok := a.moduleAddresses[contract.Name] + ok := a.hasModuleAddress(contract.Name) if !ok { return nil, fmt.Errorf("no such contract: %s", contract.Name) } @@ -210,22 +212,27 @@ func (a *loopChainReader) QueryKey(ctx context.Context, contract types.BoundCont } func (a *loopChainReader) Bind(ctx context.Context, bindings []types.BoundContract) error { + a.moduleAddressesMu.Lock() for _, binding := range bindings { a.moduleAddresses[binding.Name] = binding.Address } + a.moduleAddressesMu.Unlock() return a.cr.Bind(ctx, bindings) } func (a *loopChainReader) Unbind(ctx context.Context, bindings []types.BoundContract) error { + a.moduleAddressesMu.Lock() for _, binding := range bindings { key := binding.Name if _, ok := a.moduleAddresses[key]; ok { delete(a.moduleAddresses, key) } else { + a.moduleAddressesMu.Unlock() return fmt.Errorf("no such binding: %s", key) } } + a.moduleAddressesMu.Unlock() // we ignore unbind errors, because if the LOOP plugin restarted, the binding would not exist. err := a.cr.Unbind(ctx, bindings) @@ -237,6 +244,9 @@ func (a *loopChainReader) Unbind(ctx context.Context, bindings []types.BoundCont } func (a *loopChainReader) getBindings() []types.BoundContract { + a.moduleAddressesMu.RLock() + defer a.moduleAddressesMu.RUnlock() + bindings := make([]types.BoundContract, 0, len(a.moduleAddresses)) for name, address := range a.moduleAddresses { @@ -249,6 +259,14 @@ func (a *loopChainReader) getBindings() []types.BoundContract { return bindings } +func (a *loopChainReader) hasModuleAddress(name string) bool { + a.moduleAddressesMu.RLock() + defer a.moduleAddressesMu.RUnlock() + + _, ok := a.moduleAddresses[name] + return ok +} + func (a *loopChainReader) decodeGLVReturnValue(label string, jsonBytes []byte, returnVal any) error { if len(jsonBytes) > maxResponseSize { return fmt.Errorf("getLatestValue response size (%d bytes) exceeds maximum allowed size (%d bytes)", len(jsonBytes), maxResponseSize) diff --git a/relayer/chainreader/loop/loop_reader_test.go b/relayer/chainreader/loop/loop_reader_test.go new file mode 100644 index 00000000..bfba0ba0 --- /dev/null +++ b/relayer/chainreader/loop/loop_reader_test.go @@ -0,0 +1,71 @@ +package loop + +import ( + "context" + "fmt" + "sync" + "testing" + + "github.com/smartcontractkit/chainlink-common/pkg/logger" + "github.com/smartcontractkit/chainlink-common/pkg/types" +) + +type fakeContractReader struct { + types.UnimplementedContractReader + mu sync.Mutex + bindCalls int + unbindCalls int +} + +func (f *fakeContractReader) Bind(_ context.Context, bindings []types.BoundContract) error { + f.mu.Lock() + defer f.mu.Unlock() + f.bindCalls++ + return nil +} + +func (f *fakeContractReader) Unbind(_ context.Context, bindings []types.BoundContract) error { + f.mu.Lock() + defer f.mu.Unlock() + f.unbindCalls++ + return nil +} + +func TestLoopChainReaderConcurrentMapAccess(t *testing.T) { + t.Parallel() + + cr := &fakeContractReader{} + reader := NewLoopChainReader(logger.Test(t), cr).(*loopChainReader) + ctx := context.Background() + + err := reader.Bind(ctx, []types.BoundContract{ + {Name: "router", Address: "0x1"}, + {Name: "offramp", Address: "0x2"}, + }) + if err != nil { + t.Fatalf("bind seed contracts: %v", err) + } + + var wg sync.WaitGroup + for i := 0; i < 24; i++ { + i := i + wg.Add(1) + go func() { + defer wg.Done() + + name := fmt.Sprintf("contract-%d", i%8) + for j := 0; j < 250; j++ { + _ = reader.Bind(ctx, []types.BoundContract{ + {Name: name, Address: fmt.Sprintf("0x%x", j)}, + }) + + _ = reader.getBindings() + _ = reader.hasModuleAddress(name) + + _ = reader.Unbind(ctx, []types.BoundContract{{Name: name}}) + } + }() + } + + wg.Wait() +}