Fix ExecuteCallbackWithCtx to use the context that was provided (#5236)

* Fix `ExecuteCallbackWithCtx` to use the context that was provided

This updates `ExecuteCallbackWithCtx` to use the context that was
provided.

* remove more hardcoded context

---------

Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
This commit is contained in:
Douglas Danger Manley 2024-05-30 06:34:15 -04:00 committed by GitHub
parent 4ae0b39f53
commit 8011012c42
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 40 additions and 21 deletions

View File

@ -164,11 +164,17 @@ func WithConcurrency(opts Concurrency) NucleiSDKOptions {
}
// WithGlobalRateLimit sets global rate (i.e all hosts combined) limit options
// Deprecated: will be removed in favour of WithGlobalRateLimitCtx in next release
func WithGlobalRateLimit(maxTokens int, duration time.Duration) NucleiSDKOptions {
return WithGlobalRateLimitCtx(context.Background(), maxTokens, duration)
}
// WithGlobalRateLimitCtx allows setting a global rate limit for the entire engine
func WithGlobalRateLimitCtx(ctx context.Context, maxTokens int, duration time.Duration) NucleiSDKOptions {
return func(e *NucleiEngine) error {
e.opts.RateLimit = maxTokens
e.opts.RateLimitDuration = duration
e.rateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
e.rateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
return nil
}
}

View File

@ -26,7 +26,7 @@ type unsafeOptions struct {
}
// createEphemeralObjects creates ephemeral nuclei objects/instances/types
func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) {
func createEphemeralObjects(ctx context.Context, base *NucleiEngine, opts *types.Options) (*unsafeOptions, error) {
u := &unsafeOptions{}
u.executerOpts = protocols.ExecutorOptions{
Output: base.customWriter,
@ -49,9 +49,9 @@ func createEphemeralObjects(base *NucleiEngine, opts *types.Options) (*unsafeOpt
opts.RateLimitDuration = time.Second
}
if opts.RateLimit == 0 && opts.RateLimitDuration == 0 {
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
u.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else {
u.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(opts.RateLimit), opts.RateLimitDuration)
u.executerOpts.RateLimiter = ratelimit.New(ctx, uint(opts.RateLimit), opts.RateLimitDuration)
}
u.engine = core.New(opts)
u.engine.SetExecuterOptions(u.executerOpts)
@ -83,7 +83,7 @@ type ThreadSafeNucleiEngine struct {
// NewThreadSafeNucleiEngine creates a new nuclei engine with given options
// whose methods are thread-safe and can be used concurrently
// Note: Non-thread-safe methods start with Global prefix
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
func NewThreadSafeNucleiEngineCtx(ctx context.Context, opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
// default options
e := &NucleiEngine{
opts: types.DefaultOptions(),
@ -94,12 +94,17 @@ func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngin
return nil, err
}
}
if err := e.init(); err != nil {
if err := e.init(ctx); err != nil {
return nil, err
}
return &ThreadSafeNucleiEngine{eng: e}, nil
}
// Deprecated: use NewThreadSafeNucleiEngineCtx instead
func NewThreadSafeNucleiEngine(opts ...NucleiSDKOptions) (*ThreadSafeNucleiEngine, error) {
return NewThreadSafeNucleiEngineCtx(context.Background(), opts...)
}
// GlobalLoadAllTemplates loads all templates from nuclei-templates repo
// This method will load all templates based on filters given at the time of nuclei engine creation in opts
func (e *ThreadSafeNucleiEngine) GlobalLoadAllTemplates() error {
@ -124,7 +129,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
}
}
// create ephemeral nuclei objects/instances/types using base nuclei engine
unsafeOpts, err := createEphemeralObjects(e.eng, tmpEngine.opts)
unsafeOpts, err := createEphemeralObjects(ctx, e.eng, tmpEngine.opts)
if err != nil {
return err
}
@ -156,7 +161,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts)
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)
_ = engine.ExecuteScanWithOpts(ctx, store.Templates(), inputProvider, false)
engine.WorkPool().Wait()
return nil

View File

@ -240,7 +240,7 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
_ = e.engine.ExecuteScanWithOpts(ctx, e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait()
return nil
}
@ -261,8 +261,8 @@ func (e *NucleiEngine) Engine() *core.Engine {
return e.engine
}
// NewNucleiEngine creates a new nuclei engine instance
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
// NewNucleiEngineCtx creates a new nuclei engine instance with given context
func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*NucleiEngine, error) {
// default options
e := &NucleiEngine{
opts: types.DefaultOptions(),
@ -273,8 +273,13 @@ func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return nil, err
}
}
if err := e.init(); err != nil {
if err := e.init(ctx); err != nil {
return nil, err
}
return e, nil
}
// Deprecated: use NewNucleiEngineCtx instead
func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) {
return NewNucleiEngineCtx(context.Background(), options...)
}

