/ client.go
client.go
  1  package rendezvous
  2  
  3  import (
  4  	"context"
  5  	"errors"
  6  	"fmt"
  7  	"math/rand"
  8  	"time"
  9  
 10  	pb "github.com/waku-org/go-waku-rendezvous/pb"
 11  
 12  	ggio "github.com/gogo/protobuf/io"
 13  
 14  	"github.com/libp2p/go-libp2p/core/host"
 15  	inet "github.com/libp2p/go-libp2p/core/network"
 16  	"github.com/libp2p/go-libp2p/core/peer"
 17  	"github.com/libp2p/go-libp2p/core/peerstore"
 18  	"github.com/libp2p/go-libp2p/core/record"
 19  )
 20  
 21  var (
 22  	DiscoverAsyncInterval = 2 * time.Minute
 23  )
 24  
 25  type RendezvousPoint interface {
 26  	Register(ctx context.Context, ns string, ttl int) (time.Duration, error)
 27  	Discover(ctx context.Context, ns string, limit int) ([]Registration, error)
 28  	DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error)
 29  }
 30  
 31  type Registration struct {
 32  	Peer peer.AddrInfo
 33  	Ns   string
 34  	Ttl  int
 35  }
 36  
 37  type RendezvousClient interface {
 38  	Register(ctx context.Context, ns string, ttl int) (time.Duration, error)
 39  	Discover(ctx context.Context, ns string, limit int) ([]peer.AddrInfo, error)
 40  	DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error)
 41  }
 42  
 43  func NewRendezvousPoint(host host.Host) RendezvousPoint {
 44  	return &rendezvousPoint{
 45  		host: host,
 46  	}
 47  }
 48  
 49  type rendezvousPoint struct {
 50  	host host.Host
 51  }
 52  
 53  func NewRendezvousClient(host host.Host) RendezvousClient {
 54  	return NewRendezvousClientWithPoint(NewRendezvousPoint(host))
 55  }
 56  
 57  func NewRendezvousClientWithPoint(rp RendezvousPoint) RendezvousClient {
 58  	return &rendezvousClient{rp: rp}
 59  }
 60  
 61  type rendezvousClient struct {
 62  	rp RendezvousPoint
 63  }
 64  
 65  func (r *rendezvousPoint) getRandomPeer() (peer.ID, error) {
 66  	var peerIDs []peer.ID
 67  	for _, peer := range r.host.Peerstore().Peers() {
 68  		protocols, err := r.host.Peerstore().SupportsProtocols(peer, string(RendezvousID_v001))
 69  		if err != nil {
 70  			log.Error("error obtaining the protocols supported by peers", err)
 71  			return "", err
 72  		}
 73  		if len(protocols) > 0 {
 74  			peerIDs = append(peerIDs, peer)
 75  		}
 76  	}
 77  
 78  	if len(peerIDs) == 0 {
 79  		return "", errors.New("no peers available")
 80  	}
 81  
 82  	return peerIDs[rand.Intn(len(peerIDs))], nil // nolint: gosec
 83  }
 84  
 85  func (rp *rendezvousPoint) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) {
 86  	randomPeer, err := rp.getRandomPeer()
 87  	if err != nil {
 88  		return 0, err
 89  	}
 90  
 91  	s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001)
 92  	if err != nil {
 93  		return 0, err
 94  	}
 95  	defer s.Reset()
 96  
 97  	r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
 98  	w := ggio.NewDelimitedWriter(s)
 99  
