feat: added support for context cancellation to engine (#5096)

* feat: added support for context cancellation to engine

* misc

* feat: added contexts everywhere

* misc

* misc

* use granular http timeouts and increase http timeout to 30s using multiplier

* track response header timeout in mhe

* update responseHeaderTimeout to 5sec

* skip failing windows test

---------

Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
This commit is contained in:
Ice3man 2024-04-25 15:37:56 +05:30 committed by GitHub
parent 3dfcec0a36
commit 0b82e8b7aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 279 additions and 113 deletions

View File

@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error)
}
store.Load()
_ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
_ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
engine.WorkPool().Wait() // Wait for the scan to finish
return results, nil

View File

@ -1,6 +1,7 @@
package runner
import (
"context"
"fmt"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret
tmpl := tmpls[0]
// add args to tmpl here
vars := map[string]interface{}{}
ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input))
mainCtx := context.Background()
ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input))
for _, v := range d.Variables {
vars[v.Key] = v.Value
ctx.Input.Add(v.Key, v.Value)

View File

@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
if r.inputProvider == nil {
return nil, errors.New("no input provider found")
}
results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering)
results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering)
return results, nil
}

View File

@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts ..
engine := core.New(tmpEngine.opts)
engine.SetExecuterOptions(unsafeOpts.executerOpts)
_ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false)
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)
engine.WorkPool().Wait()
return nil

View File

@ -3,6 +3,7 @@ package nuclei
import (
"bufio"
"bytes"
"context"
"io"
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
}
e.resultCallbacks = append(e.resultCallbacks, filtered...)
_ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false)
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
defer e.engine.WorkPool().Wait()
return nil
}

View File

@ -1,6 +1,7 @@
package core
import (
"context"
"sync"
"sync/atomic"
@ -20,18 +21,18 @@ import (
//
// All the execution logic for the templates/workflows happens in this part
// of the engine.
func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(templates, target, false)
func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
return e.ExecuteScanWithOpts(ctx, templates, target, false)
}
// ExecuteWithResults a list of templates with results
func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
e.Callback = callback
return e.ExecuteScanWithOpts(templatesList, target, false)
return e.ExecuteScanWithOpts(ctx, templatesList, target, false)
}
// ExecuteScanWithOpts executes scan with given scanStrategy
func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
results := &atomic.Bool{}
selfcontainedWg := &sync.WaitGroup{}
@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}
// Execute All SelfContained in parallel
e.executeAllSelfContained(selfContained, results, selfcontainedWg)
e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg)
strategyResult := &atomic.Bool{}
switch e.options.ScanStrategy {
case scanstrategy.TemplateSpray.String():
strategyResult = e.executeTemplateSpray(filtered, target)
strategyResult = e.executeTemplateSpray(ctx, filtered, target)
case scanstrategy.HostSpray.String():
strategyResult = e.executeHostSpray(filtered, target)
strategyResult = e.executeHostSpray(ctx, filtered, target)
}
results.CompareAndSwap(false, strategyResult.Load())
@ -100,7 +101,7 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
}
// executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template
func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}
// wp is workpool that contains different waitgroups for
@ -108,6 +109,12 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
wp := e.GetWorkPool()
for _, template := range templatesList {
select {
case <-ctx.Done():
return results
default:
}
// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())
@ -125,7 +132,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
// All other request types are executed here
// Note: executeTemplateWithTargets creates goroutines and blocks
// given template is executed on all targets
e.executeTemplateWithTargets(tpl, target, results)
e.executeTemplateWithTargets(ctx, tpl, target, results)
}(template)
}
wp.Wait()
@ -133,15 +140,21 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
}
// executeHostSpray executes scan using host spray strategy where templates are iterated over each target
func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
results := &atomic.Bool{}
wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))
target.Iterate(func(value *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false
default:
}
wp.Add()
go func(targetval *contextargs.MetaInput) {
defer wp.Done()
e.executeTemplatesOnTarget(templatesList, targetval, results)
e.executeTemplatesOnTarget(ctx, templatesList, targetval, results)
}(value)
return true
})

View File