View File

@ -37,7 +37,7 @@ import (
var sharedInit sync.Once = sync.Once{}
// applyRequiredDefaults to options
func (e *NucleiEngine) applyRequiredDefaults() {
func (e *NucleiEngine) applyRequiredDefaults(ctx context.Context) {
mockoutput := testutils.NewMockOutputWriter(e.opts.OmitTemplate)
mockoutput.WriteCallback = func(event *output.ResultEvent) {
if len(e.resultCallbacks) > 0 {
@ -81,7 +81,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
e.interactshOpts = interactsh.DefaultOptions(e.customWriter, e.rc, e.customProgress)
}
if e.rateLimiter == nil {
e.rateLimiter = ratelimit.New(context.Background(), 150, time.Second)
e.rateLimiter = ratelimit.New(ctx, 150, time.Second)
}
if e.opts.ExcludeTags == nil {
e.opts.ExcludeTags = []string{}
@ -94,7 +94,7 @@ func (e *NucleiEngine) applyRequiredDefaults() {
}
// init
func (e *NucleiEngine) init() error {
func (e *NucleiEngine) init(ctx context.Context) error {
if e.opts.Verbose {
gologger.DefaultLogger.SetMaxLevel(levels.LevelVerbose)
} else if e.opts.Debug {
@ -121,7 +121,7 @@ func (e *NucleiEngine) init() error {
_ = protocolinit.Init(e.opts)
})
e.applyRequiredDefaults()
e.applyRequiredDefaults(ctx)
var err error
// setup progressbar
@ -204,9 +204,9 @@ func (e *NucleiEngine) init() error {
e.opts.RateLimitDuration = time.Second
}
if e.opts.RateLimit == 0 && e.opts.RateLimitDuration == 0 {
e.executerOpts.RateLimiter = ratelimit.NewUnlimited(context.Background())
e.executerOpts.RateLimiter = ratelimit.NewUnlimited(ctx)
} else {
e.executerOpts.RateLimiter = ratelimit.New(context.Background(), uint(e.opts.RateLimit), e.opts.RateLimitDuration)
e.executerOpts.RateLimiter = ratelimit.New(ctx, uint(e.opts.RateLimit), e.opts.RateLimitDuration)
}
}

View File

@ -1,6 +1,7 @@
package sdk_test
import (
"context"
"os"
"os/exec"
"testing"
@ -28,7 +29,8 @@ func TestSimpleNuclei(t *testing.T) {
time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...)
}()
ne, err := nuclei.NewNucleiEngine(
ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplateFilters(nuclei.TemplateFilters{ProtocolTypes: "dns"}), // filter dns templates
nuclei.EnableStatsWithOpts(nuclei.StatsOptions{JSON: true}),
)
@ -62,7 +64,8 @@ func TestSimpleNucleiRemote(t *testing.T) {
time.Sleep(2 * time.Second)
goleak.VerifyNone(t, knownLeaks...)
}()
ne, err := nuclei.NewNucleiEngine(
ne, err := nuclei.NewNucleiEngineCtx(
context.TODO(),
nuclei.WithTemplatesOrWorkflows(
nuclei.TemplateSources{
RemoteTemplates: []string{"https://cloud.projectdiscovery.io/public/nameserver-fingerprint.yaml"},
@ -100,7 +103,7 @@ func TestThreadSafeNuclei(t *testing.T) {
goleak.VerifyNone(t, knownLeaks...)
}()
// create nuclei engine with options
ne, err := nuclei.NewThreadSafeNucleiEngine()
ne, err := nuclei.NewThreadSafeNucleiEngineCtx(context.TODO())
require.Nil(t, err)
// scan 1 = run dns templates on scanme.sh