plugai_updsrv/deps/github.com/anacrolix/torrent/tracker/udp/client.go

226 lines
5.4 KiB
Go

package udp
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"time"
"github.com/anacrolix/dht/v2/krpc"
)
// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
// connection specifics.
type Client struct {
// mu is held by writeRequest for non-connect actions, from obtaining the connection ID until the
// request has been written.
mu sync.Mutex
// connId is the connection ID taken from the most recent successful connect response. It is
// placed in the header of every non-connect request.
connId ConnectionId
// connIdIssued is the local time at which connId was stored. The zero value means no connection
// ID has been obtained yet.
connIdIssued time.Time
// shouldReconnectOverride, when non-nil, replaces the default decision about whether a new
// connection ID must be obtained before the next request.
shouldReconnectOverride func() bool
// Dispatcher creates the transactions through which responses are delivered back to request.
Dispatcher *Dispatcher
// Writer receives each encoded request (header followed by body) in a single Write call.
Writer io.Writer
}
// Announce sends an announce request, with the encoded opts appended after the marshalled req,
// and decodes the reply into a response header and a list of peers.
//
// Errors from the request itself are returned unchanged. Failures while decoding the header or
// the peers are wrapped with a description of the part being read.
func (cl *Client) Announce(
	ctx context.Context, req AnnounceRequest, opts Options,
	// Called with the address the response came from. Decides whether the response body is
	// decoded as IPv6 or IPv4 peers, see BEP 15.
	ipv6 func(net.Addr) bool,
) (
	respHdr AnnounceResponseHeader,
	// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
	peers AnnounceResponsePeers,
	err error,
) {
	payload := append(mustMarshal(req), opts.Encode()...)
	respBody, addr, err := cl.request(ctx, ActionAnnounce, payload)
	if err != nil {
		return
	}
	rest := bytes.NewBuffer(respBody)
	if err = Read(rest, &respHdr); err != nil {
		err = fmt.Errorf("reading response header: %w", err)
		return
	}
	// Whatever follows the header is the compact peer list; its address family is chosen by the
	// caller-supplied predicate.
	switch {
	case ipv6(addr):
		peers = &krpc.CompactIPv6NodeAddrs{}
	default:
		peers = &krpc.CompactIPv4NodeAddrs{}
	}
	if err = peers.UnmarshalBinary(rest.Bytes()); err != nil {
		err = fmt.Errorf("reading response peers: %w", err)
	}
	return
}
// Scrape sends a scrape request for ihs and decodes results from the response until its body is
// exhausted. If the tracker returns more results than infohashes were requested, an error is
// returned alongside the results decoded so far.
//
// There's no way to pass options in a scrape, since we don't know when the request body ends.
func (cl *Client) Scrape(
	ctx context.Context, ihs []InfoHash,
) (
	out ScrapeResponse, err error,
) {
	respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
	if err != nil {
		return
	}
	for buf := bytes.NewBuffer(respBody); buf.Len() > 0; {
		var result ScrapeInfohashResult
		if err = Read(buf, &result); err != nil {
			return
		}
		out = append(out, result)
	}
	if got, want := len(out), len(ihs); got > want {
		err = fmt.Errorf("got %v results but expected %v", got, want)
	}
	return
}
// shouldReconnectDefault reports whether a new connection ID is needed: either none has been
// issued yet, or the current one was issued at least a minute ago.
func (cl *Client) shouldReconnectDefault() bool {
	if cl.connIdIssued.IsZero() {
		return true
	}
	return time.Since(cl.connIdIssued) >= time.Minute
}
// shouldReconnect reports whether a connect round trip is required before the next request. The
// override hook, when set, takes precedence over the default policy.
func (cl *Client) shouldReconnect() bool {
	override := cl.shouldReconnectOverride
	if override == nil {
		return cl.shouldReconnectDefault()
	}
	return override()
}
// connect obtains a fresh connection ID if shouldReconnect says one is needed, and otherwise
// does nothing.
func (cl *Client) connect(ctx context.Context) (err error) {
	if cl.shouldReconnect() {
		err = cl.doConnectRoundTrip(ctx)
	}
	return
}
// doConnectRoundTrip performs a single connect request. Only when the response decodes
// successfully are the stored connection ID and its issue time updated; on any error the
// client's state is left as it was.
func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) {
	respBody, _, err := cl.request(ctx, ActionConnect, nil)
	if err != nil {
		return err
	}
	var connResp ConnectionResponse
	if err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp); err != nil {
		return err
	}
	cl.connId, cl.connIdIssued = connResp.ConnectionId, time.Now()
	return nil
}
// connIdForRequest returns the connection ID to put in the header of a request for action.
// Connect requests always use ConnectRequestConnectionId. Every other action first makes sure
// a usable connection ID exists (reconnecting if needed) and then returns the stored one.
func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
	if action == ActionConnect {
		return ConnectRequestConnectionId, nil
	}
	if err = cl.connect(ctx); err != nil {
		return
	}
	return cl.connId, nil
}
// writeRequest encodes one request (header followed by body) into buf, which is reset first,
// and sends it with a single call to cl.Writer.Write. The returned error is either the failure
// to obtain a connection ID or the error from that Write. It panics if the header cannot be
// encoded into buf.
func (cl *Client) writeRequest(
	ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer,
) (
	err error,
) {
	var connId ConnectionId
	if action != ActionConnect {
		// The lock is taken before a connection ID is established and released only after the
		// request has been written, so the ID cannot change in between. This ensures the server
		// doesn't assign us another ID before we've sent this request. Note that this doesn't
		// allow us to return if the context is cancelled while we wait to obtain a new ID.
		cl.mu.Lock()
		defer cl.mu.Unlock()
		if connId, err = cl.connIdForRequest(ctx, action); err != nil {
			return
		}
	} else {
		connId = ConnectRequestConnectionId
	}
	buf.Reset()
	hdr := RequestHeader{
		ConnectionId:  connId,
		Action:        action,
		TransactionId: tId,
	}
	if err = Write(buf, hdr); err != nil {
		panic(err)
	}
	buf.Write(body)
	_, err = cl.Writer.Write(buf.Bytes())
	return
}
// requestWriter sends the request repeatedly, waiting timeout(n) after the n-th attempt
// (counting from zero) before sending again. It only returns once a send fails or ctx is done,
// so the returned error is never nil.
func (cl *Client) requestWriter(ctx context.Context, action Action, body []byte, tId TransactionId) (err error) {
	// One buffer is reused across attempts; writeRequest resets it each time.
	var buf bytes.Buffer
	for n := 0; ; n++ {
		err = cl.writeRequest(ctx, action, body, tId, &buf)
		if err != nil {
			return
		}
		// An explicit timer is used instead of time.After so that it can be stopped when ctx is
		// cancelled, rather than lingering until the (possibly long) retransmit delay expires.
		retry := time.NewTimer(timeout(n))
		select {
		case <-ctx.Done():
			retry.Stop()
			return ctx.Err()
		case <-retry.C:
		}
	}
}
// ConnectionIdMissmatchNul is the text of an error response that some trackers send, including
// its trailing NUL byte. The misspelling is part of the message and must not be corrected.
// NOTE(review): presumably intended for comparison against ErrorResponse.Message — confirm at
// the call sites.
const ConnectionIdMissmatchNul = "Connection ID missmatch.\x00"
// ErrorResponse is the error returned when a tracker answers a request with an error action.
type ErrorResponse struct {
	// Message is the raw body of the error response, converted to a string without any trimming.
	Message string
}