@ -1,6 +1,7 @@
package core
import (
"context"
"sync"
"sync/atomic"
@ -17,14 +18,14 @@ import (
// Executors are low level executors that deals with template execution on a target
// executeAllSelfContained executes all self contained templates that do not use `target`
func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
for _, v := range alltemplates {
sg.Add(1)
go func(template *templates.Template) {
defer sg.Done()
var err error
var match bool
ctx := scan.NewScanContext(contextargs.New())
ctx := scan.NewScanContext(ctx, contextargs.New(ctx))
if e.Callback != nil {
if results, err := template.Executer.ExecuteWithResults(ctx); err != nil {
for _, result := range results {
@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res
}
// executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency))
func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
// this is target pool i.e max target to execute
wg := e.workPool.InputPool(template.Type())
@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}
target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
select {
case <-ctx.Done():
return false // exit
default:
}
// Best effort to track the host progression
// skips indexes lower than the minimum in-flight at interruption time
var skip bool
@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
}
// executeTemplatesOnTarget execute given templates on given single target
func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
// all templates are executed on single target
// wp is workpool that contains different waitgroups for
@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
wp := e.GetWorkPool()
for _, tpl := range alltemplates {
select {
case <-ctx.Done():
return
default:
}
// resize check point - nop if there are no changes
wp.RefreshWithConfig(e.GetWorkPoolConfig())
@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
var match bool
var err error
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx)
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
switch template.Type() {
case types.WorkflowProtocol:
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs
go func(tpl *templates.Template) {
defer wg.Done()
ctxArgs := contextargs.New()
// TODO: Workflows are a no-op for now. We need to
// implement them in the future with context cancellation
ctxArgs := contextargs.New(context.Background())
ctxArgs.MetaInput = value
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(context.Background(), ctxArgs)
match, err := template.Executer.Execute(ctx)
if err != nil {
gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err)

View File

@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b
// at this point we should be at the start root execution of a workflow tree, hence we create global shared instances
workflowCookieJar, _ := cookiejar.New(nil)
ctxArgs := contextargs.New()
ctxArgs := contextargs.New(ctx.Context())
ctxArgs.MetaInput = ctx.Input.MetaInput
ctxArgs.CookieJar = workflowCookieJar
@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
defer swg.Done()
// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err)
}
@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
go func(template *workflows.WorkflowTemplate) {
// create a new context with the same input but with unset callbacks
subCtx := scan.NewScanContext(ctx.Input)
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil {
gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
}

View File

@ -1,6 +1,7 @@
package core
import (
"context"
"testing"
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
}
@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")
@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.True(t, matched, "could not get correct match value")
@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) {
}}
engine := &Engine{}
input := contextargs.NewWithInput("https://test.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "https://test.com")
ctx := scan.NewScanContext(context.Background(), input)
matched := engine.executeWorkflow(ctx, workflow)
require.False(t, matched, "could not get correct match value")

View File

@ -3,6 +3,7 @@ package list
import (
"net"
"os"
"runtime"
"strconv"
"strings"
"testing"
@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
func Test_scanallips_normalizeStoreInputValue(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097")
}
srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"}
srv.Handler = &mockDnsHandler{}

View File

@ -37,6 +37,8 @@ type ExecuteOptions struct {
// Source is original source of the script
Source *string
Context context.Context
// Manually exported objects
exports map[string]interface{}
}
@ -77,13 +79,13 @@ func (c *Compiler) Execute(code string, args *ExecuteArgs) (ExecuteResult, error
if err != nil {
return nil, err
}
return c.ExecuteWithOptions(p, args, &ExecuteOptions{})
return c.ExecuteWithOptions(p, args, &ExecuteOptions{Context: context.Background()})
}
// ExecuteWithOptions executes a script with the provided options.
func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (ExecuteResult, error) {
if opts == nil {
opts = &ExecuteOptions{}
opts = &ExecuteOptions{Context: context.Background()}
}
if args == nil {
args = NewExecuteArgs()
@ -105,7 +107,7 @@ func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs,
}
// execute with context and timeout
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(opts.Timeout)*time.Second)
ctx, cancel := context.WithTimeout(opts.Context, time.Duration(opts.Timeout)*time.Second)
defer cancel()
// execute the script
results, err := contextutil.ExecFuncWithTwoReturns(ctx, func() (val goja.Value, err error) {

View File

@ -199,6 +199,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
Source: &request.PreCondition,
Callback: registerPreConditionFunctions,
Cleanup: cleanUpPreConditionFunctions,
Context: input.Context(),
})
if err != nil {
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)

View File

@ -3,6 +3,7 @@
package code
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -31,7 +32,7 @@ func TestCodeProtocol(t *testing.T) {
require.Nil(t, err, "could not compile code request")
var gotEvent output.InternalEvent
ctxArgs := contextargs.NewWithInput("")
ctxArgs := contextargs.NewWithInput(context.Background(), "")
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
gotEvent = event.InternalEvent
})

View File

@ -1,6 +1,7 @@
package automaticscan
import (
"context"
"io"
"net/http"
"os"
@ -189,7 +190,7 @@ func (s *Service) executeAutomaticScanOnTarget(input *contextargs.MetaInput) {
execOptions.Progress = &testutils.MockProgressClient{} // stats are not supported yet due to centralized logic and cannot be reinitialized
eng.SetExecuterOptions(execOptions)
tmp := eng.ExecuteScanWithOpts(finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true)
tmp := eng.ExecuteScanWithOpts(context.Background(), finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true)
s.hasResults.Store(tmp.Load())
}
@ -244,7 +245,9 @@ func (s *Service) getTagsUsingWappalyzer(input *contextargs.MetaInput) []string
// getTagsUsingDetectionTemplates returns tags using detection templates
func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ([]string, int) {
ctxArgs := contextargs.NewWithInput(input.Input)
ctx := context.Background()
ctxArgs := contextargs.NewWithInput(ctx, input.Input)
// execute tech detection templates on target
tags := map[string]struct{}{}
@ -256,7 +259,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) (
sg.Add()
go func(template *templates.Template) {
defer sg.Done()
ctx := scan.NewScanContext(ctxArgs)
ctx := scan.NewScanContext(ctx, ctxArgs)
ctx.OnResult = func(event *output.InternalWrappedEvent) {
if event == nil {
return

View File

@ -1,6 +1,7 @@
package contextargs
import (
"context"
"net/http/cookiejar"
"strings"
"sync/atomic"
@ -19,6 +20,8 @@ var (
// Context implements a shared context struct to share information across multiple templates within a workflow
type Context struct {
ctx context.Context
// Meta is the target for the executor
MetaInput *MetaInput
@ -30,17 +33,18 @@ type Context struct {
}
// Create a new contextargs instance
func New() *Context {
return NewWithInput("")
func New(ctx context.Context) *Context {
return NewWithInput(ctx, "")
}
// Create a new contextargs instance with input string
func NewWithInput(input string) *Context {
func NewWithInput(ctx context.Context, input string) *Context {
jar, err := cookiejar.New(nil)
if err != nil {
gologger.Error().Msgf("contextargs: could not create cookie jar: %s\n", err)
}
return &Context{
ctx: ctx,
MetaInput: &MetaInput{Input: input},
CookieJar: jar,
args: &mapsutil.SyncLockMap[string, interface{}]{
@ -50,6 +54,11 @@ func NewWithInput(input string) *Context {
}
}
// Context returns the context of the current contextargs
func (ctx *Context) Context() context.Context {
return ctx.ctx
}
// Set the specific key-value pair
func (ctx *Context) Set(key string, value interface{}) {
_ = ctx.args.Set(key, value)
@ -158,6 +167,7 @@ func (ctx *Context) HasArgs() bool {
func (ctx *Context) Clone() *Context {
newCtx := &Context{
ctx: ctx.ctx,
MetaInput: ctx.MetaInput.Clone(),
args: ctx.args.Clone(),
CookieJar: ctx.CookieJar,

View File

@ -124,7 +124,7 @@ func (c *Cache) MarkFailed(value string, err error) {
_ = c.failedTargets.Set(finalValue, existingCacheItemValue)
}
var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host)`)
var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host|timeout awaiting response headers)`)
// checkError checks if an error represents a type that should be
// added to the host skipping table.

View File

@ -80,6 +80,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
break
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency)

View File

@ -1,6 +1,7 @@
package dns
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -54,7 +55,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
t.Run("domain-valid", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput("example.com")
ctxArgs := contextargs.NewWithInput(context.Background(), "example.com")
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})
@ -70,7 +71,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
t.Run("url-to-domain", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
err := request.ExecuteWithResults(contextargs.NewWithInput("https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
err := request.ExecuteWithResults(contextargs.NewWithInput(context.Background(), "https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})
require.Nil(t, err, "could not execute dns request")

View File

@ -1,6 +1,7 @@
package file
import (
"context"
"os"
"path/filepath"
"testing"
@ -67,7 +68,7 @@ func TestFileExecuteWithResults(t *testing.T) {
t.Run("valid", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(tempDir)
ctxArgs := contextargs.NewWithInput(context.Background(), tempDir)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})

View File

@ -1,6 +1,7 @@
package engine
import (
"context"
"fmt"
"io"
"math/rand"
@ -595,7 +596,7 @@ func testHeadless(t *testing.T, actions []*Action, timeout time.Duration, handle
ts := httptest.NewServer(http.HandlerFunc(handler))
defer ts.Close()
input := contextargs.NewWithInput(ts.URL)
input := contextargs.NewWithInput(context.Background(), ts.URL)
input.CookieJar, err = cookiejar.New(nil)
require.Nil(t, err)
@ -674,7 +675,7 @@ func TestBlockedHeadlessURLS(t *testing.T) {
{ActionType: ActionTypeHolder{ActionType: ActionWaitLoad}},
}
data, page, err := instance.Run(contextargs.NewWithInput(ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test
data, page, err := instance.Run(contextargs.NewWithInput(context.Background(), ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test
require.Error(t, err, "expected error for url %s got %v", testcase, data)
require.True(t, stringsutil.ContainsAny(err.Error(), "net::ERR_ACCESS_DENIED", "failed to parse url", "Cannot navigate to invalid URL", "net::ERR_ABORTED", "net::ERR_INVALID_URL"), "found different error %v for testcases %v", err, testcase)
require.Len(t, data, 0, "expected no data for url %s got %v", testcase, data)

View File

@ -44,7 +44,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
if err != nil {
return err
}
input = contextargs.NewWithInput(url)
input = contextargs.NewWithInput(input.Context(), url)
}
if request.options.Browser.UserAgent() == "" {

View File

@ -40,7 +40,7 @@ func TestMakeRequestFromModal(t *testing.T) {
generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
if req.request.URL == nil {
t.Fatalf("url is nil in generator make")
@ -70,13 +70,13 @@ func TestMakeRequestFromModalTrimSuffixSlash(t *testing.T) {
generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test.php"), inputData, payloads, map[string]interface{}{})
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test.php"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
require.Equal(t, "https://example.com/test.php?query=example", req.request.URL.String(), "could not get correct request path")
generator = request.newGenerator(false)
inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test/"), inputData, payloads, map[string]interface{}{})
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test/"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
require.Equal(t, "https://example.com/test/?query=example", req.request.URL.String(), "could not get correct request path")
}
@ -110,13 +110,13 @@ Accept-Encoding: gzip`},
generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
authorization := req.request.Header.Get("Authorization")
require.Equal(t, "Basic admin:admin", authorization, "could not get correct authorization headers from raw")
inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
authorization = req.request.Header.Get("Authorization")
require.Equal(t, "Basic admin:guest", authorization, "could not get correct authorization headers from raw")
@ -151,13 +151,13 @@ Accept-Encoding: gzip`},
generator := request.newGenerator(false)
inputData, payloads, _ := generator.nextValue()
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
authorization := req.request.Header.Get("Authorization")
require.Equal(t, "Basic YWRtaW46YWRtaW4=", authorization, "could not get correct authorization headers from raw")
inputData, payloads, _ = generator.nextValue()
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
authorization = req.request.Header.Get("Authorization")
require.Equal(t, "Basic YWRtaW46Z3Vlc3Q=", authorization, "could not get correct authorization headers from raw")
@ -195,7 +195,7 @@ func TestMakeRequestFromModelUniqueInteractsh(t *testing.T) {
require.Nil(t, err, "could not create interactsh client")
inputData, payloads, _ := generator.nextValue()
got, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
got, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
require.Nil(t, err, "could not make http request")
// check if all the interactsh markers are replaced with unique urls

View File

@ -35,8 +35,18 @@ var (
forceMaxRedirects int
normalClient *retryablehttp.Client
clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
// ResponseHeaderTimeout is the timeout for response headers
// to be read from the server (this prevents infinite hang started by server if any)
ResponseHeaderTimeout = time.Duration(5) * time.Second
// HttpTimeoutMultiplier is the multiplier for the http timeout
HttpTimeoutMultiplier = 3
)
// GetHttpTimeout returns the http timeout for the client
func GetHttpTimeout(opts *types.Options) time.Duration {
return time.Duration(opts.Timeout*HttpTimeoutMultiplier) * time.Second
}
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
// Don't create clients if already created in the past.
@ -139,7 +149,7 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client {
} else if Dialer != nil {
rawHttpOptions.FastDialer = Dialer
}
rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
rawHttpOptions.Timeout = GetHttpTimeout(options)
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
}
return rawHttpClient
@ -242,6 +252,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
MaxConnsPerHost: maxConnsPerHost,
TLSClientConfig: tlsConfig,
DisableKeepAlives: disableKeepAlives,
ResponseHeaderTimeout: ResponseHeaderTimeout,
}
if types.ProxyURL != "" {
@ -288,7 +299,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
}
if !configuration.NoTimeout {
httpclient.Timeout = time.Duration(options.Timeout) * time.Second
httpclient.Timeout = GetHttpTimeout(options)
}
client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
if jar != nil {

View File

@ -235,6 +235,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
break
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
spmHandler.Resize(request.options.Options.PayloadConcurrency)
@ -356,6 +362,13 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu
if !ok {
break
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
if spmHandler.FoundFirstMatch() || request.isUnresponsiveHost(input) || spmHandler.Cancelled() {
// skip if first match is found
break
@ -433,8 +446,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
request.options.RateLimitTake()
ctx := request.newContext(input)
ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second)
ctxWithTimeout, cancel := context.WithTimeout(ctx, httpclientpool.GetHttpTimeout(request.options.Options))
defer cancel()
generatedHttpRequest, err := generator.Make(ctxWithTimeout, input, data, payloads, dynamicValue)
if err != nil {
if err == types.ErrNoMoreRequests {
@ -512,6 +526,13 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
if !ok {
break
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
var gotErr error
var skip bool
if len(gotDynamicValues) > 0 {
@ -1041,9 +1062,9 @@ func (request *Request) pruneSignatureInternalValues(maps ...map[string]interfac
func (request *Request) newContext(input *contextargs.Context) context.Context {
if input.MetaInput.CustomIP != "" {
return context.WithValue(context.Background(), fastdialer.IP, input.MetaInput.CustomIP)
return context.WithValue(input.Context(), fastdialer.IP, input.MetaInput.CustomIP)
}
return context.Background()
return input.Context()
}
// markUnresponsiveHost checks if the error is a unreponsive host error and marks it

View File

@ -9,6 +9,7 @@ import (
"time"
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool"
"github.com/projectdiscovery/retryablehttp-go"
iputil "github.com/projectdiscovery/utils/ip"
stringsutil "github.com/projectdiscovery/utils/strings"
@ -124,7 +125,7 @@ func (r *Request) parseAnnotations(rawRequest string, request *retryablehttp.Req
}
} else {
//nolint:govet // cancelled automatically by withTimeout
ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), httpclientpool.GetHttpTimeout(r.options.Options))
request = request.Clone(ctx)
}
}

View File

@ -110,9 +110,21 @@ func (request *Request) executeFuzzingRule(input *contextargs.Context, previous
func (request *Request) executeAllFuzzingRules(input *contextargs.Context, values map[string]interface{}, baseRequest *retryablehttp.Request, callback protocols.OutputEventCallback) error {
applicable := false
for _, rule := range request.Fuzzing {
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
err := rule.Execute(&fuzz.ExecuteRuleInput{
Input: input,
Callback: func(gr fuzz.GeneratedRequest) bool {
select {
case <-input.Context().Done():
return false
default:
}
// TODO: replace this after scanContext Refactor
return request.executeGeneratedFuzzingRequest(gr, input, callback)
},

View File

@ -1,6 +1,7 @@
package http
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -83,7 +84,7 @@ Disallow: /c`))
t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++
@ -159,7 +160,7 @@ func TestDisableTE(t *testing.T) {
t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++
@ -172,7 +173,7 @@ func TestDisableTE(t *testing.T) {
t.Run("test2", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request2.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++
@ -242,7 +243,7 @@ func TestReqURLPattern(t *testing.T) {
t.Run("test", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(ts.URL)
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
matchCount++

View File

@ -154,6 +154,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
opts := &compiler.ExecuteOptions{
Timeout: request.Timeout,
Source: &request.Init,
Context: context.Background(),
}
// register 'export' function to export variables from init code
// these are saved in args and are available in pre-condition and request code
@ -343,7 +344,7 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
argsCopy.TemplateCtx = templateCtx.GetAll()
result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy,
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition})
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition, Context: target.Context()})
if err != nil {
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)
}
@ -373,6 +374,12 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
return nil
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
if err := request.executeRequestWithPayloads(hostPort, input, hostname, value, payloadValues, func(result *output.InternalWrappedEvent) {
if result.OperatorsResult != nil && result.OperatorsResult.Matched {
gotMatches = true
@ -419,6 +426,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
break
}
select {
case <-input.Context().Done():
return
default:
}
// resize check point - nop if there are no changes
if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency {
sg.Resize(request.options.Options.PayloadConcurrency)
@ -486,7 +499,7 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte
}
results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy,
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code})
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code, Context: input.Context()})
if err != nil {
// shouldn't fail even if it returned error instead create a failure event
results = compiler.ExecuteResult{"success": false, "error": err.Error()}

View File

@ -139,6 +139,12 @@ func (request *Request) executeOnTarget(input *contextargs.Context, visited maps
variables = generators.MergeMaps(variablesMap, variables, request.options.Constants)
for _, kv := range request.addresses {
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
actualAddress := replacer.Replace(kv.address, variables)
if visited.Has(actualAddress) && !request.options.Options.DisableClustering {
@ -186,6 +192,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
break
}
select {
case <-input.Context().Done():
return input.Context().Err()
default:
}
// resize check point - nop if there are no changes
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
swg.Resize(request.options.Options.PayloadConcurrency)

View File

@ -1,6 +1,7 @@
package network
import (
"context"
"encoding/hex"
"fmt"
"net/http"
@ -65,7 +66,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("domain-valid", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(parsed.Host)
ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})
@ -81,7 +82,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("invalid-port-override", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput("127.0.0.1:11211")
ctxArgs := contextargs.NewWithInput(context.Background(), "127.0.0.1:11211")
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})
@ -95,7 +96,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
t.Run("hex-to-string", func(t *testing.T) {
metadata := make(output.InternalEvent)
previous := make(output.InternalEvent)
ctxArgs := contextargs.NewWithInput(parsed.Host)
ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
finalEvent = event
})

View File

@ -1,6 +1,7 @@
package protocols
import (
"context"
"encoding/base64"
"sync/atomic"
@ -173,7 +174,7 @@ func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contexta
templateCtx, ok := e.templateCtxStore.Get(scanId)
if !ok {
// if template context does not exist create new and add it to store and return it
templateCtx = contextargs.New()
templateCtx = contextargs.New(context.Background())
templateCtx.MetaInput = input
_ = e.templateCtxStore.Set(scanId, templateCtx)
}

View File

@ -1,6 +1,7 @@
package ssl
import (
"context"
"testing"
"github.com/stretchr/testify/require"
@ -28,7 +29,7 @@ func TestSSLProtocol(t *testing.T) {
require.Nil(t, err, "could not compile ssl request")
var gotEvent output.InternalEvent
ctxArgs := contextargs.NewWithInput("scanme.sh:443")
ctxArgs := contextargs.NewWithInput(context.Background(), "scanme.sh:443")
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
gotEvent = event.InternalEvent
})

View File

@ -1,6 +1,7 @@
package utils
import (
"context"
"reflect"
"testing"
@ -49,7 +50,7 @@ func TestHTTPVariables(t *testing.T) {
require.Equal(t, values["Hostname"], "foobar.com", "incorrect hostname")
baseURL = "http://scanme.sh"
ctxArgs := contextargs.NewWithInput(baseURL)
ctxArgs := contextargs.NewWithInput(context.Background(), baseURL)
ctxArgs.MetaInput.CustomIP = "1.2.3.4"
values = GenerateVariablesWithContextArgs(ctxArgs, true)

View File

@ -10,7 +10,6 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
)
type ScanContextOption func(*ScanContext)
func WithEvents() ScanContextOption {
@ -20,7 +19,8 @@ func WithEvents() ScanContextOption {
}
type ScanContext struct {
context.Context
ctx context.Context
// exported / configurable fields
Input *contextargs.Context
@ -43,8 +43,13 @@ type ScanContext struct {
}
// NewScanContext creates a new scan context using input
func NewScanContext(input *contextargs.Context) *ScanContext {
return &ScanContext{Input: input}
func NewScanContext(ctx context.Context, input *contextargs.Context) *ScanContext {
return &ScanContext{ctx: ctx, Input: input}
}
// Context returns the context of the scan
func (s *ScanContext) Context() context.Context {
return s.ctx
}
// GenerateResult returns final results slice from all events

View File

@ -303,7 +303,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
// ExecuteWithResults executes the protocol requests and returns results instead of writing them.
func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
scanCtx := scan.NewScanContext(ctx.Input)
scanCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
dynamicValues := make(map[string]interface{})
inputItem := ctx.Input.Clone()

View File

@ -174,6 +174,12 @@ func (f *FlowExecutor) Compile() error {
// ExecuteWithResults executes the flow and returns results
func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
f.ctx.Input = ctx.Input
// -----Load all types of variables-----
// add all input args to template context

View File

@ -55,8 +55,8 @@ func TestFlowTemplateWithIndex(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("hackerone.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)
@ -74,8 +74,8 @@ func TestFlowTemplateWithID(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
target := contextargs.NewWithInput("hackerone.com")
ctx := scan.NewScanContext(target)
target := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(context.Background(), target)
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)
@ -96,8 +96,8 @@ func TestFlowWithProtoPrefix(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("hackerone.com")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "hackerone.com")
ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)
@ -116,8 +116,8 @@ func TestFlowWithConditionNegative(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("scanme.sh")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "scanme.sh")
ctx := scan.NewScanContext(context.Background(), input)
// expect no results and verify thant dns request is executed and http is not
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
@ -137,8 +137,8 @@ func TestFlowWithConditionPositive(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(context.Background(), input)
// positive match . expect results also verify that both dns() and http() were executed
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
@ -158,8 +158,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(context.Background(), input)
// positive match . expect results also verify that both dns() and http() were executed
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
@ -174,8 +174,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
anotherInput := contextargs.NewWithInput("blog.projectdiscovery.io")
anotherCtx := scan.NewScanContext(anotherInput)
anotherInput := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
anotherCtx := scan.NewScanContext(context.Background(), anotherInput)
// positive match . expect results also verify that both dns() and http() were executed
gotresults, err = Template.Executer.Execute(anotherCtx)
require.Nil(t, err, "could not execute template")

View File

@ -45,6 +45,12 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
previous := mapsutil.NewSyncLockMap[string, any]()
for _, req := range g.requests {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
inputItem := ctx.Input.Clone()
if g.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" {
if inputItem.MetaInput.Input = g.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {

View File

@ -44,6 +44,12 @@ func (m *MultiProtocol) Compile() error {
// ExecuteWithResults executes the template and returns results
func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
// put all readonly args into template context
m.options.GetTemplateCtx(ctx.Input.MetaInput).Merge(m.readOnlyArgs)
@ -96,6 +102,12 @@ func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
// execute all protocols in the queue
for _, req := range m.requests {
select {
case <-ctx.Context().Done():
return ctx.Context().Err()
default:
}
values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll()
err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback)
// if error skip execution of next protocols

View File

@ -54,8 +54,8 @@ func TestMultiProtoWithDynamicExtractor(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)
@ -71,8 +71,8 @@ func TestMultiProtoWithProtoPrefix(t *testing.T) {
err = Template.Executer.Compile()
require.Nil(t, err, "could not compile template")
input := contextargs.NewWithInput("blog.projectdiscovery.io")
ctx := scan.NewScanContext(input)
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
ctx := scan.NewScanContext(context.Background(), input)
gotresults, err := Template.Executer.Execute(ctx)
require.Nil(t, err, "could not execute template")
require.True(t, gotresults)