// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 package bandwidthlimiter import ( "errors" "fmt" "io" "sync" "time" log "github.com/sirupsen/logrus" "go.amzn.com/lambda/interop" "go.amzn.com/lambda/metering" ) var ErrBufferSizeTooLarge = errors.New("buffer size cannot be greater than bucket size") func NewBucket(capacity int64, initialTokenCount int64, refillNumber int64, refillInterval time.Duration) (*Bucket, error) { if capacity <= 0 || initialTokenCount < 0 || refillNumber <= 0 || refillInterval <= 0 || capacity < initialTokenCount { errorMsg := fmt.Sprintf("invalid bucket parameters (capacity: %d, initialTokenCount: %d, refillNumber: %d,"+ "refillInterval: %d)", capacity, initialTokenCount, refillInterval, refillInterval) log.Error(errorMsg) return nil, errors.New(errorMsg) } return &Bucket{ capacity: capacity, tokenCount: initialTokenCount, refillNumber: refillNumber, refillInterval: refillInterval, mutex: sync.Mutex{}, }, nil } type Bucket struct { capacity int64 tokenCount int64 refillNumber int64 refillInterval time.Duration mutex sync.Mutex } func (b *Bucket) produceTokens() { b.mutex.Lock() defer b.mutex.Unlock() if b.tokenCount < b.capacity { b.tokenCount = min64(b.tokenCount+b.refillNumber, b.capacity) } } func (b *Bucket) consumeTokens(n int64) bool { b.mutex.Lock() defer b.mutex.Unlock() if n <= b.tokenCount { b.tokenCount -= n return true } return false } func (b *Bucket) getTokenCount() int64 { b.mutex.Lock() defer b.mutex.Unlock() return b.tokenCount } func NewThrottler(bucket *Bucket) (*Throttler, error) { if bucket == nil { errorMsg := "cannot create a throttler with nil bucket" log.Error(errorMsg) return nil, errors.New(errorMsg) } return &Throttler{ b: bucket, running: false, produced: make(chan int64), done: make(chan struct{}), // FIXME: // The runtime tells whether the function response mode is streaming or not. // Ideally, we would want to use that value here. Since I'm just rebasing, I will leave // as-is, but we should use that instead of relying on our memory to set this here // because we "know" it's a streaming code path. metrics: &interop.InvokeResponseMetrics{FunctionResponseMode: interop.FunctionResponseModeStreaming}, }, nil } type Throttler struct { b *Bucket running bool produced chan int64 done chan struct{} metrics *interop.InvokeResponseMetrics } func (th *Throttler) start() { if th.running { return } th.running = true th.metrics.StartReadingResponseMonoTimeMs = metering.Monotime() go func() { ticker := time.NewTicker(th.b.refillInterval) for { select { case <-ticker.C: th.b.produceTokens() select { case th.produced <- metering.Monotime(): default: } case <-th.done: ticker.Stop() return } } }() } func (th *Throttler) stop() { if !th.running { return } th.running = false th.metrics.FinishReadingResponseMonoTimeMs = metering.Monotime() durationMs := (th.metrics.FinishReadingResponseMonoTimeMs - th.metrics.StartReadingResponseMonoTimeMs) / int64(time.Millisecond) if durationMs > 0 { th.metrics.OutboundThroughputBps = (th.metrics.ProducedBytes / durationMs) * int64(time.Second/time.Millisecond) } else { th.metrics.OutboundThroughputBps = -1 } th.done <- struct{}{} } func (th *Throttler) bandwidthLimitingWrite(w io.Writer, p []byte) (written int, err error) { n := int64(len(p)) if n > th.b.capacity { return 0, ErrBufferSizeTooLarge } for { if th.b.consumeTokens(n) { written, err = w.Write(p) th.metrics.ProducedBytes += int64(written) return } waitStart := metering.Monotime() elapsed := <-th.produced - waitStart if elapsed > 0 { th.metrics.TimeShapedNs += elapsed } } }