diff --git a/cmd/msak-client/client.go b/cmd/msak-client/client.go index 422bdc2..cc5f4e4 100644 --- a/cmd/msak-client/client.go +++ b/cmd/msak-client/client.go @@ -15,15 +15,16 @@ const clientName = "msak-client-go" var clientVersion = version.Version var ( - flagServer = flag.String("server", "", "Server address") - flagStreams = flag.Int("streams", client.DefaultStreams, "Number of streams") - flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use") - flagDelay = flag.Duration("delay", 0, "Delay between each stream") - flagDuration = flag.Duration("duration", client.DefaultLength, "Length of the last stream") - flagScheme = flag.String("scheme", client.DefaultScheme, "Websocket scheme (wss or ws)") - flagMID = flag.String("mid", uuid.NewString(), "Measurement ID to use") - flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification") - flagDebug = flag.Bool("debug", false, "Enable debug logging") + flagServer = flag.String("server", "", "Server address") + flagStreams = flag.Int("streams", client.DefaultStreams, "Number of streams") + flagCC = flag.String("cc", "bbr", "Congestion control algorithm to use") + flagDelay = flag.Duration("delay", 0, "Delay between each stream") + flagDuration = flag.Duration("duration", client.DefaultLength, "Length of the last stream") + flagScheme = flag.String("scheme", client.DefaultScheme, "Websocket scheme (wss or ws)") + flagMID = flag.String("mid", uuid.NewString(), "Measurement ID to use") + flagNoVerify = flag.Bool("no-verify", false, "Skip TLS certificate verification") + flagDebug = flag.Bool("debug", false, "Enable debug logging") + flagByteLimit = flag.Int("bytes", 0, "Byte limit to request to the server") ) func main() { @@ -46,7 +47,8 @@ func main() { Emitter: client.HumanReadable{ Debug: *flagDebug, }, - NoVerify: *flagNoVerify, + NoVerify: *flagNoVerify, + ByteLimit: *flagByteLimit, } cl := client.New(clientName, clientVersion, config) diff --git a/internal/handler/handler.go b/internal/handler/handler.go index 65e10ae..b90bde3 100644 --- a/internal/handler/handler.go +++ b/internal/handler/handler.go @@ -16,6 +16,7 @@ import ( "github.com/m-lab/msak/internal/persistence" "github.com/m-lab/msak/pkg/throughput1" "github.com/m-lab/msak/pkg/throughput1/model" + "github.com/m-lab/msak/pkg/throughput1/spec" "github.com/m-lab/msak/pkg/version" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -130,6 +131,20 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res model.NameValue{Name: "delay", Value: requestDelay}) } + requestByteLimit := query.Get(spec.ByteLimitParameterName) + var byteLimit int + if requestByteLimit != "" { + if byteLimit, err = strconv.Atoi(requestByteLimit); err != nil { + ClientConnections.WithLabelValues(string(kind), "invalid-byte-limit").Inc() + log.Info("Received request with an invalid byte limit", "source", req.RemoteAddr, + "value", requestByteLimit) + writeBadRequest(rw) + return + } + clientOptions = append(clientOptions, + model.NameValue{Name: spec.ByteLimitParameterName, Value: requestByteLimit}) + } + // Read metadata (i.e. everything in the querystring that's not a known // option). metadata, err := getRequestMetadata(req) @@ -198,6 +213,7 @@ func (h *Handler) upgradeAndRunMeasurement(kind model.TestDirection, rw http.Res defer cancel() proto := throughput1.New(wsConn) + proto.SetByteLimit(byteLimit) var senderCh, receiverCh <-chan model.WireMeasurement var errCh <-chan error if kind == model.DirectionDownload { diff --git a/internal/handler/handler_test.go b/internal/handler/handler_test.go index 60a9ea6..f0ae4c0 100644 --- a/internal/handler/handler_test.go +++ b/internal/handler/handler_test.go @@ -215,6 +215,11 @@ func TestHandler_Validation(t *testing.T) { target: "/?mid=test&streams=2&duration=invalid", statusCode: http.StatusBadRequest, }, + { + name: "invalid byte limit", + target: "/?mid=test&streams=2&duration=1000&bytes=invalid", + statusCode: http.StatusBadRequest, + }, { name: "metadata key too long", target: "/?mid=test&streams=2&" + longKey, diff --git a/pkg/client/client.go b/pkg/client/client.go index 22585fb..edb8be5 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -136,6 +136,7 @@ func (c *Throughput1Client) connect(ctx context.Context, serviceURL *url.URL) (* q := serviceURL.Query() q.Set("streams", fmt.Sprint(c.config.NumStreams)) q.Set("cc", c.config.CongestionControl) + q.Set(spec.ByteLimitParameterName, fmt.Sprint(c.config.ByteLimit)) q.Set("duration", fmt.Sprintf("%d", c.config.Length.Milliseconds())) q.Set("client_arch", runtime.GOARCH) q.Set("client_library_name", libraryName) diff --git a/pkg/client/config.go b/pkg/client/config.go index 5a89ac2..e55416c 100644 --- a/pkg/client/config.go +++ b/pkg/client/config.go @@ -35,4 +35,8 @@ type Config struct { // NoVerify disables the TLS certificate verification. NoVerify bool + + // ByteLimit is the maximum number of bytes to download or upload. If set to 0, the + // limit is disabled. + ByteLimit int } diff --git a/pkg/throughput1/protocol.go b/pkg/throughput1/protocol.go index 1800819..43df87f 100644 --- a/pkg/throughput1/protocol.go +++ b/pkg/throughput1/protocol.go @@ -47,6 +47,8 @@ type Protocol struct { applicationBytesReceived atomic.Int64 applicationBytesSent atomic.Int64 + + byteLimit int } // New returns a new Protocol with the specified connection and every other @@ -61,6 +63,12 @@ func New(conn *websocket.Conn) *Protocol { } } +// SetByteLimit sets the number of bytes sent after which a test (either download or upload) will stop. +// Set the value to zero to disable the byte limit. +func (p *Protocol) SetByteLimit(value int) { + p.byteLimit = value +} + // Upgrade takes a HTTP request and upgrades the connection to WebSocket. // Returns a websocket Conn if the upgrade succeeded, and an error otherwise. func Upgrade(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { @@ -180,6 +188,7 @@ func (p *Protocol) receiver(ctx context.Context, func (p *Protocol) sendCounterflow(ctx context.Context, measurerCh <-chan model.Measurement, results chan<- model.WireMeasurement, errCh chan<- error) { + byteLimit := int64(p.byteLimit) for { select { case <-ctx.Done(): @@ -218,13 +227,19 @@ func (p *Protocol) sendCounterflow(ctx context.Context, case results <- wm: default: } + + // End the test once enough bytes have been received. + if byteLimit > 0 && m.TCPInfo != nil && m.TCPInfo.BytesReceived >= byteLimit { + p.close(ctx) + return + } } } } func (p *Protocol) sender(ctx context.Context, measurerCh <-chan model.Measurement, results chan<- model.WireMeasurement, errCh chan<- error) { - size := spec.MinMessageSize + size := p.ScaleMessage(spec.MinMessageSize, 0) message, err := p.makePreparedMessage(size) if err != nil { log.Printf("makePreparedMessage failed (ctx: %p)", ctx) @@ -283,27 +298,39 @@ func (p *Protocol) sender(ctx context.Context, measurerCh <-chan model.Measureme } p.applicationBytesSent.Add(int64(size)) - // Determine whether it's time to scale the message size. - if size >= spec.MaxScaledMessageSize { - continue + bytesSent := int(p.applicationBytesSent.Load()) + if p.byteLimit > 0 && bytesSent >= p.byteLimit { + p.close(ctx) + return } - if size > int(p.applicationBytesSent.Load())/spec.ScalingFraction { + // Determine whether it's time to scale the message size. + if size >= spec.MaxScaledMessageSize || size > bytesSent/spec.ScalingFraction { + size = p.ScaleMessage(size, bytesSent) continue } - size *= 2 + size = p.ScaleMessage(size*2, bytesSent) message, err = p.makePreparedMessage(size) if err != nil { log.Printf("failed to make prepared message (ctx: %p, err: %v)", ctx, err) errCh <- err return } - } } } +// ScaleMessage sets the binary message size taking into consideration byte limits. +func (p *Protocol) ScaleMessage(msgSize int, bytesSent int) int { + // Check if the next payload size will push the total number of bytes over the limit. + excess := bytesSent + msgSize - p.byteLimit + if p.byteLimit > 0 && excess > 0 { + msgSize -= excess + } + return msgSize +} + func (p *Protocol) close(ctx context.Context) { msg := websocket.FormatCloseMessage( websocket.CloseNormalClosure, "Done sending") diff --git a/pkg/throughput1/protocol_test.go b/pkg/throughput1/protocol_test.go index 6180fca..37c0e5d 100644 --- a/pkg/throughput1/protocol_test.go +++ b/pkg/throughput1/protocol_test.go @@ -137,3 +137,51 @@ func TestProtocol_Download(t *testing.T) { } } } + +func TestProtocol_ScaleMessage(t *testing.T) { + tests := []struct { + name string + byteLimit int + msgSize int + bytesSent int + want int + }{ + { + name: "no-limit", + byteLimit: 0, + msgSize: 10, + bytesSent: 100, + want: 10, + }, + { + name: "under-limit", + byteLimit: 200, + msgSize: 10, + bytesSent: 100, + want: 10, + }, + { + name: "at-limit", + byteLimit: 110, + msgSize: 10, + bytesSent: 100, + want: 10, + }, + { + name: "over-limit", + byteLimit: 110, + msgSize: 20, + bytesSent: 100, + want: 10, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &throughput1.Protocol{} + p.SetByteLimit(tt.byteLimit) + if got := p.ScaleMessage(tt.msgSize, tt.bytesSent); got != tt.want { + t.Errorf("Protocol.ScaleMessage() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/throughput1/spec/spec.go b/pkg/throughput1/spec/spec.go index cb966b8..4a00424 100644 --- a/pkg/throughput1/spec/spec.go +++ b/pkg/throughput1/spec/spec.go @@ -37,6 +37,11 @@ const ( // SecWebSocketProtocol is the value of the Sec-WebSocket-Protocol header. SecWebSocketProtocol = "net.measurementlab.throughput.v1" + + // ByteLimitParameterName is the name of the parameter that clients can use + // to terminate throughput1 download tests once the test has transferred + // the specified number of bytes. + ByteLimitParameterName = "bytes" ) // SubtestKind indicates the subtest kind