diff --git a/pkg/core/executors.go b/pkg/core/executors.go index 8abcc9d78..aeb85ddfe 100644 --- a/pkg/core/executors.go +++ b/pkg/core/executors.go @@ -48,8 +48,15 @@ func (e *Engine) executeAllSelfContained(ctx context.Context, alltemplates []*te // executeTemplateWithTargets executes a given template on x targets (with a internal targetpool(i.e concurrency)) func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templates.Template, target provider.InputProvider, results *atomic.Bool) { - // this is target pool i.e max target to execute - wg := e.workPool.InputPool(template.Type()) + if e.workPool == nil { + e.workPool = e.GetWorkPool() + } + // Bounded worker pool using input concurrency + pool := e.workPool.InputPool(template.Type()) + workerCount := 1 + if pool != nil && pool.Size > 0 { + workerCount = pool.Size + } var ( index uint32 @@ -78,6 +85,41 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ currentInfo.Unlock() } + // task represents a single target execution unit + type task struct { + index uint32 + skip bool + value *contextargs.MetaInput + } + + tasks := make(chan task) + var workersWg sync.WaitGroup + workersWg.Add(workerCount) + for i := 0; i < workerCount; i++ { + go func() { + defer workersWg.Done() + for t := range tasks { + func() { + defer cleanupInFlight(t.index) + select { + case <-ctx.Done(): + return + default: + } + if t.skip { + return + } + + match, err := e.executeTemplateOnInput(ctx, template, t.value) + if err != nil { + e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), t.value.Input, err) + } + results.CompareAndSwap(false, match) + }() + } + }() + } + target.Iterate(func(scannedValue *contextargs.MetaInput) bool { select { case <-ctx.Done(): @@ -128,43 +170,13 @@ func (e *Engine) executeTemplateWithTargets(ctx context.Context, template *templ return true } - wg.Add() - go func(index uint32, skip bool, value *contextargs.MetaInput) { - defer wg.Done() - defer cleanupInFlight(index) - if skip { - return - } - - var match bool - var err error - ctxArgs := contextargs.New(ctx) - ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctx, ctxArgs) - switch template.Type() { - case types.WorkflowProtocol: - match = e.executeWorkflow(ctx, template.CompiledWorkflow) - default: - if e.Callback != nil { - if results, err := template.Executer.ExecuteWithResults(ctx); err == nil { - for _, result := range results { - e.Callback(result) - } - } - match = true - } else { - match, err = template.Executer.Execute(ctx) - } - } - if err != nil { - e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err) - } - results.CompareAndSwap(false, match) - }(index, skip, scannedValue) + tasks <- task{index: index, skip: skip, value: scannedValue} index++ return true }) - wg.Wait() + + close(tasks) + workersWg.Wait() // on completion marks the template as completed currentInfo.Lock() @@ -202,26 +214,7 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t go func(template *templates.Template, value *contextargs.MetaInput, wg *syncutil.AdaptiveWaitGroup) { defer wg.Done() - var match bool - var err error - ctxArgs := contextargs.New(ctx) - ctxArgs.MetaInput = value - ctx := scan.NewScanContext(ctx, ctxArgs) - switch template.Type() { - case types.WorkflowProtocol: - match = e.executeWorkflow(ctx, template.CompiledWorkflow) - default: - if 
e.Callback != nil { - if results, err := template.Executer.ExecuteWithResults(ctx); err == nil { - for _, result := range results { - e.Callback(result) - } - } - match = true - } else { - match, err = template.Executer.Execute(ctx) - } - } + match, err := e.executeTemplateOnInput(ctx, template, value) if err != nil { e.options.Logger.Warning().Msgf("[%s] Could not execute step on %s: %s\n", e.executerOpts.Colorizer.BrightBlue(template.ID), value.Input, err) } @@ -229,3 +222,27 @@ func (e *Engine) executeTemplatesOnTarget(ctx context.Context, alltemplates []*t }(tpl, target, sg) } } + +// executeTemplateOnInput performs template execution for a single input and returns match status and error +func (e *Engine) executeTemplateOnInput(ctx context.Context, template *templates.Template, value *contextargs.MetaInput) (bool, error) { + ctxArgs := contextargs.New(ctx) + ctxArgs.MetaInput = value + scanCtx := scan.NewScanContext(ctx, ctxArgs) + + switch template.Type() { + case types.WorkflowProtocol: + return e.executeWorkflow(scanCtx, template.CompiledWorkflow), nil + default: + if e.Callback != nil { + results, err := template.Executer.ExecuteWithResults(scanCtx) + if err != nil { + return false, err + } + for _, result := range results { + e.Callback(result) + } + return len(results) > 0, nil + } + return template.Executer.Execute(scanCtx) + } +} diff --git a/pkg/core/executors_test.go b/pkg/core/executors_test.go new file mode 100644 index 000000000..394b2e6d9 --- /dev/null +++ b/pkg/core/executors_test.go @@ -0,0 +1,148 @@ +package core + +import ( + "context" + "fmt" + "sync/atomic" + "testing" + "time" + + inputtypes "github.com/projectdiscovery/nuclei/v3/pkg/input/types" + "github.com/projectdiscovery/nuclei/v3/pkg/output" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols" + "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs" + "github.com/projectdiscovery/nuclei/v3/pkg/scan" + "github.com/projectdiscovery/nuclei/v3/pkg/templates" + tmpltypes "github.com/projectdiscovery/nuclei/v3/pkg/templates/types" + "github.com/projectdiscovery/nuclei/v3/pkg/types" +) + +// fakeExecuter is a simple stub for protocols.Executer used to test executeTemplateOnInput +type fakeExecuter struct { + withResults bool +} + +func (f *fakeExecuter) Compile() error { return nil } +func (f *fakeExecuter) Requests() int { return 1 } +func (f *fakeExecuter) Execute(ctx *scan.ScanContext) (bool, error) { return !f.withResults, nil } +func (f *fakeExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + if !f.withResults { + return nil, nil + } + return []*output.ResultEvent{{Host: "h"}}, nil +} + +// newTestEngine creates a minimal Engine for tests +func newTestEngine() *Engine { + return New(&types.Options{}) +} + +func Test_executeTemplateOnInput_CallbackPath(t *testing.T) { + e := newTestEngine() + called := 0 + e.Callback = func(*output.ResultEvent) { called++ } + + tpl := &templates.Template{} + tpl.Executer = &fakeExecuter{withResults: true} + + ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatalf("expected match true") + } + if called == 0 { + t.Fatalf("expected callback to be called") + } +} + +func Test_executeTemplateOnInput_ExecutePath(t *testing.T) { + e := newTestEngine() + tpl := &templates.Template{} + tpl.Executer = &fakeExecuter{withResults: false} + + ok, err := e.executeTemplateOnInput(context.Background(), tpl, 
&contextargs.MetaInput{Input: "x"}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if !ok { + t.Fatalf("expected match true from Execute path") + } +} + +type fakeExecuterErr struct{} + +func (f *fakeExecuterErr) Compile() error { return nil } +func (f *fakeExecuterErr) Requests() int { return 1 } +func (f *fakeExecuterErr) Execute(ctx *scan.ScanContext) (bool, error) { return false, nil } +func (f *fakeExecuterErr) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + return nil, fmt.Errorf("boom") +} + +func Test_executeTemplateOnInput_CallbackErrorPropagates(t *testing.T) { + e := newTestEngine() + e.Callback = func(*output.ResultEvent) {} + tpl := &templates.Template{} + tpl.Executer = &fakeExecuterErr{} + + ok, err := e.executeTemplateOnInput(context.Background(), tpl, &contextargs.MetaInput{Input: "x"}) + if err == nil { + t.Fatalf("expected error to propagate") + } + if ok { + t.Fatalf("expected match to be false on error") + } +} + +type fakeTargetProvider struct { + values []*contextargs.MetaInput +} + +func (f *fakeTargetProvider) Count() int64 { return int64(len(f.values)) } +func (f *fakeTargetProvider) Iterate(cb func(value *contextargs.MetaInput) bool) { + for _, v := range f.values { + if !cb(v) { + return + } + } +} +func (f *fakeTargetProvider) Set(string, string) {} +func (f *fakeTargetProvider) SetWithProbe(string, string, inputtypes.InputLivenessProbe) error { + return nil +} +func (f *fakeTargetProvider) SetWithExclusions(string, string) error { return nil } +func (f *fakeTargetProvider) InputType() string { return "test" } +func (f *fakeTargetProvider) Close() {} + +type slowExecuter struct{} + +func (s *slowExecuter) Compile() error { return nil } +func (s *slowExecuter) Requests() int { return 1 } +func (s *slowExecuter) Execute(ctx *scan.ScanContext) (bool, error) { + select { + case <-ctx.Context().Done(): + return false, ctx.Context().Err() + case <-time.After(200 * time.Millisecond): + return true, nil + } +} +func (s *slowExecuter) ExecuteWithResults(ctx *scan.ScanContext) ([]*output.ResultEvent, error) { + return nil, nil +} + +func Test_executeTemplateWithTargets_RespectsCancellation(t *testing.T) { + e := newTestEngine() + e.SetExecuterOptions(&protocols.ExecutorOptions{Logger: e.Logger, ResumeCfg: types.NewResumeCfg(), ProtocolType: tmpltypes.HTTPProtocol}) + + tpl := &templates.Template{} + tpl.Executer = &slowExecuter{} + + targets := &fakeTargetProvider{values: []*contextargs.MetaInput{{Input: "a"}, {Input: "b"}, {Input: "c"}}} + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + var matched atomic.Bool + e.executeTemplateWithTargets(ctx, tpl, targets, &matched) +} diff --git a/pkg/operators/cache/cache.go b/pkg/operators/cache/cache.go new file mode 100644 index 000000000..ca486097c --- /dev/null +++ b/pkg/operators/cache/cache.go @@ -0,0 +1,62 @@ +package cache + +import ( + "regexp" + "sync" + + "github.com/Knetic/govaluate" + "github.com/projectdiscovery/gcache" +) + +var ( + initOnce sync.Once + mu sync.RWMutex + + regexCap = 4096 + dslCap = 4096 + + regexCache gcache.Cache[string, *regexp.Regexp] + dslCache gcache.Cache[string, *govaluate.EvaluableExpression] +) + +func initCaches() { + initOnce.Do(func() { + regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build() + dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build() + }) +} + +func SetCapacities(regexCapacity, dslCapacity int) { + // ensure caches are initialized under initOnce, so later 
Regex()/DSL() won't re-init + initCaches() + + mu.Lock() + defer mu.Unlock() + + if regexCapacity > 0 { + regexCap = regexCapacity + } + if dslCapacity > 0 { + dslCap = dslCapacity + } + if regexCapacity <= 0 && dslCapacity <= 0 { + return + } + // rebuild caches with new capacities + regexCache = gcache.New[string, *regexp.Regexp](regexCap).LRU().Build() + dslCache = gcache.New[string, *govaluate.EvaluableExpression](dslCap).LRU().Build() +} + +func Regex() gcache.Cache[string, *regexp.Regexp] { + initCaches() + mu.RLock() + defer mu.RUnlock() + return regexCache +} + +func DSL() gcache.Cache[string, *govaluate.EvaluableExpression] { + initCaches() + mu.RLock() + defer mu.RUnlock() + return dslCache +} diff --git a/pkg/operators/cache/cache_test.go b/pkg/operators/cache/cache_test.go new file mode 100644 index 000000000..c44b72c84 --- /dev/null +++ b/pkg/operators/cache/cache_test.go @@ -0,0 +1,114 @@ +package cache + +import ( + "regexp" + "testing" + + "github.com/Knetic/govaluate" +) + +func TestRegexCache_SetGet(t *testing.T) { + // ensure init + c := Regex() + pattern := "abc(\n)?123" + re, err := regexp.Compile(pattern) + if err != nil { + t.Fatalf("compile: %v", err) + } + if err := c.Set(pattern, re); err != nil { + t.Fatalf("set: %v", err) + } + got, err := c.GetIFPresent(pattern) + if err != nil || got == nil { + t.Fatalf("get: %v got=%v", err, got) + } + if got.String() != re.String() { + t.Fatalf("mismatch: %s != %s", got.String(), re.String()) + } +} + +func TestDSLCache_SetGet(t *testing.T) { + c := DSL() + expr := "1 + 2 == 3" + ast, err := govaluate.NewEvaluableExpression(expr) + if err != nil { + t.Fatalf("dsl compile: %v", err) + } + if err := c.Set(expr, ast); err != nil { + t.Fatalf("set: %v", err) + } + got, err := c.GetIFPresent(expr) + if err != nil || got == nil { + t.Fatalf("get: %v got=%v", err, got) + } + if got.String() != ast.String() { + t.Fatalf("mismatch: %s != %s", got.String(), ast.String()) + } +} + +func TestRegexCache_EvictionByCapacity(t *testing.T) { + SetCapacities(3, 3) + c := Regex() + for i := 0; i < 5; i++ { + k := string(rune('a' + i)) + re := regexp.MustCompile(k) + _ = c.Set(k, re) + } + // last 3 keys expected to remain under LRU: 'c','d','e' + if _, err := c.GetIFPresent("a"); err == nil { + t.Fatalf("expected 'a' to be evicted") + } + if _, err := c.GetIFPresent("b"); err == nil { + t.Fatalf("expected 'b' to be evicted") + } + if _, err := c.GetIFPresent("c"); err != nil { + t.Fatalf("expected 'c' present") + } +} + +func TestSetCapacities_NoRebuildOnZero(t *testing.T) { + // init + SetCapacities(4, 4) + c1 := Regex() + _ = c1.Set("k", regexp.MustCompile("k")) + if _, err := c1.GetIFPresent("k"); err != nil { + t.Fatalf("expected key present: %v", err) + } + // zero changes should not rebuild/clear caches + SetCapacities(0, 0) + c2 := Regex() + if _, err := c2.GetIFPresent("k"); err != nil { + t.Fatalf("key lost after zero-capacity SetCapacities: %v", err) + } +} + +func TestSetCapacities_BeforeFirstUse(t *testing.T) { + // This should not be overridden by later initCaches + SetCapacities(2, 0) + c := Regex() + _ = c.Set("a", regexp.MustCompile("a")) + _ = c.Set("b", regexp.MustCompile("b")) + _ = c.Set("c", regexp.MustCompile("c")) + if _, err := c.GetIFPresent("a"); err == nil { + t.Fatalf("expected 'a' to be evicted under cap=2") + } +} + +func TestSetCapacities_ConcurrentAccess(t *testing.T) { + SetCapacities(64, 64) + stop := make(chan struct{}) + + go func() { + for i := 0; i < 5000; i++ { + _ = 
Regex().Set("k"+string(rune('a'+(i%26))), regexp.MustCompile("a")) + _, _ = Regex().GetIFPresent("k" + string(rune('a'+(i%26)))) + _, _ = DSL().GetIFPresent("1+2==3") + } + close(stop) + }() + + for i := 0; i < 200; i++ { + SetCapacities(64+(i%5), 64+((i+1)%5)) + } + <-stop +} diff --git a/pkg/operators/extractors/compile.go b/pkg/operators/extractors/compile.go index 2b55d374a..bcfd37eeb 100644 --- a/pkg/operators/extractors/compile.go +++ b/pkg/operators/extractors/compile.go @@ -7,6 +7,7 @@ import ( "github.com/Knetic/govaluate" "github.com/itchyny/gojq" + "github.com/projectdiscovery/nuclei/v3/pkg/operators/cache" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" ) @@ -20,10 +21,15 @@ func (e *Extractor) CompileExtractors() error { e.extractorType = computedType // Compile the regexes for _, regex := range e.Regex { + if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil { + e.regexCompiled = append(e.regexCompiled, cached) + continue + } compiled, err := regexp.Compile(regex) if err != nil { return fmt.Errorf("could not compile regex: %s", regex) } + _ = cache.Regex().Set(regex, compiled) e.regexCompiled = append(e.regexCompiled, compiled) } for i, kval := range e.KVal { @@ -43,10 +49,15 @@ func (e *Extractor) CompileExtractors() error { } for _, dslExp := range e.DSL { + if cached, err := cache.DSL().GetIFPresent(dslExp); err == nil && cached != nil { + e.dslCompiled = append(e.dslCompiled, cached) + continue + } compiled, err := govaluate.NewEvaluableExpressionWithFunctions(dslExp, dsl.HelperFunctions) if err != nil { return &dsl.CompilationError{DslSignature: dslExp, WrappedError: err} } + _ = cache.DSL().Set(dslExp, compiled) e.dslCompiled = append(e.dslCompiled, compiled) } diff --git a/pkg/operators/extractors/extract.go b/pkg/operators/extractors/extract.go index 194b4648f..1a8ca63b6 100644 --- a/pkg/operators/extractors/extract.go +++ b/pkg/operators/extractors/extract.go @@ -17,9 +17,19 @@ func (e *Extractor) ExtractRegex(corpus string) map[string]struct{} { groupPlusOne := e.RegexGroup + 1 for _, regex := range e.regexCompiled { - matches := regex.FindAllStringSubmatch(corpus, -1) + // skip prefix short-circuit for case-insensitive patterns + rstr := regex.String() + if !strings.Contains(rstr, "(?i") { + if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" { + if !strings.Contains(corpus, prefix) { + continue + } + } + } - for _, match := range matches { + submatches := regex.FindAllStringSubmatch(corpus, -1) + + for _, match := range submatches { if len(match) < groupPlusOne { continue } diff --git a/pkg/operators/matchers/compile.go b/pkg/operators/matchers/compile.go index 5a99347c5..4ae72be60 100644 --- a/pkg/operators/matchers/compile.go +++ b/pkg/operators/matchers/compile.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/Knetic/govaluate" - + "github.com/projectdiscovery/nuclei/v3/pkg/operators/cache" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" ) @@ -42,12 +42,17 @@ func (matcher *Matcher) CompileMatchers() error { matcher.Part = "body" } - // Compile the regexes + // Compile the regexes (with shared cache) for _, regex := range matcher.Regex { + if cached, err := cache.Regex().GetIFPresent(regex); err == nil && cached != nil { + matcher.regexCompiled = append(matcher.regexCompiled, cached) + continue + } compiled, err := regexp.Compile(regex) if err != nil { return fmt.Errorf("could not compile regex: %s", regex) } + _ = cache.Regex().Set(regex, compiled) matcher.regexCompiled = 
append(matcher.regexCompiled, compiled) } @@ -60,12 +65,17 @@ func (matcher *Matcher) CompileMatchers() error { } } - // Compile the dsl expressions + // Compile the dsl expressions (with shared cache) for _, dslExpression := range matcher.DSL { + if cached, err := cache.DSL().GetIFPresent(dslExpression); err == nil && cached != nil { + matcher.dslCompiled = append(matcher.dslCompiled, cached) + continue + } compiledExpression, err := govaluate.NewEvaluableExpressionWithFunctions(dslExpression, dsl.HelperFunctions) if err != nil { return &dsl.CompilationError{DslSignature: dslExpression, WrappedError: err} } + _ = cache.DSL().Set(dslExpression, compiledExpression) matcher.dslCompiled = append(matcher.dslCompiled, compiledExpression) } diff --git a/pkg/operators/matchers/match.go b/pkg/operators/matchers/match.go index 78bd40175..a914b5f98 100644 --- a/pkg/operators/matchers/match.go +++ b/pkg/operators/matchers/match.go @@ -106,10 +106,33 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) { var matchedRegexes []string // Iterate over all the regexes accepted as valid for i, regex := range matcher.regexCompiled { - // Continue if the regex doesn't match - if !regex.MatchString(corpus) { - // If we are in an AND request and a match failed, - // return false as the AND condition fails on any single mismatch. + // Literal prefix short-circuit + rstr := regex.String() + if !strings.Contains(rstr, "(?i") { // covers (?i) and (?i: + if prefix, ok := regex.LiteralPrefix(); ok && prefix != "" { + if !strings.Contains(corpus, prefix) { + switch matcher.condition { + case ANDCondition: + return false, []string{} + case ORCondition: + continue + } + } + } + } + + // Fast OR-path: return first match without full scan + if matcher.condition == ORCondition && !matcher.MatchAll { + m := regex.FindAllString(corpus, 1) + if len(m) == 0 { + continue + } + return true, m + } + + // Single scan: get all matches directly + currentMatches := regex.FindAllString(corpus, -1) + if len(currentMatches) == 0 { switch matcher.condition { case ANDCondition: return false, []string{} @@ -118,12 +141,7 @@ func (matcher *Matcher) MatchRegex(corpus string) (bool, []string) { } } - currentMatches := regex.FindAllString(corpus, -1) - // If the condition was an OR, return on the first match. - if matcher.condition == ORCondition && !matcher.MatchAll { - return true, currentMatches - } - + // If the condition was an OR (and MatchAll true), we still need to gather all matchedRegexes = append(matchedRegexes, currentMatches...) // If we are at the end of the regex, return with true diff --git a/pkg/operators/matchers/match_test.go b/pkg/operators/matchers/match_test.go index ea6258ae0..8a073318b 100644 --- a/pkg/operators/matchers/match_test.go +++ b/pkg/operators/matchers/match_test.go @@ -84,7 +84,7 @@ func TestMatcher_MatchDSL(t *testing.T) { values := []string{"PING", "pong"} - for value := range values { + for _, value := range values { isMatched := m.MatchDSL(map[string]interface{}{"body": value, "VARIABLE": value}) require.True(t, isMatched) } @@ -209,3 +209,66 @@ func TestMatcher_MatchXPath_XML(t *testing.T) { isMatched = m.MatchXPath("
notvalid") require.False(t, isMatched, "Invalid xpath did not return false") } + +func TestMatchRegex_CaseInsensitivePrefixSkip(t *testing.T) { + m := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"(?i)abc"}} + err := m.CompileMatchers() + require.NoError(t, err) + ok, got := m.MatchRegex("zzz AbC yyy") + require.True(t, ok) + require.NotEmpty(t, got) +} + +func TestMatchStatusCodeAndSize(t *testing.T) { + mStatus := &Matcher{Status: []int{200, 302}} + require.True(t, mStatus.MatchStatusCode(200)) + require.True(t, mStatus.MatchStatusCode(302)) + require.False(t, mStatus.MatchStatusCode(404)) + + mSize := &Matcher{Size: []int{5, 10}} + require.True(t, mSize.MatchSize(5)) + require.False(t, mSize.MatchSize(7)) +} + +func TestMatchBinary_AND_OR(t *testing.T) { + // AND should fail if any binary not present + mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "and", Binary: []string{"50494e47", "414141"}} // "PING", "AAA" + require.NoError(t, mAnd.CompileMatchers()) + ok, _ := mAnd.MatchBinary("PING") + require.False(t, ok) + // OR should succeed if any present + mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: BinaryMatcher}, Condition: "or", Binary: []string{"414141", "50494e47"}} // "AAA", "PING" + require.NoError(t, mOr.CompileMatchers()) + ok, got := mOr.MatchBinary("xxPINGyy") + require.True(t, ok) + require.NotEmpty(t, got) +} + +func TestMatchRegex_LiteralPrefixShortCircuit(t *testing.T) { + // AND: first regex has literal prefix "abc"; corpus lacks it => early false + mAnd := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "and", Regex: []string{"abc[0-9]*", "[0-9]{2}"}} + require.NoError(t, mAnd.CompileMatchers()) + ok, matches := mAnd.MatchRegex("zzz 12 yyy") + require.False(t, ok) + require.Empty(t, matches) + + // OR: first regex skipped due to missing prefix, second matches => true + mOr := &Matcher{Type: MatcherTypeHolder{MatcherType: RegexMatcher}, Condition: "or", Regex: []string{"abc[0-9]*", "[0-9]{2}"}} + require.NoError(t, mOr.CompileMatchers()) + ok, matches = mOr.MatchRegex("zzz 12 yyy") + require.True(t, ok) + require.Equal(t, []string{"12"}, matches) +} + +func TestMatcher_MatchDSL_ErrorHandling(t *testing.T) { + // First expression errors (division by zero), second is true + bad, err := govaluate.NewEvaluableExpression("1 / 0") + require.NoError(t, err) + good, err := govaluate.NewEvaluableExpression("1 == 1") + require.NoError(t, err) + + m := &Matcher{Type: MatcherTypeHolder{MatcherType: DSLMatcher}, Condition: "or", dslCompiled: []*govaluate.EvaluableExpression{bad, good}} + require.NoError(t, m.CompileMatchers()) + ok := m.MatchDSL(map[string]interface{}{}) + require.True(t, ok) +} diff --git a/pkg/protocols/common/generators/attack_types_test.go b/pkg/protocols/common/generators/attack_types_test.go new file mode 100644 index 000000000..a1c808319 --- /dev/null +++ b/pkg/protocols/common/generators/attack_types_test.go @@ -0,0 +1,26 @@ +package generators + +import "testing" + +func TestAttackTypeHelpers(t *testing.T) { + // GetSupportedAttackTypes should include three values + types := GetSupportedAttackTypes() + if len(types) != 3 { + t.Fatalf("expected 3 types, got %d", len(types)) + } + // toAttackType valid + if got, err := toAttackType("pitchfork"); err != nil || got != PitchForkAttack { + t.Fatalf("toAttackType failed: %v %v", got, err) + } + // toAttackType invalid + if _, err := toAttackType("nope"); err == nil { + t.Fatalf("expected error for invalid 
attack type") + } + // normalizeValue and String + if normalizeValue(" ClusterBomb ") != "clusterbomb" { + t.Fatalf("normalizeValue failed") + } + if ClusterBombAttack.String() != "clusterbomb" { + t.Fatalf("String failed") + } +} diff --git a/pkg/protocols/common/generators/env_test.go b/pkg/protocols/common/generators/env_test.go new file mode 100644 index 000000000..88b8fa377 --- /dev/null +++ b/pkg/protocols/common/generators/env_test.go @@ -0,0 +1,38 @@ +package generators + +import ( + "os" + "testing" +) + +func TestParseEnvVars(t *testing.T) { + old := os.Environ() + // set a scoped env var + _ = os.Setenv("NUCLEI_TEST_K", "V1") + t.Cleanup(func() { + // restore + for _, kv := range old { + parts := kv + _ = parts // nothing, environment already has superset; best-effort cleanup below + } + _ = os.Unsetenv("NUCLEI_TEST_K") + }) + vars := parseEnvVars() + if vars["NUCLEI_TEST_K"] != "V1" { + t.Fatalf("expected V1, got %v", vars["NUCLEI_TEST_K"]) + } +} + +func TestEnvVarsMemoization(t *testing.T) { + // reset memoized map + envVars = nil + _ = os.Setenv("NUCLEI_TEST_MEMO", "A") + t.Cleanup(func() { _ = os.Unsetenv("NUCLEI_TEST_MEMO") }) + v1 := EnvVars()["NUCLEI_TEST_MEMO"] + // change env after memoization + _ = os.Setenv("NUCLEI_TEST_MEMO", "B") + v2 := EnvVars()["NUCLEI_TEST_MEMO"] + if v1 != "A" || v2 != "A" { + t.Fatalf("memoization failed: %v %v", v1, v2) + } +} diff --git a/pkg/protocols/common/generators/load.go b/pkg/protocols/common/generators/load.go index 892fe358a..91f631e4b 100644 --- a/pkg/protocols/common/generators/load.go +++ b/pkg/protocols/common/generators/load.go @@ -17,6 +17,20 @@ func (generator *PayloadGenerator) loadPayloads(payloads map[string]interface{}, for name, payload := range payloads { switch pt := payload.(type) { case string: + // Fast path: if no newline, treat as file path + if !strings.ContainsRune(pt, '\n') { + file, err := generator.options.LoadHelperFile(pt, templatePath, generator.catalog) + if err != nil { + return nil, errors.Wrap(err, "could not load payload file") + } + payloads, err := generator.loadPayloadsFromFile(file) + if err != nil { + return nil, errors.Wrap(err, "could not load payloads") + } + loadedPayloads[name] = payloads + break + } + // Multiline inline payloads elements := strings.Split(pt, "\n") //golint:gomnd // this is not a magic number if len(elements) >= 2 { diff --git a/pkg/protocols/common/generators/load_test.go b/pkg/protocols/common/generators/load_test.go index ebec9fd72..04d886270 100644 --- a/pkg/protocols/common/generators/load_test.go +++ b/pkg/protocols/common/generators/load_test.go @@ -1,120 +1,108 @@ package generators import ( - "os" - "os/exec" - "path/filepath" + "io" + "strings" "testing" - "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" - "github.com/projectdiscovery/nuclei/v3/pkg/catalog/disk" - osutils "github.com/projectdiscovery/utils/os" - "github.com/stretchr/testify/require" + "github.com/pkg/errors" + "github.com/projectdiscovery/nuclei/v3/pkg/catalog" + "github.com/projectdiscovery/nuclei/v3/pkg/types" ) -func TestLoadPayloads(t *testing.T) { - // since we are changing value of global variable i.e templates directory - // run this test as subprocess - if os.Getenv("LOAD_PAYLOAD_NO_ACCESS") != "1" { - cmd := exec.Command(os.Args[0], "-test.run=TestLoadPayloadsWithAccess") - cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_NO_ACCESS=1") - err := cmd.Run() - if e, ok := err.(*exec.ExitError); ok && !e.Success() { - return - } - if err != nil { - t.Fatalf("process ran with err %v, 
want exit status 1", err) +type fakeCatalog struct{ catalog.Catalog } + +func (f *fakeCatalog) OpenFile(filename string) (io.ReadCloser, error) { + return nil, errors.New("not used") +} +func (f *fakeCatalog) GetTemplatePath(target string) ([]string, error) { return nil, nil } +func (f *fakeCatalog) GetTemplatesPath(definitions []string) ([]string, map[string]error) { + return nil, nil +} +func (f *fakeCatalog) ResolvePath(templateName, second string) (string, error) { + return templateName, nil +} + +func newTestGenerator() *PayloadGenerator { + opts := types.DefaultOptions() + // inject helper loader function + opts.LoadHelperFileFunction = func(path, templatePath string, _ catalog.Catalog) (io.ReadCloser, error) { + switch path { + case "fileA.txt": + return io.NopCloser(strings.NewReader("one\n two\n\nthree\n")), nil + default: + return io.NopCloser(strings.NewReader("x\ny\nz\n")), nil } } - templateDir := getTemplatesDir(t) - config.DefaultConfig.SetTemplatesDir(templateDir) - - generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(false)} - fullpath := filepath.Join(templateDir, "payloads.txt") - - // Test sandbox - t.Run("templates-directory", func(t *testing.T) { - // testcase when loading file from template directory and template file is in root - // expected to succeed - values, err := generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, "/test") - require.NoError(t, err, "could not load payloads") - require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values") - }) - t.Run("templates-path-relative", func(t *testing.T) { - // testcase when loading file from template directory and template file is current working directory - // expected to fail since this is LFI - _, err := generator.loadPayloads(map[string]interface{}{ - "new": "../../../../../../../../../etc/passwd", - }, ".") - require.Error(t, err, "could load payloads") - }) - t.Run("template-directory", func(t *testing.T) { - // testcase when loading file from template directory and template file is inside template directory - // expected to succeed - values, err := generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, filepath.Join(templateDir, "test.yaml")) - require.NoError(t, err, "could not load payloads") - require.Equal(t, map[string][]string{"new": {"test", "another"}}, values, "could not get values") - }) - - t.Run("invalid", func(t *testing.T) { - // testcase when loading file from /etc/passwd and template file is at root i.e / - // expected to fail since this is LFI - values, err := generator.loadPayloads(map[string]interface{}{ - "new": "/etc/passwd", - }, "/random") - require.Error(t, err, "could load payloads got %v", values) - require.Equal(t, 0, len(values), "could get values") - - // testcase when loading file from template directory and template file is at root i.e / - // expected to succeed - values, err = generator.loadPayloads(map[string]interface{}{ - "new": fullpath, - }, "/random") - require.NoError(t, err, "could load payloads %v", values) - require.Equal(t, 1, len(values), "could get values") - require.Equal(t, []string{"test", "another"}, values["new"], "could get values") - }) + return &PayloadGenerator{options: opts, catalog: &fakeCatalog{}} } -func TestLoadPayloadsWithAccess(t *testing.T) { - // since we are changing value of global variable i.e templates directory - // run this test as subprocess - if os.Getenv("LOAD_PAYLOAD_WITH_ACCESS") != "1" { - cmd := exec.Command(os.Args[0], 
"-test.run=TestLoadPayloadsWithAccess") - cmd.Env = append(os.Environ(), "LOAD_PAYLOAD_WITH_ACCESS=1") - err := cmd.Run() - if e, ok := err.(*exec.ExitError); ok && !e.Success() { - return - } - if err != nil { - t.Fatalf("process ran with err %v, want exit status 1", err) - } +func TestLoadPayloads_FastPathFile(t *testing.T) { + g := newTestGenerator() + out, err := g.loadPayloads(map[string]interface{}{"A": "fileA.txt"}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["A"] + if len(got) != 3 || got[0] != "one" || got[1] != " two" || got[2] != "three" { + t.Fatalf("unexpected: %#v", got) } - templateDir := getTemplatesDir(t) - config.DefaultConfig.SetTemplatesDir(templateDir) - - generator := &PayloadGenerator{catalog: disk.NewCatalog(templateDir), options: getOptions(true)} - - t.Run("no-sandbox-unix", func(t *testing.T) { - if osutils.IsWindows() { - return - } - _, err := generator.loadPayloads(map[string]interface{}{ - "new": "/etc/passwd", - }, "/random") - require.NoError(t, err, "could load payloads") - }) } -func getTemplatesDir(t *testing.T) string { - tempdir, err := os.MkdirTemp("", "templates-*") - require.NoError(t, err, "could not create temp dir") - fullpath := filepath.Join(tempdir, "payloads.txt") - err = os.WriteFile(fullpath, []byte("test\nanother"), 0777) - require.NoError(t, err, "could not write payload") - return tempdir +func TestLoadPayloads_InlineMultiline(t *testing.T) { + g := newTestGenerator() + inline := "a\nb\n" + out, err := g.loadPayloads(map[string]interface{}{"B": inline}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["B"] + if len(got) != 3 || got[0] != "a" || got[1] != "b" || got[2] != "" { + t.Fatalf("unexpected: %#v", got) + } +} + +func TestLoadPayloads_SingleLineFallsBackToFile(t *testing.T) { + g := newTestGenerator() + inline := "fileA.txt" // single line, should be treated as file path + out, err := g.loadPayloads(map[string]interface{}{"C": inline}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["C"] + if len(got) != 3 { + t.Fatalf("unexpected len: %d", len(got)) + } +} + +func TestLoadPayloads_InterfaceSlice(t *testing.T) { + g := newTestGenerator() + out, err := g.loadPayloads(map[string]interface{}{"D": []interface{}{"p", "q"}}, "") + if err != nil { + t.Fatalf("err: %v", err) + } + got := out["D"] + if len(got) != 2 || got[0] != "p" || got[1] != "q" { + t.Fatalf("unexpected: %#v", got) + } +} + +func TestLoadPayloadsFromFile_SkipsEmpty(t *testing.T) { + g := newTestGenerator() + rc := io.NopCloser(strings.NewReader("a\n\n\n b \n")) + lines, err := g.loadPayloadsFromFile(rc) + if err != nil { + t.Fatalf("err: %v", err) + } + if len(lines) != 2 || lines[0] != "a" || lines[1] != " b " { + t.Fatalf("unexpected: %#v", lines) + } +} + +func TestValidate_AllowsInlineMultiline(t *testing.T) { + g := newTestGenerator() + inline := "x\ny\n" + if err := g.validate(map[string]interface{}{"E": inline}, ""); err != nil { + t.Fatalf("validate rejected inline multiline: %v", err) + } } diff --git a/pkg/protocols/common/generators/maps_test.go b/pkg/protocols/common/generators/maps_test.go index ca75bb655..8b73539f8 100644 --- a/pkg/protocols/common/generators/maps_test.go +++ b/pkg/protocols/common/generators/maps_test.go @@ -14,3 +14,32 @@ func TestMergeMapsMany(t *testing.T) { "c": {"5"}, }, got, "could not get correct merged map") } + +func TestMergeMapsAndExpand(t *testing.T) { + m1 := map[string]interface{}{"a": "1"} + m2 := map[string]interface{}{"b": "2"} + out := MergeMaps(m1, m2) + if 
out["a"].(string) != "1" || out["b"].(string) != "2" { + t.Fatalf("unexpected merge: %#v", out) + } + flat := map[string]string{"x": "y"} + exp := ExpandMapValues(flat) + if len(exp["x"]) != 1 || exp["x"][0] != "y" { + t.Fatalf("unexpected expand: %#v", exp) + } +} + +func TestIteratorRemaining(t *testing.T) { + g, err := New(map[string]interface{}{"k": []interface{}{"a", "b"}}, BatteringRamAttack, "", nil, "", nil) + if err != nil { + t.Fatalf("new: %v", err) + } + it := g.NewIterator() + if it.Total() != 2 || it.Remaining() != 2 { + t.Fatalf("unexpected totals: %d %d", it.Total(), it.Remaining()) + } + _, _ = it.Value() + if it.Remaining() != 1 { + t.Fatalf("unexpected remaining after one: %d", it.Remaining()) + } +} diff --git a/pkg/protocols/common/generators/validate.go b/pkg/protocols/common/generators/validate.go index 0aa073714..48365e0db 100644 --- a/pkg/protocols/common/generators/validate.go +++ b/pkg/protocols/common/generators/validate.go @@ -1,7 +1,6 @@ package generators import ( - "errors" "fmt" "path/filepath" "strings" @@ -17,9 +16,8 @@ func (g *PayloadGenerator) validate(payloads map[string]interface{}, templatePat for name, payload := range payloads { switch payloadType := payload.(type) { case string: - // check if it's a multiline string list - if len(strings.Split(payloadType, "\n")) != 1 { - return errors.New("invalid number of lines in payload") + if strings.ContainsRune(payloadType, '\n') { + continue } // For historical reasons, "validate" checks to see if the payload file exist. diff --git a/pkg/protocols/http/build_request.go b/pkg/protocols/http/build_request.go index 0fa8d1502..980573c96 100644 --- a/pkg/protocols/http/build_request.go +++ b/pkg/protocols/http/build_request.go @@ -92,9 +92,8 @@ type generatedRequest struct { // setReqURLPattern sets the url request pattern for the generated request func (gr *generatedRequest) setReqURLPattern(reqURLPattern string) { - data := strings.Split(reqURLPattern, "\n") - if len(data) > 1 { - reqURLPattern = strings.TrimSpace(data[0]) + if idx := strings.IndexByte(reqURLPattern, '\n'); idx >= 0 { + reqURLPattern = strings.TrimSpace(reqURLPattern[:idx]) // this is raw request (if it has 3 parts after strings.Fields then its valid only use 2nd part) parts := strings.Fields(reqURLPattern) if len(parts) >= 3 { diff --git a/pkg/protocols/http/raw/raw.go b/pkg/protocols/http/raw/raw.go index c2a2121b6..7b1457afa 100644 --- a/pkg/protocols/http/raw/raw.go +++ b/pkg/protocols/http/raw/raw.go @@ -8,6 +8,7 @@ import ( "fmt" "io" "strings" + "sync" "github.com/projectdiscovery/gologger" "github.com/projectdiscovery/nuclei/v3/pkg/authprovider/authx" @@ -17,6 +18,8 @@ import ( urlutil "github.com/projectdiscovery/utils/url" ) +var bufferPool = sync.Pool{New: func() any { return new(bytes.Buffer) }} + // Request defines a basic HTTP raw request type Request struct { FullURL string @@ -270,13 +273,17 @@ func (r *Request) TryFillCustomHeaders(headers []string) error { if newLineIndex > 0 { newLineIndex += hostHeaderIndex + 2 // insert custom headers - var buf bytes.Buffer + buf := bufferPool.Get().(*bytes.Buffer) + buf.Reset() buf.Write(r.UnsafeRawBytes[:newLineIndex]) for _, header := range headers { - buf.WriteString(fmt.Sprintf("%s\r\n", header)) + buf.WriteString(header) + buf.WriteString("\r\n") } buf.Write(r.UnsafeRawBytes[newLineIndex:]) - r.UnsafeRawBytes = buf.Bytes() + r.UnsafeRawBytes = append([]byte(nil), buf.Bytes()...) 
+ buf.Reset() + bufferPool.Put(buf) return nil } return errors.New("no new line found at the end of host header") @@ -301,9 +308,10 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) { parsed.Params.Add(p.Key, p.Value) } case *authx.CookiesAuthStrategy: - var buff bytes.Buffer + buff := bufferPool.Get().(*bytes.Buffer) + buff.Reset() for _, cookie := range s.Data.Cookies { - buff.WriteString(fmt.Sprintf("%s=%s; ", cookie.Key, cookie.Value)) + fmt.Fprintf(buff, "%s=%s; ", cookie.Key, cookie.Value) } if buff.Len() > 0 { if val, ok := r.Headers["Cookie"]; ok { @@ -312,6 +320,7 @@ func (r *Request) ApplyAuthStrategy(strategy authx.AuthStrategy) { r.Headers["Cookie"] = buff.String() } } + bufferPool.Put(buff) case *authx.HeadersAuthStrategy: for _, header := range s.Data.Headers { r.Headers[header.Key] = header.Value diff --git a/pkg/protocols/http/raw/raw_test.go b/pkg/protocols/http/raw/raw_test.go index a44664d48..80fefff7f 100644 --- a/pkg/protocols/http/raw/raw_test.go +++ b/pkg/protocols/http/raw/raw_test.go @@ -7,6 +7,21 @@ import ( "github.com/stretchr/testify/require" ) +func TestTryFillCustomHeaders_BufferDetached(t *testing.T) { + r := &Request{ + UnsafeRawBytes: []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\nBody"), + } + // first fill + err := r.TryFillCustomHeaders([]string{"X-Test: 1"}) + require.NoError(t, err, "unexpected error on first call") + prev := r.UnsafeRawBytes + prevStr := string(prev) // content snapshot + err = r.TryFillCustomHeaders([]string{"X-Another: 2"}) + require.NoError(t, err, "unexpected error on second call") + require.Equal(t, prevStr, string(prev), "first slice mutated after second call; buffer not detached") + require.NotEqual(t, prevStr, string(r.UnsafeRawBytes), "request bytes did not change after second call") +} + func TestParseRawRequestWithPort(t *testing.T) { request, err := Parse(`GET /gg/phpinfo.php HTTP/1.1 Host: {{Hostname}}:123 diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 381f4d385..c6e3de4d9 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -240,6 +240,48 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV } }) + // bounded worker-pool to avoid spawning one goroutine per payload + type task struct { + req *generatedRequest + updatedInput *contextargs.Context + } + + var workersWg sync.WaitGroup + currentWorkers := maxWorkers + tasks := make(chan task, maxWorkers) + spawnWorker := func(ctx context.Context) { + workersWg.Add(1) + go func() { + defer workersWg.Done() + for t := range tasks { + select { + case <-ctx.Done(): + return + default: + } + if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() { + continue + } + spmHandler.Acquire() + if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(t.updatedInput) || spmHandler.Cancelled() { + spmHandler.Release() + continue + } + request.options.RateLimitTake() + select { + case <-spmHandler.Done(): + spmHandler.Release() + continue + case spmHandler.ResultChan <- request.executeRequest(t.updatedInput, t.req, make(map[string]interface{}), false, wrappedCallback, 0): + spmHandler.Release() + } + } + }() + } + for i := 0; i < currentWorkers; i++ { + spawnWorker(ctx) + } + // iterate payloads and make requests generator := request.newGenerator(false) for { @@ -259,6 +301,13 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV if err := spmHandler.Resize(input.Context(), 
request.options.Options.PayloadConcurrency); err != nil { return err } + // if payload concurrency increased, add more workers + if spmHandler.Size() > currentWorkers { + for i := 0; i < spmHandler.Size()-currentWorkers; i++ { + spawnWorker(ctx) + } + currentWorkers = spmHandler.Size() + } } // break if stop at first match is found or host is unresponsive @@ -284,29 +333,21 @@ func (request *Request) executeParallelHTTP(input *contextargs.Context, dynamicV spmHandler.Cancel() return nil } - spmHandler.Acquire() - go func(httpRequest *generatedRequest) { - defer spmHandler.Release() - if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() { - return + select { + case <-spmHandler.Done(): + close(tasks) + workersWg.Wait() + spmHandler.Wait() + if spmHandler.FoundFirstMatch() { + return nil } - // putting ratelimiter here prevents any unnecessary waiting if any - request.options.RateLimitTake() - - // after ratelimit take, check if we need to stop - if spmHandler.FoundFirstMatch() || request.isUnresponsiveAddress(updatedInput) || spmHandler.Cancelled() { - return - } - - select { - case <-spmHandler.Done(): - return - case spmHandler.ResultChan <- request.executeRequest(input, httpRequest, make(map[string]interface{}), false, wrappedCallback, 0): - return - } - }(generatedHttpRequest) + return multierr.Combine(spmHandler.CombinedResults()...) + case tasks <- task{req: generatedHttpRequest, updatedInput: updatedInput}: + } request.options.Progress.IncrementRequests() } + close(tasks) + workersWg.Wait() spmHandler.Wait() if spmHandler.FoundFirstMatch() { // ignore any context cancellation and in-transit execution errors diff --git a/pkg/protocols/http/request_test.go b/pkg/protocols/http/request_test.go index a6314ae5a..b547d91b5 100644 --- a/pkg/protocols/http/request_test.go +++ b/pkg/protocols/http/request_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" "net/http/httptest" + "sync/atomic" "testing" "github.com/stretchr/testify/require" @@ -257,3 +258,116 @@ func TestReqURLPattern(t *testing.T) { require.NotEmpty(t, finalEvent.Results[0].ReqURLPattern, "could not get req url pattern") require.Equal(t, `/{{rand_char("abc")}}/{{interactsh-url}}/123?query={{rand_int(1, 10)}}&data={{randstr}}`, finalEvent.Results[0].ReqURLPattern) } + +// fakeHostErrorsCache implements hosterrorscache.CacheInterface minimally for tests +type fakeHostErrorsCache struct{} + +func (f *fakeHostErrorsCache) SetVerbose(bool) {} +func (f *fakeHostErrorsCache) Close() {} +func (f *fakeHostErrorsCache) Remove(*contextargs.Context) {} +func (f *fakeHostErrorsCache) MarkFailed(string, *contextargs.Context, error) {} +func (f *fakeHostErrorsCache) MarkFailedOrRemove(string, *contextargs.Context, error) { +} + +// Check always returns true to simulate an already unresponsive host +func (f *fakeHostErrorsCache) Check(string, *contextargs.Context) bool { return true } + +func TestExecuteParallelHTTP_StopAtFirstMatch(t *testing.T) { + options := testutils.DefaultOptions + testutils.Init(options) + + // server that always matches + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "match") + })) + defer ts.Close() + + templateID := "parallel-stop-first" + req := &Request{ + ID: templateID, + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/p?x={{v}}"}, + Threads: 2, + Payloads: map[string]interface{}{ + "v": []string{"1", "2"}, + }, + Operators: operators.Operators{ + Matchers: 
[]*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"match"}, + }}, + }, + StopAtFirstMatch: true, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: templateID, + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + err := req.Compile(executerOpts) + require.NoError(t, err) + + var matches int32 + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { + if event.OperatorsResult != nil && event.OperatorsResult.Matched { + atomic.AddInt32(&matches, 1) + } + }) + require.NoError(t, err) + require.Equal(t, int32(1), atomic.LoadInt32(&matches), "expected only first match to be processed") +} + +func TestExecuteParallelHTTP_SkipOnUnresponsiveFromCache(t *testing.T) { + options := testutils.DefaultOptions + testutils.Init(options) + + // server that would match if reached + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = fmt.Fprintf(w, "match") + })) + defer ts.Close() + + templateID := "parallel-skip-unresponsive" + req := &Request{ + ID: templateID, + Method: HTTPMethodTypeHolder{MethodType: HTTPGet}, + Path: []string{"{{BaseURL}}/p?x={{v}}"}, + Threads: 2, + Payloads: map[string]interface{}{ + "v": []string{"1", "2"}, + }, + Operators: operators.Operators{ + Matchers: []*matchers.Matcher{{ + Part: "body", + Type: matchers.MatcherTypeHolder{MatcherType: matchers.WordsMatcher}, + Words: []string{"match"}, + }}, + }, + } + + executerOpts := testutils.NewMockExecuterOptions(options, &testutils.TemplateInfo{ + ID: templateID, + Info: model.Info{SeverityHolder: severity.Holder{Severity: severity.Low}, Name: "test"}, + }) + // inject fake host errors cache that forces skip + executerOpts.HostErrorsCache = &fakeHostErrorsCache{} + + err := req.Compile(executerOpts) + require.NoError(t, err) + + var matches int32 + metadata := make(output.InternalEvent) + previous := make(output.InternalEvent) + ctxArgs := contextargs.NewWithInput(context.Background(), ts.URL) + err = req.ExecuteWithResults(ctxArgs, metadata, previous, func(event *output.InternalWrappedEvent) { + if event.OperatorsResult != nil && event.OperatorsResult.Matched { + atomic.AddInt32(&matches, 1) + } + }) + require.NoError(t, err) + require.Equal(t, int32(0), atomic.LoadInt32(&matches), "expected no matches when host is marked unresponsive") +} diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 9925c3103..8b872d84a 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -811,8 +811,11 @@ func beautifyJavascript(code string) string { } func prettyPrint(templateId string, buff string) { + if buff == "" { + return + } lines := strings.Split(buff, "\n") - final := []string{} + final := make([]string, 0, len(lines)) for _, v := range lines { if v != "" { final = append(final, "\t"+v) diff --git a/pkg/reporting/reporting.go b/pkg/reporting/reporting.go index c3706c970..58d7f61fb 100644 --- a/pkg/reporting/reporting.go +++ b/pkg/reporting/reporting.go @@ -192,11 +192,13 @@ func New(options *Options, db string, doNotDedupe bool) (Client, error) { } } - storage, err := dedupe.New(db) - if err != nil { - return nil, err + if db != "" || len(client.trackers) > 0 || len(client.exporters) 
> 0 {
+		storage, err := dedupe.New(db)
+		if err != nil {
+			return nil, err
+		}
+		client.dedupe = storage
 	}
-	client.dedupe = storage
 
 	return client, nil
 }
diff --git a/pkg/templates/cache.go b/pkg/templates/cache.go
index b4c00fb36..ae7124772 100644
--- a/pkg/templates/cache.go
+++ b/pkg/templates/cache.go
@@ -12,7 +12,9 @@ type Cache struct {
 
 // New returns a new templates cache
 func NewCache() *Cache {
-	return &Cache{items: mapsutil.NewSyncLockMap[string, parsedTemplate]()}
+	return &Cache{
+		items: mapsutil.NewSyncLockMap[string, parsedTemplate](),
+	}
 }
 
 type parsedTemplate struct {
@@ -33,7 +35,31 @@ func (t *Cache) Has(template string) (*Template, []byte, error) {
 
 // Store stores a template with data and error
 func (t *Cache) Store(id string, tpl *Template, raw []byte, err error) {
-	_ = t.items.Set(id, parsedTemplate{template: tpl, raw: conversion.String(raw), err: err})
+	entry := parsedTemplate{
+		template: tpl,
+		err:      err,
+		raw:      conversion.String(raw),
+	}
+	_ = t.items.Set(id, entry)
+}
+
+// StoreWithoutRaw stores a template without raw data for memory efficiency
+func (t *Cache) StoreWithoutRaw(id string, tpl *Template, err error) {
+	entry := parsedTemplate{
+		template: tpl,
+		err:      err,
+		raw:      "",
+	}
+	_ = t.items.Set(id, entry)
+}
+
+// Get returns only the template without raw bytes
+func (t *Cache) Get(id string) (*Template, error) {
+	value, ok := t.items.Get(id)
+	if !ok {
+		return nil, nil
+	}
+	return value.template, value.err
 }
 
 // Purge the cache
diff --git a/pkg/templates/templates.go b/pkg/templates/templates.go
index 74818557e..9907e2711 100644
--- a/pkg/templates/templates.go
+++ b/pkg/templates/templates.go
@@ -404,7 +404,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
 	// for code protocol requests
 	for _, request := range template.RequestsCode {
 		// simple test to check if source is a file or a snippet
-		if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
+		if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
 			if val, ok := loadFile(request.Source); ok {
 				template.ImportedFiles = append(template.ImportedFiles, request.Source)
 				request.Source = val
@@ -415,7 +415,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
 	// for javascript protocol code references
 	for _, request := range template.RequestsJavascript {
 		// simple test to check if source is a file or a snippet
-		if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
+		if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
 			if val, ok := loadFile(request.Code); ok {
 				template.ImportedFiles = append(template.ImportedFiles, request.Code)
 				request.Code = val
@@ -442,7 +442,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
 		if req.Type() == types.CodeProtocol {
 			request := req.(*code.Request)
 			// simple test to check if source is a file or a snippet
-			if len(strings.Split(request.Source, "\n")) == 1 && fileutil.FileExists(request.Source) {
+			if !strings.ContainsRune(request.Source, '\n') && fileutil.FileExists(request.Source) {
 				if val, ok := loadFile(request.Source); ok {
 					template.ImportedFiles = append(template.ImportedFiles, request.Source)
 					request.Source = val
@@ -456,7 +456,7 @@ func (template *Template) ImportFileRefs(options *protocols.ExecutorOptions) err
 		if req.Type() == types.JavascriptProtocol {
 			request := req.(*javascript.Request)
 			// simple test to check if source is a file or a snippet
-			if len(strings.Split(request.Code, "\n")) == 1 && fileutil.FileExists(request.Code) {
+			if !strings.ContainsRune(request.Code, '\n') && fileutil.FileExists(request.Code) {
 				if val, ok := loadFile(request.Code); ok {
 					template.ImportedFiles = append(template.ImportedFiles, request.Code)
 					request.Code = val
diff --git a/pkg/templates/templates_test.go b/pkg/templates/templates_test.go
index 7621f1c57..a94ebf656 100644
--- a/pkg/templates/templates_test.go
+++ b/pkg/templates/templates_test.go
@@ -9,6 +9,32 @@ import (
 	"gopkg.in/yaml.v2"
 )
 
+func TestCacheStoreAndGet(t *testing.T) {
+	c := NewCache()
+
+	tpl := &Template{ID: "x"}
+	raw := []byte("SOME BIG RAW")
+
+	c.Store("id1", tpl, raw, nil)
+	gotTpl, gotErr := c.Get("id1")
+	if gotErr != nil {
+		t.Fatalf("unexpected err: %v", gotErr)
+	}
+	if gotTpl == nil || gotTpl.ID != "x" {
+		t.Fatalf("unexpected tpl: %#v", gotTpl)
+	}
+
+	// StoreWithoutRaw should not retain raw
+	c.StoreWithoutRaw("id2", tpl, nil)
+	gotTpl2, gotErr2 := c.Get("id2")
+	if gotErr2 != nil {
+		t.Fatalf("unexpected err: %v", gotErr2)
+	}
+	if gotTpl2 == nil || gotTpl2.ID != "x" {
+		t.Fatalf("unexpected tpl2: %#v", gotTpl2)
+	}
+}
+
 func TestTemplateStruct(t *testing.T) {
 	templatePath := "./tests/match-1.yaml"
 	bin, err := os.ReadFile(templatePath)