diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 2a3c2c470..59910f824 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -384,11 +384,7 @@ func New(options *types.Options) (*Runner, error) { if options.RateLimit > 0 && options.RateLimitDuration == 0 { options.RateLimitDuration = time.Second } - if options.RateLimit == 0 && options.RateLimitDuration == 0 { - runner.rateLimiter = ratelimit.NewUnlimited(context.Background()) - } else { - runner.rateLimiter = ratelimit.New(context.Background(), uint(options.RateLimit), options.RateLimitDuration) - } + runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration) if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil { runner.tmpDir = tmpDir diff --git a/lib/config.go b/lib/config.go index 2c2a585d9..cdc56ce06 100644 --- a/lib/config.go +++ b/lib/config.go @@ -7,7 +7,7 @@ import ( "github.com/projectdiscovery/goflags" "github.com/projectdiscovery/gologger" - "github.com/projectdiscovery/ratelimit" + "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/utils/errkit" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider" @@ -181,7 +181,7 @@ func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Du return func(e *NucleiEngine) error { e.opts.RateLimit = maxTokens e.opts.RateLimitDuration = duration - e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration) + e.rateLimiter = utils.GetRateLimiter(ctx, e.opts.RateLimit, e.opts.RateLimitDuration) return nil } } diff --git a/lib/multi.go b/lib/multi.go index 5c542513c..b6c577587 100644 --- a/lib/multi.go +++ b/lib/multi.go @@ -12,7 +12,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/protocols" "github.com/projectdiscovery/nuclei/v3/pkg/types" - "github.com/projectdiscovery/ratelimit" + "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/utils/errkit" "github.com/rs/xid" ) @@ -53,11 +53,7 @@ func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types if opts.RateLimit > 0 && opts.RateLimitDuration == 0 { opts.RateLimitDuration = time.Second } - if opts.RateLimit == 0 && opts.RateLimitDuration == 0 { - u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx) - } else { - u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration) - } + u.executerOpts.RateLimiter = utils.GetRateLimiter(ctx, opts.RateLimit, opts.RateLimitDuration) u.engine = core.New(opts) u.engine.SetExecuterOptions(u.executerOpts) return u, nil diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index 9a8c669f9..471c0e73a 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,14 +1,17 @@ package utils import ( + "context" "errors" "fmt" "io" "net/url" "strings" + "time" "github.com/cespare/xxhash" "github.com/projectdiscovery/nuclei/v3/pkg/catalog" + "github.com/projectdiscovery/ratelimit" "github.com/projectdiscovery/retryablehttp-go" mapsutil "github.com/projectdiscovery/utils/maps" "golang.org/x/exp/constraints" @@ -71,3 +74,12 @@ func MapHash[K constraints.Ordered, V any](m map[K]V) uint64 { } return xxhash.Sum64([]byte(sb.String())) } + +// GetRateLimiter returns a rate limiter with the given max tokens and duration +// if maxTokens is 0 or duration is 0, it returns an unlimited rate limiter +func GetRateLimiter(ctx context.Context, maxTokens int, duration time.Duration) *ratelimit.Limiter { + if maxTokens == 0 || duration == 0 { + return ratelimit.NewUnlimited(ctx) + } + return ratelimit.New(ctx, uint(maxTokens), duration) +}