100  	privKey := rp.host.Peerstore().PrivKey(rp.host.ID())
101  	req, err := newRegisterMessage(privKey, ns, peer.AddrInfo{ID: rp.host.ID(), Addrs: rp.host.Addrs()}, ttl)
102  	if err != nil {
103  		return 0, err
104  	}
105  
106  	err = w.WriteMsg(req)
107  	if err != nil {
108  		return 0, err
109  	}
110  
111  	var res pb.Message
112  	err = r.ReadMsg(&res)
113  	if err != nil {
114  		return 0, err
115  	}
116  
117  	if res.GetType() != pb.Message_REGISTER_RESPONSE {
118  		return 0, fmt.Errorf("Unexpected response: %s", res.GetType().String())
119  	}
120  
121  	response := res.GetRegisterResponse()
122  	status := response.GetStatus()
123  	if status != pb.Message_OK {
124  		return 0, RendezvousError{Status: status, Text: res.GetRegisterResponse().GetStatusText()}
125  	}
126  
127  	return time.Duration(response.Ttl) * time.Second, nil
128  }
129  
130  func (rc *rendezvousClient) Register(ctx context.Context, ns string, ttl int) (time.Duration, error) {
131  	if ttl < 120 {
132  		return 0, fmt.Errorf("registration TTL is too short")
133  	}
134  
135  	returnedTTL, err := rc.rp.Register(ctx, ns, ttl)
136  	if err != nil {
137  		return 0, err
138  	}
139  
140  	go registerRefresh(ctx, rc.rp, ns, ttl)
141  	return returnedTTL, nil
142  }
143  
144  func registerRefresh(ctx context.Context, rz RendezvousPoint, ns string, ttl int) {
145  	var refresh time.Duration
146  	errcount := 0
147  
148  	for {
149  		if errcount > 0 {
150  			// do randomized exponential backoff, up to ~4 hours
151  			if errcount > 7 {
152  				errcount = 7
153  			}
154  			backoff := 2 << uint(errcount)
155  			refresh = 5*time.Minute + time.Duration(rand.Intn(backoff*60000))*time.Millisecond
156  		} else {
157  			refresh = time.Duration(ttl-30) * time.Second
158  		}
159  
160  		select {
161  		case <-time.After(refresh):
162  		case <-ctx.Done():
163  			return
164  		}
165  
166  		_, err := rz.Register(ctx, ns, ttl)
167  		if err != nil {
168  			log.Errorf("Error registering [%s]: %s", ns, err.Error())
169  			errcount++
170  		} else {
171  			errcount = 0
172  		}
173  	}
174  }
175  
176  func (rp *rendezvousPoint) Discover(ctx context.Context, ns string, limit int) ([]Registration, error) {
177  	randomPeer, err := rp.getRandomPeer()
178  	if err != nil {
179  		return nil, err
180  	}
181  
182  	s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001)
183  	if err != nil {
184  		return nil, err
185  	}
186  	defer s.Reset()
187  
188  	r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
189  	w := ggio.NewDelimitedWriter(s)
190  
191  	return rp.discoverQuery(ns, limit, r, w)
192  }
193  
194  func (rp *rendezvousPoint) discoverQuery(ns string, limit int, r ggio.Reader, w ggio.Writer) ([]Registration, error) {
195  	req := newDiscoverMessage(ns, limit)
196  	err := w.WriteMsg(req)
197  	if err != nil {
198  		return nil, err
199  	}
200  
201  	var res pb.Message
202  	err = r.ReadMsg(&res)
203  	if err != nil {
204  		return nil, err
205  	}
206  
207  	if res.GetType() != pb.Message_DISCOVER_RESPONSE {
208  		return nil, fmt.Errorf("unexpected response: %s", res.GetType().String())
209  	}
210  
211  	status := res.GetDiscoverResponse().GetStatus()
212  	if status != pb.Message_OK {
213  		return nil, RendezvousError{Status: status, Text: res.GetDiscoverResponse().GetStatusText()}
214  	}
215  
216  	regs := res.GetDiscoverResponse().GetRegistrations()
217  	result := make([]Registration, 0, len(regs))
218  	for _, reg := range regs {
219  
220  		reg.GetSignedPeerRecord()
221  		envelope, err := record.UnmarshalEnvelope(reg.GetSignedPeerRecord())
222  		if err != nil {
223  			log.Errorf("Invalid peer info: %s", err.Error())
224  			continue
225  		}
226  
227  		cab, ok := peerstore.GetCertifiedAddrBook(rp.host.Peerstore())
228  		if !ok {
229  			return nil, errors.New("a certified addr book is required")
230  		}
231  
232  		_, err = cab.ConsumePeerRecord(envelope, time.Duration(reg.Ttl))
233  		if err != nil {
234  			log.Errorf("Invalid peer info: %s", err.Error())
235  			continue
236  		}
237  
238  		var record peer.PeerRecord
239  		err = envelope.TypedRecord(&record)
240  		if err != nil {
241  			log.Errorf("Invalid peer record: %s", err.Error())
242  			continue
243  		}
244  
245  		result = append(result, Registration{Peer: peer.AddrInfo{ID: record.PeerID, Addrs: record.Addrs}, Ns: reg.GetNs(), Ttl: int(reg.GetTtl())})
246  	}
247  
248  	return result, nil
249  }
250  
251  func (rp *rendezvousPoint) DiscoverAsync(ctx context.Context, ns string) (<-chan Registration, error) {
252  	randomPeer, err := rp.getRandomPeer()
253  	if err != nil {
254  		return nil, err
255  	}
256  
257  	s, err := rp.host.NewStream(ctx, randomPeer, RendezvousID_v001)
258  	if err != nil {
259  		return nil, err
260  	}
261  
262  	ch := make(chan Registration)
263  	go rp.discoverAsync(ctx, ns, s, ch)
264  	return ch, nil
265  }
266  
267  func (rp *rendezvousPoint) discoverAsync(ctx context.Context, ns string, s inet.Stream, ch chan Registration) {
268  	defer s.Reset()
269  	defer close(ch)
270  
271  	r := ggio.NewDelimitedReader(s, inet.MessageSizeMax)
272  	w := ggio.NewDelimitedWriter(s)
273  
274  	const batch = 200
275  
276  	var (
277  		regs []Registration
278  		err  error
279  	)
280  
281  	for {
282  		regs, err = rp.discoverQuery(ns, batch, r, w)
283  		if err != nil {
284  			// TODO robust error recovery
285  			//      - handle closed streams with backoff + new stream
286  			log.Errorf("Error in discovery [%s]: %s", ns, err.Error())
287  			return
288  		}
289  
290  		for _, reg := range regs {
291  			select {
292  			case ch <- reg:
293  			case <-ctx.Done():
294  				return
295  			}
296  		}
297  
298  		if len(regs) < batch {
299  			// TODO adaptive backoff for heavily loaded rendezvous points
300  			select {
301  			case <-time.After(DiscoverAsyncInterval):
302  			case <-ctx.Done():
303  				return
304  			}
305  		}
306  	}
307  }
308  
309  func (rc *rendezvousClient) Discover(ctx context.Context, ns string, limit int) ([]peer.AddrInfo, error) {
310  	regs, err := rc.rp.Discover(ctx, ns, limit)
311  	if err != nil {
312  		return nil, err
313  	}
314  
315  	pinfos := make([]peer.AddrInfo, len(regs))
316  	for i, reg := range regs {
317  		pinfos[i] = reg.Peer
318  	}
319  
320  	return pinfos, nil
321  }
322  
323  func (rc *rendezvousClient) DiscoverAsync(ctx context.Context, ns string) (<-chan peer.AddrInfo, error) {
324  	rch, err := rc.rp.DiscoverAsync(ctx, ns)
325  	if err != nil {
326  		return nil, err
327  	}
328  
329  	ch := make(chan peer.AddrInfo)
330  	go discoverPeersAsync(ctx, rch, ch)
331  	return ch, nil
332  }
333  
334  func discoverPeersAsync(ctx context.Context, rch <-chan Registration, ch chan peer.AddrInfo) {
335  	defer close(ch)
336  	for {
337  		select {
338  		case reg, ok := <-rch:
339  			if !ok {
340  				return
341  			}
342  
343  			select {
344  			case ch <- reg.Peer:
345  			case <-ctx.Done():
346  				return
347  			}
348  		case <-ctx.Done():
349  			return
350  		}
351  	}
352  }