feat: added support for context cancellation to engine (#5096)

* feat: added support for context cancellation to engine

* misc

* feat: added contexts everywhere

* misc

* misc

* use granular http timeouts and increase http timeout to 30s using multiplier

* track response header timeout in mhe

* update responseHeaderTimeout to 5sec

* skip failing windows test

---------

Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
This commit is contained in:
Ice3man 2024-04-25 15:37:56 +05:30 committed by GitHub
parent 3dfcec0a36
commit 0b82e8b7aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 279 additions and 113 deletions

View File

@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error)
} }
store.Load() store.Load()
_ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL)) _ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
engine.WorkPool().Wait() // Wait for the scan to finish engine.WorkPool().Wait() // Wait for the scan to finish
return results, nil return results, nil

View File

@ -1,6 +1,7 @@
package runner package runner
import ( import (
"context"
"fmt" "fmt"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret
tmpl := tmpls[0] tmpl := tmpls[0]
// add args to tmpl here // add args to tmpl here
vars := map[string]interface{}{} vars := map[string]interface{}{}
ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input)) mainCtx := context.Background()
ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input))
for _, v := range d.Variables { for _, v := range d.Variables {
vars[v.Key] = v.Value vars[v.Key] = v.Value
ctx.Input.Add(v.Key, v.Value) ctx.Input.Add(v.Key, v.Value)

View File

@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
if r.inputProvider == nil { if r.inputProvider == nil {
return nil, errors.New("no input provider found") return nil, errors.New("no input provider found")
} }
results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering) results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering)
return results, nil return results, nil
} }

View File

@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts ..
engine := core.New(tmpEngine.opts) engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts) engine.SetExecuterOptions(unsafeOpts.executerOpts)
_ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false) _ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)
engine.WorkPool().Wait() engine.WorkPool().Wait()
return nil return nil

View File

