From 0b82e8b7aa84ec637f0e2cb33da0757622c956ef Mon Sep 17 00:00:00 2001 From: Ice3man Date: Thu, 25 Apr 2024 15:37:56 +0530 Subject: [PATCH] feat: added support for context cancellation to engine (#5096) * feat: added support for context cancellation to engine * misc * feat: added contexts everywhere * misc * misc * use granular http timeouts and increase http timeout to 30s using multiplier * track response header timeout in mhe * update responseHeaderTimeout to 5sec * skip failing windows test --------- Co-authored-by: Tarun Koyalwar --- cmd/integration-test/library.go | 2 +- internal/runner/lazy.go | 4 +- internal/runner/runner.go | 2 +- lib/multi.go | 2 +- lib/sdk.go | 3 +- pkg/core/execute_options.go | 37 +++++++++++++------ pkg/core/executors.go | 35 +++++++++++++----- pkg/core/workflow_execute.go | 6 +-- pkg/core/workflow_execute_test.go | 25 +++++++------ pkg/input/provider/list/hmap_test.go | 4 ++ pkg/js/compiler/compiler.go | 8 ++-- pkg/protocols/code/code.go | 1 + pkg/protocols/code/code_test.go | 3 +- .../common/automaticscan/automaticscan.go | 9 +++-- .../common/contextargs/contextargs.go | 16 ++++++-- .../common/hosterrorscache/hosterrorscache.go | 2 +- pkg/protocols/dns/request.go | 6 +++ pkg/protocols/dns/request_test.go | 5 ++- pkg/protocols/file/request_test.go | 3 +- .../headless/engine/page_actions_test.go | 5 ++- pkg/protocols/headless/request.go | 2 +- pkg/protocols/http/build_request_test.go | 16 ++++---- .../http/httpclientpool/clientpool.go | 25 +++++++++---- pkg/protocols/http/request.go | 27 ++++++++++++-- pkg/protocols/http/request_annotations.go | 3 +- pkg/protocols/http/request_fuzz.go | 12 ++++++ pkg/protocols/http/request_test.go | 9 +++-- pkg/protocols/javascript/js.go | 17 ++++++++- pkg/protocols/network/request.go | 12 ++++++ pkg/protocols/network/request_test.go | 7 ++-- pkg/protocols/protocols.go | 3 +- pkg/protocols/ssl/ssl_test.go | 3 +- pkg/protocols/utils/variables_test.go | 3 +- pkg/scan/scan_context.go | 13 +++++-- pkg/templates/cluster.go | 2 +- pkg/tmplexec/flow/flow_executor.go | 6 +++ pkg/tmplexec/flow/flow_executor_test.go | 28 +++++++------- pkg/tmplexec/generic/exec.go | 6 +++ pkg/tmplexec/multiproto/multi.go | 12 ++++++ pkg/tmplexec/multiproto/multi_test.go | 8 ++-- 40 files changed, 279 insertions(+), 113 deletions(-) diff --git a/cmd/integration-test/library.go b/cmd/integration-test/library.go index a714b29ab..1324688e0 100644 --- a/cmd/integration-test/library.go +++ b/cmd/integration-test/library.go @@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error) } store.Load() - _ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL)) + _ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL)) engine.WorkPool().Wait() // Wait for the scan to finish return results, nil diff --git a/internal/runner/lazy.go b/internal/runner/lazy.go index b61dd5515..eb4137451 100644 --- a/internal/runner/lazy.go +++ b/internal/runner/lazy.go @@ -1,6 +1,7 @@ package runner import ( + "context" "fmt" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" @@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret tmpl := tmpls[0] // add args to tmpl here vars := map[string]interface{}{} - ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input)) + mainCtx := context.Background() + ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input)) for _, v := range d.Variables { vars[v.Key] = v.Value ctx.Input.Add(v.Key, v.Value) diff --git a/internal/runner/runner.go b/internal/runner/runner.go index ab2fe5de9..43f16ff10 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine) if r.inputProvider == nil { return nil, errors.New("no input provider found") } - results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering) + results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering) return results, nil } diff --git a/lib/multi.go b/lib/multi.go index 6fa791a2f..bdcee7967 100644 --- a/lib/multi.go +++ b/lib/multi.go @@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts .. engine := core.New(tmpEngine.opts) engine.SetExecuterOptions(unsafeOpts.executerOpts) - _ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false) + _ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false) engine.WorkPool().Wait() return nil diff --git a/lib/sdk.go b/lib/sdk.go index c80abb061..f6b5ea7c4 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -3,6 +3,7 @@ package nuclei import ( "bufio" "bytes" + "context" "io" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider" @@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result } e.resultCallbacks = append(e.resultCallbacks, filtered...) - _ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false) + _ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false) defer e.engine.WorkPool().Wait() return nil } diff --git a/pkg/core/execute_options.go b/pkg/core/execute_options.go index 93f197fc2..4d27b5f66 100644 --- a/pkg/core/execute_options.go +++ b/pkg/core/execute_options.go @@ -1,6 +1,7 @@ package core import ( + "context" "sync" "sync/atomic" @@ -20,18 +21,18 @@ import ( // // All the execution logic for the templates/workflows happens in this part // of the engine. -func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool { - return e.ExecuteScanWithOpts(templates, target, false) +func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool { + return e.ExecuteScanWithOpts(ctx, templates, target, false) } // ExecuteWithResults a list of templates with results -func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool { +func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool { e.Callback = callback - return e.ExecuteScanWithOpts(templatesList, target, false) + return e.ExecuteScanWithOpts(ctx, templatesList, target, false) } // ExecuteScanWithOpts executes scan with given scanStrategy -func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool { +func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool { results := &atomic.Bool{} selfcontainedWg := &sync.WaitGroup{} @@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target } // Execute All SelfContained in parallel - e.executeAllSelfContained(selfContained, results, selfcontainedWg) + e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg) strategyResult := &atomic.Bool{} switch e.options.ScanStrategy { case scanstrategy.TemplateSpray.String(): - strategyResult = e.executeTemplateSpray(filtered, target) + strategyResult = e.executeTemplateSpray(ctx, filtered, target) case scanstrategy.HostSpray.String(): - strategyResult = e.executeHostSpray(filtered, target) + strategyResult = e.executeHostSpray(ctx, filtered, target) } results.CompareAndSwap(false, strategyResult.Load()) @@ -100,7 +101,7 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target } // executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template -func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { +func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { results := &atomic.Bool{} // wp is workpool that contains different waitgroups for @@ -108,6 +109,12 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe wp := e.GetWorkPool() for _, template := range templatesList { + select { + case <-ctx.Done(): + return results + default: + } + // resize check point - nop if there are no changes wp.RefreshWithConfig(e.GetWorkPoolConfig()) @@ -125,7 +132,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe // All other request types are executed here // Note: executeTemplateWithTargets creates goroutines and blocks // given template is executed on all targets - e.executeTemplateWithTargets(tpl, target, results) + e.executeTemplateWithTargets(ctx, tpl, target, results) }(template) } wp.Wait() @@ -133,15 +140,21 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe } // executeHostSpray executes scan using host spray strategy where templates are iterated over each target -func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { +func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { results := &atomic.Bool{} wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize)) target.Iterate(func(value *contextargs.MetaInput) bool { + select { + case <-ctx.Done(): + return false + default: + } + wp.Add() go func(targetval *contextargs.MetaInput) { defer wp.Done() - e.executeTemplatesOnTarget(templatesList, targetval, results) + e.executeTemplatesOnTarget(ctx, templatesList, targetval, results) }(value) return true }) diff --git a/pkg/core/executors.go b/pkg/core/executors.go index 14fb75c63..89c85b2ad 100644 --- a/pkg/core/executors.go +++ b/pkg/core/executors.go @@ -1,6 +1,7 @@ package core import ( + "context" "sync" "sync/atomic" @@ -17,14 +18,14 @@ import ( // Executors are low level executors that deals with template execution on a target // executeAllSelfContained executes all self contained templates that do not use `target` -func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) { +func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) { for _, v := range alltemplates { sg.Add(1) go func(template *templates.Template) { defer sg.Done() var err error var match bool - ctx := scan.NewScanContext(contextargs.New()) + ctx := scan.NewScanContext(ctx, contextargs.New(ctx)) if e.Callback != nil { if results, err := template.Executer.ExecuteWithResults(ctx); err != nil { for _, result := range results { @@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res } // executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency)) -func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) { +func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) { // this is target pool i.e max target to execute wg := e.workPool.InputPool(template.Type()) @@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target } target.Iterate(func(scannedValue *contextargs.MetaInput) bool { + select { + case <-ctx.Done(): + return false // exit + default: + } + // Best effort to track the host progression // skips indexes lower than the minimum in-flight at interruption time var skip bool @@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target var match bool var err error - ctxArgs := contextargs.New() + ctxArgs := contextargs.New(ctx) ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctxArgs) + ctx := scan.NewScanContext(ctx, ctxArgs) switch template.Type() { case types.WorkflowProtocol: match = e.executeWorkflow(ctx, template.CompiledWorkflow) @@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target } // executeTemplatesOnTarget execute given templates on given single target -func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) { +func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) { // all templates are executed on single target // wp is workpool that contains different waitgroups for @@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta wp := e.GetWorkPool() for _, tpl := range alltemplates { + select { + case <-ctx.Done(): + return + default: + } + // resize check point - nop if there are no changes wp.RefreshWithConfig(e.GetWorkPoolConfig()) @@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta var match bool var err error - ctxArgs := contextargs.New() + ctxArgs := contextargs.New(ctx) ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctxArgs) + ctx := scan.NewScanContext(ctx, ctxArgs) switch template.Type() { case types.WorkflowProtocol: match = e.executeWorkflow(ctx, template.CompiledWorkflow) @@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs go func(tpl *templates.Template) { defer wg.Done() - ctxArgs := contextargs.New() + // TODO: Workflows are a no-op for now. We need to + // implement them in the future with context cancellation + ctxArgs := contextargs.New(context.Background()) ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctxArgs) + ctx := scan.NewScanContext(context.Background(), ctxArgs) match, err := template.Executer.Execute(ctx) if err != nil { gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err) diff --git a/pkg/core/workflow_execute.go b/pkg/core/workflow_execute.go index c17a1af8d..d0d3ede00 100644 --- a/pkg/core/workflow_execute.go +++ b/pkg/core/workflow_execute.go @@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b // at this point we should be at the start root execution of a workflow tree, hence we create global shared instances workflowCookieJar, _ := cookiejar.New(nil) - ctxArgs := contextargs.New() + ctxArgs := contextargs.New(ctx.Context()) ctxArgs.MetaInput = ctx.Input.MetaInput ctxArgs.CookieJar = workflowCookieJar @@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan defer swg.Done() // create a new context with the same input but with unset callbacks - subCtx := scan.NewScanContext(ctx.Input) + subCtx := scan.NewScanContext(ctx.Context(), ctx.Input) if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil { gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err) } @@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan go func(template *workflows.WorkflowTemplate) { // create a new context with the same input but with unset callbacks - subCtx := scan.NewScanContext(ctx.Input) + subCtx := scan.NewScanContext(ctx.Context(), ctx.Input) if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil { gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err) } diff --git a/pkg/core/workflow_execute_test.go b/pkg/core/workflow_execute_test.go index f3d6a7f23..0c478a5e1 100644 --- a/pkg/core/workflow_execute_test.go +++ b/pkg/core/workflow_execute_test.go @@ -1,6 +1,7 @@ package core import ( + "context" "testing" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice" @@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.True(t, matched, "could not get correct match value") } @@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.True(t, matched, "could not get correct match value") @@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.True(t, matched, "could not get correct match value") @@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.False(t, matched, "could not get correct match value") @@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.True(t, matched, "could not get correct match value") @@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) { }} engine := &Engine{} - input := contextargs.NewWithInput("https://test.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "https://test.com") + ctx := scan.NewScanContext(context.Background(), input) matched := engine.executeWorkflow(ctx, workflow) require.False(t, matched, "could not get correct match value") diff --git a/pkg/input/provider/list/hmap_test.go b/pkg/input/provider/list/hmap_test.go index 95fc57f2d..b08782e12 100644 --- a/pkg/input/provider/list/hmap_test.go +++ b/pkg/input/provider/list/hmap_test.go @@ -3,6 +3,7 @@ package list import ( "net" "os" + "runtime" "strconv" "strings" "testing" @@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } func Test_scanallips_normalizeStoreInputValue(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097") + } srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"} srv.Handler = &mockDnsHandler{} diff --git a/pkg/js/compiler/compiler.go b/pkg/js/compiler/compiler.go index 265bbaf04..05eacda6c 100644 --- a/pkg/js/compiler/compiler.go +++ b/pkg/js/compiler/compiler.go @@ -37,6 +37,8 @@ type ExecuteOptions struct { // Source is original source of the script Source *string + Context context.Context + // Manually exported objects exports map[string]interface{} } @@ -77,13 +79,13 @@ func (c *Compiler) Execute(code string, args *ExecuteArgs) (ExecuteResult, error if err != nil { return nil, err } - return c.ExecuteWithOptions(p, args, &ExecuteOptions{}) + return c.ExecuteWithOptions(p, args, &ExecuteOptions{Context: context.Background()}) } // ExecuteWithOptions executes a script with the provided options. func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (ExecuteResult, error) { if opts == nil { - opts = &ExecuteOptions{} + opts = &ExecuteOptions{Context: context.Background()} } if args == nil { args = NewExecuteArgs() @@ -105,7 +107,7 @@ func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, } // execute with context and timeout - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(opts.Timeout)*time.Second) + ctx, cancel := context.WithTimeout(opts.Context, time.Duration(opts.Timeout)*time.Second) defer cancel() // execute the script results, err := contextutil.ExecFuncWithTwoReturns(ctx, func() (val goja.Value, err error) { diff --git a/pkg/protocols/code/code.go b/pkg/protocols/code/code.go index 1b59de3ef..70dda7254 100644 --- a/pkg/protocols/code/code.go +++ b/pkg/protocols/code/code.go @@ -199,6 +199,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa Source: &request.PreCondition, Callback: registerPreConditionFunctions, Cleanup: cleanUpPreConditionFunctions, + Context: input.Context(), }) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err) diff --git a/pkg/protocols/code/code_test.go b/pkg/protocols/code/code_test.go index 1ace1388f..320f7c548 100644 --- a/pkg/protocols/code/code_test.go +++ b/pkg/protocols/code/code_test.go @@ -3,6 +3,7 @@ package code import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -31,7 +32,7 @@ func TestCodeProtocol(t *testing.T) { require.Nil(t, err, "could not compile code request") var gotEvent output.InternalEvent - ctxArgs := contextargs.NewWithInput("") + ctxArgs := contextargs.NewWithInput(context.Background(), "") err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) { gotEvent = event.InternalEvent }) diff --git a/pkg/protocols/common/automaticscan/automaticscan.go b/pkg/protocols/common/automaticscan/automaticscan.go index 7119377b0..4567f498d 100644 --- a/pkg/protocols/common/automaticscan/automaticscan.go +++ b/pkg/protocols/common/automaticscan/automaticscan.go @@ -1,6 +1,7 @@ package automaticscan import ( + "context" "io" "net/http" "os" @@ -189,7 +190,7 @@ func (s *Service) executeAutomaticScanOnTarget(input *contextargs.MetaInput) { execOptions.Progress = &testutils.MockProgressClient{} // stats are not supported yet due to centralized logic and cannot be reinitialized eng.SetExecuterOptions(execOptions) - tmp := eng.ExecuteScanWithOpts(finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true) + tmp := eng.ExecuteScanWithOpts(context.Background(), finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true) s.hasResults.Store(tmp.Load()) } @@ -244,7 +245,9 @@ func (s *Service) getTagsUsingWappalyzer(input *contextargs.MetaInput) []string // getTagsUsingDetectionTemplates returns tags using detection templates func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ([]string, int) { - ctxArgs := contextargs.NewWithInput(input.Input) + ctx := context.Background() + + ctxArgs := contextargs.NewWithInput(ctx, input.Input) // execute tech detection templates on target tags := map[string]struct{}{} @@ -256,7 +259,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ( sg.Add() go func(template *templates.Template) { defer sg.Done() - ctx := scan.NewScanContext(ctxArgs) + ctx := scan.NewScanContext(ctx, ctxArgs) ctx.OnResult = func(event *output.InternalWrappedEvent) { if event == nil { return diff --git a/pkg/protocols/common/contextargs/contextargs.go b/pkg/protocols/common/contextargs/contextargs.go index 4ebaa1561..8e8f0361c 100644 --- a/pkg/protocols/common/contextargs/contextargs.go +++ b/pkg/protocols/common/contextargs/contextargs.go @@ -1,6 +1,7 @@ package contextargs import ( + "context" "net/http/cookiejar" "strings" "sync/atomic" @@ -19,6 +20,8 @@ var ( // Context implements a shared context struct to share information across multiple templates within a workflow type Context struct { + ctx context.Context + // Meta is the target for the executor MetaInput *MetaInput @@ -30,17 +33,18 @@ type Context struct { } // Create a new contextargs instance -func New() *Context { - return NewWithInput("") +func New(ctx context.Context) *Context { + return NewWithInput(ctx, "") } // Create a new contextargs instance with input string -func NewWithInput(input string) *Context { +func NewWithInput(ctx context.Context, input string) *Context { jar, err := cookiejar.New(nil) if err != nil { gologger.Error().Msgf("contextargs: could not create cookie jar: %s\n", err) } return &Context{ + ctx: ctx, MetaInput: &MetaInput{Input: input}, CookieJar: jar, args: &mapsutil.SyncLockMap[string, interface{}]{ @@ -50,6 +54,11 @@ func NewWithInput(input string) *Context { } } +// Context returns the context of the current contextargs +func (ctx *Context) Context() context.Context { + return ctx.ctx +} + // Set the specific key-value pair func (ctx *Context) Set(key string, value interface{}) { _ = ctx.args.Set(key, value) @@ -158,6 +167,7 @@ func (ctx *Context) HasArgs() bool { func (ctx *Context) Clone() *Context { newCtx := &Context{ + ctx: ctx.ctx, MetaInput: ctx.MetaInput.Clone(), args: ctx.args.Clone(), CookieJar: ctx.CookieJar, diff --git a/pkg/protocols/common/hosterrorscache/hosterrorscache.go b/pkg/protocols/common/hosterrorscache/hosterrorscache.go index 3ada7718a..3abd1ec55 100644 --- a/pkg/protocols/common/hosterrorscache/hosterrorscache.go +++ b/pkg/protocols/common/hosterrorscache/hosterrorscache.go @@ -124,7 +124,7 @@ func (c *Cache) MarkFailed(value string, err error) { _ = c.failedTargets.Set(finalValue, existingCacheItemValue) } -var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host)`) +var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host|timeout awaiting response headers)`) // checkError checks if an error represents a type that should be // added to the host skipping table. diff --git a/pkg/protocols/dns/request.go b/pkg/protocols/dns/request.go index d4e70e13e..608e4730a 100644 --- a/pkg/protocols/dns/request.go +++ b/pkg/protocols/dns/request.go @@ -80,6 +80,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, break } + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + // resize check point - nop if there are no changes if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { swg.Resize(request.options.Options.PayloadConcurrency) diff --git a/pkg/protocols/dns/request_test.go b/pkg/protocols/dns/request_test.go index f32a8c622..81f0b98ab 100644 --- a/pkg/protocols/dns/request_test.go +++ b/pkg/protocols/dns/request_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -54,7 +55,7 @@ func TestDNSExecuteWithResults(t *testing.T) { t.Run("domain-valid", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput("example.com") + ctxArgs := contextargs.NewWithInput(context.Background(), "example.com") err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) @@ -70,7 +71,7 @@ func TestDNSExecuteWithResults(t *testing.T) { t.Run("url-to-domain", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - err := request.ExecuteWithResults(contextargs.NewWithInput("https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) { + err := request.ExecuteWithResults(contextargs.NewWithInput(context.Background(), "https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) require.Nil(t, err, "could not execute dns request") diff --git a/pkg/protocols/file/request_test.go b/pkg/protocols/file/request_test.go index ff41e3e8b..7c69cf5bc 100644 --- a/pkg/protocols/file/request_test.go +++ b/pkg/protocols/file/request_test.go @@ -1,6 +1,7 @@ package file import ( + "context" "os" "path/filepath" "testing" @@ -67,7 +68,7 @@ func TestFileExecuteWithResults(t *testing.T) { t.Run("valid", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(tempDir) + ctxArgs := contextargs.NewWithInput(context.Background(), tempDir) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) diff --git a/pkg/protocols/headless/engine/page_actions_test.go b/pkg/protocols/headless/engine/page_actions_test.go index a4b69eeff..e6699d640 100644 --- a/pkg/protocols/headless/engine/page_actions_test.go +++ b/pkg/protocols/headless/engine/page_actions_test.go @@ -1,6 +1,7 @@ package engine import ( + "context" "fmt" "io" "math/rand" @@ -595,7 +596,7 @@ func testHeadless(t *testing.T, actions []*Action, timeout time.Duration, handle ts := httptest.NewServer(http.HandlerFunc(handler)) defer ts.Close() - input := contextargs.NewWithInput(ts.URL) + input := contextargs.NewWithInput(context.Background(), ts.URL) input.CookieJar, err = cookiejar.New(nil) require.Nil(t, err) @@ -674,7 +675,7 @@ func TestBlockedHeadlessURLS(t *testing.T) { {ActionType: ActionTypeHolder{ActionType: ActionWaitLoad}}, } - data, page, err := instance.Run(contextargs.NewWithInput(ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test + data, page, err := instance.Run(contextargs.NewWithInput(context.Background(), ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test require.Error(t, err, "expected error for url %s got %v", testcase, data) require.True(t, stringsutil.ContainsAny(err.Error(), "net::ERR_ACCESS_DENIED", "failed to parse url", "Cannot navigate to invalid URL", "net::ERR_ABORTED", "net::ERR_INVALID_URL"), "found different error %v for testcases %v", err, testcase) require.Len(t, data, 0, "expected no data for url %s got %v", testcase, data) diff --git a/pkg/protocols/headless/request.go b/pkg/protocols/headless/request.go index f45afdd35..05bf5d354 100644 --- a/pkg/protocols/headless/request.go +++ b/pkg/protocols/headless/request.go @@ -44,7 +44,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata, if err != nil { return err } - input = contextargs.NewWithInput(url) + input = contextargs.NewWithInput(input.Context(), url) } if request.options.Browser.UserAgent() == "" { diff --git a/pkg/protocols/http/build_request_test.go b/pkg/protocols/http/build_request_test.go index bc87549f0..4405bd10b 100644 --- a/pkg/protocols/http/build_request_test.go +++ b/pkg/protocols/http/build_request_test.go @@ -40,7 +40,7 @@ func TestMakeRequestFromModal(t *testing.T) { generator := request.newGenerator(false) inputData, payloads, _ := generator.nextValue() - req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") if req.request.URL == nil { t.Fatalf("url is nil in generator make") @@ -70,13 +70,13 @@ func TestMakeRequestFromModalTrimSuffixSlash(t *testing.T) { generator := request.newGenerator(false) inputData, payloads, _ := generator.nextValue() - req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test.php"), inputData, payloads, map[string]interface{}{}) + req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test.php"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") require.Equal(t, "https://example.com/test.php?query=example", req.request.URL.String(), "could not get correct request path") generator = request.newGenerator(false) inputData, payloads, _ = generator.nextValue() - req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test/"), inputData, payloads, map[string]interface{}{}) + req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test/"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") require.Equal(t, "https://example.com/test/?query=example", req.request.URL.String(), "could not get correct request path") } @@ -110,13 +110,13 @@ Accept-Encoding: gzip`}, generator := request.newGenerator(false) inputData, payloads, _ := generator.nextValue() - req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") authorization := req.request.Header.Get("Authorization") require.Equal(t, "Basic admin:admin", authorization, "could not get correct authorization headers from raw") inputData, payloads, _ = generator.nextValue() - req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") authorization = req.request.Header.Get("Authorization") require.Equal(t, "Basic admin:guest", authorization, "could not get correct authorization headers from raw") @@ -151,13 +151,13 @@ Accept-Encoding: gzip`}, generator := request.newGenerator(false) inputData, payloads, _ := generator.nextValue() - req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") authorization := req.request.Header.Get("Authorization") require.Equal(t, "Basic YWRtaW46YWRtaW4=", authorization, "could not get correct authorization headers from raw") inputData, payloads, _ = generator.nextValue() - req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") authorization = req.request.Header.Get("Authorization") require.Equal(t, "Basic YWRtaW46Z3Vlc3Q=", authorization, "could not get correct authorization headers from raw") @@ -195,7 +195,7 @@ func TestMakeRequestFromModelUniqueInteractsh(t *testing.T) { require.Nil(t, err, "could not create interactsh client") inputData, payloads, _ := generator.nextValue() - got, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) + got, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{}) require.Nil(t, err, "could not make http request") // check if all the interactsh markers are replaced with unique urls diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 1560e787a..3e2baf55a 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -35,8 +35,18 @@ var ( forceMaxRedirects int normalClient *retryablehttp.Client clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] + // ResponseHeaderTimeout is the timeout for response headers + // to be read from the server (this prevents infinite hang started by server if any) + ResponseHeaderTimeout = time.Duration(5) * time.Second + // HttpTimeoutMultiplier is the multiplier for the http timeout + HttpTimeoutMultiplier = 3 ) +// GetHttpTimeout returns the http timeout for the client +func GetHttpTimeout(opts *types.Options) time.Duration { + return time.Duration(opts.Timeout*HttpTimeoutMultiplier) * time.Second +} + // Init initializes the clientpool implementation func Init(options *types.Options) error { // Don't create clients if already created in the past. @@ -139,7 +149,7 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client { } else if Dialer != nil { rawHttpOptions.FastDialer = Dialer } - rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second + rawHttpOptions.Timeout = GetHttpTimeout(options) rawHttpClient = rawhttp.NewClient(rawHttpOptions) } return rawHttpClient @@ -237,11 +247,12 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl } return Dialer.DialTLS(ctx, network, addr) }, - MaxIdleConns: maxIdleConns, - MaxIdleConnsPerHost: maxIdleConnsPerHost, - MaxConnsPerHost: maxConnsPerHost, - TLSClientConfig: tlsConfig, - DisableKeepAlives: disableKeepAlives, + MaxIdleConns: maxIdleConns, + MaxIdleConnsPerHost: maxIdleConnsPerHost, + MaxConnsPerHost: maxConnsPerHost, + TLSClientConfig: tlsConfig, + DisableKeepAlives: disableKeepAlives, + ResponseHeaderTimeout: ResponseHeaderTimeout, } if types.ProxyURL != "" { @@ -288,7 +299,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects), } if !configuration.NoTimeout { - httpclient.Timeout = time.Duration(options.Timeout) * time.Second + httpclient.Timeout = GetHttpTimeout(options) } client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions) if jar != nil { diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index d966a4bb7..545247ef5 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -235,6 +235,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV break } + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + // resize check point - nop if there are no changes if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency { spmHandler.Resize(request.options.Options.PayloadConcurrency) @@ -356,6 +362,13 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu if !ok { break } + + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + if spmHandler.FoundFirstMatch() || request.isUnresponsiveHost(input) || spmHandler.Cancelled() { // skip if first match is found break @@ -433,8 +446,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa request.options.RateLimitTake() ctx := request.newContext(input) - ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second) + ctxWithTimeout, cancel := context.WithTimeout(ctx, httpclientpool.GetHttpTimeout(request.options.Options)) defer cancel() + generatedHttpRequest, err := generator.Make(ctxWithTimeout, input, data, payloads, dynamicValue) if err != nil { if err == types.ErrNoMoreRequests { @@ -512,6 +526,13 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa if !ok { break } + + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + var gotErr error var skip bool if len(gotDynamicValues) > 0 { @@ -1041,9 +1062,9 @@ func (request *Request) pruneSignatureInternalValues(maps ...map[string]interfac func (request *Request) newContext(input *contextargs.Context) context.Context { if input.MetaInput.CustomIP != "" { - return context.WithValue(context.Background(), fastdialer.IP, input.MetaInput.CustomIP) + return context.WithValue(input.Context(), fastdialer.IP, input.MetaInput.CustomIP) } - return context.Background() + return input.Context() } // markUnresponsiveHost checks if the error is a unreponsive host error and marks it diff --git a/pkg/protocols/http/request_annotations.go b/pkg/protocols/http/request_annotations.go index 6325e4e73..67bf7c167 100644 --- a/pkg/protocols/http/request_annotations.go +++ b/pkg/protocols/http/request_annotations.go @@ -9,6 +9,7 @@ import ( "time" "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool" "github.com/projectdiscovery/retryablehttp-go" iputil "github.com/projectdiscovery/utils/ip" stringsutil "github.com/projectdiscovery/utils/strings" @@ -124,7 +125,7 @@ func (r *Request) parseAnnotations(rawRequest string, request *retryablehttp.Req } } else { //nolint:govet // cancelled automatically by withTimeout - ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second) + ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), httpclientpool.GetHttpTimeout(r.options.Options)) request = request.Clone(ctx) } } diff --git a/pkg/protocols/http/request_fuzz.go b/pkg/protocols/http/request_fuzz.go index aed17f54e..fc2a4e757 100644 --- a/pkg/protocols/http/request_fuzz.go +++ b/pkg/protocols/http/request_fuzz.go @@ -110,9 +110,21 @@ func (request *Request) executeFuzzingRule(input *contextargs.Context, previous func (request *Request) executeAllFuzzingRules(input *contextargs.Context, values map[string]interface{}, baseRequest *retryablehttp.Request, callback protocols.OutputEventCallback) error { applicable := false for _, rule := range request.Fuzzing { + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + err := rule.Execute(&fuzz.ExecuteRuleInput{ Input: input, Callback: func(gr fuzz.GeneratedRequest) bool { + select { + case <-input.Context().Done(): + return false + default: + } + // TODO: replace this after scanContext Refactor return request.executeGeneratedFuzzingRequest(gr, input, callback) }, diff --git a/pkg/protocols/http/request_test.go b/pkg/protocols/http/request_test.go index cd8e860b0..c0bd2bb34 100644 --- a/pkg/protocols/http/request_test.go +++ b/pkg/protocols/http/request_test.go @@ -1,6 +1,7 @@ package http import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -83,7 +84,7 @@ Disallow: /c`)) t.Run("test", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(ts.URL) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { if event.OperatorsResult != nil && event.OperatorsResult.Matched { matchCount++ @@ -159,7 +160,7 @@ func TestDisableTE(t *testing.T) { t.Run("test", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(ts.URL) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { if event.OperatorsResult != nil && event.OperatorsResult.Matched { matchCount++ @@ -172,7 +173,7 @@ func TestDisableTE(t *testing.T) { t.Run("test2", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(ts.URL) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) err := request2.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { if event.OperatorsResult != nil && event.OperatorsResult.Matched { matchCount++ @@ -242,7 +243,7 @@ func TestReqURLPattern(t *testing.T) { t.Run("test", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(ts.URL) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { if event.OperatorsResult != nil && event.OperatorsResult.Matched { matchCount++ diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 54bb37883..61257a968 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -154,6 +154,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { opts := &compiler.ExecuteOptions{ Timeout: request.Timeout, Source: &request.Init, + Context: context.Background(), } // register 'export' function to export variables from init code // these are saved in args and are available in pre-condition and request code @@ -343,7 +344,7 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV argsCopy.TemplateCtx = templateCtx.GetAll() result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy, - &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition}) + &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition, Context: target.Context()}) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err) } @@ -373,6 +374,12 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV return nil } + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + if err := request.executeRequestWithPayloads(hostPort, input, hostname, value, payloadValues, func(result *output.InternalWrappedEvent) { if result.OperatorsResult != nil && result.OperatorsResult.Matched { gotMatches = true @@ -419,6 +426,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo break } + select { + case <-input.Context().Done(): + return + default: + } + // resize check point - nop if there are no changes if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency { sg.Resize(request.options.Options.PayloadConcurrency) @@ -486,7 +499,7 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte } results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy, - &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code}) + &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code, Context: input.Context()}) if err != nil { // shouldn't fail even if it returned error instead create a failure event results = compiler.ExecuteResult{"success": false, "error": err.Error()} diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index 1f51ebe11..23eed8cb0 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -139,6 +139,12 @@ func (request *Request) executeOnTarget(input *contextargs.Context, visited maps variables = generators.MergeMaps(variablesMap, variables, request.options.Constants) for _, kv := range request.addresses { + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + actualAddress := replacer.Replace(kv.address, variables) if visited.Has(actualAddress) && !request.options.Options.DisableClustering { @@ -186,6 +192,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA break } + select { + case <-input.Context().Done(): + return input.Context().Err() + default: + } + // resize check point - nop if there are no changes if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { swg.Resize(request.options.Options.PayloadConcurrency) diff --git a/pkg/protocols/network/request_test.go b/pkg/protocols/network/request_test.go index bf0cd531d..1945888e9 100644 --- a/pkg/protocols/network/request_test.go +++ b/pkg/protocols/network/request_test.go @@ -1,6 +1,7 @@ package network import ( + "context" "encoding/hex" "fmt" "net/http" @@ -65,7 +66,7 @@ func TestNetworkExecuteWithResults(t *testing.T) { t.Run("domain-valid", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(parsed.Host) + ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) @@ -81,7 +82,7 @@ func TestNetworkExecuteWithResults(t *testing.T) { t.Run("invalid-port-override", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput("127.0.0.1:11211") + ctxArgs := contextargs.NewWithInput(context.Background(), "127.0.0.1:11211") err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) @@ -95,7 +96,7 @@ func TestNetworkExecuteWithResults(t *testing.T) { t.Run("hex-to-string", func(t *testing.T) { metadata := make(output.InternalEvent) previous := make(output.InternalEvent) - ctxArgs := contextargs.NewWithInput(parsed.Host) + ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host) err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { finalEvent = event }) diff --git a/pkg/protocols/protocols.go b/pkg/protocols/protocols.go index 6328bc36a..6b4904c8d 100644 --- a/pkg/protocols/protocols.go +++ b/pkg/protocols/protocols.go @@ -1,6 +1,7 @@ package protocols import ( + "context" "encoding/base64" "sync/atomic" @@ -173,7 +174,7 @@ func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contexta templateCtx, ok := e.templateCtxStore.Get(scanId) if !ok { // if template context does not exist create new and add it to store and return it - templateCtx = contextargs.New() + templateCtx = contextargs.New(context.Background()) templateCtx.MetaInput = input _ = e.templateCtxStore.Set(scanId, templateCtx) } diff --git a/pkg/protocols/ssl/ssl_test.go b/pkg/protocols/ssl/ssl_test.go index 009cf98d3..59c9f85f3 100644 --- a/pkg/protocols/ssl/ssl_test.go +++ b/pkg/protocols/ssl/ssl_test.go @@ -1,6 +1,7 @@ package ssl import ( + "context" "testing" "github.com/stretchr/testify/require" @@ -28,7 +29,7 @@ func TestSSLProtocol(t *testing.T) { require.Nil(t, err, "could not compile ssl request") var gotEvent output.InternalEvent - ctxArgs := contextargs.NewWithInput("scanme.sh:443") + ctxArgs := contextargs.NewWithInput(context.Background(), "scanme.sh:443") err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) { gotEvent = event.InternalEvent }) diff --git a/pkg/protocols/utils/variables_test.go b/pkg/protocols/utils/variables_test.go index c24305017..b83529499 100644 --- a/pkg/protocols/utils/variables_test.go +++ b/pkg/protocols/utils/variables_test.go @@ -1,6 +1,7 @@ package utils import ( + "context" "reflect" "testing" @@ -49,7 +50,7 @@ func TestHTTPVariables(t *testing.T) { require.Equal(t, values["Hostname"], "foobar.com", "incorrect hostname") baseURL = "http://scanme.sh" - ctxArgs := contextargs.NewWithInput(baseURL) + ctxArgs := contextargs.NewWithInput(context.Background(), baseURL) ctxArgs.MetaInput.CustomIP = "1.2.3.4" values = GenerateVariablesWithContextArgs(ctxArgs, true) diff --git a/pkg/scan/scan_context.go b/pkg/scan/scan_context.go index 885134901..69fbfde4c 100644 --- a/pkg/scan/scan_context.go +++ b/pkg/scan/scan_context.go @@ -10,7 +10,6 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" ) - type ScanContextOption func(*ScanContext) func WithEvents() ScanContextOption { @@ -20,7 +19,8 @@ func WithEvents() ScanContextOption { } type ScanContext struct { - context.Context + ctx context.Context + // exported / configurable fields Input *contextargs.Context @@ -43,8 +43,13 @@ type ScanContext struct { } // NewScanContext creates a new scan context using input -func NewScanContext(input *contextargs.Context) *ScanContext { - return &ScanContext{Input: input} +func NewScanContext(ctx context.Context, input *contextargs.Context) *ScanContext { + return &ScanContext{ctx: ctx, Input: input} +} + +// Context returns the context of the scan +func (s *ScanContext) Context() context.Context { + return s.ctx } // GenerateResult returns final results slice from all events diff --git a/pkg/templates/cluster.go b/pkg/templates/cluster.go index b1e0b56e2..97e7b67bd 100644 --- a/pkg/templates/cluster.go +++ b/pkg/templates/cluster.go @@ -303,7 +303,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) { // ExecuteWithResults executes the protocol requests and returns results instead of writing them. func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { - scanCtx := scan.NewScanContext(ctx.Input) + scanCtx := scan.NewScanContext(ctx.Context(), ctx.Input) dynamicValues := make(map[string]interface{}) inputItem := ctx.Input.Clone() diff --git a/pkg/tmplexec/flow/flow_executor.go b/pkg/tmplexec/flow/flow_executor.go index 31aa9f0f4..c3f0014ca 100644 --- a/pkg/tmplexec/flow/flow_executor.go +++ b/pkg/tmplexec/flow/flow_executor.go @@ -174,6 +174,12 @@ func (f *FlowExecutor) Compile() error { // ExecuteWithResults executes the flow and returns results func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error { + select { + case <-ctx.Context().Done(): + return ctx.Context().Err() + default: + } + f.ctx.Input = ctx.Input // -----Load all types of variables----- // add all input args to template context diff --git a/pkg/tmplexec/flow/flow_executor_test.go b/pkg/tmplexec/flow/flow_executor_test.go index bd3a526d2..b47b38a2a 100644 --- a/pkg/tmplexec/flow/flow_executor_test.go +++ b/pkg/tmplexec/flow/flow_executor_test.go @@ -55,8 +55,8 @@ func TestFlowTemplateWithIndex(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("hackerone.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "hackerone.com") + ctx := scan.NewScanContext(context.Background(), input) gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") require.True(t, gotresults) @@ -74,8 +74,8 @@ func TestFlowTemplateWithID(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - target := contextargs.NewWithInput("hackerone.com") - ctx := scan.NewScanContext(target) + target := contextargs.NewWithInput(context.Background(), "hackerone.com") + ctx := scan.NewScanContext(context.Background(), target) gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") require.True(t, gotresults) @@ -96,8 +96,8 @@ func TestFlowWithProtoPrefix(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("hackerone.com") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "hackerone.com") + ctx := scan.NewScanContext(context.Background(), input) gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") require.True(t, gotresults) @@ -116,8 +116,8 @@ func TestFlowWithConditionNegative(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("scanme.sh") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "scanme.sh") + ctx := scan.NewScanContext(context.Background(), input) // expect no results and verify thant dns request is executed and http is not gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") @@ -137,8 +137,8 @@ func TestFlowWithConditionPositive(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("blog.projectdiscovery.io") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io") + ctx := scan.NewScanContext(context.Background(), input) // positive match . expect results also verify that both dns() and http() were executed gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") @@ -158,8 +158,8 @@ func TestFlowWithNoMatchers(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("blog.projectdiscovery.io") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io") + ctx := scan.NewScanContext(context.Background(), input) // positive match . expect results also verify that both dns() and http() were executed gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") @@ -174,8 +174,8 @@ func TestFlowWithNoMatchers(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - anotherInput := contextargs.NewWithInput("blog.projectdiscovery.io") - anotherCtx := scan.NewScanContext(anotherInput) + anotherInput := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io") + anotherCtx := scan.NewScanContext(context.Background(), anotherInput) // positive match . expect results also verify that both dns() and http() were executed gotresults, err = Template.Executer.Execute(anotherCtx) require.Nil(t, err, "could not execute template") diff --git a/pkg/tmplexec/generic/exec.go b/pkg/tmplexec/generic/exec.go index e47cfd730..29ace4477 100644 --- a/pkg/tmplexec/generic/exec.go +++ b/pkg/tmplexec/generic/exec.go @@ -45,6 +45,12 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error { previous := mapsutil.NewSyncLockMap[string, any]() for _, req := range g.requests { + select { + case <-ctx.Context().Done(): + return ctx.Context().Err() + default: + } + inputItem := ctx.Input.Clone() if g.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" { if inputItem.MetaInput.Input = g.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" { diff --git a/pkg/tmplexec/multiproto/multi.go b/pkg/tmplexec/multiproto/multi.go index 997e62243..58858f971 100644 --- a/pkg/tmplexec/multiproto/multi.go +++ b/pkg/tmplexec/multiproto/multi.go @@ -44,6 +44,12 @@ func (m *MultiProtocol) Compile() error { // ExecuteWithResults executes the template and returns results func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error { + select { + case <-ctx.Context().Done(): + return ctx.Context().Err() + default: + } + // put all readonly args into template context m.options.GetTemplateCtx(ctx.Input.MetaInput).Merge(m.readOnlyArgs) @@ -96,6 +102,12 @@ func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error { // execute all protocols in the queue for _, req := range m.requests { + select { + case <-ctx.Context().Done(): + return ctx.Context().Err() + default: + } + values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll() err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback) // if error skip execution of next protocols diff --git a/pkg/tmplexec/multiproto/multi_test.go b/pkg/tmplexec/multiproto/multi_test.go index b63268029..4f2aa25e4 100644 --- a/pkg/tmplexec/multiproto/multi_test.go +++ b/pkg/tmplexec/multiproto/multi_test.go @@ -54,8 +54,8 @@ func TestMultiProtoWithDynamicExtractor(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("blog.projectdiscovery.io") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io") + ctx := scan.NewScanContext(context.Background(), input) gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") require.True(t, gotresults) @@ -71,8 +71,8 @@ func TestMultiProtoWithProtoPrefix(t *testing.T) { err = Template.Executer.Compile() require.Nil(t, err, "could not compile template") - input := contextargs.NewWithInput("blog.projectdiscovery.io") - ctx := scan.NewScanContext(input) + input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io") + ctx := scan.NewScanContext(context.Background(), input) gotresults, err := Template.Executer.Execute(ctx) require.Nil(t, err, "could not execute template") require.True(t, gotresults)