package ratelimit

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"testing"
)

func TestTokenRateLimit(t *testing.T) {
	type usage struct {
		Cost      uint
		Release   bool
		Err       string
		AddTokens uint
	}

	cases := map[string]struct {
		Tokens uint
		Usages []usage
	}{
		"retrieve": {
			Tokens: 10,
			Usages: []usage{
				{Cost: 5, Release: true},
				{Cost: 5},
				{Cost: 5},
				{Cost: 5, Err: "retry quota exceeded"},
				{AddTokens: 5},
				{Cost: 5},
			},
		},
	}

	for name, c := range cases {
		t.Run(name, func(t *testing.T) {
			rl := NewTokenRateLimit(c.Tokens)

			for i, u := range c.Usages {
				t.Run(fmt.Sprintf("usage_%d", i), func(t *testing.T) {
					if u.Cost != 0 {
						f, err := rl.GetToken(context.Background(), u.Cost)
						if len(u.Err) != 0 {
							if err == nil {
								t.Fatalf("expect error, got none")
							}
							if e, a := u.Err, err.Error(); !strings.Contains(a, e) {
								t.Fatalf("expect %q error, got %q", e, a)
							}
						} else if err != nil {
							t.Fatalf("expect no error, got %v", err)
						}

						if u.Release {
							if err := f(); err != nil {
								t.Fatalf("expect no error, got %v", err)
							}
						}
					}

					if u.AddTokens != 0 {
						rl.AddTokens(u.AddTokens)
					}
				})
			}
		})
	}
}

func TestTokenRateLimit_canceled(t *testing.T) {
	rl := NewTokenRateLimit(10)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	fn, err := rl.GetToken(ctx, 1)
	if err == nil {
		t.Fatalf("expect error, got none")
	}
	if fn != nil {
		t.Errorf("expect no release func returned")
	}

	var v interface{ CanceledError() bool }
	if !errors.As(err, &v) {
		t.Fatalf("expect %T error, got %v", v, err)
	}
}