@ -3,6 +3,7 @@ package nuclei
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"io" "io"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
} }
e.resultCallbacks = append(e.resultCallbacks, filtered...) e.resultCallbacks = append(e.resultCallbacks, filtered...)
_ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false) _ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait() defer e.engine.WorkPool().Wait()
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -20,18 +21,18 @@ import (
// //
// All the execution logic for the templates/workflows happens in this part // All the execution logic for the templates/workflows happens in this part
// of the engine. // of the engine.
func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool { func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(templates, target, false) return e.ExecuteScanWithOpts(ctx, templates, target, false)
} }
// ExecuteWithResults a list of templates with results // ExecuteWithResults a list of templates with results
func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool { func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
e.Callback = callback e.Callback = callback
return e.ExecuteScanWithOpts(templatesList, target, false) return e.ExecuteScanWithOpts(ctx, templatesList, target, false)
} }
// ExecuteScanWithOpts executes scan with given scanStrategy // ExecuteScanWithOpts executes scan with given scanStrategy
func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool { func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
results := &atomic.Bool{} results := &atomic.Bool{}
selfcontainedWg := &sync.WaitGroup{} selfcontainedWg := &sync.WaitGroup{}
@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
} }
// Execute All SelfContained in parallel // Execute All SelfContained in parallel
e.executeAllSelfContained(selfContained, results, selfcontainedWg) e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg)
strategyResult := &atomic.Bool{} strategyResult := &atomic.Bool{}
switch e.options.ScanStrategy { switch e.options.ScanStrategy {
case scanstrategy.TemplateSpray.String(): case scanstrategy.TemplateSpray.String():
strategyResult = e.executeTemplateSpray(filtered, target) strategyResult = e.executeTemplateSpray(ctx, filtered, target)
case scanstrategy.HostSpray.String(): case scanstrategy.HostSpray.String():
strategyResult = e.executeHostSpray(filtered, target) strategyResult = e.executeHostSpray(ctx, filtered, target)
} }
results.CompareAndSwap(false, strategyResult.Load()) results.CompareAndSwap(false, strategyResult.Load())
@ -100,7 +101,7 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
} }
// executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template // executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template
func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{} results := &atomic.Bool{}
// wp is workpool that contains different waitgroups for // wp is workpool that contains different waitgroups for
@ -108,6 +109,12 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
wp := e.GetWorkPool() wp := e.GetWorkPool()
for _, template := range templatesList { for _, template := range templatesList {
select {
case <-ctx.Done():
return results
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig()) wp.RefreshWithConfig(e.GetWorkPoolConfig())
@ -125,7 +132,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
// All other request types are executed here // All other request types are executed here
// Note: executeTemplateWithTargets creates goroutines and blocks // Note: executeTemplateWithTargets creates goroutines and blocks
// given template is executed on all targets // given template is executed on all targets
e.executeTemplateWithTargets(tpl, target, results) e.executeTemplateWithTargets(ctx, tpl, target, results)
}(template) }(template)
} }
wp.Wait() wp.Wait()
@ -133,15 +140,21 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
} }
// executeHostSpray executes scan using host spray strategy where templates are iterated over each target // executeHostSpray executes scan using host spray strategy where templates are iterated over each target
func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool { func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{} results := &atomic.Bool{}
wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize)) wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))
target.Iterate(func(value *contextargs.MetaInput) bool { target.Iterate(func(value *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false
default:
}
wp.Add() wp.Add()
go func(targetval *contextargs.MetaInput) { go func(targetval *contextargs.MetaInput) {
defer wp.Done() defer wp.Done()
e.executeTemplatesOnTarget(templatesList, targetval, results) e.executeTemplatesOnTarget(ctx, templatesList, targetval, results)
}(value) }(value)
return true return true
}) })

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"sync" "sync"
"sync/atomic" "sync/atomic"
@ -17,14 +18,14 @@ import (
// Executors are low level executors that deals with template execution on a target // Executors are low level executors that deals with template execution on a target
// executeAllSelfContained executes all self contained templates that do not use `target` // executeAllSelfContained executes all self contained templates that do not use `target`
func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) { func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
for _, v := range alltemplates { for _, v := range alltemplates {
sg.Add(1) sg.Add(1)
go func(template *templates.Template) { go func(template *templates.Template) {
defer sg.Done() defer sg.Done()
var err error var err error
var match bool var match bool
ctx := scan.NewScanContext(contextargs.New()) ctx := scan.NewScanContext(ctx, contextargs.New(ctx))
if e.Callback != nil { if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err != nil { if results, err := template.Executer.ExecuteWithResults(ctx); err != nil {
for _, result := range results { for _, result := range results {
@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res
} }
// executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency)) // executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency))
func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) { func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
// this is target pool i.e max target to execute // this is target pool i.e max target to execute
wg := e.workPool.InputPool(template.Type()) wg := e.workPool.InputPool(template.Type())
@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
} }
target.Iterate(func(scannedValue *contextargs.MetaInput) bool { target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false // exit
default:
}
// Best effort to track the host progression // Best effort to track the host progression
// skips indexes lower than the minimum in-flight at interruption time // skips indexes lower than the minimum in-flight at interruption time
var skip bool var skip bool
@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
var match bool var match bool
var err error var err error
ctxArgs := contextargs.New() ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs) ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() { switch template.Type() {
case types.WorkflowProtocol: case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow) match = e.executeWorkflow(ctx, template.CompiledWorkflow)
@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
} }
// executeTemplatesOnTarget execute given templates on given single target // executeTemplatesOnTarget execute given templates on given single target
func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) { func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
// all templates are executed on single target // all templates are executed on single target
// wp is workpool that contains different waitgroups for // wp is workpool that contains different waitgroups for
@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
wp := e.GetWorkPool() wp := e.GetWorkPool()
for _, tpl := range alltemplates { for _, tpl := range alltemplates {
select {
case <-ctx.Done():
return
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig()) wp.RefreshWithConfig(e.GetWorkPoolConfig())
@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
var match bool var match bool
var err error var err error
ctxArgs := contextargs.New() ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs) ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() { switch template.Type() {
case types.WorkflowProtocol: case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow) match = e.executeWorkflow(ctx, template.CompiledWorkflow)
@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs
go func(tpl *templates.Template) { go func(tpl *templates.Template) {
defer wg.Done() defer wg.Done()
ctxArgs := contextargs.New() // TODO: Workflows are a no-op for now. We need to
// implement them in the future with context cancellation
ctxArgs := contextargs.New(context.Background())
ctxArgs.MetaInput = value ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs) ctx := scan.NewScanContext(context.Background(), ctxArgs)
match, err := template.Executer.Execute(ctx) match, err := template.Executer.Execute(ctx)
if err != nil { if err != nil {
gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err) gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err)

View File

@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b
// at this point we should be at the start root execution of a workflow tree, hence we create global shared instances // at this point we should be at the start root execution of a workflow tree, hence we create global shared instances
workflowCookieJar, _ := cookiejar.New(nil) workflowCookieJar, _ := cookiejar.New(nil)
ctxArgs := contextargs.New() ctxArgs := contextargs.New(ctx.Context())
ctxArgs.MetaInput = ctx.Input.MetaInput ctxArgs.MetaInput = ctx.Input.MetaInput
ctxArgs.CookieJar = workflowCookieJar ctxArgs.CookieJar = workflowCookieJar
@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
defer swg.Done() defer swg.Done()
// create a new context with the same input but with unset callbacks // create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input) subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil { if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err) gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err)
} }
@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
go func(template *workflows.WorkflowTemplate) { go func(template *workflows.WorkflowTemplate) {
// create a new context with the same input but with unset callbacks // create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input) subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil { if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err) gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
} }

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"context"
"testing" "testing"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice" "github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value") require.True(t, matched, "could not get correct match value")
} }
@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value") require.True(t, matched, "could not get correct match value")
@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value") require.True(t, matched, "could not get correct match value")
@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value") require.False(t, matched, "could not get correct match value")
@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value") require.True(t, matched, "could not get correct match value")
@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) {
}} }}
engine := &Engine{} engine := &Engine{}
input := contextargs.NewWithInput("https://test.com") input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow) matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value") require.False(t, matched, "could not get correct match value")

View File

