diff --git a/lib/sdk.go b/lib/sdk.go index 7f2ec5bcc..a8639b1d6 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "io" + "sync" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider" "github.com/projectdiscovery/nuclei/v3/pkg/catalog" @@ -64,6 +65,7 @@ type NucleiEngine struct { templatesLoaded bool // unexported core fields + ctx context.Context interactshClient *interactsh.Client catalog catalog.Catalog rateLimiter *ratelimit.Limiter @@ -246,9 +248,9 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f } filtered := []func(event *output.ResultEvent){} - for _, callback := range callback { - if callback != nil { - filtered = append(filtered, callback) + for _, cb := range callback { + if cb != nil { + filtered = append(filtered, cb) } } e.resultCallbacks = append(e.resultCallbacks, filtered...) @@ -258,15 +260,32 @@ func (e *NucleiEngine) ExecuteCallbackWithCtx(ctx context.Context, callback ...f return ErrNoTemplatesAvailable } - _ = e.engine.ExecuteScanWithOpts(ctx, templatesAndWorkflows, e.inputProvider, false) - defer e.engine.WorkPool().Wait() + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = e.engine.ExecuteScanWithOpts(ctx, templatesAndWorkflows, e.inputProvider, false) + }() + + // wait for context to be cancelled + select { + case <-ctx.Done(): + <-wait(&wg) // wait for scan to finish + return ctx.Err() + case <-wait(&wg): + // scan finished + } return nil } // ExecuteWithCallback is same as ExecuteCallbackWithCtx but with default context // Note this is deprecated and will be removed in future major release func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.ResultEvent)) error { - return e.ExecuteCallbackWithCtx(context.Background(), callback...) + ctx := context.Background() + if e.ctx != nil { + ctx = e.ctx + } + return e.ExecuteCallbackWithCtx(ctx, callback...) } // Options return nuclei Type Options @@ -290,6 +309,7 @@ func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*Nucl e := &NucleiEngine{ opts: types.DefaultOptions(), mode: singleInstance, + ctx: ctx, } for _, option := range options { if err := option(e); err != nil { @@ -306,3 +326,13 @@ func NewNucleiEngineCtx(ctx context.Context, options ...NucleiSDKOptions) (*Nucl func NewNucleiEngine(options ...NucleiSDKOptions) (*NucleiEngine, error) { return NewNucleiEngineCtx(context.Background(), options...) } + +// wait for a waitgroup to finish +func wait(wg *sync.WaitGroup) <-chan struct{} { + ch := make(chan struct{}) + go func() { + defer close(ch) + wg.Wait() + }() + return ch +} diff --git a/lib/sdk_test.go b/lib/sdk_test.go new file mode 100644 index 000000000..c86f8ebbf --- /dev/null +++ b/lib/sdk_test.go @@ -0,0 +1,37 @@ +package nuclei_test + +import ( + "context" + "log" + "testing" + "time" + + nuclei "github.com/projectdiscovery/nuclei/v3/lib" + "github.com/stretchr/testify/require" +) + +func TestContextCancelNucleiEngine(t *testing.T) { + // create nuclei engine with options + ctx, cancel := context.WithCancel(context.Background()) + ne, err := nuclei.NewNucleiEngineCtx(ctx, + nuclei.WithTemplateFilters(nuclei.TemplateFilters{Tags: []string{"oast"}}), + nuclei.EnableStatsWithOpts(nuclei.StatsOptions{MetricServerPort: 0}), + ) + require.NoError(t, err, "could not create nuclei engine") + + go func() { + time.Sleep(time.Second * 2) + cancel() + log.Println("Test: context cancelled") + }() + + // load targets and optionally probe non http/https targets + ne.LoadTargets([]string{"http://honey.scanme.sh"}, false) + // when callback is nil it nuclei will print JSON output to stdout + err = ne.ExecuteWithCallback(nil) + if err != nil { + // we expect a context cancellation error + require.ErrorIs(t, err, context.Canceled, "was expecting context cancellation error") + } + defer ne.Close() +}