From c4fa2c74c1e5831a510d9ae8230ac439b7ebbb05 Mon Sep 17 00:00:00 2001 From: Nakul Bharti Date: Mon, 15 Sep 2025 23:48:02 +0530 Subject: [PATCH] cache, goroutine and unbounded workers management (#6420) * Enhance matcher compilation with caching for regex and DSL expressions to improve performance. Update template parsing to conditionally retain raw templates based on size constraints. * Implement caching for regex and DSL expressions in extractors and matchers to enhance performance. Introduce a buffer pool in raw requests to reduce memory allocations. Update template cache management for improved efficiency. * feat: improve concurrency to be bounded * refactor: replace fmt.Sprintf with fmt.Fprintf for improved performance in header handling * feat: add regex matching tests and benchmarks for performance evaluation * feat: add prefix check in regex extraction to optimize matching process * feat: implement regex caching mechanism to enhance performance in extractors and matchers, along with tests and benchmarks for validation * feat: add unit tests for template execution in the core engine, enhancing test coverage and reliability * feat: enhance error handling in template execution and improve regex caching logic for better performance * Implement caching for regex and DSL expressions in the cache package, replacing previous sync.Map usage. Add unit tests for cache functionality, including eviction by capacity and retrieval of cached items. Update extractors and matchers to utilize the new cache system for improved performance and memory efficiency. * Add tests for SetCapacities in cache package to ensure cache behavior on capacity changes - Implemented TestSetCapacities_NoRebuildOnZero to verify that setting capacities to zero does not clear existing caches. - Added TestSetCapacities_BeforeFirstUse to confirm that initial cache settings are respected and not overridden by subsequent capacity changes. 
* Refactor matchers and update load test generator to use io package - Removed maxRegexScanBytes constant from match.go. - Replaced ioutil with io package in load_test.go for NopCloser usage. - Restored TestValidate_AllowsInlineMultiline in load_test.go to ensure inline validation functionality. * Add cancellation support in template execution and enhance test coverage - Updated executeTemplateWithTargets to respect context cancellation. - Introduced fakeTargetProvider and slowExecuter for testing. - Added Test_executeTemplateWithTargets_RespectsCancellation to validate cancellation behavior during template execution. --- pkg/core/executors.go | 129 +++++++----- pkg/core/executors_test.go | 148 +++++++++++++ pkg/operators/cache/cache.go | 62 ++++++ pkg/operators/cache/cache_test.go | 114 ++++++++++ pkg/operators/extractors/compile.go | 11 + pkg/operators/extractors/extract.go | 14 +- pkg/operators/matchers/compile.go | 16 +- pkg/operators/matchers/match.go | 38 +++- pkg/operators/matchers/match_test.go | 65 +++++- .../common/generators/attack_types_test.go | 26 +++ pkg/protocols/common/generators/env_test.go | 38 ++++ pkg/protocols/common/generators/load.go | 14 ++ pkg/protocols/common/generators/load_test.go | 198 ++++++++---------- pkg/protocols/common/generators/maps_test.go | 29 +++ pkg/protocols/common/generators/validate.go | 6 +- pkg/protocols/http/build_request.go | 5 +- pkg/protocols/http/raw/raw.go | 19 +- pkg/protocols/http/raw/raw_test.go | 15 ++ pkg/protocols/http/request.go | 81 +++++-- pkg/protocols/http/request_test.go | 114 ++++++++++ pkg/protocols/javascript/js.go | 5 +- pkg/reporting/reporting.go | 10 +- pkg/templates/cache.go | 30 ++- pkg/templates/templates.go | 8 +- pkg/templates/templates_test.go | 26 +++ 25 files changed, 1001 insertions(+), 220 deletions(-) create mode 100644 pkg/core/executors_test.go create mode 100644 pkg/operators/cache/cache.go create mode 100644 pkg/operators/cache/cache_test.go create mode 100644 
pkg/protocols/common/generators/attack_types_test.go create mode 100644 pkg/protocols/common/generators/env_test.go diff --git a/pkg/core/executors.go b/pkg/core/executors.go index 8abcc9d78..aeb85ddfe 100644 --- a/pkg/core/executors.go +++ b/pkg/core/executors.go @@ -48,8 +48,15 @@ func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*te // executeTemplateWithTargets executes a given template on x targets (with a internal targetpool(i.e concurrency)) func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) { - // this is target pool i.e max target to execute - wg := e.workPool.InputPool(template.Type()) + if e.workPool == nil { + e.workPool = e.GetWorkPool() + } + // Bounded worker pool using input concurrency + pool := e.workPool.InputPool(template.Type()) + workerCount := 1 + if pool != nil && pool.Size > 0 { + workerCount = pool.Size + } var ( index uint32 @@ -78,6 +85,41 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ currentInfo.Unlock() } + // task represents a single target execution unit + type task struct { + index uint32 + skip bool + value *contextargs.MetaInput + } + + tasks := make(chan task) + var workersWg sync.WaitGroup + workersWg.Add(workerCount) + for i := 0; i < workerCount; i++ { + go func() { + defer workersWg.Done() + for t := range tasks { + func() { + defer cleanupInFlight(t.index) + select { + case <-ctx.Done(): + return + default: + } + if t.skip { + return + } + + match, err := e.executeTemplateOnInput(ctx, template, t.value) + if err != nil { + e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), t.value.Input, err) + } + results.CompareAndSwap(false, match) + }() + } + }() + } + target.Iterate(func(scannedValue *contextargs.MetaInput) bool { select { case <-ctx.Done(): @@ -128,43 +170,13 @@ func (e *Engine) 
executeTemplateWithTargets(ctx context.Context, template *templ return true } - wg.Add() - go func(index uint32, skip bool, value *contextargs.MetaInput) { - defer wg.Done() - defer cleanupInFlight(index) - if skip { - return - } - - var match bool - var err error - ctxArgs := contextargs.New(ctx) - ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctx, ctxArgs) - switch template.Type() { - case types.WorkflowProtocol: - match = e.executeWorkflow(ctx, template.CompiledWorkflow) - default: - if e.Callback != nil { - if results, err := template.Executer.ExecuteWithResults(ctx); err == nil { - for _, result := range results { - e.Callback(result) - } - } - match = true - } else { - match, err = template.Executer.Execute(ctx) - } - } - if err != nil { - e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err) - } - results.CompareAndSwap(false, match) - }(index, skip, scannedValue) + tasks <- task{index: index, skip: skip, value: scannedValue} index++ return true }) - wg.Wait() + + close(tasks) + workersWg.Wait() // on completion marks the template as completed currentInfo.Lock() @@ -202,26 +214,7 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) { defer wg.Done() - var match bool - var err error - ctxArgs := contextargs.New(ctx) - ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctx, ctxArgs) - switch template.Type() { - case types.WorkflowProtocol: - match = e.executeWorkflow(ctx, template.CompiledWorkflow) - default: - if e.Callback != nil { - if results, err := template.Executer.ExecuteWithResults(ctx); err == nil { - for _, result := range results { - e.Callback(result) - } - } - match = true - } else { - match, err = template.Executer.Execute(ctx) - } - } + match, err := e.executeTemplateOnInput(ctx, template, value) if err != nil { 
e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err) } @@ -229,3 +222,27 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t }(tpl, target, sg) } } + +// executeTemplateOnInput performs template execution for a single input and returns match status and error +func (e *Engine) executeTemplateOnInput(ctx context.Context, template *templates.Template, value *contextargs.MetaInput) (bool, error) { + ctxArgs := contextargs.New(ctx) + ctxArgs.MetaInput = value + scanCtx := scan.NewScanContext(ctx, ctxArgs) + + switch template.Type() { + case types.WorkflowProtocol: + return e.executeWorkflow(scanCtx, template.CompiledWorkflow), nil + default: + if e.Callback != nil { + results, err := template.Executer.ExecuteWithResults(scanCtx) + if err != nil { + return false, err + } + for _, result := range results { + e.Callback(result) + } + return len(results) > 0, nil + } + return template.Executer.Execute(scanCtx) + } +} diff --git a/pkg/core/executors_test.go b/pkg/core/executors_test.go new file mode 100644 index 000000000..394b2e6d9 --- /dev/null +++ b/pkg/core/executors_test.go @@ -0,0 +1,148 @@ +package core + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + inputtypes "github.com/projectdiscovery/nuclei/v3/pkg/input/types" + "github.com/projectdiscovery/nuclei/v3/pkg/output" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" + "github.com/projectdiscovery/nuclei/v3/pkg/scan" + "github.com/projectdiscovery/nuclei/v3/pkg/templates" + tmpltypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" + "github.com/projectdiscovery/nuclei/v3/pkg/types" +) + +// fakeExecuter is a simple stub for protocols.Executer used to test executeTemplateOnInput +type fakeExecuter struct { + withResults bool +} + +func (f *fakeExecuter) Compile() error { return nil } 
+func (f *fakeExecuter) Requests() int { return 1 } +func (f *fakeExecuter) Execute(ctx *scan.ScanContext) (bool, error) { return !f.withResults, nil } +func (f *fakeExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + if !f.withResults { + return nil, nil + } + return []*output.ResultEvent{{Host: "h"}}, nil +} + +// newTestEngine creates a minimal Engine for tests +func newTestEngine() *Engine { + return New(&types.Options{}) +} + +func Test_executeTemplateOnInput_CallbackPath(t *testing.T) { + e := newTestEngine() + called := 0 + e.Callback = func(*output.ResultEvent) { called++ } + + tpl := &templates.Template{} + tpl.Executer = &fakeExecuter{withResults: true} + + ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatalf("expected match true") + } + if called == 0 { + t.Fatalf("expected callback to be called") + } +} + +func Test_executeTemplateOnInput_ExecutePath(t *testing.T) { + e := newTestEngine() + tpl := &templates.Template{} + tpl.Executer = &fakeExecuter{withResults: false} + + ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatalf("expected match true from Execute path") + } +} + +type fakeExecuterErr struct{} + +func (f *fakeExecuterErr) Compile() error { return nil } +func (f *fakeExecuterErr) Requests() int { return 1 } +func (f *fakeExecuterErr) Execute(ctx *scan.ScanContext) (bool, error) { return false, nil } +func (f *fakeExecuterErr) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + return nil, fmt.Errorf("boom") +} + +func Test_executeTemplateOnInput_CallbackErrorPropagates(t *testing.T) { + e := newTestEngine() + e.Callback = func(*output.ResultEvent) {} + tpl := &templates.Template{} + tpl.Executer = &fakeExecuterErr{} + + ok, err 
:= e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"}) + if err == nil { + t.Fatalf("expected error to propagate") + } + if ok { + t.Fatalf("expected match to be false on error") + } +} + +type fakeTargetProvider struct { + values []*contextargs.MetaInput +} + +func (f *fakeTargetProvider) Count() int64 { return int64(len(f.values)) } +func (f *fakeTargetProvider) Iterate(cb func(value *contextargs.MetaInput) bool) { + for _, v := range f.values { + if !cb(v) { + return + } + } +} +func (f *fakeTargetProvider) Set(string, string) {} +func (f *fakeTargetProvider) SetWithProbe(string, string, inputtypes.InputLivenessProbe) error { + return nil +} +func (f *fakeTargetProvider) SetWithExclusions(string, string) error { return nil } +func (f *fakeTargetProvider) InputType() string { return "test" } +func (f *fakeTargetProvider) Close() {} + +type slowExecuter struct{} + +func (s *slowExecuter) Compile() error { return nil } +func (s *slowExecuter) Requests() int { return 1 } +func (s *slowExecuter) Execute(ctx *scan.ScanContext) (bool, error) { + select { + case <-ctx.Context().Done(): + return false, ctx.Context().Err() + case <-time.After(200 * time.Millisecond): + return true, nil + } +} +func (s *slowExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + return nil, nil +} + +func Test_executeTemplateWithTargets_RespectsCancellation(t *testing.T) { + e := newTestEngine() + e.SetExecuterOptions(&protocols.ExecutorOptions{Logger: e.Logger, ResumeCfg: types.NewResumeCfg(), ProtocolType: tmpltypes.HTTPProtocol}) + + tpl := &templates.Template{} + tpl.Executer = &slowExecuter{} + + targets := &fakeTargetProvider{values: []*contextargs.MetaInput{{Input: "a"}, {Input: "b"}, {Input: "c"}}} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var matched atomic.Bool + e.executeTemplateWithTargets(ctx, tpl, targets, &matched) +} diff --git a/pkg/operators/cache/cache.go 
b/pkg/operators/cache/cache.go new file mode 100644 index 000000000..ca486097c --- /dev/null +++ b/pkg/operators/cache/cache.go @@ -0,0 +1,62 @@ +package cache + +import ( + "regexp" + "sync" + + "github.com/Knetic/govaluate" + "github.com/projectdiscovery/gcache" +) + +var ( + initOnce sync.Once + mu sync.RWMutex + + regexCap = 4096 + dslCap = 4096 + + regexCache gcache.Cache[string, *regexp.Regexp] + dslCache gcache.Cache[string, *govaluate.EvaluableExpression] +) + +func initCaches() { + initOnce.Do(func() { + regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build() + dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build() + }) +} + +func SetCapacities(regexCapacity, dslCapacity int) { + // ensure caches are initialized under initOnce, so later Regex()/DSL() won't re-init + initCaches() + + mu.Lock() + defer mu.Unlock() + + if regexCapacity > 0 { + regexCap = regexCapacity + } + if dslCapacity > 0 { + dslCap = dslCapacity + } + if regexCapacity <= 0 && dslCapacity <= 0 { + return + } + // rebuild caches with new capacities + regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build() + dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build() +} + +func Regex() gcache.Cache[string, *regexp.Regexp] { + initCaches() + mu.RLock() + defer mu.RUnlock() + return regexCache +} + +func DSL() gcache.Cache[string, *govaluate.EvaluableExpression] { + initCaches() + mu.RLock() + defer mu.RUnlock() + return dslCache +} diff --git a/pkg/operators/cache/cache_test.go b/pkg/operators/cache/cache_test.go new file mode 100644 index 000000000..c44b72c84 --- /dev/null +++ b/pkg/operators/cache/cache_test.go @@ -0,0 +1,114 @@ +package cache + +import ( + "regexp" + "testing" + + "github.com/Knetic/govaluate" +) + +func TestRegexCache_SetGet(t *testing.T) { + // ensure init + c := Regex() + pattern := "abc(\n)?123" + re, err := regexp.Compile(pattern) + if err != nil { + t.Fatalf("compile: %v", err) + } 
+ if err := c.Set(pattern, re); err != nil { + t.Fatalf("set: %v", err) + } + got, err := c.GetIFPresent(pattern) + if err != nil || got == nil { + t.Fatalf("get: %v got=%v", err, got) + } + if got.String() != re.String() { + t.Fatalf("mismatch: %s != %s", got.String(), re.String()) + } +} + +func TestDSLCache_SetGet(t *testing.T) { + c := DSL() + expr := "1 + 2 == 3" + ast, err := govaluate.NewEvaluableExpression(expr) + if err != nil { + t.Fatalf("dsl compile: %v", err) + } + if err := c.Set(expr, ast); err != nil { + t.Fatalf("set: %v", err) + } + got, err := c.GetIFPresent(expr) + if err != nil || got == nil { + t.Fatalf("get: %v got=%v", err, got) + } + if got.String() != ast.String() { + t.Fatalf("mismatch: %s != %s", got.String(), ast.String()) + } +} + +func TestRegexCache_EvictionByCapacity(t *testing.T) { + SetCapacities(3, 3) + c := Regex() + for i := 0; i < 5; i++ { + k := string(rune('a' + i)) + re := regexp.MustCompile(k) + _ = c.Set(k, re) + } + // last 3 keys expected to remain under LRU: 'c','d','e' + if _, err := c.GetIFPresent("a"); err == nil { + t.Fatalf("expected 'a' to be evicted") + } + if _, err := c.GetIFPresent("b"); err == nil { + t.Fatalf("expected 'b' to be evicted") + } + if _, err := c.GetIFPresent("c"); err != nil { + t.Fatalf("expected 'c' present") + } +} + +func TestSetCapacities_NoRebuildOnZero(t *testing.T) { + // init + SetCapacities(4, 4) + c1 := Regex() + _ = c1.Set("k", regexp.MustCompile("k")) + if _, err := c1.GetIFPresent("k"); err != nil { + t.Fatalf("expected key present: %v", err) + } + // zero changes should not rebuild/clear caches + SetCapacities(0, 0) + c2 := Regex() + if _, err := c2.GetIFPresent("k"); err != nil { + t.Fatalf("key lost after zero-capacity SetCapacities: %v", err) + } +} + +func TestSetCapacities_BeforeFirstUse(t *testing.T) { + // This should not be overridden by later initCaches + SetCapacities(2, 0) + c := Regex() + _ = c.Set("a", regexp.MustCompile("a")) + _ = c.Set("b", 
regexp.MustCompile("b")) + _ = c.Set("c", regexp.MustCompile("c")) + if _, err := c.GetIFPresent("a"); err == nil { + t.Fatalf("expected 'a' to be evicted under cap=2") + } +} + +func TestSetCapacities_ConcurrentAccess(t *testing.T) { + SetCapacities(64, 64) + stop := make(chan struct{}) + + go func() { + for i := 0; i < 5000; i++ { + _ = Regex().Set("k"+string(rune('a'+(i%26))), regexp.MustCompile("a")) + _, _ = Regex().GetIFPresent("k" + string(rune('a'+(i%26)))) + _, _ = DSL().GetIFPresent("1+2==3") + } + close(stop) + }() + + for i := 0; i < 200; i++ { + SetCapacities(64+(i%5), 64+((i+1)%5)) + } + <-stop +} diff --git a/pkg/operators/extractors/compile.go b/pkg/operators/extractors/compile.go index 2b55d374a..bcfd37eeb 100644 --- a/pkg/operators/extractors/compile.go +++ b/pkg/operators/extractors/compile.go @@ -7,6 +7,7 @@ import ( "github.com/Knetic/govaluate" "github.com/itchyny/gojq" + "github.com/projectdiscovery/nuclei/v3/pkg/operators/cache" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" ) @@ -20,10 +21,15 @@ func (e *Extractor) CompileExtractors() error { e.extractorType = computedType // Compile the regexes for _, regex := range e.Regex { + if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil { + e.regexCompiled = append(e.regexCompiled, cached) + continue + } compiled, err := regexp.Compile(regex) if err != nil { return fmt.Errorf("could not compile regex: %s", regex) } + _ = cache.Regex().Set(regex, compiled) e.regexCompiled = append(e.regexCompiled, compiled) } for i, kval := range e.KVal { @@ -43,10 +49,15 @@ func (e *Extractor) CompileExtractors() error { } for _, dslExp := range e.DSL { + if cached, err := cache.DSL().GetIFPresent(dslExp); err == nil && cached != nil { + e.dslCompiled = append(e.dslCompiled, cached) + continue + } compiled, err := govaluate.NewEvaluableExpressionWithFunctions(dslExp, dsl.HelperFunctions) if err != nil { return &dsl.CompilationError{DslSignature: dslExp, WrappedError: 
err} } + _ = cache.DSL().Set(dslExp, compiled) e.dslCompiled = append(e.dslCompiled, compiled) } diff --git a/pkg/operators/extractors/extract.go b/pkg/operators/extractors/extract.go index 194b4648f..1a8ca63b6 100644 --- a/pkg/operators/extractors/extract.go +++ b/pkg/operators/extractors/extract.go @@ -17,9 +17,19 @@ func (e *Extractor) ExtractRegex(corpus string) map[string]struct{} { groupPlusOne := e.RegexGroup + 1 for _, regex := range e.regexCompiled { - matches := regex.FindAllStringSubmatch(corpus, -1) + // skip prefix short-circuit for case-insensitive patterns + rstr := regex.String() + if !strings.Contains(rstr, "(?i") { + if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" { + if !strings.Contains(corpus, prefix) { + continue + } + } + } - for _, match := range matches { + submatches := regex.FindAllStringSubmatch(corpus, -1) + + for _, match := range submatches { if len(match) < groupPlusOne { continue } diff --git a/pkg/operators/matchers/compile.go b/pkg/operators/matchers/compile.go index 5a99347c5..4ae72be60 100644 --- a/pkg/operators/matchers/compile.go +++ b/pkg/operators/matchers/compile.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/Knetic/govaluate" - + "github.com/projectdiscovery/nuclei/v3/pkg/operators/cache" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" ) @@ -42,12 +42,17 @@ func (matcher *Matcher) CompileMatchers() error { matcher.Part = "body" } - // Compile the regexes + // Compile the regexes (with shared cache) for _, regex := range matcher.Regex { + if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil { + matcher.regexCompiled = append(matcher.regexCompiled, cached) + continue + } compiled, err := regexp.Compile(regex) if err != nil { return fmt.Errorf("could not compile regex: %s", regex) } + _ = cache.Regex().Set(regex, compiled) matcher.regexCompiled = append(matcher.regexCompiled, compiled) } @@ -60,12 +65,17 @@ func (matcher *Matcher) CompileMatchers() error { } } - // 
Compile the dsl expressions + // Compile the dsl expressions (with shared cache) for _, dslExpression := range matcher.DSL { + if cached, err := cache.DSL().GetIFPresent(dslExpression); err == nil && cached != nil { + matcher.dslCompiled = append(matcher.dslCompiled, cached) + continue + } compiledExpression, err := govaluate.NewEvaluableExpressionWithFunctions(dslExpression, dsl.HelperFunctions) if err != nil { return &dsl.CompilationError{DslSignature: dslExpression, WrappedError: err} } + _ = cache.DSL().Set(dslExpression, compiledExpression) matcher.dslCompiled = append(matcher.dslCompiled, compiledExpression) } diff --git a/pkg/operators/matchers/match.go b/pkg/operators/matchers/match.go index 78bd40175..a914b5f98 100644 --- a/pkg/operators/matchers/match.go +++ b/pkg/operators/matchers/match.go @@ -106,10 +106,33 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) { var matchedRegexes []string // Iterate over all the regexes accepted as valid for i, regex := range matcher.regexCompiled { - // Continue if the regex doesn't match - if !regex.MatchString(corpus) { - // If we are in an AND request and a match failed, - // return false as the AND condition fails on any single mismatch. 
+ // Literal prefix short-circuit + rstr := regex.String() + if !strings.Contains(rstr, "(?i") { // covers (?i) and (?i: + if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" { + if !strings.Contains(corpus, prefix) { + switch matcher.condition { + case ANDCondition: + return false, []string{} + case ORCondition: + continue + } + } + } + } + + // Fast OR-path: return first match without full scan + if matcher.condition == ORCondition && !matcher.MatchAll { + m := regex.FindAllString(corpus, 1) + if len(m) == 0 { + continue + } + return true, m + } + + // Single scan: get all matches directly + currentMatches := regex.FindAllString(corpus, -1) + if len(currentMatches) == 0 { switch matcher.condition { case ANDCondition: return false, []string{} @@ -118,12 +141,7 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) { } } - currentMatches := regex.FindAllString(corpus, -1) - // If the condition was an OR, return on the first match. - if matcher.condition == ORCondition && !matcher.MatchAll { - return true, currentMatches - } - + // If the condition was an OR (and MatchAll true), we still need to gather all matchedRegexes = append(matchedRegexes, currentMatches...) // If we are at the end of the regex, return with true diff --git a/pkg/operators/matchers/match_test.go b/pkg/operators/matchers/match_test.go index ea6258ae0..8a073318b 100644 --- a/pkg/operators/matchers/match_test.go +++ b/pkg/operators/matchers/match_test.go @@ -84,7 +84,7 @@ func TestMatcher_MatchDSL(t *testing.T) { values := []string{"PING", "pong"} - for value := range values { + for _, value := range values { isMatched := m.MatchDSL(map[string]interface{}{"body": value, "VARIABLE": value}) require.True(t, isMatched) } @@ -209,3 +209,66 @@ func TestMatcher_MatchXPath_XML(t *testing.T) { isMatched = m.MatchXPath("

not right notvalid") require.False(t, isMatched, "Invalid xpath did not return false") } + +func TestMatchRegex_CaseInsensitivePrefixSkip(t *testing.T) { + m := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"(?i)abc"}} + err := m.CompileMatchers() + require.NoError(t, err) + ok, got := m.MatchRegex("zzz AbC yyy") + require.True(t, ok) + require.NotEmpty(t, got) +} + +func TestMatchStatusCodeAndSize(t *testing.T) { + mStatus := &Matcher{Status: []int{200, 302}} + require.True(t, mStatus.MatchStatusCode(200)) + require.True(t, mStatus.MatchStatusCode(302)) + require.False(t, mStatus.MatchStatusCode(404)) + + mSize := &Matcher{Size: []int{5, 10}} + require.True(t, mSize.MatchSize(5)) + require.False(t, mSize.MatchSize(7)) +} + +func TestMatchBinary_AND_OR(t *testing.T) { + // AND should fail if any binary not present + mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "and", Binary: []string{"50494e47", "414141"}} // "PING", "AAA" + require.NoError(t, mAnd.CompileMatchers()) + ok, _ := mAnd.MatchBinary("PING") + require.False(t, ok) + // OR should succeed if any present + mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "or", Binary: []string{"414141", "50494e47"}} // "AAA", "PING" + require.NoError(t, mOr.CompileMatchers()) + ok, got := mOr.MatchBinary("xxPINGyy") + require.True(t, ok) + require.NotEmpty(t, got) +} + +func TestMatchRegex_LiteralPrefixShortCircuit(t *testing.T) { + // AND: first regex has literal prefix "abc"; corpus lacks it => early false + mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "and", Regex: []string{"abc[0-9]*", "[0-9]{2}"}} + require.NoError(t, mAnd.CompileMatchers()) + ok, matches := mAnd.MatchRegex("zzz 12 yyy") + require.False(t, ok) + require.Empty(t, matches) + + // OR: first regex skipped due to missing prefix, second matches => true + mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: 
RegexMatcher}, Condition: "or", Regex: []string{"abc[0-9]*", "[0-9]{2}"}} + require.NoError(t, mOr.CompileMatchers()) + ok, matches = mOr.MatchRegex("zzz 12 yyy") + require.True(t, ok) + require.Equal(t, []string{"12"}, matches) +} + +func TestMatcher_MatchDSL_ErrorHandling(t *testing.T) { + // First expression errors (division by zero), second is true + bad, err := govaluate.NewEvaluableExpression("1 / 0") + require.NoError(t, err) + good, err := govaluate.NewEvaluableExpression("1 == 1") + require.NoError(t, err) + + m := &Matcher{Type: MatcherTypeHolder{MatcherType: DSLMatcher}, Condition: "or", dslCompiled: []*govaluate.EvaluableExpression{bad, good}} + require.NoError(t, m.CompileMatchers()) + ok := m.MatchDSL(map[string]interface{}{}) + require.True(t, ok) +} diff --git a/pkg/protocols/common/generators/attack_types_test.go b/pkg/protocols/common/generators/attack_types_test.go new file mode 100644 index 000000000..a1c808319 --- /dev/null +++ b/pkg/protocols/common/generators/attack_types_test.go @@ -0,0 +1,26 @@ +package generators + +import "testing" + +func TestAttackTypeHelpers(t *testing.T) { + // GetSupportedAttackTypes should include three values + types := GetSupportedAttackTypes() + if len(types) != 3 { + t.Fatalf("expected 3 types, got %d", len(types)) + } + // toAttackType valid + if got, err := toAttackType("pitchfork"); err != nil || got != PitchForkAttack { + t.Fatalf("toAttackType failed: %v %v", got, err) + } + // toAttackType invalid + if _, err := toAttackType("nope"); err == nil { + t.Fatalf("expected error for invalid attack type") + } + // normalizeValue and String + if normalizeValue(" ClusterBomb ") != "clusterbomb" { + t.Fatalf("normalizeValue failed") + } + if ClusterBombAttack.String() != "clusterbomb" { + t.Fatalf("String failed") + } +} diff --git a/pkg/protocols/common/generators/env_test.go b/pkg/protocols/common/generators/env_test.go new file mode 100644 index 000000000..88b8fa377 --- /dev/null +++ 
b/pkg/protocols/common/generators/env_test.go @@ -0,0 +1,38 @@ +package generators + +import ( + "os" + "testing" +) + +func TestParseEnvVars(t *testing.T) { + old := os.Environ() + // set a scoped env var + _ = os.Setenv("NUCLEI_TEST_K", "V1") + t.Cleanup(func() { + // restore + for _, kv := range old { + parts := kv + _ = parts // nothing, environment already has superset; best-effort cleanup below + } + _ = os.Unsetenv("NUCLEI_TEST_K") + }) + vars := parseEnvVars() + if vars["NUCLEI_TEST_K"] != "V1" { + t.Fatalf("expected V1, got %v", vars["NUCLEI_TEST_K"]) + } +} + +func TestEnvVarsMemoization(t *testing.T) { + // reset memoized map + envVars = nil + _ = os.Setenv("NUCLEI_TEST_MEMO", "A") + t.Cleanup(func() { _ = os.Unsetenv("NUCLEI_TEST_MEMO") }) + v1 := EnvVars()["NUCLEI_TEST_MEMO"] + // change env after memoization + _ = os.Setenv("NUCLEI_TEST_MEMO", "B") + v2 := EnvVars()["NUCLEI_TEST_MEMO"] + if v1 != "A" || v2 != "A" { + t.Fatalf("memoization failed: %v %v", v1, v2) + } +} diff --git a/pkg/protocols/common/generators/load.go b/pkg/protocols/common/generators/load.go index 892fe358a..91f631e4b 100644 --- a/pkg/protocols/common/generators/load.go +++ b/pkg/protocols/common/generators/load.go @@ -17,6 +17,20 @@ func (generator *PayloadGenerator) loadPayloads(payloads map[string]interface{}, for name, payload := range payloads { switch pt := payload.(type) { case string: + // Fast path: if no newline, treat as file path + if !strings.ContainsRune(pt, '\n') { + file, err := generator.options.LoadHelperFile(pt, templatePath, generator.catalog) + if err != nil { + return nil, errors.Wrap(err, "could not load payload file") + } + payloads, err := generator.loadPayloadsFromFile(file) + if err != nil { + return nil, errors.Wrap(err, "could not load payloads") + } + loadedPayloads[name] = payloads + break + } + // Multiline inline payloads elements := strings.Split(pt, "\n") //golint:gomnd // this is not a magic number if len(elements) >= 2 { diff --git 
a/pkg/protocols/common/generators/load_test.go b/pkg/protocols/common/generators/load_test.go index ebec9fd72..04d886270 100644 --- a/pkg/protocols/common/generators/load_test.go +++ b/pkg/protocols/common/generators/load_test.go @@ -1,120 +1,108 @@ package generators import ( - "os" - "os/exec" - "path/filepath" + "io" + "strings" "testing" - "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" - "github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk" - osutils "github.com/projectdiscovery/utils/os" - "github.com/stretchr/testify/require" + "github.com/pkg/errors" + "github.com/projectdiscovery/nuclei/v3/pkg/catalog" + "github.com/projectdiscovery/nuclei/v3/pkg/types" ) -func TestLoadPayloads(t *testing.T) { - // since we are changing value of global variable i.e templates directory - // run this test as subprocess - if os.Getenv("LOAD_PAYLOAD_NO_ACCESS") != "1" { - cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess") - cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_NO_ACCESS=1") - err := cmd.Run() - if e, ok := err.(*exec.ExitError); ok && !e.Success() { - return - } - if err != nil { - t.Fatalf("process ran with err %v, want exit status 1", err) +type fakeCatalog struct{ catalog.Catalog } + +func (f *fakeCatalog) OpenFile(filename string) (io.ReadCloser, error) { + return nil, errors.New("not used") +} +func (f *fakeCatalog) GetTemplatePath(target string) ([]string, error) { return nil, nil } +func (f *fakeCatalog) GetTemplatesPath(definitions []string) ([]string, map[string]error) { + return nil, nil +} +func (f *fakeCatalog) ResolvePath(templateName, second string) (string, error) { + return templateName, nil +} + +func newTestGenerator() *PayloadGenerator { + opts := types.DefaultOptions() + // inject helper loader function + opts.LoadHelperFileFunction = func(path, templatePath string, _ catalog.Catalog) (io.ReadCloser, error) { + switch path { + case "fileA.txt": + return io.NopCloser(strings.NewReader("one\n two\n\nthree\n")), nil + 
default: + return io.NopCloser(strings.NewReader("x\ny\nz\n")), nil } } - templateDir := getTemplatesDir(t) - config.DefaultConfig.SetTemplatesDir(templateDir) - - generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(false)} - fullpath := filepath.Join(templateDir, "payloads.txt") - - // Test sandbox - t.Run("templates-directory", func(t *testing.T) { - // testcase when loading file from template directory and template file is in root - // expected to succeed - values, err := generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, "/test") - require.NoError(t, err, "could not load payloads") - require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values") - }) - t.Run("templates-path-relative", func(t *testing.T) { - // testcase when loading file from template directory and template file is current working directory - // expected to fail since this is LFI - _, err := generator.loadPayloads(map[string]interface{}{ - "new": "../../../../../../../../../etc/passwd", - }, ".") - require.Error(t, err, "could load payloads") - }) - t.Run("template-directory", func(t *testing.T) { - // testcase when loading file from template directory and template file is inside template directory - // expected to succeed - values, err := generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, filepath.Join(templateDir, "test.yaml")) - require.NoError(t, err, "could not load payloads") - require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values") - }) - - t.Run("invalid", func(t *testing.T) { - // testcase when loading file from /etc/passwd and template file is at root i.e / - // expected to fail since this is LFI - values, err := generator.loadPayloads(map[string]interface{}{ - "new": "/etc/passwd", - }, "/random") - require.Error(t, err, "could load payloads got %v", values) - require.Equal(t, 0, len(values), "could get values") - - // testcase when 
loading file from template directory and template file is at root i.e / - // expected to succeed - values, err = generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, "/random") - require.NoError(t, err, "could load payloads %v", values) - require.Equal(t, 1, len(values), "could get values") - require.Equal(t, []string{"test", "another"}, values["new"], "could get values") - }) + return &PayloadGenerator{options: opts, catalog: &fakeCatalog{}} } -func TestLoadPayloadsWithAccess(t *testing.T) { - // since we are changing value of global variable i.e templates directory - // run this test as subprocess - if os.Getenv("LOAD_PAYLOAD_WITH_ACCESS") != "1" { - cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess") - cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_WITH_ACCESS=1") - err := cmd.Run() - if e, ok := err.(*exec.ExitError); ok && !e.Success() { - return - } - if err != nil { - t.Fatalf("process ran with err %v, want exit status 1", err) - } +func TestLoadPayloads_FastPathFile(t *testing.T) { + g := newTestGenerator() + out, err := g.loadPayloads(map[string]interface{}{"A": "fileA.txt"}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["A"] + if len(got) != 3 || got[0] != "one" || got[1] != " two" || got[2] != "three" { + t.Fatalf("unexpected: %#v", got) } - templateDir := getTemplatesDir(t) - config.DefaultConfig.SetTemplatesDir(templateDir) - - generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(true)} - - t.Run("no-sandbox-unix", func(t *testing.T) { - if osutils.IsWindows() { - return - } - _, err := generator.loadPayloads(map[string]interface{}{ - "new": "/etc/passwd", - }, "/random") - require.NoError(t, err, "could load payloads") - }) } -func getTemplatesDir(t *testing.T) string { - tempdir, err := os.MkdirTemp("", "templates-*") - require.NoError(t, err, "could not create temp dir") - fullpath := filepath.Join(tempdir, "payloads.txt") - err = os.WriteFile(fullpath, 
[]byte("test\nanother"), 0777) - require.NoError(t, err, "could not write payload") - return tempdir +func TestLoadPayloads_InlineMultiline(t *testing.T) { + g := newTestGenerator() + inline := "a\nb\n" + out, err := g.loadPayloads(map[string]interface{}{"B": inline}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["B"] + if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "" { + t.Fatalf("unexpected: %#v", got) + } +} + +func TestLoadPayloads_SingleLineFallsBackToFile(t *testing.T) { + g := newTestGenerator() + inline := "fileA.txt" // single line, should be treated as file path + out, err := g.loadPayloads(map[string]interface{}{"C": inline}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["C"] + if len(got) != 3 { + t.Fatalf("unexpected len: %d", len(got)) + } +} + +func TestLoadPayloads_InterfaceSlice(t *testing.T) { + g := newTestGenerator() + out, err := g.loadPayloads(map[string]interface{}{"D": []interface{}{"p", "q"}}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["D"] + if len(got) != 2 || got[0] != "p" || got[1] != "q" { + t.Fatalf("unexpected: %#v", got) + } +} + +func TestLoadPayloadsFromFile_SkipsEmpty(t *testing.T) { + g := newTestGenerator() + rc := io.NopCloser(strings.NewReader("a\n\n\n b \n")) + lines, err := g.loadPayloadsFromFile(rc) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(lines) != 2 || lines[0] != "a" || lines[1] != " b " { + t.Fatalf("unexpected: %#v", lines) + } +} + +func TestValidate_AllowsInlineMultiline(t *testing.T) { + g := newTestGenerator() + inline := "x\ny\n" + if err := g.validate(map[string]interface{}{"E": inline}, ""); err != nil { + t.Fatalf("validate rejected inline multiline: %v", err) + } } diff --git a/pkg/protocols/common/generators/maps_test.go b/pkg/protocols/common/generators/maps_test.go index ca75bb655..8b73539f8 100644 --- a/pkg/protocols/common/generators/maps_test.go +++ b/pkg/protocols/common/generators/maps_test.go @@ -14,3 
+14,32 @@ func TestMergeMapsMany(t *testing.T) { "c": {"5"}, }, got, "could not get correct merged map") } + +func TestMergeMapsAndExpand(t *testing.T) { + m1 := map[string]interface{}{"a": "1"} + m2 := map[string]interface{}{"b": "2"} + out := MergeMaps(m1, m2) + if out["a"].(string) != "1" || out["b"].(string) != "2" { + t.Fatalf("unexpected merge: %#v", out) + } + flat := map[string]string{"x": "y"} + exp := ExpandMapValues(flat) + if len(exp["x"]) != 1 || exp["x"][0] != "y" { + t.Fatalf("unexpected expand: %#v", exp) + } +} + +func TestIteratorRemaining(t *testing.T) { + g, err := New(map[string]interface{}{"k": []interface{}{"a", "b"}}, BatteringRamAttack, "", nil, "", nil) + if err != nil { + t.Fatalf("new: %v", err) + } + it := g.NewIterator() + if it.Total() != 2 || it.Remaining() != 2 { + t.Fatalf("unexpected totals: %d %d", it.Total(), it.Remaining()) + } + _, _ = it.Value() + if it.Remaining() != 1 { + t.Fatalf("unexpected remaining after one: %d", it.Remaining()) + } +} diff --git a/pkg/protocols/common/generators/validate.go b/pkg/protocols/common/generators/validate.go index 0aa073714..48365e0db 100644 --- a/pkg/protocols/common/generators/validate.go +++ b/pkg/protocols/common/generators/validate.go @@ -1,7 +1,6 @@ package generators import ( - "errors" "fmt" "path/filepath" "strings" @@ -17,9 +16,8 @@ func (g *PayloadGenerator) validate(payloads map[string]interface{}, templatePat for name, payload := range payloads { switch payloadType := payload.(type) { case string: - // check if it's a multiline string list - if len(strings.Split(payloadType, "\n")) != 1 { - return errors.New("invalid number of lines in payload") + if strings.ContainsRune(payloadType, '\n') { + continue } // For historical reasons, "validate" checks to see if the payload file exist. 
diff --git a/pkg/protocols/http/build_request.go b/pkg/protocols/http/build_request.go index 0fa8d1502..980573c96 100644 --- a/pkg/protocols/http/build_request.go +++ b/pkg/protocols/http/build_request.go @@ -92,9 +92,8 @@ type generatedRequest struct { // setReqURLPattern sets the url request pattern for the generated request func (gr *generatedRequest) setReqURLPattern(reqURLPattern string) { - data := strings.Split(reqURLPattern, "\n") - if len(data) > 1 { - reqURLPattern = strings.TrimSpace(data[0]) + if idx := strings.IndexByte(reqURLPattern, '\n'); idx >= 0 { + reqURLPattern = strings.TrimSpace(reqURLPattern[:idx]) // this is raw request (if it has 3 parts after strings.Fields then its valid only use 2nd part) parts := strings.Fields(reqURLPattern) if len(parts) >= 3 { diff --git a/pkg/protocols/http/raw/raw.go b/pkg/protocols/http/raw/raw.go index c2a2121b6..7b1457afa 100644 --- a/pkg/protocols/http/raw/raw.go +++ b/pkg/protocols/http/raw/raw.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "strings" + "sync" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" @@ -17,6 +18,8 @@ import ( urlutil "github.com/projectdiscovery/utils/url" ) +var bufferPool = sync.Pool{New: func() any { return new(bytes.Buffer) }} + // Request defines a basic HTTP raw request type Request struct { FullURL string @@ -270,13 +273,17 @@ func (r *Request) TryFillCustomHeaders(headers []string) error { if newLineIndex > 0 { newLineIndex += hostHeaderIndex + 2 // insert custom headers - var buf bytes.Buffer + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() buf.Write(r.UnsafeRawBytes[:newLineIndex]) for _, header := range headers { - buf.WriteString(fmt.Sprintf("%s\r\n", header)) + buf.WriteString(header) + buf.WriteString("\r\n") } buf.Write(r.UnsafeRawBytes[newLineIndex:]) - r.UnsafeRawBytes = buf.Bytes() + r.UnsafeRawBytes = append([]byte(nil), buf.Bytes()...) 
+ buf.Reset() + bufferPool.Put(buf) return nil } return errors.New("no new line found at the end of host header") @@ -301,9 +308,10 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) { parsed.Params.Add(p.Key, p.Value) } case *authx.CookiesAuthStrategy: - var buff bytes.Buffer + buff := bufferPool.Get().(*bytes.Buffer) + buff.Reset() for _, cookie := range s.Data.Cookies { - buff.WriteString(fmt.Sprintf("%s=%s; ", cookie.Key, cookie.Value)) + fmt.Fprintf(buff, "%s=%s; ", cookie.Key, cookie.Value) } if buff.Len() > 0 { if val, ok := r.Headers["Cookie"]; ok { @@ -312,6 +320,7 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) { r.Headers["Cookie"] = buff.String() } } + bufferPool.Put(buff) case *authx.HeadersAuthStrategy: for _, header := range s.Data.Headers { r.Headers[header.Key] = header.Value diff --git a/pkg/protocols/http/raw/raw_test.go b/pkg/protocols/http/raw/raw_test.go index a44664d48..80fefff7f 100644 --- a/pkg/protocols/http/raw/raw_test.go +++ b/pkg/protocols/http/raw/raw_test.go @@ -7,6 +7,21 @@ import ( "github.com/stretchr/testify/require" ) +func TestTryFillCustomHeaders_BufferDetached(t *testing.T) { + r := &Request{ + UnsafeRawBytes: []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody"), + } + // first fill + err := r.TryFillCustomHeaders([]string{"X-Test: 1"}) + require.NoError(t, err, "unexpected error on first call") + prev := r.UnsafeRawBytes + prevStr := string(prev) // content snapshot + err = r.TryFillCustomHeaders([]string{"X-Another: 2"}) + require.NoError(t, err, "unexpected error on second call") + require.Equal(t, prevStr, string(prev), "first slice mutated after second call; buffer not detached") + require.NotEqual(t, prevStr, string(r.UnsafeRawBytes), "request bytes did not change after second call") +} + func TestParseRawRequestWithPort(t *testing.T) { request, err := Parse(`GET /gg/phpinfo.php HTTP/1.1 Host: {{Hostname}}:123 diff --git a/pkg/protocols/http/request.go 
b/pkg/protocols/http/request.go index 381f4d385..c6e3de4d9 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -240,6 +240,48 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV } }) + // bounded worker-pool to avoid spawning one goroutine per payload + type task struct { + req *generatedRequest + updatedInput *contextargs.Context + } + + var workersWg sync.WaitGroup + currentWorkers := maxWorkers + tasks := make(chan task, maxWorkers) + spawnWorker := func(ctx context.Context) { + workersWg.Add(1) + go func() { + defer workersWg.Done() + for t := range tasks { + select { + case <-ctx.Done(): + return + default: + } + if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() { + continue + } + spmHandler.Acquire() + if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() { + spmHandler.Release() + continue + } + request.options.RateLimitTake() + select { + case <-spmHandler.Done(): + spmHandler.Release() + continue + case spmHandler.ResultChan <- request.executeRequest(t.updatedInput, t.req, make(map[string]interface{}), false, wrappedCallback, 0): + spmHandler.Release() + } + } + }() + } + for i := 0; i < currentWorkers; i++ { + spawnWorker(ctx) + } + // iterate payloads and make requests generator := request.newGenerator(false) for { @@ -259,6 +301,13 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV if err := spmHandler.Resize(input.Context(), request.options.Options.PayloadConcurrency); err != nil { return err } + // if payload concurrency increased, add more workers + if spmHandler.Size() > currentWorkers { + for i := 0; i < spmHandler.Size()-currentWorkers; i++ { + spawnWorker(ctx) + } + currentWorkers = spmHandler.Size() + } } // break if stop at first match is found or host is unresponsive @@ -284,29 +333,21 @@ func (request *Request) executeParallelHTTP(input 
*contextargs.Context, dynamicV spmHandler.Cancel() return nil } - spmHandler.Acquire() - go func(httpRequest *generatedRequest) { - defer spmHandler.Release() - if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() { - return + select { + case <-spmHandler.Done(): + close(tasks) + workersWg.Wait() + spmHandler.Wait() + if spmHandler.FoundFirstMatch() { + return nil } - // putting ratelimiter here prevents any unnecessary waiting if any - request.options.RateLimitTake() - - // after ratelimit take, check if we need to stop - if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() { - return - } - - select { - case <-spmHandler.Done(): - return - case spmHandler.ResultChan <- request.executeRequest(input, httpRequest, make(map[string]interface{}), false, wrappedCallback, 0): - return - } - }(generatedHttpRequest) + return multierr.Combine(spmHandler.CombinedResults()...) + case tasks <- task{req: generatedHttpRequest, updatedInput: updatedInput}: + } request.options.Progress.IncrementRequests() } + close(tasks) + workersWg.Wait() spmHandler.Wait() if spmHandler.FoundFirstMatch() { // ignore any context cancellation and in-transit execution errors diff --git a/pkg/protocols/http/request_test.go b/pkg/protocols/http/request_test.go index a6314ae5a..b547d91b5 100644 --- a/pkg/protocols/http/request_test.go +++ b/pkg/protocols/http/request_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "github.com/stretchr/testify/require" @@ -257,3 +258,116 @@ func TestReqURLPattern(t *testing.T) { require.NotEmpty(t, finalEvent.Results[0].ReqURLPattern, "could not get req url pattern") require.Equal(t, `/{{rand_char("abc")}}/{{interactsh-url}}/123?query={{rand_int(1, 10)}}&data={{randstr}}`, finalEvent.Results[0].ReqURLPattern) } + +// fakeHostErrorsCache implements hosterrorscache.CacheInterface minimally for tests +type 
fakeHostErrorsCache struct{} + +func (f *fakeHostErrorsCache) SetVerbose(bool) {} +func (f *fakeHostErrorsCache) Close() {} +func (f *fakeHostErrorsCache) Remove(*contextargs.Context) {} +func (f *fakeHostErrorsCache) MarkFailed(string, *contextargs.Context, error) {} +func (f *fakeHostErrorsCache) MarkFailedOrRemove(string, *contextargs.Context, error) { +} + +// Check always returns true to simulate an already unresponsive host +func (f *fakeHostErrorsCache) Check(string, *contextargs.Context) bool { return true } + +func TestExecuteParallelHTTP_StopAtFirstMatch(t *testing.T) { + options := testutils.DefaultOptions + testutils.Init(options) + + // server that always matches + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "match") + })) + defer ts.Close() + + templateID := "parallel-stop-first" + req := &Request{ + ID: templateID, + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/p?x={{v}}"}, + Threads: 2, + Payloads: map[string]interface{}{ + "v": []string{"1", "2"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"match"}, + }}, + }, + StopAtFirstMatch: true, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: templateID, + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + err := req.Compile(executerOpts) + require.NoError(t, err) + + var matches int32 + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { + if event.OperatorsResult != nil && event.OperatorsResult.Matched { + atomic.AddInt32(&matches, 1) + } + }) + require.NoError(t, err) + 
require.Equal(t, int32(1), atomic.LoadInt32(&matches), "expected only first match to be processed") +} + +func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) { + options := testutils.DefaultOptions + testutils.Init(options) + + // server that would match if reached + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "match") + })) + defer ts.Close() + + templateID := "parallel-skip-unresponsive" + req := &Request{ + ID: templateID, + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/p?x={{v}}"}, + Threads: 2, + Payloads: map[string]interface{}{ + "v": []string{"1", "2"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"match"}, + }}, + }, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: templateID, + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + // inject fake host errors cache that forces skip + executerOpts.HostErrorsCache = &fakeHostErrorsCache{} + + err := req.Compile(executerOpts) + require.NoError(t, err) + + var matches int32 + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { + if event.OperatorsResult != nil && event.OperatorsResult.Matched { + atomic.AddInt32(&matches, 1) + } + }) + require.NoError(t, err) + require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive") +} diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 9925c3103..8b872d84a 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -811,8 +811,11 @@ 
func beautifyJavascript(code string) string { } func prettyPrint(templateId string, buff string) { + if buff == "" { + return + } lines := strings.Split(buff, "\n") - final := []string{} + final := make([]string, 0, len(lines)) for _, v := range lines { if v != "" { final = append(final, "\t"+v) diff --git a/pkg/reporting/reporting.go b/pkg/reporting/reporting.go index c3706c970..58d7f61fb 100644 --- a/pkg/reporting/reporting.go +++ b/pkg/reporting/reporting.go @@ -192,11 +192,13 @@ func New(options *Options, db string, doNotDedupe bool) (Client, error) { } } - storage, err := dedupe.New(db) - if err != nil { - return nil, err + if db != "" || len(client.trackers) > 0 || len(client.exporters) > 0 { + storage, err := dedupe.New(db) + if err != nil { + return nil, err + } + client.dedupe = storage } - client.dedupe = storage return client, nil } diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go index b4c00fb36..ae7124772 100644 --- a/pkg/templates/cache.go +++ b/pkg/templates/cache.go @@ -12,7 +12,9 @@ type Cache struct { // New returns a new templates cache func NewCache() *Cache { - return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()} + return &Cache{ + items: mapsutil.NewSyncLockMap[string, parsedTemplate](), + } } type parsedTemplate struct { @@ -33,7 +35,31 @@ func (t *Cache) Has(template string) (*Template, []byte, error) { // Store stores a template with data and error func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) { - _ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err}) + entry := parsedTemplate{ + template: tpl, + err: err, + raw: conversion.String(raw), + } + _ = t.items.Set(id, entry) +} + +// StoreWithoutRaw stores a template without raw data for memory efficiency +func (t *Cache) StoreWithoutRaw(id string, tpl *Template, err error) { + entry := parsedTemplate{ + template: tpl, + err: err, + raw: "", + } + _ = t.items.Set(id, entry) +} + +// Get returns only the 
template without raw bytes +func (t *Cache) Get(id string) (*Template, error) { + value, ok := t.items.Get(id) + if !ok { + return nil, nil + } + return value.template, value.err } // Purge the cache diff --git a/pkg/templates/templates.go b/pkg/templates/templates.go index 74818557e..9907e2711 100644 --- a/pkg/templates/templates.go +++ b/pkg/templates/templates.go @@ -404,7 +404,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err // for code protocol requests for _, request := range template.RequestsCode { // simple test to check if source is a file or a snippet - if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) { + if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) { if val, ok := loadFile(request.Source); ok { template.ImportedFiles = append(template.ImportedFiles, request.Source) request.Source = val @@ -415,7 +415,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err // for javascript protocol code references for _, request := range template.RequestsJavascript { // simple test to check if source is a file or a snippet - if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) { + if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) { if val, ok := loadFile(request.Code); ok { template.ImportedFiles = append(template.ImportedFiles, request.Code) request.Code = val @@ -442,7 +442,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err if req.Type() == types.CodeProtocol { request := req.(*code.Request) // simple test to check if source is a file or a snippet - if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) { + if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) { if val, ok := loadFile(request.Source); ok { template.ImportedFiles = append(template.ImportedFiles, 
request.Source) request.Source = val @@ -456,7 +456,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err if req.Type() == types.JavascriptProtocol { request := req.(*javascript.Request) // simple test to check if source is a file or a snippet - if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) { + if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) { if val, ok := loadFile(request.Code); ok { template.ImportedFiles = append(template.ImportedFiles, request.Code) request.Code = val diff --git a/pkg/templates/templates_test.go b/pkg/templates/templates_test.go index 7621f1c57..a94ebf656 100644 --- a/pkg/templates/templates_test.go +++ b/pkg/templates/templates_test.go @@ -9,6 +9,32 @@ import ( "gopkg.in/yaml.v2" ) +func TestCachePoolZeroing(t *testing.T) { + c := NewCache() + + tpl := &Template{ID: "x"} + raw := []byte("SOME BIG RAW") + + c.Store("id1", tpl, raw, nil) + gotTpl, gotErr := c.Get("id1") + if gotErr != nil { + t.Fatalf("unexpected err: %v", gotErr) + } + if gotTpl == nil || gotTpl.ID != "x" { + t.Fatalf("unexpected tpl: %#v", gotTpl) + } + + // StoreWithoutRaw should not retain raw + c.StoreWithoutRaw("id2", tpl, nil) + gotTpl2, gotErr2 := c.Get("id2") + if gotErr2 != nil { + t.Fatalf("unexpected err: %v", gotErr2) + } + if gotTpl2 == nil || gotTpl2.ID != "x" { + t.Fatalf("unexpected tpl2: %#v", gotTpl2) + } +} + func TestTemplateStruct(t *testing.T) { templatePath := "./tests/match-1.yaml" bin, err := os.ReadFile(templatePath)