Fix ExecuteCallbackWithCtx to use the context that was provided (#5236)

* Fix `ExecuteCallbackWithCtx` to use the context that was provided

This updates `ExecuteCallbackWithCtx` to use the context that was
provided.

* remove more hardcoded context

---------

Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
This commit is contained in:
Douglas Danger Manley 2024-05-30 06:34:15 -04:00 committed by GitHub
parent 4ae0b39f53
commit 8011012c42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 21 deletions

View File

@ -164,11 +164,17 @@ func WithConcurrency(opts Concurrency) NucleiSDKOptions {
} }
// WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options // WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options
// Deprecated: will be removed in favour of WithGlobalRateLimitCtx in next release
func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions { func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions {
return WithGlobalRateLimitCtx(context.Background(), maxTokens, duration)
}
// WithGlobalRateLimitCtx allows setting a global rate limit for the entire engine
func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Duration) NucleiSDKOptions {
return func(e *NucleiEngine) error { return func(e *NucleiEngine) error {
e.opts.RateLimit = maxTokens e.opts.RateLimit = maxTokens
e.opts.RateLimitDuration = duration e.opts.RateLimitDuration = duration
e.rateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration) e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
return nil return nil
} }
} }

View File

@ -26,7 +26,7 @@ type unsafeOptions struct {
} }
// createEphemeralObjects creates ephemeral nuclei objects/instances/types // createEphemeralObjects creates ephemeral nuclei objects/instances/types
func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) { func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) {
u := &unsafeOptions{} u := &unsafeOptions{}
u.executerOpts = protocols.ExecutorOptions{ u.executerOpts = protocols.ExecutorOptions{
Output: base.customWriter, Output: base.customWriter,
@ -49,9 +49,9 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt
opts.RateLimitDuration = time.Second opts.RateLimitDuration = time.Second
} }
if opts.RateLimit == 0 && opts.RateLimitDuration == 0 { if opts.RateLimit == 0 && opts.RateLimitDuration == 0 {
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background()) u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else { } else {
u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), opts.RateLimitDuration) u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration)
} }
u.engine = core.New(opts) u.engine = core.New(opts)
u.engine.SetExecuterOptions(u.executerOpts) u.engine.SetExecuterOptions(u.executerOpts)
@ -83,7 +83,7 @@ type ThreadSafeNucleiEngine struct {
// NewThreadSafeNucleiEngine creates a new nuclei engine with given options // NewThreadSafeNucleiEngine creates a new nuclei engine with given options
// whose methods are thread-safe and can be used concurrently // whose methods are thread-safe and can be used concurrently
// Note: Non-thread-safe methods start with Global prefix // Note: Non-thread-safe methods start with Global prefix
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) { func NewThreadSafeNucleiEngineCtx(ctx context.Context, opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
// default options // default options
e := &NucleiEngine{ e := &NucleiEngine{
opts: types.DefaultOptions(), opts: types.DefaultOptions(),
@ -94,12 +94,17 @@ func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngin
return nil, err return nil, err
} }
} }
if err := e.init(); err != nil { if err := e.init(ctx); err != nil {
return nil, err return nil, err
} }
return &ThreadSafeNucleiEngine{eng: e}, nil return &ThreadSafeNucleiEngine{eng: e}, nil
} }
// Deprecated: use NewThreadSafeNucleiEngineCtx instead
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
return NewThreadSafeNucleiEngineCtx(context.Background(), opts...)
}
// GlobalLoadAllTemplates loads all templates from nuclei-templates repo // GlobalLoadAllTemplates loads all templates from nuclei-templates repo
// This method will load all templates based on filters given at the time of nuclei engine creation in opts // This method will load all templates based on filters given at the time of nuclei engine creation in opts
func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error { func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error {
@ -124,7 +129,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
} }
} }
// create ephemeral nuclei objects/instances/types using base nuclei engine // create ephemeral nuclei objects/instances/types using base nuclei engine
unsafeOpts, err := createEphemeralObjects(e.eng, tmpEngine.opts) unsafeOpts, err := createEphemeralObjects(ctx, e.eng, tmpEngine.opts)
if err != nil { if err != nil {
return err return err
} }
@ -156,7 +161,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
engine := core.New(tmpEngine.opts) engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts) engine.SetExecuterOptions(unsafeOpts.executerOpts)
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false) _ = engine.ExecuteScanWithOpts(ctx, store.Templates(), inputProvider, false)
engine.WorkPool().Wait() engine.WorkPool().Wait()
return nil return nil

View File