// Error implements the error interface. The message is rendered quoted so that unprintable
// bytes, such as a trailing NUL, remain visible.
func (e ErrorResponse) Error() string {
	const format = "error response: %#q"
	return fmt.Sprintf(format, e.Message)
}
// request performs one transaction with the tracker. It registers a transaction with the
// Dispatcher, starts a goroutine that sends (and resends) the request, and then waits for
// whichever happens first: a dispatched response, an error from the sender, or ctx being done.
//
// A response whose action matches the request yields its body and source address. An error
// action yields an ErrorResponse carrying the body as its message. Any other action is reported
// as unexpected.
func (cl *Client) request(ctx context.Context, action Action, body []byte) (respBody []byte, addr net.Addr, err error) {
	// Both channels have room for one value so that neither the dispatcher callback nor the
	// sender goroutine blocks on its first send if this function has already returned.
	respChan := make(chan DispatchedResponse, 1)
	t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
		respChan <- dr
	})
	defer t.End()
	// Cancelling on return stops the sender goroutine's retransmit loop.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	writeErr := make(chan error, 1)
	go func() {
		writeErr <- cl.requestWriter(ctx, action, body, t.Id())
	}()
	select {
	case dr := <-respChan:
		switch dr.Header.Action {
		case action:
			respBody, addr = dr.Body, dr.Addr
		case ActionError:
			// udp://tracker.torrent.eu.org:451/announce frequently returns "Connection ID
			// missmatch.\x00"
			err = ErrorResponse{Message: string(dr.Body)}
		default:
			err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
		}
	case err = <-writeErr:
		err = fmt.Errorf("write error: %w", err)
	case <-ctx.Done():
		err = ctx.Err()
	}
	return
}