diff --git a/dht/dht.go b/dht/dht.go index 26fbd86..53a747b 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -3,6 +3,7 @@ package dht import ( "bytes" "context" + "log" "sync" "github.com/decanus/bureka/dht/state" @@ -129,7 +130,7 @@ func (d *DHT) AddPeer(id state.Peer) { } // RemovePeer removes a peer from the dht. -func (d *DHT) RemovePeer(id state.Peer) { +func (d *DHT) RemovePeer(id state.Peer) bool { d.Lock() defer d.Unlock() @@ -137,7 +138,7 @@ func (d *DHT) RemovePeer(id state.Peer) { d.NeighborhoodSet = ns d.RoutingTable = d.RoutingTable.Remove(d.ID, id) - d.LeafSet.Remove(id) + return d.LeafSet.Remove(id) } func (d *DHT) Heartbeat(id state.Peer) { @@ -149,6 +150,20 @@ func (d *DHT) Heartbeat(id state.Peer) { } } +// Close sends a message to all neighbors that a peer is exiting the DHT. +func (d *DHT) Close() { + d.MapNeighbors(func(peer state.Peer) { + err := d.Send( + context.Background(), + &pb.Message{Key: peer, Type: pb.Message_NODE_EXIT, Sender: d.ID}, + ) + + if err != nil { + log.Println(err) + } + }) +} + // MapNeighbors iterates over the NeighborhoodSet and calls the process for every peer. func (d *DHT) MapNeighbors(process MapFunc) { d.RLock() diff --git a/node/node.go b/node/node.go index f47d7e0..66bb189 100644 --- a/node/node.go +++ b/node/node.go @@ -2,6 +2,7 @@ package node import ( "context" + "errors" logging "github.com/ipfs/go-log" "github.com/libp2p/go-libp2p-core/event" @@ -75,7 +76,7 @@ func (n *Node) FindPeer(ctx context.Context, id peer.ID) (peer.AddrInfo, error) b := []byte(id) p := n.dht.Find(b) if p == nil { - return peer.AddrInfo{}, nil // @todo error + return peer.AddrInfo{}, errors.New("failed to find peer") } id, err := peer.IDFromBytes(p) @@ -90,6 +91,12 @@ func (n *Node) Send(ctx context.Context, msg *pb.Message) error { return n.dht.Send(ctx, msg) } +// Close closes the node and stops the DHT. +func (n *Node) Close() { + n.dht.Close() + n.host.Close() +} + func (n *Node) handleOutgoingMessages() { c := make(chan dht.Packet) n.dht.Feed().Subscribe(c) diff --git a/node/node_test.go b/node/node_test.go index ab00a57..0b9c877 100644 --- a/node/node_test.go +++ b/node/node_test.go @@ -105,7 +105,7 @@ func TestFindPeer(t *testing.T) { dhts := setupNodes(t, ctx, 4) defer func() { for i := 0; i < 4; i++ { - dhts[i].host.Close() + dhts[i].Close() } }()