@ -3,6 +3,7 @@ package list
import ( import (
"net" "net"
"os" "os"
"runtime"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
func Test_scanallips_normalizeStoreInputValue(t *testing.T) { func Test_scanallips_normalizeStoreInputValue(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097")
}
srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"} srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"}
srv.Handler = &mockDnsHandler{} srv.Handler = &mockDnsHandler{}

View File

@ -37,6 +37,8 @@ type ExecuteOptions struct {
// Source is original source of the script // Source is original source of the script
Source *string Source *string
Context context.Context
// Manually exported objects // Manually exported objects
exports map[string]interface{} exports map[string]interface{}
} }
@ -77,13 +79,13 @@ func (c *Compiler) Execute(code string, args *ExecuteArgs) (ExecuteResult, error
if err != nil { if err != nil {
return nil, err return nil, err
} }
return c.ExecuteWithOptions(p, args, &ExecuteOptions{}) return c.ExecuteWithOptions(p, args, &ExecuteOptions{Context: context.Background()})
} }
// ExecuteWithOptions executes a script with the provided options. // ExecuteWithOptions executes a script with the provided options.
func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (ExecuteResult, error) { func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (ExecuteResult, error) {
if opts == nil { if opts == nil {
opts = &ExecuteOptions{} opts = &ExecuteOptions{Context: context.Background()}
} }
if args == nil { if args == nil {
args = NewExecuteArgs() args = NewExecuteArgs()
@ -105,7 +107,7 @@ func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs,
} }
// execute with context and timeout // execute with context and timeout
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(opts.Timeout)*time.Second) ctx, cancel := context.WithTimeout(opts.Context, time.Duration(opts.Timeout)*time.Second)
defer cancel() defer cancel()
// execute the script // execute the script
results, err := contextutil.ExecFuncWithTwoReturns(ctx, func() (val goja.Value, err error) { results, err := contextutil.ExecFuncWithTwoReturns(ctx, func() (val goja.Value, err error) {

View File

@ -199,6 +199,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
Source: &request.PreCondition, Source: &request.PreCondition,
Callback: registerPreConditionFunctions, Callback: registerPreConditionFunctions,
Cleanup: cleanUpPreConditionFunctions, Cleanup: cleanUpPreConditionFunctions,
Context: input.Context(),
}) })
if err != nil { if err != nil {
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err) return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)

View File

@ -3,6 +3,7 @@
package code package code
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -31,7 +32,7 @@ func TestCodeProtocol(t *testing.T) {
require.Nil(t, err, "could not compile code request") require.Nil(t, err, "could not compile code request")
var gotEvent output.InternalEvent var gotEvent output.InternalEvent
ctxArgs := contextargs.NewWithInput("") ctxArgs := contextargs.NewWithInput(context.Background(), "")
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) { err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
gotEvent = event.InternalEvent gotEvent = event.InternalEvent
}) })

View File

