centralizing ratelimiter logic

This commit is contained in:
Mzack9999 2025-09-12 17:46:42 +02:00
parent 46555bcd1e
commit 089e2a4ee0
4 changed files with 17 additions and 13 deletions

View File

@ -384,11 +384,7 @@ func New(options *types.Options) (*Runner, error) {
if options.RateLimit > 0 && options.RateLimitDuration == 0 { if options.RateLimit > 0 && options.RateLimitDuration == 0 {
options.RateLimitDuration = time.Second options.RateLimitDuration = time.Second
} }
if options.RateLimit == 0 && options.RateLimitDuration == 0 { runner.rateLimiter = utils.GetRateLimiter(context.Background(), options.RateLimit, options.RateLimitDuration)
runner.rateLimiter = ratelimit.NewUnlimited(context.Background())
} else {
runner.rateLimiter = ratelimit.New(context.Background(), uint(options.RateLimit), options.RateLimitDuration)
}
if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil { if tmpDir, err := os.MkdirTemp("", "nuclei-tmp-*"); err == nil {
runner.tmpDir = tmpDir runner.tmpDir = tmpDir

View File

@ -7,7 +7,7 @@ import (
"github.com/projectdiscovery/goflags" "github.com/projectdiscovery/goflags"
"github.com/projectdiscovery/gologger" "github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/ratelimit" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/utils/errkit" "github.com/projectdiscovery/utils/errkit"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
@ -181,7 +181,7 @@ func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Du
return func(e *NucleiEngine) error { return func(e *NucleiEngine) error {
e.opts.RateLimit = maxTokens e.opts.RateLimit = maxTokens
e.opts.RateLimitDuration = duration e.opts.RateLimitDuration = duration
e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration) e.rateLimiter = utils.GetRateLimiter(ctx, e.opts.RateLimit, e.opts.RateLimitDuration)
return nil return nil
} }
} }

View File

@ -12,7 +12,7 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/output"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols" "github.com/projectdiscovery/nuclei/v3/pkg/protocols"
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/ratelimit" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/projectdiscovery/utils/errkit" "github.com/projectdiscovery/utils/errkit"
"github.com/rs/xid" "github.com/rs/xid"
) )
@ -53,11 +53,7 @@ func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types
if opts.RateLimit > 0 && opts.RateLimitDuration == 0 { if opts.RateLimit > 0 && opts.RateLimitDuration == 0 {
opts.RateLimitDuration = time.Second opts.RateLimitDuration = time.Second
} }
if opts.RateLimit == 0 && opts.RateLimitDuration == 0 { u.executerOpts.RateLimiter = utils.GetRateLimiter(ctx, opts.RateLimit, opts.RateLimitDuration)
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else {
u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration)
}
u.engine = core.New(opts) u.engine = core.New(opts)
u.engine.SetExecuterOptions(u.executerOpts) u.engine.SetExecuterOptions(u.executerOpts)
return u, nil return u, nil

View File

@ -1,14 +1,17 @@
package utils package utils
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/url" "net/url"
"strings" "strings"
"time"
"github.com/cespare/xxhash" "github.com/cespare/xxhash"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog" "github.com/projectdiscovery/nuclei/v3/pkg/catalog"
"github.com/projectdiscovery/ratelimit"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
mapsutil "github.com/projectdiscovery/utils/maps" mapsutil "github.com/projectdiscovery/utils/maps"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
@ -71,3 +74,12 @@ func MapHash[K constraints.Ordered, V any](m map[K]V) uint64 {
} }
return xxhash.Sum64([]byte(sb.String())) return xxhash.Sum64([]byte(sb.String()))
} }
// GetRateLimiter returns a rate limiter with the given max tokens and duration
// if maxTokens is 0 or duration is 0, it returns an unlimited rate limiter
func GetRateLimiter(ctx context.Context, maxTokens int, duration time.Duration) *ratelimit.Limiter {
if maxTokens == 0 || duration == 0 {
return ratelimit.NewUnlimited(ctx)
}
return ratelimit.New(ctx, uint(maxTokens), duration)
}