@ -240,7 +240,7 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f
} }
e.resultCallbacks = append(e.resultCallbacks, filtered...) e.resultCallbacks = append(e.resultCallbacks, filtered...)
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false) _ = e.engine.ExecuteScanWithOpts(ctx, e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait() defer e.engine.WorkPool().Wait()
return nil return nil
} }
@ -261,8 +261,8 @@ func (e *NucleiEngine) Engine() *core.Engine {
return e.engine return e.engine
} }
// NewNucleiEngine creates a new nuclei engine instance // NewNucleiEngineCtx creates a new nuclei engine instance with given context
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) { func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*NucleiEngine, error) {
// default options // default options
e := &NucleiEngine{ e := &NucleiEngine{
opts: types.DefaultOptions(), opts: types.DefaultOptions(),
@ -273,8 +273,13 @@ func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return nil, err return nil, err
} }
} }
if err := e.init(); err != nil { if err := e.init(ctx); err != nil {
return nil, err return nil, err
} }
return e, nil return e, nil
} }
// Deprecated: use NewNucleiEngineCtx instead
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return NewNucleiEngineCtx(context.Background(), options...)
}

View File

@ -37,7 +37,7 @@ import (
var sharedInit sync.Once = sync.Once{} var sharedInit sync.Once = sync.Once{}
// applyRequiredDefaults to options // applyRequiredDefaults to options
func (e *NucleiEngine) applyRequiredDefaults() { func (e *NucleiEngine) applyRequiredDefaults(ctx context.Context) {
mockoutput := testutils.NewMockOutputWriter(e.opts.OmitTemplate) mockoutput := testutils.NewMockOutputWriter(e.opts.OmitTemplate)
mockoutput.WriteCallback = func(event *output.ResultEvent) { mockoutput.WriteCallback = func(event *output.ResultEvent) {
if len(e.resultCallbacks) > 0 { if len(e.resultCallbacks) > 0 {
@ -81,7 +81,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
e.interactshOpts = interactsh.DefaultOptions(e.customWriter, e.rc, e.customProgress) e.interactshOpts = interactsh.DefaultOptions(e.customWriter, e.rc, e.customProgress)
} }
if e.rateLimiter == nil { if e.rateLimiter == nil {
e.rateLimiter = ratelimit.New(context.Background(), 150, time.Second) e.rateLimiter = ratelimit.New(ctx, 150, time.Second)
} }
if e.opts.ExcludeTags == nil { if e.opts.ExcludeTags == nil {
e.opts.ExcludeTags = []string{} e.opts.ExcludeTags = []string{}
@ -94,7 +94,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
} }
// init // init
func (e *NucleiEngine) init() error { func (e *NucleiEngine) init(ctx context.Context) error {
if e.opts.Verbose { if e.opts.Verbose {
gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose) gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose)
} else if e.opts.Debug { } else if e.opts.Debug {
@ -121,7 +121,7 @@ func (e *NucleiEngine) init() error {
_ = protocolinit.Init(e.opts) _ = protocolinit.Init(e.opts)
}) })
e.applyRequiredDefaults() e.applyRequiredDefaults(ctx)
var err error var err error
// setup progressbar // setup progressbar
@ -204,9 +204,9 @@ func (e *NucleiEngine) init() error {
e.opts.RateLimitDuration = time.Second e.opts.RateLimitDuration = time.Second
} }
if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 { if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 {
e.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background()) e.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else { } else {
e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration) e.executerOpts.RateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
} }
} }

View File

@ -1,6 +1,7 @@
package sdk_test package sdk_test
import ( import (
"context"
"os" "os"
"os/exec" "os/exec"
"testing" "testing"
@ -28,7 +29,8 @@ func TestSimpleNuclei(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...) goleak.VerifyNone(t, knownLeaks...)
}() }()
ne, err := nuclei.NewNucleiEngine( ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplateFilters(nuclei.TemplateFilters{ProtocolTypes: "dns"}), // filter dns templates nuclei.WithTemplateFilters(nuclei.TemplateFilters{ProtocolTypes: "dns"}), // filter dns templates
nuclei.EnableStatsWithOpts(nuclei.StatsOptions{JSON: true}), nuclei.EnableStatsWithOpts(nuclei.StatsOptions{JSON: true}),
) )
@ -62,7 +64,8 @@ func TestSimpleNucleiRemote(t *testing.T) {
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...) goleak.VerifyNone(t, knownLeaks...)
}() }()
ne, err := nuclei.NewNucleiEngine( ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplatesOrWorkflows( nuclei.WithTemplatesOrWorkflows(
nuclei.TemplateSources{ nuclei.TemplateSources{
RemoteTemplates: []string{"https://cloud.projectdiscovery.io/public/nameserver-fingerprint.yaml"}, RemoteTemplates: []string{"https://cloud.projectdiscovery.io/public/nameserver-fingerprint.yaml"},
@ -100,7 +103,7 @@ func TestThreadSafeNuclei(t *testing.T) {
goleak.VerifyNone(t, knownLeaks...) goleak.VerifyNone(t, knownLeaks...)
}() }()
// create nuclei engine with options // create nuclei engine with options
ne, err := nuclei.NewThreadSafeNucleiEngine() ne, err := nuclei.NewThreadSafeNucleiEngineCtx(context.TODO())
require.Nil(t, err) require.Nil(t, err)
// scan 1 = run dns templates on scanme.sh // scan 1 = run dns templates on scanme.sh