@ -1,6 +1,7 @@
package automaticscan package automaticscan
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -189,7 +190,7 @@ func (s *Service) executeAutomaticScanOnTarget(input *contextargs.MetaInput) {
execOptions.Progress = &testutils.MockProgressClient{} // stats are not supported yet due to centralized logic and cannot be reinitialized execOptions.Progress = &testutils.MockProgressClient{} // stats are not supported yet due to centralized logic and cannot be reinitialized
eng.SetExecuterOptions(execOptions) eng.SetExecuterOptions(execOptions)
tmp := eng.ExecuteScanWithOpts(finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true) tmp := eng.ExecuteScanWithOpts(context.Background(), finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true)
s.hasResults.Store(tmp.Load()) s.hasResults.Store(tmp.Load())
} }
@ -244,7 +245,9 @@ func (s *Service) getTagsUsingWappalyzer(input *contextargs.MetaInput) []string
// getTagsUsingDetectionTemplates returns tags using detection templates // getTagsUsingDetectionTemplates returns tags using detection templates
func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ([]string, int) { func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ([]string, int) {
ctxArgs := contextargs.NewWithInput(input.Input) ctx := context.Background()
ctxArgs := contextargs.NewWithInput(ctx, input.Input)
// execute tech detection templates on target // execute tech detection templates on target
tags := map[string]struct{}{} tags := map[string]struct{}{}
@ -256,7 +259,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) (
sg.Add() sg.Add()
go func(template *templates.Template) { go func(template *templates.Template) {
defer sg.Done() defer sg.Done()
ctx := scan.NewScanContext(ctxArgs) ctx := scan.NewScanContext(ctx, ctxArgs)
ctx.OnResult = func(event *output.InternalWrappedEvent) { ctx.OnResult = func(event *output.InternalWrappedEvent) {
if event == nil { if event == nil {
return return

View File

@ -1,6 +1,7 @@
package contextargs package contextargs
import ( import (
"context"
"net/http/cookiejar" "net/http/cookiejar"
"strings" "strings"
"sync/atomic" "sync/atomic"
@ -19,6 +20,8 @@ var (
// Context implements a shared context struct to share information across multiple templates within a workflow // Context implements a shared context struct to share information across multiple templates within a workflow
type Context struct { type Context struct {
ctx context.Context
// Meta is the target for the executor // Meta is the target for the executor
MetaInput *MetaInput MetaInput *MetaInput
@ -30,17 +33,18 @@ type Context struct {
} }
// Create a new contextargs instance // Create a new contextargs instance
func New() *Context { func New(ctx context.Context) *Context {
return NewWithInput("") return NewWithInput(ctx, "")
} }
// Create a new contextargs instance with input string // Create a new contextargs instance with input string
func NewWithInput(input string) *Context { func NewWithInput(ctx context.Context, input string) *Context {
jar, err := cookiejar.New(nil) jar, err := cookiejar.New(nil)
if err != nil { if err != nil {
gologger.Error().Msgf("contextargs: could not create cookie jar: %s\n", err) gologger.Error().Msgf("contextargs: could not create cookie jar: %s\n", err)
} }
return &Context{ return &Context{
ctx: ctx,
MetaInput: &MetaInput{Input: input}, MetaInput: &MetaInput{Input: input},
CookieJar: jar, CookieJar: jar,
args: &mapsutil.SyncLockMap[string, interface{}]{ args: &mapsutil.SyncLockMap[string, interface{}]{
@ -50,6 +54,11 @@ func NewWithInput(input string) *Context {
} }
} }
// Context returns the context of the current contextargs
func (ctx *Context) Context() context.Context {
return ctx.ctx
}
// Set the specific key-value pair // Set the specific key-value pair
func (ctx *Context) Set(key string, value interface{}) { func (ctx *Context) Set(key string, value interface{}) {
_ = ctx.args.Set(key, value) _ = ctx.args.Set(key, value)
@ -158,6 +167,7 @@ func (ctx *Context) HasArgs() bool {
func (ctx *Context) Clone() *Context { func (ctx *Context) Clone() *Context {
newCtx := &Context{ newCtx := &Context{
ctx: ctx.ctx,
MetaInput: ctx.MetaInput.Clone(), MetaInput: ctx.MetaInput.Clone(),
args: ctx.args.Clone(), args: ctx.args.Clone(),
CookieJar: ctx.CookieJar, CookieJar: ctx.CookieJar,

View File

@ -124,7 +124,7 @@ func (c *Cache) MarkFailed(value string, err error) {
_ = c.failedTargets.Set(finalValue, existingCacheItemValue) _ = c.failedTargets.Set(finalValue, existingCacheItemValue)
} }
var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host)`) var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host|timeout awaiting response headers)`)
// checkError checks if an error represents a type that should be // checkError checks if an error represents a type that should be
// added to the host skipping table. // added to the host skipping table.

View File

@ -80,6 +80,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
break break
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency) swg.Resize(request.options.Options.PayloadConcurrency)

View File

@ -1,6 +1,7 @@
package dns package dns
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -54,7 +55,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
t.Run("domain-valid", func(t *testing.T) { t.Run("domain-valid", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput("example.com") ctxArgs := contextargs.NewWithInput(context.Background(), "example.com")
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })
@ -70,7 +71,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
t.Run("url-to-domain", func(t *testing.T) { t.Run("url-to-domain", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
err := request.ExecuteWithResults(contextargs.NewWithInput("https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(contextargs.NewWithInput(context.Background(), "https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })
require.Nil(t, err, "could not execute dns request") require.Nil(t, err, "could not execute dns request")

View File

@ -1,6 +1,7 @@
package file package file
import ( import (
"context"
"os" "os"
"path/filepath" "path/filepath"
"testing" "testing"
@ -67,7 +68,7 @@ func TestFileExecuteWithResults(t *testing.T) {
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(tempDir) ctxArgs := contextargs.NewWithInput(context.Background(), tempDir)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })

View File

@ -1,6 +1,7 @@
package engine package engine
import ( import (
"context"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
@ -595,7 +596,7 @@ func testHeadless(t *testing.T, actions []*Action, timeout time.Duration, handle
ts := httptest.NewServer(http.HandlerFunc(handler)) ts := httptest.NewServer(http.HandlerFunc(handler))
defer ts.Close() defer ts.Close()
input := contextargs.NewWithInput(ts.URL) input := contextargs.NewWithInput(context.Background(), ts.URL)
input.CookieJar, err = cookiejar.New(nil) input.CookieJar, err = cookiejar.New(nil)
require.Nil(t, err) require.Nil(t, err)
@ -674,7 +675,7 @@ func TestBlockedHeadlessURLS(t *testing.T) {
{ActionType: ActionTypeHolder{ActionType: ActionWaitLoad}}, {ActionType: ActionTypeHolder{ActionType: ActionWaitLoad}},
} }
data, page, err := instance.Run(contextargs.NewWithInput(ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test data, page, err := instance.Run(contextargs.NewWithInput(context.Background(), ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test
require.Error(t, err, "expected error for url %s got %v", testcase, data) require.Error(t, err, "expected error for url %s got %v", testcase, data)
require.True(t, stringsutil.ContainsAny(err.Error(), "net::ERR_ACCESS_DENIED", "failed to parse url", "Cannot navigate to invalid URL", "net::ERR_ABORTED", "net::ERR_INVALID_URL"), "found different error %v for testcases %v", err, testcase) require.True(t, stringsutil.ContainsAny(err.Error(), "net::ERR_ACCESS_DENIED", "failed to parse url", "Cannot navigate to invalid URL", "net::ERR_ABORTED", "net::ERR_INVALID_URL"), "found different error %v for testcases %v", err, testcase)
require.Len(t, data, 0, "expected no data for url %s got %v", testcase, data) require.Len(t, data, 0, "expected no data for url %s got %v", testcase, data)

View File

@ -44,7 +44,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
if err != nil { if err != nil {
return err return err
} }
input = contextargs.NewWithInput(url) input = contextargs.NewWithInput(input.Context(), url)
} }
if request.options.Browser.UserAgent() == "" { if request.options.Browser.UserAgent() == "" {

View File

@ -40,7 +40,7 @@ func TestMakeRequestFromModal(t *testing.T) {
generator := request.newGenerator(false) generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue() inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
if req.request.URL == nil { if req.request.URL == nil {
t.Fatalf("url is nil in generator make") t.Fatalf("url is nil in generator make")
@ -70,13 +70,13 @@ func TestMakeRequestFromModalTrimSuffixSlash(t *testing.T) {
generator := request.newGenerator(false) generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue() inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test.php"), inputData, payloads, map[string]interface{}{}) req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test.php"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
require.Equal(t, "https://example.com/test.php?query=example", req.request.URL.String(), "could not get correct request path") require.Equal(t, "https://example.com/test.php?query=example", req.request.URL.String(), "could not get correct request path")
generator = request.newGenerator(false) generator = request.newGenerator(false)
inputData, payloads, _ = generator.nextValue() inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test/"), inputData, payloads, map[string]interface{}{}) req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test/"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
require.Equal(t, "https://example.com/test/?query=example", req.request.URL.String(), "could not get correct request path") require.Equal(t, "https://example.com/test/?query=example", req.request.URL.String(), "could not get correct request path")
} }
@ -110,13 +110,13 @@ Accept-Encoding: gzip`},
generator := request.newGenerator(false) generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue() inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
authorization := req.request.Header.Get("Authorization") authorization := req.request.Header.Get("Authorization")
require.Equal(t, "Basic admin:admin", authorization, "could not get correct authorization headers from raw") require.Equal(t, "Basic admin:admin", authorization, "could not get correct authorization headers from raw")
inputData, payloads, _ = generator.nextValue() inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
authorization = req.request.Header.Get("Authorization") authorization = req.request.Header.Get("Authorization")
require.Equal(t, "Basic admin:guest", authorization, "could not get correct authorization headers from raw") require.Equal(t, "Basic admin:guest", authorization, "could not get correct authorization headers from raw")
@ -151,13 +151,13 @@ Accept-Encoding: gzip`},
generator := request.newGenerator(false) generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue() inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
authorization := req.request.Header.Get("Authorization") authorization := req.request.Header.Get("Authorization")
require.Equal(t, "Basic YWRtaW46YWRtaW4=", authorization, "could not get correct authorization headers from raw") require.Equal(t, "Basic YWRtaW46YWRtaW4=", authorization, "could not get correct authorization headers from raw")
inputData, payloads, _ = generator.nextValue() inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
authorization = req.request.Header.Get("Authorization") authorization = req.request.Header.Get("Authorization")
require.Equal(t, "Basic YWRtaW46Z3Vlc3Q=", authorization, "could not get correct authorization headers from raw") require.Equal(t, "Basic YWRtaW46Z3Vlc3Q=", authorization, "could not get correct authorization headers from raw")
@ -195,7 +195,7 @@ func TestMakeRequestFromModelUniqueInteractsh(t *testing.T) {
require.Nil(t, err, "could not create interactsh client") require.Nil(t, err, "could not create interactsh client")
inputData, payloads, _ := generator.nextValue() inputData, payloads, _ := generator.nextValue()
got, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{}) got, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request") require.Nil(t, err, "could not make http request")
// check if all the interactsh markers are replaced with unique urls // check if all the interactsh markers are replaced with unique urls

View File

@ -35,8 +35,18 @@ var (
forceMaxRedirects int forceMaxRedirects int
normalClient *retryablehttp.Client normalClient *retryablehttp.Client
clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
// ResponseHeaderTimeout is the timeout for response headers
// to be read from the server (this prevents infinite hang started by server if any)
ResponseHeaderTimeout = time.Duration(5) * time.Second
// HttpTimeoutMultiplier is the multiplier for the http timeout
HttpTimeoutMultiplier = 3
) )
// GetHttpTimeout returns the http timeout for the client
func GetHttpTimeout(opts *types.Options) time.Duration {
return time.Duration(opts.Timeout*HttpTimeoutMultiplier) * time.Second
}
// Init initializes the clientpool implementation // Init initializes the clientpool implementation
func Init(options *types.Options) error { func Init(options *types.Options) error {
// Don't create clients if already created in the past. // Don't create clients if already created in the past.
@ -139,7 +149,7 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client {
} else if Dialer != nil { } else if Dialer != nil {
rawHttpOptions.FastDialer = Dialer rawHttpOptions.FastDialer = Dialer
} }
rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second rawHttpOptions.Timeout = GetHttpTimeout(options)
rawHttpClient = rawhttp.NewClient(rawHttpOptions) rawHttpClient = rawhttp.NewClient(rawHttpOptions)
} }
return rawHttpClient return rawHttpClient
@ -242,6 +252,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
MaxConnsPerHost: maxConnsPerHost, MaxConnsPerHost: maxConnsPerHost,
TLSClientConfig: tlsConfig, TLSClientConfig: tlsConfig,
DisableKeepAlives: disableKeepAlives, DisableKeepAlives: disableKeepAlives,
ResponseHeaderTimeout: ResponseHeaderTimeout,
} }
if types.ProxyURL != "" { if types.ProxyURL != "" {
@ -288,7 +299,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects), CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
} }
if !configuration.NoTimeout { if !configuration.NoTimeout {
httpclient.Timeout = time.Duration(options.Timeout) * time.Second httpclient.Timeout = GetHttpTimeout(options)
} }
client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions) client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
if jar != nil { if jar != nil {

View File

@ -235,6 +235,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
break break
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency { if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
spmHandler.Resize(request.options.Options.PayloadConcurrency) spmHandler.Resize(request.options.Options.PayloadConcurrency)
@ -356,6 +362,13 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu
if !ok { if !ok {
break break
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
if spmHandler.FoundFirstMatch() || request.isUnresponsiveHost(input) || spmHandler.Cancelled() { if spmHandler.FoundFirstMatch() || request.isUnresponsiveHost(input) || spmHandler.Cancelled() {
// skip if first match is found // skip if first match is found
break break
@ -433,8 +446,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
request.options.RateLimitTake() request.options.RateLimitTake()
ctx := request.newContext(input) ctx := request.newContext(input)
ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second) ctxWithTimeout, cancel := context.WithTimeout(ctx, httpclientpool.GetHttpTimeout(request.options.Options))
defer cancel() defer cancel()
generatedHttpRequest, err := generator.Make(ctxWithTimeout, input, data, payloads, dynamicValue) generatedHttpRequest, err := generator.Make(ctxWithTimeout, input, data, payloads, dynamicValue)
if err != nil { if err != nil {
if err == types.ErrNoMoreRequests { if err == types.ErrNoMoreRequests {
@ -512,6 +526,13 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if !ok { if !ok {
break break
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
var gotErr error var gotErr error
var skip bool var skip bool
if len(gotDynamicValues) > 0 { if len(gotDynamicValues) > 0 {
@ -1041,9 +1062,9 @@ func (request *Request) pruneSignatureInternalValues(maps ...map[string]interfac
func (request *Request) newContext(input *contextargs.Context) context.Context { func (request *Request) newContext(input *contextargs.Context) context.Context {
if input.MetaInput.CustomIP != "" { if input.MetaInput.CustomIP != "" {
return context.WithValue(context.Background(), fastdialer.IP, input.MetaInput.CustomIP) return context.WithValue(input.Context(), fastdialer.IP, input.MetaInput.CustomIP)
} }
return context.Background() return input.Context()
} }
// markUnresponsiveHost checks if the error is a unreponsive host error and marks it // markUnresponsiveHost checks if the error is a unreponsive host error and marks it

View File

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
iputil "github.com/projectdiscovery/utils/ip" iputil "github.com/projectdiscovery/utils/ip"
stringsutil "github.com/projectdiscovery/utils/strings" stringsutil "github.com/projectdiscovery/utils/strings"
@ -124,7 +125,7 @@ func (r *Request) parseAnnotations(rawRequest string, request *retryablehttp.Req
} }
} else { } else {
//nolint:govet // cancelled automatically by withTimeout //nolint:govet // cancelled automatically by withTimeout
ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second) ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), httpclientpool.GetHttpTimeout(r.options.Options))
request = request.Clone(ctx) request = request.Clone(ctx)
} }
} }

View File

@ -110,9 +110,21 @@ func (request *Request) executeFuzzingRule(input *contextargs.Context, previous
func (request *Request) executeAllFuzzingRules(input *contextargs.Context, values map[string]interface{}, baseRequest *retryablehttp.Request, callback protocols.OutputEventCallback) error { func (request *Request) executeAllFuzzingRules(input *contextargs.Context, values map[string]interface{}, baseRequest *retryablehttp.Request, callback protocols.OutputEventCallback) error {
applicable := false applicable := false
for _, rule := range request.Fuzzing { for _, rule := range request.Fuzzing {
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
err := rule.Execute(&fuzz.ExecuteRuleInput{ err := rule.Execute(&fuzz.ExecuteRuleInput{
Input: input, Input: input,
Callback: func(gr fuzz.GeneratedRequest) bool { Callback: func(gr fuzz.GeneratedRequest) bool {
select {
case <-input.Context().Done():
return false
default:
}
// TODO: replace this after scanContext Refactor // TODO: replace this after scanContext Refactor
return request.executeGeneratedFuzzingRequest(gr, input, callback) return request.executeGeneratedFuzzingRequest(gr, input, callback)
}, },

View File

@ -1,6 +1,7 @@
package http package http
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -83,7 +84,7 @@ Disallow: /c`))
t.Run("test", func(t *testing.T) { t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL) ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched { if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++ matchCount++
@ -159,7 +160,7 @@ func TestDisableTE(t *testing.T) {
t.Run("test", func(t *testing.T) { t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL) ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched { if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++ matchCount++
@ -172,7 +173,7 @@ func TestDisableTE(t *testing.T) {
t.Run("test2", func(t *testing.T) { t.Run("test2", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL) ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request2.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request2.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched { if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++ matchCount++
@ -242,7 +243,7 @@ func TestReqURLPattern(t *testing.T) {
t.Run("test", func(t *testing.T) { t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL) ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched { if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++ matchCount++

View File

@ -154,6 +154,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
opts := &compiler.ExecuteOptions{ opts := &compiler.ExecuteOptions{
Timeout: request.Timeout, Timeout: request.Timeout,
Source: &request.Init, Source: &request.Init,
Context: context.Background(),
} }
// register 'export' function to export variables from init code // register 'export' function to export variables from init code
// these are saved in args and are available in pre-condition and request code // these are saved in args and are available in pre-condition and request code
@ -343,7 +344,7 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
argsCopy.TemplateCtx = templateCtx.GetAll() argsCopy.TemplateCtx = templateCtx.GetAll()
result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy, result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy,
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition}) &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition, Context: target.Context()})
if err != nil { if err != nil {
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err) return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)
} }
@ -373,6 +374,12 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
return nil return nil
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
if err := request.executeRequestWithPayloads(hostPort, input, hostname, value, payloadValues, func(result *output.InternalWrappedEvent) { if err := request.executeRequestWithPayloads(hostPort, input, hostname, value, payloadValues, func(result *output.InternalWrappedEvent) {
if result.OperatorsResult != nil && result.OperatorsResult.Matched { if result.OperatorsResult != nil && result.OperatorsResult.Matched {
gotMatches = true gotMatches = true
@ -419,6 +426,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
break break
} }
select {
case <-input.Context().Done():
return
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency { if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency {
sg.Resize(request.options.Options.PayloadConcurrency) sg.Resize(request.options.Options.PayloadConcurrency)
@ -486,7 +499,7 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte
} }
results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy, results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy,
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code}) &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code, Context: input.Context()})
if err != nil { if err != nil {
// shouldn't fail even if it returned error instead create a failure event // shouldn't fail even if it returned error instead create a failure event
results = compiler.ExecuteResult{"success": false, "error": err.Error()} results = compiler.ExecuteResult{"success": false, "error": err.Error()}

View File

@ -139,6 +139,12 @@ func (request *Request) executeOnTarget(input *contextargs.Context, visited maps
variables = generators.MergeMaps(variablesMap, variables, request.options.Constants) variables = generators.MergeMaps(variablesMap, variables, request.options.Constants)
for _, kv := range request.addresses { for _, kv := range request.addresses {
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
actualAddress := replacer.Replace(kv.address, variables) actualAddress := replacer.Replace(kv.address, variables)
if visited.Has(actualAddress) && !request.options.Options.DisableClustering { if visited.Has(actualAddress) && !request.options.Options.DisableClustering {
@ -186,6 +192,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
break break
} }
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes // resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency { if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency) swg.Resize(request.options.Options.PayloadConcurrency)

View File

@ -1,6 +1,7 @@
package network package network
import ( import (
"context"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http" "net/http"
@ -65,7 +66,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("domain-valid", func(t *testing.T) { t.Run("domain-valid", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(parsed.Host) ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })
@ -81,7 +82,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("invalid-port-override", func(t *testing.T) { t.Run("invalid-port-override", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput("127.0.0.1:11211") ctxArgs := contextargs.NewWithInput(context.Background(), "127.0.0.1:11211")
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })
@ -95,7 +96,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("hex-to-string", func(t *testing.T) { t.Run("hex-to-string", func(t *testing.T) {
metadata := make(output.InternalEvent) metadata := make(output.InternalEvent)
previous := make(output.InternalEvent) previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(parsed.Host) ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event finalEvent = event
}) })

View File

@ -1,6 +1,7 @@
package protocols package protocols
import ( import (
"context"
"encoding/base64" "encoding/base64"
"sync/atomic" "sync/atomic"
@ -173,7 +174,7 @@ func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contexta
templateCtx, ok := e.templateCtxStore.Get(scanId) templateCtx, ok := e.templateCtxStore.Get(scanId)
if !ok { if !ok {
// if template context does not exist create new and add it to store and return it // if template context does not exist create new and add it to store and return it
templateCtx = contextargs.New() templateCtx = contextargs.New(context.Background())
templateCtx.MetaInput = input templateCtx.MetaInput = input
_ = e.templateCtxStore.Set(scanId, templateCtx) _ = e.templateCtxStore.Set(scanId, templateCtx)
} }

View File

@ -1,6 +1,7 @@
package ssl package ssl
import ( import (
"context"
"testing" "testing"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
@ -28,7 +29,7 @@ func TestSSLProtocol(t *testing.T) {
require.Nil(t, err, "could not compile ssl request") require.Nil(t, err, "could not compile ssl request")
var gotEvent output.InternalEvent var gotEvent output.InternalEvent
ctxArgs := contextargs.NewWithInput("scanme.sh:443") ctxArgs := contextargs.NewWithInput(context.Background(), "scanme.sh:443")
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) { err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
gotEvent = event.InternalEvent gotEvent = event.InternalEvent
}) })

View File

@ -1,6 +1,7 @@
package utils package utils
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
@ -49,7 +50,7 @@ func TestHTTPVariables(t *testing.T) {
require.Equal(t, values["Hostname"], "foobar.com", "incorrect hostname") require.Equal(t, values["Hostname"], "foobar.com", "incorrect hostname")
baseURL = "http://scanme.sh" baseURL = "http://scanme.sh"
ctxArgs := contextargs.NewWithInput(baseURL) ctxArgs := contextargs.NewWithInput(context.Background(), baseURL)
ctxArgs.MetaInput.CustomIP = "1.2.3.4" ctxArgs.MetaInput.CustomIP = "1.2.3.4"
values = GenerateVariablesWithContextArgs(ctxArgs, true) values = GenerateVariablesWithContextArgs(ctxArgs, true)

View File

@ -10,7 +10,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
) )
type ScanContextOption func(*ScanContext) type ScanContextOption func(*ScanContext)
func WithEvents() ScanContextOption { func WithEvents() ScanContextOption {
@ -20,7 +19,8 @@ func WithEvents() ScanContextOption {
} }
type ScanContext struct { type ScanContext struct {
context.Context ctx context.Context
// exported / configurable fields // exported / configurable fields
Input *contextargs.Context Input *contextargs.Context
@ -43,8 +43,13 @@ type ScanContext struct {
} }
// NewScanContext creates a new scan context using input // NewScanContext creates a new scan context using input
func NewScanContext(input *contextargs.Context) *ScanContext { func NewScanContext(ctx context.Context, input *contextargs.Context) *ScanContext {
return &ScanContext{Input: input} return &ScanContext{ctx: ctx, Input: input}
}
// Context returns the context of the scan
func (s *ScanContext) Context() context.Context {
return s.ctx
} }
// GenerateResult returns final results slice from all events // GenerateResult returns final results slice from all events

View File

@ -303,7 +303,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
// ExecuteWithResults executes the protocol requests and returns results instead of writing them. // ExecuteWithResults executes the protocol requests and returns results instead of writing them.
func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
scanCtx := scan.NewScanContext(ctx.Input) scanCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
dynamicValues := make(map[string]interface{}) dynamicValues := make(map[string]interface{})
inputItem := ctx.Input.Clone() inputItem := ctx.Input.Clone()

View File

@ -174,6 +174,12 @@ func (f *FlowExecutor) Compile() error {
// ExecuteWithResults executes the flow and returns results // ExecuteWithResults executes the flow and returns results
func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error { func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
f.ctx.Input = ctx.Input f.ctx.Input = ctx.Input
// -----Load all types of variables----- // -----Load all types of variables-----
// add all input args to template context // add all input args to template context

View File

@ -55,8 +55,8 @@ func TestFlowTemplateWithIndex(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("hackerone.com") input := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
require.True(t, gotresults) require.True(t, gotresults)
@ -74,8 +74,8 @@ func TestFlowTemplateWithID(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
target := contextargs.NewWithInput("hackerone.com") target := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(target) ctx := scan.NewScanContext(context.Background(), target)
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
require.True(t, gotresults) require.True(t, gotresults)
@ -96,8 +96,8 @@ func TestFlowWithProtoPrefix(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("hackerone.com") input := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
require.True(t, gotresults) require.True(t, gotresults)
@ -116,8 +116,8 @@ func TestFlowWithConditionNegative(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("scanme.sh") input := contextargs.NewWithInput(context.Background(), "scanme.sh")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
// expect no results and verify thant dns request is executed and http is not // expect no results and verify thant dns request is executed and http is not
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
@ -137,8 +137,8 @@ func TestFlowWithConditionPositive(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io") input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
// positive match . expect results also verify that both dns() and http() were executed // positive match . expect results also verify that both dns() and http() were executed
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
@ -158,8 +158,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io") input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
// positive match . expect results also verify that both dns() and http() were executed // positive match . expect results also verify that both dns() and http() were executed
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
@ -174,8 +174,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
anotherInput := contextargs.NewWithInput("blog.projectdiscovery.io") anotherInput := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
anotherCtx := scan.NewScanContext(anotherInput) anotherCtx := scan.NewScanContext(context.Background(), anotherInput)
// positive match . expect results also verify that both dns() and http() were executed // positive match . expect results also verify that both dns() and http() were executed
gotresults, err = Template.Executer.Execute(anotherCtx) gotresults, err = Template.Executer.Execute(anotherCtx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")

View File

@ -45,6 +45,12 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
previous := mapsutil.NewSyncLockMap[string, any]() previous := mapsutil.NewSyncLockMap[string, any]()
for _, req := range g.requests { for _, req := range g.requests {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
inputItem := ctx.Input.Clone() inputItem := ctx.Input.Clone()
if g.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" { if g.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" {
if inputItem.MetaInput.Input = g.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" { if inputItem.MetaInput.Input = g.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {

View File

@ -44,6 +44,12 @@ func (m *MultiProtocol) Compile() error {
// ExecuteWithResults executes the template and returns results // ExecuteWithResults executes the template and returns results
func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error { func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
// put all readonly args into template context // put all readonly args into template context
m.options.GetTemplateCtx(ctx.Input.MetaInput).Merge(m.readOnlyArgs) m.options.GetTemplateCtx(ctx.Input.MetaInput).Merge(m.readOnlyArgs)
@ -96,6 +102,12 @@ func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
// execute all protocols in the queue // execute all protocols in the queue
for _, req := range m.requests { for _, req := range m.requests {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll() values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll()
err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback) err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback)
// if error skip execution of next protocols // if error skip execution of next protocols

View File

@ -54,8 +54,8 @@ func TestMultiProtoWithDynamicExtractor(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io") input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
require.True(t, gotresults) require.True(t, gotresults)
@ -71,8 +71,8 @@ func TestMultiProtoWithProtoPrefix(t *testing.T) {
err = Template.Executer.Compile() err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template") require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io") input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(input) ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx) gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template") require.Nil(t, err, "could not execute template")
require.True(t, gotresults) require.True(t, gotresults)