mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2025-12-17 15:55:26 +00:00
feat: added support for context cancellation to engine (#5096)
* feat: added support for context cancellation to engine * misc * feat: added contexts everywhere * misc * misc * use granular http timeouts and increase http timeout to 30s using multiplier * track response header timeout in mhe * update responseHeaderTimeout to 5sec * skip failing windows test --------- Co-authored-by: Tarun Koyalwar <tarun@projectdiscovery.io>
This commit is contained in:
parent
3dfcec0a36
commit
0b82e8b7aa
@ -128,7 +128,7 @@ func executeNucleiAsLibrary(templatePath, templateURL string) ([]string, error)
|
||||
}
|
||||
store.Load()
|
||||
|
||||
_ = engine.Execute(store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
|
||||
_ = engine.Execute(context.Background(), store.Templates(), provider.NewSimpleInputProviderWithUrls(templateURL))
|
||||
engine.WorkPool().Wait() // Wait for the scan to finish
|
||||
|
||||
return results, nil
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package runner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx"
|
||||
@ -71,7 +72,8 @@ func GetLazyAuthFetchCallback(opts *AuthLazyFetchOptions) authx.LazyFetchSecret
|
||||
tmpl := tmpls[0]
|
||||
// add args to tmpl here
|
||||
vars := map[string]interface{}{}
|
||||
ctx := scan.NewScanContext(contextargs.NewWithInput(d.Input))
|
||||
mainCtx := context.Background()
|
||||
ctx := scan.NewScanContext(mainCtx, contextargs.NewWithInput(mainCtx, d.Input))
|
||||
for _, v := range d.Variables {
|
||||
vars[v.Key] = v.Value
|
||||
ctx.Input.Add(v.Key, v.Value)
|
||||
|
||||
@ -669,7 +669,7 @@ func (r *Runner) executeTemplatesInput(store *loader.Store, engine *core.Engine)
|
||||
if r.inputProvider == nil {
|
||||
return nil, errors.New("no input provider found")
|
||||
}
|
||||
results := engine.ExecuteScanWithOpts(finalTemplates, r.inputProvider, r.options.DisableClustering)
|
||||
results := engine.ExecuteScanWithOpts(context.Background(), finalTemplates, r.inputProvider, r.options.DisableClustering)
|
||||
return results, nil
|
||||
}
|
||||
|
||||
|
||||
@ -138,7 +138,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOpts(targets []string, opts ..
|
||||
engine := core.New(tmpEngine.opts)
|
||||
engine.SetExecuterOptions(unsafeOpts.executerOpts)
|
||||
|
||||
_ = engine.ExecuteScanWithOpts(store.Templates(), inputProvider, false)
|
||||
_ = engine.ExecuteScanWithOpts(context.Background(), store.Templates(), inputProvider, false)
|
||||
|
||||
engine.WorkPool().Wait()
|
||||
return nil
|
||||
|
||||
@ -3,6 +3,7 @@ package nuclei
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/authprovider"
|
||||
@ -210,7 +211,7 @@ func (e *NucleiEngine) ExecuteWithCallback(callback ...func(event *output.Result
|
||||
}
|
||||
e.resultCallbacks = append(e.resultCallbacks, filtered...)
|
||||
|
||||
_ = e.engine.ExecuteScanWithOpts(e.store.Templates(), e.inputProvider, false)
|
||||
_ = e.engine.ExecuteScanWithOpts(context.Background(), e.store.Templates(), e.inputProvider, false)
|
||||
defer e.engine.WorkPool().Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@ -20,18 +21,18 @@ import (
|
||||
//
|
||||
// All the execution logic for the templates/workflows happens in this part
|
||||
// of the engine.
|
||||
func (e *Engine) Execute(templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
return e.ExecuteScanWithOpts(templates, target, false)
|
||||
func (e *Engine) Execute(ctx context.Context, templates []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
return e.ExecuteScanWithOpts(ctx, templates, target, false)
|
||||
}
|
||||
|
||||
// ExecuteWithResults a list of templates with results
|
||||
func (e *Engine) ExecuteWithResults(templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
|
||||
func (e *Engine) ExecuteWithResults(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, callback func(*output.ResultEvent)) *atomic.Bool {
|
||||
e.Callback = callback
|
||||
return e.ExecuteScanWithOpts(templatesList, target, false)
|
||||
return e.ExecuteScanWithOpts(ctx, templatesList, target, false)
|
||||
}
|
||||
|
||||
// ExecuteScanWithOpts executes scan with given scanStrategy
|
||||
func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
|
||||
func (e *Engine) ExecuteScanWithOpts(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider, noCluster bool) *atomic.Bool {
|
||||
results := &atomic.Bool{}
|
||||
selfcontainedWg := &sync.WaitGroup{}
|
||||
|
||||
@ -83,14 +84,14 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
|
||||
}
|
||||
|
||||
// Execute All SelfContained in parallel
|
||||
e.executeAllSelfContained(selfContained, results, selfcontainedWg)
|
||||
e.executeAllSelfContained(ctx, selfContained, results, selfcontainedWg)
|
||||
|
||||
strategyResult := &atomic.Bool{}
|
||||
switch e.options.ScanStrategy {
|
||||
case scanstrategy.TemplateSpray.String():
|
||||
strategyResult = e.executeTemplateSpray(filtered, target)
|
||||
strategyResult = e.executeTemplateSpray(ctx, filtered, target)
|
||||
case scanstrategy.HostSpray.String():
|
||||
strategyResult = e.executeHostSpray(filtered, target)
|
||||
strategyResult = e.executeHostSpray(ctx, filtered, target)
|
||||
}
|
||||
|
||||
results.CompareAndSwap(false, strategyResult.Load())
|
||||
@ -100,7 +101,7 @@ func (e *Engine) ExecuteScanWithOpts(templatesList []*templates.Template, target
|
||||
}
|
||||
|
||||
// executeTemplateSpray executes scan using template spray strategy where targets are iterated over each template
|
||||
func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
func (e *Engine) executeTemplateSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
results := &atomic.Bool{}
|
||||
|
||||
// wp is workpool that contains different waitgroups for
|
||||
@ -108,6 +109,12 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
|
||||
wp := e.GetWorkPool()
|
||||
|
||||
for _, template := range templatesList {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return results
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
wp.RefreshWithConfig(e.GetWorkPoolConfig())
|
||||
|
||||
@ -125,7 +132,7 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
|
||||
// All other request types are executed here
|
||||
// Note: executeTemplateWithTargets creates goroutines and blocks
|
||||
// given template is executed on all targets
|
||||
e.executeTemplateWithTargets(tpl, target, results)
|
||||
e.executeTemplateWithTargets(ctx, tpl, target, results)
|
||||
}(template)
|
||||
}
|
||||
wp.Wait()
|
||||
@ -133,15 +140,21 @@ func (e *Engine) executeTemplateSpray(templatesList []*templates.Template, targe
|
||||
}
|
||||
|
||||
// executeHostSpray executes scan using host spray strategy where templates are iterated over each target
|
||||
func (e *Engine) executeHostSpray(templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
func (e *Engine) executeHostSpray(ctx context.Context, templatesList []*templates.Template, target provider.InputProvider) *atomic.Bool {
|
||||
results := &atomic.Bool{}
|
||||
wp, _ := syncutil.New(syncutil.WithSize(e.options.BulkSize + e.options.HeadlessBulkSize))
|
||||
|
||||
target.Iterate(func(value *contextargs.MetaInput) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
wp.Add()
|
||||
go func(targetval *contextargs.MetaInput) {
|
||||
defer wp.Done()
|
||||
e.executeTemplatesOnTarget(templatesList, targetval, results)
|
||||
e.executeTemplatesOnTarget(ctx, templatesList, targetval, results)
|
||||
}(value)
|
||||
return true
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
@ -17,14 +18,14 @@ import (
|
||||
// Executors are low level executors that deals with template execution on a target
|
||||
|
||||
// executeAllSelfContained executes all self contained templates that do not use `target`
|
||||
func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
|
||||
func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*templates.Template, results *atomic.Bool, sg *sync.WaitGroup) {
|
||||
for _, v := range alltemplates {
|
||||
sg.Add(1)
|
||||
go func(template *templates.Template) {
|
||||
defer sg.Done()
|
||||
var err error
|
||||
var match bool
|
||||
ctx := scan.NewScanContext(contextargs.New())
|
||||
ctx := scan.NewScanContext(ctx, contextargs.New(ctx))
|
||||
if e.Callback != nil {
|
||||
if results, err := template.Executer.ExecuteWithResults(ctx); err != nil {
|
||||
for _, result := range results {
|
||||
@ -45,7 +46,7 @@ func (e *Engine) executeAllSelfContained(alltemplates []*templates.Template, res
|
||||
}
|
||||
|
||||
// executeTemplateWithTarget executes a given template on x targets (with a internal targetpool(i.e concurrency))
|
||||
func (e *Engine) executeTemplateWithTargets(template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
|
||||
func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) {
|
||||
// this is target pool i.e max target to execute
|
||||
wg := e.workPool.InputPool(template.Type())
|
||||
|
||||
@ -77,6 +78,12 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
|
||||
}
|
||||
|
||||
target.Iterate(func(scannedValue *contextargs.MetaInput) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false // exit
|
||||
default:
|
||||
}
|
||||
|
||||
// Best effort to track the host progression
|
||||
// skips indexes lower than the minimum in-flight at interruption time
|
||||
var skip bool
|
||||
@ -114,9 +121,9 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
|
||||
|
||||
var match bool
|
||||
var err error
|
||||
ctxArgs := contextargs.New()
|
||||
ctxArgs := contextargs.New(ctx)
|
||||
ctxArgs.MetaInput = value
|
||||
ctx := scan.NewScanContext(ctxArgs)
|
||||
ctx := scan.NewScanContext(ctx, ctxArgs)
|
||||
switch template.Type() {
|
||||
case types.WorkflowProtocol:
|
||||
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
|
||||
@ -149,7 +156,7 @@ func (e *Engine) executeTemplateWithTargets(template *templates.Template, target
|
||||
}
|
||||
|
||||
// executeTemplatesOnTarget execute given templates on given single target
|
||||
func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
|
||||
func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*templates.Template, target *contextargs.MetaInput, results *atomic.Bool) {
|
||||
// all templates are executed on single target
|
||||
|
||||
// wp is workpool that contains different waitgroups for
|
||||
@ -158,6 +165,12 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
|
||||
wp := e.GetWorkPool()
|
||||
|
||||
for _, tpl := range alltemplates {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
wp.RefreshWithConfig(e.GetWorkPoolConfig())
|
||||
|
||||
@ -173,9 +186,9 @@ func (e *Engine) executeTemplatesOnTarget(alltemplates []*templates.Template, ta
|
||||
|
||||
var match bool
|
||||
var err error
|
||||
ctxArgs := contextargs.New()
|
||||
ctxArgs := contextargs.New(ctx)
|
||||
ctxArgs.MetaInput = value
|
||||
ctx := scan.NewScanContext(ctxArgs)
|
||||
ctx := scan.NewScanContext(ctx, ctxArgs)
|
||||
switch template.Type() {
|
||||
case types.WorkflowProtocol:
|
||||
match = e.executeWorkflow(ctx, template.CompiledWorkflow)
|
||||
@ -230,9 +243,11 @@ func (e *ChildExecuter) Execute(template *templates.Template, value *contextargs
|
||||
go func(tpl *templates.Template) {
|
||||
defer wg.Done()
|
||||
|
||||
ctxArgs := contextargs.New()
|
||||
// TODO: Workflows are a no-op for now. We need to
|
||||
// implement them in the future with context cancellation
|
||||
ctxArgs := contextargs.New(context.Background())
|
||||
ctxArgs.MetaInput = value
|
||||
ctx := scan.NewScanContext(ctxArgs)
|
||||
ctx := scan.NewScanContext(context.Background(), ctxArgs)
|
||||
match, err := template.Executer.Execute(ctx)
|
||||
if err != nil {
|
||||
gologger.Warning().Msgf("[%s] Could not execute step: %s\n", e.e.executerOpts.Colorizer.BrightBlue(template.ID), err)
|
||||
|
||||
@ -21,7 +21,7 @@ func (e *Engine) executeWorkflow(ctx *scan.ScanContext, w *workflows.Workflow) b
|
||||
|
||||
// at this point we should be at the start root execution of a workflow tree, hence we create global shared instances
|
||||
workflowCookieJar, _ := cookiejar.New(nil)
|
||||
ctxArgs := contextargs.New()
|
||||
ctxArgs := contextargs.New(ctx.Context())
|
||||
ctxArgs.MetaInput = ctx.Input.MetaInput
|
||||
ctxArgs.CookieJar = workflowCookieJar
|
||||
|
||||
@ -139,7 +139,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
|
||||
defer swg.Done()
|
||||
|
||||
// create a new context with the same input but with unset callbacks
|
||||
subCtx := scan.NewScanContext(ctx.Input)
|
||||
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
|
||||
if err := e.runWorkflowStep(subtemplate, subCtx, results, swg, w); err != nil {
|
||||
gologger.Warning().Msgf(workflowStepExecutionError, subtemplate.Template, err)
|
||||
}
|
||||
@ -165,7 +165,7 @@ func (e *Engine) runWorkflowStep(template *workflows.WorkflowTemplate, ctx *scan
|
||||
|
||||
go func(template *workflows.WorkflowTemplate) {
|
||||
// create a new context with the same input but with unset callbacks
|
||||
subCtx := scan.NewScanContext(ctx.Input)
|
||||
subCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
|
||||
if err := e.runWorkflowStep(template, subCtx, results, swg, w); err != nil {
|
||||
gologger.Warning().Msgf(workflowStepExecutionError, template.Template, err)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/model/types/stringslice"
|
||||
@ -25,8 +26,8 @@ func TestWorkflowsSimple(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.True(t, matched, "could not get correct match value")
|
||||
}
|
||||
@ -49,8 +50,8 @@ func TestWorkflowsSimpleMultiple(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.True(t, matched, "could not get correct match value")
|
||||
|
||||
@ -77,8 +78,8 @@ func TestWorkflowsSubtemplates(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.True(t, matched, "could not get correct match value")
|
||||
|
||||
@ -103,8 +104,8 @@ func TestWorkflowsSubtemplatesNoMatch(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.False(t, matched, "could not get correct match value")
|
||||
|
||||
@ -134,8 +135,8 @@ func TestWorkflowsSubtemplatesWithMatcher(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.True(t, matched, "could not get correct match value")
|
||||
|
||||
@ -165,8 +166,8 @@ func TestWorkflowsSubtemplatesWithMatcherNoMatch(t *testing.T) {
|
||||
}}
|
||||
|
||||
engine := &Engine{}
|
||||
input := contextargs.NewWithInput("https://test.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "https://test.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
matched := engine.executeWorkflow(ctx, workflow)
|
||||
require.False(t, matched, "could not get correct match value")
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ package list
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@ -77,6 +78,9 @@ func (m *mockDnsHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
}
|
||||
|
||||
func Test_scanallips_normalizeStoreInputValue(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Skipping test see: https://github.com/projectdiscovery/nuclei/issues/5097")
|
||||
}
|
||||
srv := &dns.Server{Addr: ":" + strconv.Itoa(61234), Net: "udp"}
|
||||
srv.Handler = &mockDnsHandler{}
|
||||
|
||||
|
||||
@ -37,6 +37,8 @@ type ExecuteOptions struct {
|
||||
// Source is original source of the script
|
||||
Source *string
|
||||
|
||||
Context context.Context
|
||||
|
||||
// Manually exported objects
|
||||
exports map[string]interface{}
|
||||
}
|
||||
@ -77,13 +79,13 @@ func (c *Compiler) Execute(code string, args *ExecuteArgs) (ExecuteResult, error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.ExecuteWithOptions(p, args, &ExecuteOptions{})
|
||||
return c.ExecuteWithOptions(p, args, &ExecuteOptions{Context: context.Background()})
|
||||
}
|
||||
|
||||
// ExecuteWithOptions executes a script with the provided options.
|
||||
func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (ExecuteResult, error) {
|
||||
if opts == nil {
|
||||
opts = &ExecuteOptions{}
|
||||
opts = &ExecuteOptions{Context: context.Background()}
|
||||
}
|
||||
if args == nil {
|
||||
args = NewExecuteArgs()
|
||||
@ -105,7 +107,7 @@ func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs,
|
||||
}
|
||||
|
||||
// execute with context and timeout
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(opts.Timeout)*time.Second)
|
||||
ctx, cancel := context.WithTimeout(opts.Context, time.Duration(opts.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
// execute the script
|
||||
results, err := contextutil.ExecFuncWithTwoReturns(ctx, func() (val goja.Value, err error) {
|
||||
|
||||
@ -199,6 +199,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
|
||||
Source: &request.PreCondition,
|
||||
Callback: registerPreConditionFunctions,
|
||||
Cleanup: cleanUpPreConditionFunctions,
|
||||
Context: input.Context(),
|
||||
})
|
||||
if err != nil {
|
||||
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
package code
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -31,7 +32,7 @@ func TestCodeProtocol(t *testing.T) {
|
||||
require.Nil(t, err, "could not compile code request")
|
||||
|
||||
var gotEvent output.InternalEvent
|
||||
ctxArgs := contextargs.NewWithInput("")
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), "")
|
||||
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
|
||||
gotEvent = event.InternalEvent
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package automaticscan
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@ -189,7 +190,7 @@ func (s *Service) executeAutomaticScanOnTarget(input *contextargs.MetaInput) {
|
||||
execOptions.Progress = &testutils.MockProgressClient{} // stats are not supported yet due to centralized logic and cannot be reinitialized
|
||||
eng.SetExecuterOptions(execOptions)
|
||||
|
||||
tmp := eng.ExecuteScanWithOpts(finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true)
|
||||
tmp := eng.ExecuteScanWithOpts(context.Background(), finalTemplates, provider.NewSimpleInputProviderWithUrls(input.Input), true)
|
||||
s.hasResults.Store(tmp.Load())
|
||||
}
|
||||
|
||||
@ -244,7 +245,9 @@ func (s *Service) getTagsUsingWappalyzer(input *contextargs.MetaInput) []string
|
||||
|
||||
// getTagsUsingDetectionTemplates returns tags using detection templates
|
||||
func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) ([]string, int) {
|
||||
ctxArgs := contextargs.NewWithInput(input.Input)
|
||||
ctx := context.Background()
|
||||
|
||||
ctxArgs := contextargs.NewWithInput(ctx, input.Input)
|
||||
|
||||
// execute tech detection templates on target
|
||||
tags := map[string]struct{}{}
|
||||
@ -256,7 +259,7 @@ func (s *Service) getTagsUsingDetectionTemplates(input *contextargs.MetaInput) (
|
||||
sg.Add()
|
||||
go func(template *templates.Template) {
|
||||
defer sg.Done()
|
||||
ctx := scan.NewScanContext(ctxArgs)
|
||||
ctx := scan.NewScanContext(ctx, ctxArgs)
|
||||
ctx.OnResult = func(event *output.InternalWrappedEvent) {
|
||||
if event == nil {
|
||||
return
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package contextargs
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/cookiejar"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@ -19,6 +20,8 @@ var (
|
||||
|
||||
// Context implements a shared context struct to share information across multiple templates within a workflow
|
||||
type Context struct {
|
||||
ctx context.Context
|
||||
|
||||
// Meta is the target for the executor
|
||||
MetaInput *MetaInput
|
||||
|
||||
@ -30,17 +33,18 @@ type Context struct {
|
||||
}
|
||||
|
||||
// Create a new contextargs instance
|
||||
func New() *Context {
|
||||
return NewWithInput("")
|
||||
func New(ctx context.Context) *Context {
|
||||
return NewWithInput(ctx, "")
|
||||
}
|
||||
|
||||
// Create a new contextargs instance with input string
|
||||
func NewWithInput(input string) *Context {
|
||||
func NewWithInput(ctx context.Context, input string) *Context {
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
gologger.Error().Msgf("contextargs: could not create cookie jar: %s\n", err)
|
||||
}
|
||||
return &Context{
|
||||
ctx: ctx,
|
||||
MetaInput: &MetaInput{Input: input},
|
||||
CookieJar: jar,
|
||||
args: &mapsutil.SyncLockMap[string, interface{}]{
|
||||
@ -50,6 +54,11 @@ func NewWithInput(input string) *Context {
|
||||
}
|
||||
}
|
||||
|
||||
// Context returns the context of the current contextargs
|
||||
func (ctx *Context) Context() context.Context {
|
||||
return ctx.ctx
|
||||
}
|
||||
|
||||
// Set the specific key-value pair
|
||||
func (ctx *Context) Set(key string, value interface{}) {
|
||||
_ = ctx.args.Set(key, value)
|
||||
@ -158,6 +167,7 @@ func (ctx *Context) HasArgs() bool {
|
||||
|
||||
func (ctx *Context) Clone() *Context {
|
||||
newCtx := &Context{
|
||||
ctx: ctx.ctx,
|
||||
MetaInput: ctx.MetaInput.Clone(),
|
||||
args: ctx.args.Clone(),
|
||||
CookieJar: ctx.CookieJar,
|
||||
|
||||
@ -124,7 +124,7 @@ func (c *Cache) MarkFailed(value string, err error) {
|
||||
_ = c.failedTargets.Set(finalValue, existingCacheItemValue)
|
||||
}
|
||||
|
||||
var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host)`)
|
||||
var reCheckError = regexp.MustCompile(`(no address found for host|Client\.Timeout exceeded while awaiting headers|could not resolve host|connection refused|connection reset by peer|i/o timeout|could not connect to any address found for host|timeout awaiting response headers)`)
|
||||
|
||||
// checkError checks if an error represents a type that should be
|
||||
// added to the host skipping table.
|
||||
|
||||
@ -80,6 +80,12 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
|
||||
swg.Resize(request.options.Options.PayloadConcurrency)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -54,7 +55,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
|
||||
t.Run("domain-valid", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput("example.com")
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), "example.com")
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
@ -70,7 +71,7 @@ func TestDNSExecuteWithResults(t *testing.T) {
|
||||
t.Run("url-to-domain", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
err := request.ExecuteWithResults(contextargs.NewWithInput("https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
err := request.ExecuteWithResults(contextargs.NewWithInput(context.Background(), "https://example.com"), metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
require.Nil(t, err, "could not execute dns request")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package file
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
@ -67,7 +68,7 @@ func TestFileExecuteWithResults(t *testing.T) {
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(tempDir)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), tempDir)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package engine
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
@ -595,7 +596,7 @@ func testHeadless(t *testing.T, actions []*Action, timeout time.Duration, handle
|
||||
ts := httptest.NewServer(http.HandlerFunc(handler))
|
||||
defer ts.Close()
|
||||
|
||||
input := contextargs.NewWithInput(ts.URL)
|
||||
input := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
input.CookieJar, err = cookiejar.New(nil)
|
||||
require.Nil(t, err)
|
||||
|
||||
@ -674,7 +675,7 @@ func TestBlockedHeadlessURLS(t *testing.T) {
|
||||
{ActionType: ActionTypeHolder{ActionType: ActionWaitLoad}},
|
||||
}
|
||||
|
||||
data, page, err := instance.Run(contextargs.NewWithInput(ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test
|
||||
data, page, err := instance.Run(contextargs.NewWithInput(context.Background(), ts.URL), actions, nil, &Options{Timeout: 20 * time.Second, Options: opts}) // allow file access in test
|
||||
require.Error(t, err, "expected error for url %s got %v", testcase, data)
|
||||
require.True(t, stringsutil.ContainsAny(err.Error(), "net::ERR_ACCESS_DENIED", "failed to parse url", "Cannot navigate to invalid URL", "net::ERR_ABORTED", "net::ERR_INVALID_URL"), "found different error %v for testcases %v", err, testcase)
|
||||
require.Len(t, data, 0, "expected no data for url %s got %v", testcase, data)
|
||||
|
||||
@ -44,7 +44,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, metadata,
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
input = contextargs.NewWithInput(url)
|
||||
input = contextargs.NewWithInput(input.Context(), url)
|
||||
}
|
||||
|
||||
if request.options.Browser.UserAgent() == "" {
|
||||
|
||||
@ -40,7 +40,7 @@ func TestMakeRequestFromModal(t *testing.T) {
|
||||
|
||||
generator := request.newGenerator(false)
|
||||
inputData, payloads, _ := generator.nextValue()
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
if req.request.URL == nil {
|
||||
t.Fatalf("url is nil in generator make")
|
||||
@ -70,13 +70,13 @@ func TestMakeRequestFromModalTrimSuffixSlash(t *testing.T) {
|
||||
|
||||
generator := request.newGenerator(false)
|
||||
inputData, payloads, _ := generator.nextValue()
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test.php"), inputData, payloads, map[string]interface{}{})
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test.php"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
require.Equal(t, "https://example.com/test.php?query=example", req.request.URL.String(), "could not get correct request path")
|
||||
|
||||
generator = request.newGenerator(false)
|
||||
inputData, payloads, _ = generator.nextValue()
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com/test/"), inputData, payloads, map[string]interface{}{})
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com/test/"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
require.Equal(t, "https://example.com/test/?query=example", req.request.URL.String(), "could not get correct request path")
|
||||
}
|
||||
@ -110,13 +110,13 @@ Accept-Encoding: gzip`},
|
||||
|
||||
generator := request.newGenerator(false)
|
||||
inputData, payloads, _ := generator.nextValue()
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
authorization := req.request.Header.Get("Authorization")
|
||||
require.Equal(t, "Basic admin:admin", authorization, "could not get correct authorization headers from raw")
|
||||
|
||||
inputData, payloads, _ = generator.nextValue()
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
authorization = req.request.Header.Get("Authorization")
|
||||
require.Equal(t, "Basic admin:guest", authorization, "could not get correct authorization headers from raw")
|
||||
@ -151,13 +151,13 @@ Accept-Encoding: gzip`},
|
||||
|
||||
generator := request.newGenerator(false)
|
||||
inputData, payloads, _ := generator.nextValue()
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
req, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
authorization := req.request.Header.Get("Authorization")
|
||||
require.Equal(t, "Basic YWRtaW46YWRtaW4=", authorization, "could not get correct authorization headers from raw")
|
||||
|
||||
inputData, payloads, _ = generator.nextValue()
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
req, err = generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
authorization = req.request.Header.Get("Authorization")
|
||||
require.Equal(t, "Basic YWRtaW46Z3Vlc3Q=", authorization, "could not get correct authorization headers from raw")
|
||||
@ -195,7 +195,7 @@ func TestMakeRequestFromModelUniqueInteractsh(t *testing.T) {
|
||||
require.Nil(t, err, "could not create interactsh client")
|
||||
|
||||
inputData, payloads, _ := generator.nextValue()
|
||||
got, err := generator.Make(context.Background(), contextargs.NewWithInput("https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
got, err := generator.Make(context.Background(), contextargs.NewWithInput(context.Background(), "https://example.com"), inputData, payloads, map[string]interface{}{})
|
||||
require.Nil(t, err, "could not make http request")
|
||||
|
||||
// check if all the interactsh markers are replaced with unique urls
|
||||
|
||||
@ -35,8 +35,18 @@ var (
|
||||
forceMaxRedirects int
|
||||
normalClient *retryablehttp.Client
|
||||
clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
|
||||
// ResponseHeaderTimeout is the timeout for response headers
|
||||
// to be read from the server (this prevents infinite hang started by server if any)
|
||||
ResponseHeaderTimeout = time.Duration(5) * time.Second
|
||||
// HttpTimeoutMultiplier is the multiplier for the http timeout
|
||||
HttpTimeoutMultiplier = 3
|
||||
)
|
||||
|
||||
// GetHttpTimeout returns the http timeout for the client
|
||||
func GetHttpTimeout(opts *types.Options) time.Duration {
|
||||
return time.Duration(opts.Timeout*HttpTimeoutMultiplier) * time.Second
|
||||
}
|
||||
|
||||
// Init initializes the clientpool implementation
|
||||
func Init(options *types.Options) error {
|
||||
// Don't create clients if already created in the past.
|
||||
@ -139,7 +149,7 @@ func GetRawHTTP(options *types.Options) *rawhttp.Client {
|
||||
} else if Dialer != nil {
|
||||
rawHttpOptions.FastDialer = Dialer
|
||||
}
|
||||
rawHttpOptions.Timeout = time.Duration(options.Timeout) * time.Second
|
||||
rawHttpOptions.Timeout = GetHttpTimeout(options)
|
||||
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
|
||||
}
|
||||
return rawHttpClient
|
||||
@ -237,11 +247,12 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
|
||||
}
|
||||
return Dialer.DialTLS(ctx, network, addr)
|
||||
},
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: maxConnsPerHost,
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableKeepAlives: disableKeepAlives,
|
||||
MaxIdleConns: maxIdleConns,
|
||||
MaxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: maxConnsPerHost,
|
||||
TLSClientConfig: tlsConfig,
|
||||
DisableKeepAlives: disableKeepAlives,
|
||||
ResponseHeaderTimeout: ResponseHeaderTimeout,
|
||||
}
|
||||
|
||||
if types.ProxyURL != "" {
|
||||
@ -288,7 +299,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
|
||||
CheckRedirect: makeCheckRedirectFunc(redirectFlow, maxRedirects),
|
||||
}
|
||||
if !configuration.NoTimeout {
|
||||
httpclient.Timeout = time.Duration(options.Timeout) * time.Second
|
||||
httpclient.Timeout = GetHttpTimeout(options)
|
||||
}
|
||||
client := retryablehttp.NewWithHTTPClient(httpclient, retryableHttpOptions)
|
||||
if jar != nil {
|
||||
|
||||
@ -235,6 +235,12 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
if shouldFollowGlobal && spmHandler.Size() != request.options.Options.PayloadConcurrency {
|
||||
spmHandler.Resize(request.options.Options.PayloadConcurrency)
|
||||
@ -356,6 +362,13 @@ func (request *Request) executeTurboHTTP(input *contextargs.Context, dynamicValu
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if spmHandler.FoundFirstMatch() || request.isUnresponsiveHost(input) || spmHandler.Cancelled() {
|
||||
// skip if first match is found
|
||||
break
|
||||
@ -433,8 +446,9 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
|
||||
request.options.RateLimitTake()
|
||||
|
||||
ctx := request.newContext(input)
|
||||
ctxWithTimeout, cancel := context.WithTimeout(ctx, time.Duration(request.options.Options.Timeout)*time.Second)
|
||||
ctxWithTimeout, cancel := context.WithTimeout(ctx, httpclientpool.GetHttpTimeout(request.options.Options))
|
||||
defer cancel()
|
||||
|
||||
generatedHttpRequest, err := generator.Make(ctxWithTimeout, input, data, payloads, dynamicValue)
|
||||
if err != nil {
|
||||
if err == types.ErrNoMoreRequests {
|
||||
@ -512,6 +526,13 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
var gotErr error
|
||||
var skip bool
|
||||
if len(gotDynamicValues) > 0 {
|
||||
@ -1041,9 +1062,9 @@ func (request *Request) pruneSignatureInternalValues(maps ...map[string]interfac
|
||||
|
||||
func (request *Request) newContext(input *contextargs.Context) context.Context {
|
||||
if input.MetaInput.CustomIP != "" {
|
||||
return context.WithValue(context.Background(), fastdialer.IP, input.MetaInput.CustomIP)
|
||||
return context.WithValue(input.Context(), fastdialer.IP, input.MetaInput.CustomIP)
|
||||
}
|
||||
return context.Background()
|
||||
return input.Context()
|
||||
}
|
||||
|
||||
// markUnresponsiveHost checks if the error is a unreponsive host error and marks it
|
||||
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/projectdiscovery/fastdialer/fastdialer"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/http/httpclientpool"
|
||||
"github.com/projectdiscovery/retryablehttp-go"
|
||||
iputil "github.com/projectdiscovery/utils/ip"
|
||||
stringsutil "github.com/projectdiscovery/utils/strings"
|
||||
@ -124,7 +125,7 @@ func (r *Request) parseAnnotations(rawRequest string, request *retryablehttp.Req
|
||||
}
|
||||
} else {
|
||||
//nolint:govet // cancelled automatically by withTimeout
|
||||
ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), time.Duration(r.options.Options.Timeout)*time.Second)
|
||||
ctx, overrides.cancelFunc = context.WithTimeout(context.Background(), httpclientpool.GetHttpTimeout(r.options.Options))
|
||||
request = request.Clone(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,9 +110,21 @@ func (request *Request) executeFuzzingRule(input *contextargs.Context, previous
|
||||
func (request *Request) executeAllFuzzingRules(input *contextargs.Context, values map[string]interface{}, baseRequest *retryablehttp.Request, callback protocols.OutputEventCallback) error {
|
||||
applicable := false
|
||||
for _, rule := range request.Fuzzing {
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
err := rule.Execute(&fuzz.ExecuteRuleInput{
|
||||
Input: input,
|
||||
Callback: func(gr fuzz.GeneratedRequest) bool {
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
// TODO: replace this after scanContext Refactor
|
||||
return request.executeGeneratedFuzzingRequest(gr, input, callback)
|
||||
},
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -83,7 +84,7 @@ Disallow: /c`))
|
||||
t.Run("test", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(ts.URL)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
matchCount++
|
||||
@ -159,7 +160,7 @@ func TestDisableTE(t *testing.T) {
|
||||
t.Run("test", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(ts.URL)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
matchCount++
|
||||
@ -172,7 +173,7 @@ func TestDisableTE(t *testing.T) {
|
||||
t.Run("test2", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(ts.URL)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err := request2.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
matchCount++
|
||||
@ -242,7 +243,7 @@ func TestReqURLPattern(t *testing.T) {
|
||||
t.Run("test", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(ts.URL)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
if event.OperatorsResult != nil && event.OperatorsResult.Matched {
|
||||
matchCount++
|
||||
|
||||
@ -154,6 +154,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error {
|
||||
opts := &compiler.ExecuteOptions{
|
||||
Timeout: request.Timeout,
|
||||
Source: &request.Init,
|
||||
Context: context.Background(),
|
||||
}
|
||||
// register 'export' function to export variables from init code
|
||||
// these are saved in args and are available in pre-condition and request code
|
||||
@ -343,7 +344,7 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
|
||||
argsCopy.TemplateCtx = templateCtx.GetAll()
|
||||
|
||||
result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy,
|
||||
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition})
|
||||
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition, Context: target.Context()})
|
||||
if err != nil {
|
||||
return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err)
|
||||
}
|
||||
@ -373,6 +374,12 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV
|
||||
return nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
if err := request.executeRequestWithPayloads(hostPort, input, hostname, value, payloadValues, func(result *output.InternalWrappedEvent) {
|
||||
if result.OperatorsResult != nil && result.OperatorsResult.Matched {
|
||||
gotMatches = true
|
||||
@ -419,6 +426,12 @@ func (request *Request) executeRequestParallel(ctxParent context.Context, hostPo
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
if shouldFollowGlobal && sg.Size != request.options.Options.PayloadConcurrency {
|
||||
sg.Resize(request.options.Options.PayloadConcurrency)
|
||||
@ -486,7 +499,7 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte
|
||||
}
|
||||
|
||||
results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy,
|
||||
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code})
|
||||
&compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code, Context: input.Context()})
|
||||
if err != nil {
|
||||
// shouldn't fail even if it returned error instead create a failure event
|
||||
results = compiler.ExecuteResult{"success": false, "error": err.Error()}
|
||||
|
||||
@ -139,6 +139,12 @@ func (request *Request) executeOnTarget(input *contextargs.Context, visited maps
|
||||
variables = generators.MergeMaps(variablesMap, variables, request.options.Constants)
|
||||
|
||||
for _, kv := range request.addresses {
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
actualAddress := replacer.Replace(kv.address, variables)
|
||||
|
||||
if visited.Has(actualAddress) && !request.options.Options.DisableClustering {
|
||||
@ -186,6 +192,12 @@ func (request *Request) executeAddress(variables map[string]interface{}, actualA
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-input.Context().Done():
|
||||
return input.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// resize check point - nop if there are no changes
|
||||
if shouldFollowGlobal && swg.Size != request.options.Options.PayloadConcurrency {
|
||||
swg.Resize(request.options.Options.PayloadConcurrency)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package network
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@ -65,7 +66,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
|
||||
t.Run("domain-valid", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(parsed.Host)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
@ -81,7 +82,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
|
||||
t.Run("invalid-port-override", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput("127.0.0.1:11211")
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), "127.0.0.1:11211")
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
@ -95,7 +96,7 @@ func TestNetworkExecuteWithResults(t *testing.T) {
|
||||
t.Run("hex-to-string", func(t *testing.T) {
|
||||
metadata := make(output.InternalEvent)
|
||||
previous := make(output.InternalEvent)
|
||||
ctxArgs := contextargs.NewWithInput(parsed.Host)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), parsed.Host)
|
||||
err := request.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) {
|
||||
finalEvent = event
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package protocols
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"sync/atomic"
|
||||
|
||||
@ -173,7 +174,7 @@ func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contexta
|
||||
templateCtx, ok := e.templateCtxStore.Get(scanId)
|
||||
if !ok {
|
||||
// if template context does not exist create new and add it to store and return it
|
||||
templateCtx = contextargs.New()
|
||||
templateCtx = contextargs.New(context.Background())
|
||||
templateCtx.MetaInput = input
|
||||
_ = e.templateCtxStore.Set(scanId, templateCtx)
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package ssl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -28,7 +29,7 @@ func TestSSLProtocol(t *testing.T) {
|
||||
require.Nil(t, err, "could not compile ssl request")
|
||||
|
||||
var gotEvent output.InternalEvent
|
||||
ctxArgs := contextargs.NewWithInput("scanme.sh:443")
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), "scanme.sh:443")
|
||||
err = request.ExecuteWithResults(ctxArgs, nil, nil, func(event *output.InternalWrappedEvent) {
|
||||
gotEvent = event.InternalEvent
|
||||
})
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
@ -49,7 +50,7 @@ func TestHTTPVariables(t *testing.T) {
|
||||
require.Equal(t, values["Hostname"], "foobar.com", "incorrect hostname")
|
||||
|
||||
baseURL = "http://scanme.sh"
|
||||
ctxArgs := contextargs.NewWithInput(baseURL)
|
||||
ctxArgs := contextargs.NewWithInput(context.Background(), baseURL)
|
||||
ctxArgs.MetaInput.CustomIP = "1.2.3.4"
|
||||
values = GenerateVariablesWithContextArgs(ctxArgs, true)
|
||||
|
||||
|
||||
@ -10,7 +10,6 @@ import (
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
|
||||
)
|
||||
|
||||
|
||||
type ScanContextOption func(*ScanContext)
|
||||
|
||||
func WithEvents() ScanContextOption {
|
||||
@ -20,7 +19,8 @@ func WithEvents() ScanContextOption {
|
||||
}
|
||||
|
||||
type ScanContext struct {
|
||||
context.Context
|
||||
ctx context.Context
|
||||
|
||||
// exported / configurable fields
|
||||
Input *contextargs.Context
|
||||
|
||||
@ -43,8 +43,13 @@ type ScanContext struct {
|
||||
}
|
||||
|
||||
// NewScanContext creates a new scan context using input
|
||||
func NewScanContext(input *contextargs.Context) *ScanContext {
|
||||
return &ScanContext{Input: input}
|
||||
func NewScanContext(ctx context.Context, input *contextargs.Context) *ScanContext {
|
||||
return &ScanContext{ctx: ctx, Input: input}
|
||||
}
|
||||
|
||||
// Context returns the context of the scan
|
||||
func (s *ScanContext) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
// GenerateResult returns final results slice from all events
|
||||
|
||||
@ -303,7 +303,7 @@ func (e *ClusterExecuter) Execute(ctx *scan.ScanContext) (bool, error) {
|
||||
|
||||
// ExecuteWithResults executes the protocol requests and returns results instead of writing them.
|
||||
func (e *ClusterExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) {
|
||||
scanCtx := scan.NewScanContext(ctx.Input)
|
||||
scanCtx := scan.NewScanContext(ctx.Context(), ctx.Input)
|
||||
dynamicValues := make(map[string]interface{})
|
||||
|
||||
inputItem := ctx.Input.Clone()
|
||||
|
||||
@ -174,6 +174,12 @@ func (f *FlowExecutor) Compile() error {
|
||||
|
||||
// ExecuteWithResults executes the flow and returns results
|
||||
func (f *FlowExecutor) ExecuteWithResults(ctx *scan.ScanContext) error {
|
||||
select {
|
||||
case <-ctx.Context().Done():
|
||||
return ctx.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
f.ctx.Input = ctx.Input
|
||||
// -----Load all types of variables-----
|
||||
// add all input args to template context
|
||||
|
||||
@ -55,8 +55,8 @@ func TestFlowTemplateWithIndex(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("hackerone.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "hackerone.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
require.True(t, gotresults)
|
||||
@ -74,8 +74,8 @@ func TestFlowTemplateWithID(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
target := contextargs.NewWithInput("hackerone.com")
|
||||
ctx := scan.NewScanContext(target)
|
||||
target := contextargs.NewWithInput(context.Background(), "hackerone.com")
|
||||
ctx := scan.NewScanContext(context.Background(), target)
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
require.True(t, gotresults)
|
||||
@ -96,8 +96,8 @@ func TestFlowWithProtoPrefix(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("hackerone.com")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "hackerone.com")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
require.True(t, gotresults)
|
||||
@ -116,8 +116,8 @@ func TestFlowWithConditionNegative(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("scanme.sh")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "scanme.sh")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
// expect no results and verify thant dns request is executed and http is not
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
@ -137,8 +137,8 @@ func TestFlowWithConditionPositive(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
// positive match . expect results also verify that both dns() and http() were executed
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
@ -158,8 +158,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
// positive match . expect results also verify that both dns() and http() were executed
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
@ -174,8 +174,8 @@ func TestFlowWithNoMatchers(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
anotherInput := contextargs.NewWithInput("blog.projectdiscovery.io")
|
||||
anotherCtx := scan.NewScanContext(anotherInput)
|
||||
anotherInput := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
|
||||
anotherCtx := scan.NewScanContext(context.Background(), anotherInput)
|
||||
// positive match . expect results also verify that both dns() and http() were executed
|
||||
gotresults, err = Template.Executer.Execute(anotherCtx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
|
||||
@ -45,6 +45,12 @@ func (g *Generic) ExecuteWithResults(ctx *scan.ScanContext) error {
|
||||
previous := mapsutil.NewSyncLockMap[string, any]()
|
||||
|
||||
for _, req := range g.requests {
|
||||
select {
|
||||
case <-ctx.Context().Done():
|
||||
return ctx.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
inputItem := ctx.Input.Clone()
|
||||
if g.options.InputHelper != nil && ctx.Input.MetaInput.Input != "" {
|
||||
if inputItem.MetaInput.Input = g.options.InputHelper.Transform(inputItem.MetaInput.Input, req.Type()); inputItem.MetaInput.Input == "" {
|
||||
|
||||
@ -44,6 +44,12 @@ func (m *MultiProtocol) Compile() error {
|
||||
|
||||
// ExecuteWithResults executes the template and returns results
|
||||
func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
|
||||
select {
|
||||
case <-ctx.Context().Done():
|
||||
return ctx.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
// put all readonly args into template context
|
||||
m.options.GetTemplateCtx(ctx.Input.MetaInput).Merge(m.readOnlyArgs)
|
||||
|
||||
@ -96,6 +102,12 @@ func (m *MultiProtocol) ExecuteWithResults(ctx *scan.ScanContext) error {
|
||||
|
||||
// execute all protocols in the queue
|
||||
for _, req := range m.requests {
|
||||
select {
|
||||
case <-ctx.Context().Done():
|
||||
return ctx.Context().Err()
|
||||
default:
|
||||
}
|
||||
|
||||
values := m.options.GetTemplateCtx(ctx.Input.MetaInput).GetAll()
|
||||
err := req.ExecuteWithResults(ctx.Input, output.InternalEvent(values), nil, multiProtoCallback)
|
||||
// if error skip execution of next protocols
|
||||
|
||||
@ -54,8 +54,8 @@ func TestMultiProtoWithDynamicExtractor(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
require.True(t, gotresults)
|
||||
@ -71,8 +71,8 @@ func TestMultiProtoWithProtoPrefix(t *testing.T) {
|
||||
err = Template.Executer.Compile()
|
||||
require.Nil(t, err, "could not compile template")
|
||||
|
||||
input := contextargs.NewWithInput("blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(input)
|
||||
input := contextargs.NewWithInput(context.Background(), "blog.projectdiscovery.io")
|
||||
ctx := scan.NewScanContext(context.Background(), input)
|
||||
gotresults, err := Template.Executer.Execute(ctx)
|
||||
require.Nil(t, err, "could not execute template")
|
||||
require.True(t, gotresults)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user